"""
Basic computational modules for graph definition in Rockpool
Defines :py:class:`.LinearWeights`, :py:class:`.GenericNeurons`, :py:class:`.AliasConnection`, :py:class:`.LIFNeuronWithSynsRealValue` and :py:class:`.RateNeuronWithSynsRealValue`.
"""
from rockpool.graph.graph_base import GraphModule
from dataclasses import dataclass, field
from typing import Optional
from rockpool.typehints import FloatVector
import numpy as np
from rockpool.utilities.backend_management import backend_available
if not backend_available("torch"):

    class Tensor:
        """
        Placeholder for ``torch.Tensor`` when the torch backend is unavailable

        Defined so that ``isinstance(x, Tensor)`` checks elsewhere in this module remain valid
        without importing torch.
        """

else:
    from torch import Tensor
# - Public API exported by ``from ... import *``
__all__ = [
    "LinearWeights",
    "GenericNeurons",
    "AliasConnection",
    "LIFNeuronWithSynsRealValue",
    "RateNeuronWithSynsRealValue",
]
@dataclass(eq=False, repr=False)
class LinearWeights(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a single set of linear weights

    ``weights`` (and ``biases``, if provided) may be supplied as a ``torch.Tensor`` or as anything
    accepted by :py:func:`numpy.array`. Both are converted to numpy arrays *before* their shapes are
    validated against the input and output nodes.

    Raises:
        ValueError: If ``weights`` is not ``(Nin, Nout)`` or ``biases`` is not ``(Nout,)``, where
            ``Nin`` / ``Nout`` are the numbers of input / output nodes.
    """

    weights: FloatVector
    """ FloatVector: The linear weights ``(Nin, Nout)`` encapsulated by this module """

    biases: Optional[FloatVector] = None
    """ FloatVector: The biases ``(Nout,)`` encapsulated by this module """

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Return ``value`` as a numpy array; torch tensors are detached and moved to the CPU first."""
        if isinstance(value, Tensor):
            return value.detach().cpu().numpy()
        # - ``np.array`` (not ``asarray``) takes a copy, so later mutation of the caller's data
        #   does not silently change this module
        return np.array(value)

    def __post_init__(self, *args, **kwargs):
        # - Convert to numpy first: list / nested-sequence inputs have no ``.shape``, so
        #   validating before conversion would fail with an AttributeError
        self.weights = self._to_numpy(self.weights)
        if self.biases is not None:
            self.biases = self._to_numpy(self.biases)

        # - Check size
        expected_weights_shape = (len(self.input_nodes), len(self.output_nodes))
        if self.weights.shape != expected_weights_shape:
            raise ValueError(
                f"`weights` must match size of input and output nodes. Got {self.weights.shape}, expected {expected_weights_shape}."
            )

        expected_biases_shape = (len(self.output_nodes),)
        if self.biases is not None and self.biases.shape != expected_biases_shape:
            raise ValueError(
                f"`biases` must match size of output nodes. Got {self.biases.shape}, expected {expected_biases_shape}."
            )

        super().__post_init__(*args, **kwargs)
@dataclass(eq=False, repr=False)
class GenericNeurons(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of generic neurons

    Acts as the common base class for the concrete neuron modules. It declares no neuron
    parameters of its own and adds no fields beyond those of :py:class:`.GraphModule`
    (its input and output nodes).
    """
@dataclass(eq=False, repr=False)
class AliasConnection(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of alias connections

    After the super-class initialisation runs, the module requires exactly as many input nodes
    as output nodes.

    Raises:
        ValueError: If the number of input nodes differs from the number of output nodes.
    """

    def __post_init__(self, *args, **kwargs):
        # - Let the super-class complete its own initialisation first
        super().__post_init__(*args, **kwargs)

        # - Input and output counts must agree for an alias
        num_inputs = len(self.input_nodes)
        num_outputs = len(self.output_nodes)
        if num_inputs == num_outputs:
            return

        raise ValueError(
            f"For an alias connection, the number of inputs and outputs must be identical.\nGot {num_inputs} and {num_outputs}."
        )
@dataclass(eq=False, repr=False)
class LIFNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of LIF spiking neurons with synaptic and membrane dynamics, and with real-valued parameters

    Every parameter vector defaults to a fresh empty list; ``dt`` defaults to ``None``.
    """

    tau_mem: FloatVector = field(default_factory=list)
    """ FloatVector: Membrane time constants of the neurons, in seconds ``(Nout,)`` """

    tau_syn: FloatVector = field(default_factory=list)
    """ FloatVector: Synaptic time constants of the neurons, in seconds ``(Nin,)`` """

    threshold: FloatVector = field(default_factory=list)
    """ FloatVector: Firing-threshold parameters of the neurons ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: Bias parameters of the neurons, if present ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: Time-step used for the neurons, in seconds, if present """
@dataclass(eq=False, repr=False)
class RateNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of rate neurons, with synapses, and with real-valued parameters

    Every parameter vector defaults to a fresh empty list; ``dt`` defaults to ``None``.
    """

    tau: FloatVector = field(default_factory=list)
    """ FloatVector: Time constants of the neurons, in seconds ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: Bias parameters of the neurons ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: Time-step used for the neurons, in seconds, if present """