"""
Classes to manage registered Module attributes in Rockpool
"""
from typing import Callable, Iterable, Any, Union, List, Tuple, Optional
from copy import deepcopy
from itertools import compress
from functools import partial
import numpy as np
__all__ = ["Parameter", "State", "SimulationParameter", "Constant"]
from rockpool.utilities.backend_management import backend_available
if not backend_available("torch"):
    # - Torch is not installed: provide minimal stand-ins so the rest of this
    #   module can refer to `Tensor` and `tensor` unconditionally.

    class Tensor:
        """Placeholder for ``torch.Tensor``; ordinary data values are never instances of it"""

    def tensor(_):
        """Pass-through placeholder for ``torch.tensor``: returns its argument unchanged"""
        return _

else:
    from torch import tensor, Tensor
class RP_Constant:
    """
    Marker base class flagging a concrete initialisation value as a constant (non-trainable) parameter

    :py:func:`Constant` creates a subclass of the wrapped value's own class that also inherits from this marker, so that ``isinstance(value, RP_Constant)`` identifies constant initialisation values.

    See Also:
        Use :py:func:`Constant` to wrap an initialisation value as a constant argument.
    """
def Constant(obj: Any) -> RP_Constant:
    """
    Identify an initialisation argument as a constant (non-trainable) parameter

    A new class is created that inherits from both the class of ``obj`` and :py:class:`RP_Constant`, and is given the same ``__name__`` as the original class. The function first tries to re-class ``obj`` in place, in which case ``obj`` itself is modified and returned. If that raises ``TypeError``, numpy arrays are returned as a view with the new class, and any other object is passed to the new class's constructor.

    Examples:
        >>> mod = LIFJax(1)
        >>> mod.parameters('taus')
        {'tau_mem': DeviceArray([0.02], dtype=float32),
        'tau_syn': DeviceArray([[0.02]], dtype=float32)}

        >>> mod.simulation_parameters('taus')
        {}

        >>> mod = LIFJax(1, tau_mem = Constant(10e-3))
        >>> mod.parameters('taus')
        {'tau_syn': DeviceArray([[0.02]], dtype=float32)}

        >>> mod.simulation_parameters('taus')
        {'tau_mem': DeviceArray(0.01, dtype=float32)}

    Args:
        obj (Any): The initialisation object to wrap

    Returns:
        An object whose class is a subclass of ``obj``'s class (reporting the same class name) and of :py:class:`RP_Constant`.
    """
    base_cls = obj.__class__

    class ConstantPatch(base_cls, RP_Constant):
        pass

    # - Keep the user-visible class name of the wrapped value
    ConstantPatch.__name__ = base_cls.__name__

    try:
        # - Preferred: re-class the object in place
        obj.__class__ = ConstantPatch
    except TypeError:
        # - In-place re-classing is not permitted for this type: build a wrapped value instead
        return (
            obj.view(ConstantPatch)
            if isinstance(obj, np.ndarray)
            else ConstantPatch(obj)
        )

    return obj
# -- Parameter classes
class ParameterBase:
    """
    Base class for Rockpool registered attributes

    See Also:
        See :py:class:`.Parameter` for representing the configuration of a module, :py:class:`.State` for representing the transient internal state of a neuron or module, and :py:class:`.SimulationParameter` for representing simulation- or solver-specific parameters that are not important for network configuration.
    """

    def __init__(
        self,
        data: Any = None,
        family: Optional[str] = None,
        init_func: Optional[Callable[[Any], Any]] = None,
        shape: Optional[Union[List[Tuple], Tuple, int]] = None,
        permit_reshape: bool = True,
        cast_fn: Optional[Callable[[Any], Any]] = lambda x: x,
    ):
        """
        Instantiate a Rockpool registered attribute

        Args:
            data (Optional[Any]): Concrete initialisation data for this attribute. The shape of ``data`` will specify the allowable shape of the attribute data, unless the ``shape`` argument is provided.
            family (Optional[str]): An arbitrary string to specify the "family" of this attribute. You should use ``'weights'``, ``'taus'``, ``'biases'`` if you can; otherwise you can use whatever you like. These are used by the :py:meth:`.Module.parameters`, :py:meth:`.Module.state` and :py:meth:`.Module.simulation_parameters` methods to group and select attributes.
            init_func (Optional[Callable]): A function that initialises this attribute. Called by :py:meth:`.Module.reset_parameters` and :py:meth:`.Module.reset_state`. The signature is ``f(shape: tuple) -> np.ndarray``. If ``data`` is provided, this is replaced by a function returning a fresh copy of ``data``.
            shape (Optional[Union[List[Tuple], Tuple, int]]): A list of permissible shapes for the parameter, or a tuple specifying the permitted shape, or an integer specifying the number of elements. Shape elements may be Python or numpy integers. If not provided, the shape of the concrete initialisation data will be used as the attribute shape. The first item in the list will be used as the concrete shape, if ``data`` is not provided and ``init_func`` should be used. A list passed here is not modified.
            permit_reshape (bool): If ``True``, the input data will be reshaped to a matching permitted shape. If ``False``, then an error will be raised if the shapes do not match exactly.
            cast_fn (Optional[Callable]): A function to call to cast the data for this parameter. Will only be called once on initialisation. If ``None``, no cast is applied.

        Raises:
            ValueError: If neither ``data`` nor ``shape`` is given; if ``shape`` is an empty list; if ``init_func`` is not callable; if ``data`` does not match any permitted shape and cannot (or may not) be reshaped to one; or if ``data`` is not given and ``init_func`` is ``None``.
            TypeError: If ``shape``, or any element of a shape tuple, has an unsupported type.
        """
        if data is None and shape is None:
            raise ValueError("One of `data` or `shape` must be provided.")

        # - Validate `shape` and convert it into a *new* list of permitted shape tuples,
        #   so that a list supplied by the caller is never mutated
        if shape is not None:
            shape = self._normalise_shape(shape)

        # - Assign attributes
        self.family: Optional[str] = family
        self.data: Any = data
        self.init_func: Optional[Callable] = init_func
        self.shape: Optional[Union[List[Tuple[int, ...]], Tuple[int, ...]]] = shape
        self.cast_fn: Optional[Callable] = cast_fn

        # - Captured before any re-classing below, so error messages name the requested class
        class_name = type(self).__name__

        # - Check that the initialisation function is callable
        if self.init_func is not None and not callable(self.init_func):
            raise ValueError(
                f"The `init_func` for a {class_name} must be a callable that accepts a shape tuple."
            )

        # - Force object to be a SimulationParameter, if training should be disabled
        if isinstance(self.data, RP_Constant):
            self.__class__ = SimulationParameter

        # - Unpack a torch tensor into a new tensor, detached from any autograd graph
        if isinstance(self.data, Tensor):
            self.data = tensor(self.data.detach().numpy())

        if self.data is not None:
            # - Concrete data provided: check it against the permitted shapes, if any
            if self.shape is not None:
                self.data = self._match_permitted_shape(
                    self.data, self.shape, permit_reshape, class_name
                )

            # - Record the shape of the data as the concrete shape
            self.shape = np.shape(self.data)

            # - Override the `init_func` to restore the initialisation data. A fresh copy is
            #   returned on every call so that in-place modification of a reset value cannot
            #   corrupt subsequent resets.
            data_copy = deepcopy(self.data)
            self.init_func = lambda _: deepcopy(data_copy)
        else:
            # - Get the concrete shape to use (by default: first shape option in the list)
            self.shape = self.shape[0]

            if self.init_func is None:
                raise ValueError(
                    f"If concrete initialisation `data` is not provided for a {class_name} then `init_func` must be provided.\nParameter was {self.data, self.family, self.init_func, self.shape, self.cast_fn}"
                )

            # - Call the `init_func`
            self.data = self.init_func(self.shape)

        # - Cast the data using the cast function
        if self.cast_fn is not None:
            self.data = self.cast_fn(self.data)

    @staticmethod
    def _normalise_shape(shape: Union[List, Tuple, int]) -> List[Tuple[int, ...]]:
        """Validate a ``shape`` argument and return a new list of permitted shape tuples of Python ints"""
        if not isinstance(shape, (list, tuple, int, np.integer)):
            raise TypeError(
                f"`shape` must be a list, a tuple or an integer. Instead `shape` was a {type(shape).__name__}."
            )

        # - A single tuple or integer specifies exactly one permitted shape
        if isinstance(shape, (tuple, int, np.integer)):
            shape = [shape]

        if len(shape) == 0:
            raise ValueError("`shape` must contain at least one permitted shape.")

        permitted = []
        for st in shape:
            # - Promote non-tuples (e.g. a bare integer) to 1-element shape tuples
            if not isinstance(st, tuple):
                st = (st,)

            # - Check each element of each tuple
            for elem in st:
                if not isinstance(elem, (int, np.integer)):
                    raise TypeError(
                        f"All elements in a shape tuple must be integers. Instead I found an element of type {type(elem).__name__}."
                    )

            # - Normalise numpy integers to Python ints
            permitted.append(tuple(int(elem) for elem in st))

        return permitted

    @staticmethod
    def _numel(x: Any) -> int:
        """Return the number of elements in a numpy array, torch tensor or array-like"""
        if isinstance(x, np.ndarray):
            return x.size
        if isinstance(x, Tensor):
            return x.numel()
        return np.size(x)

    def _match_permitted_shape(
        self,
        data: Any,
        permitted_shapes: List[Tuple[int, ...]],
        permit_reshape: bool,
        class_name: str,
    ) -> Any:
        """Return ``data`` unchanged if it has a permitted shape, otherwise reshaped (as a numpy array) to the first permitted shape with the same number of elements"""
        data_shape = np.shape(data)
        if any(data_shape == st for st in permitted_shapes):
            return data

        # - Check if the concrete and desired sizes match for any permitted shape
        matching_sizes = [
            self._numel(data) == int(np.prod(st)) for st in permitted_shapes
        ]

        # - Can we reshape the concrete data to match a shape?
        if not permit_reshape or not any(matching_sizes):
            raise ValueError(
                f"The shape provided for this {class_name} does not match the provided initialisation data.\n"
                + f" self.shape = {permitted_shapes}; data.shape = {data_shape}"
            )

        # - Reshape input data to first matching size
        target_shape = next(compress(permitted_shapes, matching_sizes))
        return np.array(data).reshape(target_shape)

    def __repr__(self):
        return f"{type(self).__name__}(data={self.data}, family={self.family}, init_func={self.init_func}, shape={self.shape})"

    def _tree_flatten(self) -> Tuple[tuple, tuple]:
        """Flatten this parameter / state for Jax"""

        def _wrap(fn: Optional[Callable]) -> Optional[Callable]:
            # - `partial(None)` raises TypeError, and `cast_fn` may legitimately be None
            return None if fn is None else partial(fn)

        return (
            (
                self.data,
                self.family,
                _wrap(self.init_func),
                self.shape,
                _wrap(self.cast_fn),
            ),
            (),
        )

    @classmethod
    def _tree_unflatten(cls, _, children):
        """Unflatten a tree of parameter from Jax to Rockpool"""
        data, family, init_func, _shape, _cast_fn = children
        # NOTE(review): `shape` and `cast_fn` are not forwarded, and because `data` is provided
        # the constructor replaces `init_func` with one returning a copy of `data` — confirm
        # whether the original `init_func` should be restored on the unflattened object.
        obj = cls(data=data, family=family, init_func=init_func)
        return obj
class Parameter(ParameterBase):
    """
    Represent a module parameter

    In Rockpool a :py:class:`.Parameter` holds a configuration value that matters when describing how a network is set up — for example network weights, time constants or neuron biases. These usually make up the trainable parameters of a module or network.

    See Also:
        See :py:class:`.State` for representing the transient internal state of a neuron or module, and :py:class:`.SimulationParameter` for representing simulation- or solver-specific parameters that are not important for network configuration.
    """
class State(ParameterBase):
    """
    Represent a module state

    In Rockpool a :py:class:`.State` holds a transient value needed to carry the dynamics of a stateful module forward in time — for example a neuron's membrane potential, its synaptic current, or its refractory state.

    See Also:
        See :py:class:`.Parameter` for representing the configuration of a module, and :py:class:`.SimulationParameter` for representing simulation- or solver-specific parameters that are not important for network configuration.
    """
class SimulationParameter(ParameterBase):
    """
    Represent a module simulation parameter

    In Rockpool a :py:class:`.SimulationParameter` holds a value that only controls how a network is simulated, and is **not** needed to communicate the network configuration to someone else — for example the time step used by a module's solver. :py:class:`.SimulationParameter` s are essentially never trainable.

    See Also:
        See :py:class:`.Parameter` for representing the configuration of a module, and :py:class:`.State` for representing the transient internal state of a neuron or module.
    """