# Source code for graph.graph_modules

```"""
Basic computational modules for graph definition in Rockpool

Defines :py:class:`.LinearWeights`, :py:class:`.GenericNeurons`, :py:class:`.AliasConnection`, :py:class:`.LIFNeuronWithSynsRealValue` and :py:class:`.RateNeuronWithSynsRealValue`.
"""

from rockpool.graph.graph_base import GraphModule

from dataclasses import dataclass, field
from typing import Optional

from rockpool.typehints import FloatVector

import numpy as np
from rockpool.utilities.backend_management import backend_available

# - Resolve a `Tensor` type that is always importable: the real torch class when
#   the torch backend is available, otherwise an empty stand-in.
if not backend_available("torch"):

    class Tensor:
        """Stand-in used when torch is unavailable; no ordinary value is an instance of it, so ``isinstance(x, Tensor)`` is ``False``."""

else:
    from torch import Tensor

# - Public API of this module
__all__ = [
    "LinearWeights",
    "GenericNeurons",
    "AliasConnection",
    "LIFNeuronWithSynsRealValue",
    "RateNeuronWithSynsRealValue",
]

@dataclass(eq=False, repr=False)
class LinearWeights(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a single set of linear weights

    On construction, ``weights`` and ``biases`` (if given) are converted to numpy arrays and validated against the number of input and output nodes.

    Raises:
        ValueError: If ``weights`` is not of shape ``(Nin, Nout)``, or ``biases`` is not of shape ``(Nout,)``.
    """

    weights: FloatVector
    """ FloatVector: The linear weights ``(Nin, Nout)`` encapsulated by this module """

    biases: Optional[FloatVector] = None
    """ FloatVector: The biases ``(Nout,)`` encapsulated by this module """

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Convert a torch ``Tensor`` or any array-like value to a numpy array."""
        if isinstance(value, Tensor):
            # - Detach from the autograd graph and move to host memory before converting
            return value.detach().cpu().numpy()
        return np.array(value)

    def __post_init__(self, *args, **kwargs):
        # - Convert to numpy *before* validating, so that plain lists / tuples
        #   (which have no `.shape` attribute) can be checked as well
        self.weights = self._to_numpy(self.weights)
        if self.biases is not None:
            self.biases = self._to_numpy(self.biases)

        # - Check sizes before handing over to the base class
        expected_weights_shape = (len(self.input_nodes), len(self.output_nodes))
        if self.weights.shape != expected_weights_shape:
            raise ValueError(
                f"`weights` must match size of input and output nodes. Got {self.weights.shape}, expected {expected_weights_shape}."
            )

        # - Biases are per output node only
        expected_biases_shape = (len(self.output_nodes),)
        if self.biases is not None and self.biases.shape != expected_biases_shape:
            raise ValueError(
                f"`biases` must match size of output nodes. Got {self.biases.shape}, expected {expected_biases_shape}."
            )

        super().__post_init__(*args, **kwargs)

@dataclass(eq=False, repr=False)
class GenericNeurons(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of generic neurons

    This class is used as a base class for all specific neuron subclasses. It adds no fields of its own beyond the input and output nodes inherited from :py:class:`.GraphModule`, and does not specify any parameters for the neurons.
    """

@dataclass(eq=False, repr=False)
class AliasConnection(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of alias connections

    The number of input nodes must equal the number of output nodes.

    Raises:
        ValueError: If the input and output node counts differ.
    """

    def __post_init__(self, *args, **kwargs):
        # - Base-class initialisation runs first
        super().__post_init__(*args, **kwargs)

        # - Guard: counts must agree for an alias
        num_inputs = len(self.input_nodes)
        num_outputs = len(self.output_nodes)
        if num_inputs == num_outputs:
            return

        raise ValueError(
            f"For an alias connection, the number of inputs and outputs must be identical.\nGot {num_inputs} and {num_outputs}."
        )

@dataclass(eq=False, repr=False)
class LIFNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of LIF spiking neurons with synaptic and membrane dynamics, and with real-valued parameters

    Each vector parameter defaults to a fresh empty list; ``dt`` defaults to ``None``.
    """

    tau_mem: FloatVector = field(default_factory=list)
    """ FloatVector: The membrane time constants of these neurons, in seconds ``(Nout,)`` """

    tau_syn: FloatVector = field(default_factory=list)
    """ FloatVector: The synaptic time constants of these neurons, in seconds ``(Nin,)`` """

    threshold: FloatVector = field(default_factory=list)
    """ FloatVector: The firing threshold parameters of these neurons ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: The bias parameters of these neurons, if present ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: The time-step used for these neurons in seconds, if present """

@dataclass(eq=False, repr=False)
class RateNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of rate neurons, with synapses, and with real-valued parameters

    Each vector parameter defaults to a fresh empty list; ``dt`` defaults to ``None``.
    """

    tau: FloatVector = field(default_factory=list)
    """ FloatVector: The time constants of these neurons, in seconds ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: The bias parameters of these neurons ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: The time-step used for these neurons in seconds, if present """
```