"""
Contains the base classes for TimedModules in Rockpool. Also contains classes to adapt Rockpool v1 Layer classes -> TimedModule classes, and to wrap Module objects as TimedModule objects.
"""
# - Rockpool imports
from rockpool.timeseries import TimeSeries, TSContinuous, TSEvent
from rockpool.nn.modules.module import Module, ModuleBase, PostInitMetaMixin
from rockpool.parameters import SimulationParameter, Parameter, State
# - Other imports
from typing import Optional, Union, Tuple, Dict, Any, List, Iterable, Callable
from warnings import warn
from abc import abstractmethod
import numpy as np
import functools
from decimal import Decimal
from collections import abc
# - Tolerance constants to use when comparing floats
# Relative tolerance: default `tol_rel` for `is_multiple`, and the rounding tolerance used by `lcm`'s grid check
tol_rel: float = 1e-5
# Absolute tolerance: default `tol_abs` for `is_multiple`, and added to durations before converting them to time steps
tol_abs: float = 1e-6
# Grid resolution used by `lcm`: values are divided by `decimal_base` and rounded to integers before the LCM is taken
decimal_base: float = 1e-7
# - Define some type aliases
# Real-valued inputs accepted by `is_multiple`, `gcd` and `lcm`, which convert them to `Decimal` internally
RealValue = Union[float, Decimal, SimulationParameter, str]
# A nested container structure, as traversed by `tree_map`
Tree = Union[abc.Iterable, abc.MutableMapping]
# - Public names exported by this module
__all__ = ["TimedModule", "TimedModuleWrapper", "astimedmodule"]
def is_multiple(
    a: RealValue,
    b: RealValue,
    tol_rel: RealValue = tol_rel,
    tol_abs: RealValue = tol_abs,
) -> bool:
    """
    Check whether `a` is an integer multiple of `b`, within some tolerance.

    All values are converted to :py:class:`~decimal.Decimal` via their string representation, so that e.g. ``0.3`` is recognised as an exact multiple of ``0.1``.

    :param float a: The number that may be a multiple of `b`
    :param float b: The number `a` may be a multiple of. Must be non-zero
    :param float tol_rel: Relative tolerance, scaled by ``|b|``
    :param float tol_abs: Absolute tolerance
    :return bool: True if `a` is a multiple of `b` within some tolerance
    """
    # - Convert to decimals
    a = Decimal(str(a))
    b = Decimal(str(b))
    tol_rel = Decimal(str(tol_rel))
    tol_abs = Decimal(str(tol_abs))

    # - `Decimal.__mod__` takes the sign of the dividend, so work with magnitudes.
    #   Otherwise a negative remainder would always compare below the tolerance.
    step = abs(b)
    remainder = abs(a % b)

    # - Distance from `a` to the nearest multiple of `b`
    min_remainder = min(remainder, step - remainder)
    return min_remainder < tol_rel * step + tol_abs
def gcd(a: RealValue, b: RealValue) -> Decimal:
    """
    Return the greatest common divisor of two values, via Euclid's algorithm

    Both values are converted to :py:class:`~decimal.Decimal` through their string representation before each step.

    :param float a: Value `a`
    :param float b: Value `b`
    :return Decimal: Greatest common divisor of `a` and `b`
    """
    x = Decimal(str(a))
    y = Decimal(str(b))

    # - Euclid: gcd(x, y) == gcd(y, x mod y), terminating once the divisor reaches zero
    return x if y == 0 else gcd(y, x % y)
def lcm(a: RealValue, b: RealValue) -> Decimal:
    """
    Return the least common multiple of two values

    Both values are first mapped onto an integer grid with spacing ``decimal_base``. The integer LCM is computed there and then scaled back by ``decimal_base``.

    :param float a: Value a
    :param float b: Value b
    :return Decimal: Least common multiple of `a` and `b`, on the ``decimal_base`` grid
    :raises ValueError: If either value does not lie (within ``tol_rel``) on the ``decimal_base`` grid
    """
    # - Express both values in units of `decimal_base`, and round to the nearest grid point
    exact_a = float(a) / decimal_base
    exact_b = float(b) / decimal_base
    grid_a = round(exact_a)
    grid_b = round(exact_b)

    # - Both values must sit on the grid; otherwise the integer LCM would be meaningless
    worst_offset = max(np.abs(grid_a - exact_a), np.abs(grid_b - exact_b))
    if worst_offset > tol_rel:
        raise ValueError(
            "network: Too small values to find lcm. Try changing 'decimal_base'"
        )

    int_a = Decimal(str(grid_a))
    int_b = Decimal(str(grid_b))
    grid_step = Decimal(str(decimal_base))

    return int_a / gcd(int_a, int_b) * int_b * grid_step
def tree_map(func: Callable[[Any], Any], tree: Tree) -> Tree:
    """
    Map a function over a tree

    A ``Tree`` is a nested ``list``, ``dict`` or ``tuple`` object, containing other ``list``s, ``dict``s and ``tuple``s. Every other type is a leaf node. ``func`` is applied independently to each leaf, and a new tree with the same nesting is returned. ``dict`` keys are preserved; only values are mapped.

    Args:
        func (Callable): A function to apply to each leaf node in the tree
        tree (Tree): A tree over which to iterate. `func` will be applied to every leaf node in `tree`

    Returns:
        Tree: A transformed tree, with `func` applied to every leaf node
    """
    # - Leaf node: anything that is not a dict, list or tuple
    if not isinstance(tree, (dict, list, tuple)):
        return func(tree)

    # - Mappings: keep the keys, recurse into the values
    if isinstance(tree, dict):
        return {key: tree_map(func, value) for key, value in tree.items()}

    # - Sequences: recurse into each element, rebuilding as list or tuple
    mapped = [tree_map(func, element) for element in tree]
    return mapped if isinstance(tree, list) else tuple(mapped)
def leaves(d: dict):
    """
    Lazily yield every leaf ``(key, value)`` pair of a nested ``dict``

    The traversal is depth-first, in insertion order. Any value that is itself a ``dict`` is descended into rather than yielded; every other value is a leaf.

    Args:
        d (dict): The dict over which to traverse

    Yields:
        tuple: (key, value)
            key (str): The key of a given leaf
            value (Any): The value of a given leaf
    """
    # - Explicit stack of item iterators, one per level currently being visited
    pending = [iter(d.items())]
    while pending:
        for key, value in pending[-1]:
            if isinstance(value, dict):
                # - Descend: the parent iterator resumes once this level is exhausted
                pending.append(iter(value.items()))
                break
            yield key, value
        else:
            # - This level is exhausted; return to its parent
            pending.pop()
class TimedModule(ModuleBase, metaclass=PostInitMetaMixin):
    """
    The Rockpool base class for all :py:class:`.TimedModule` modules

    :py:class:`.TimedModule` provides functionality for :py:class:`.Module` s to understand time series data, and to conveniently evolve, handle and return time series data from modules.

    The :py:meth:`.evolve` method provided by :py:class:`.TimedModule` can accept :py:class:`.TimeSeries` objects natively as input, or can accept clocked / rasterised input data.

    See Also:
        :py:class:`.TimedModule` provides the useful methods :py:meth:`~.TimedModule._prepare_input` and :py:meth:`~.TimedModule._gen_timeseries` to help you in rasterising data for your own :py:class:`.TimedModule` subclasses.

        For more information on how to use the :py:class:`.TimedModule` API for Rockpool, see :ref:`/in-depth/api-high-level.ipynb`.
    """

    __in_TimedModule_init: bool = False
    """ A flag indicating that this ``TimedModule`` is currently being initialised """

    def __init__(
        self,
        dt: Union[float, SimulationParameter],
        spiking_input: bool = False,
        spiking_output: bool = False,
        add_events: bool = True,
        *args,
        **kwargs,
    ):
        """
        Initialise this :py:class:`.TimedModule` object

        When initialised, the :py:class:`.TimedModule` will have a :py:attr:`~.TimedModule.dt` attribute assigned. The internal attributes :py:attr:`~.TimedModule._timestep`, :py:attr:`~.TimedModule._parent_dt_factor` and :py:attr:`~.TimedModule._is_child` are also initialised. The subclass :py:meth:`~.TimedModule.evolve` method will be wrapped to update the internal timestamp clock.

        Args:
            dt (float): The duration of a single time step for this module, in seconds
            spiking_input (bool): If ``True``, this module accepts :py:class:`.TSEvent` event time series objects as input. If ``False`` (default), this module accepts :py:class:`TSContinuous` continuous time series objects as input.
            spiking_output (bool): If ``True``, this module sends :py:class:`.TSEvent` event time series objects as output. If ``False`` (default), this module sends :py:class:`TSContinuous` continuous time series objects as output.
            add_events (bool): Stored as ``self._add_events`` for use by subclasses when rasterising event input. Default: ``True``
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments
        """
        # - Initialise superclass
        super().__init__(
            spiking_input=spiking_input, spiking_output=spiking_output, *args, **kwargs
        )

        # - Assign dt
        self.dt: Union[float, SimulationParameter] = SimulationParameter(dt, "dt")
        """ float: The simulation and input rasterisation timestep for this `.TimedModule` """

        # - Initialise internal timestep
        self._timestep: int = 0
        """ The current time-step count in units of :py:attr:`.dt` """

        # - Initialise dt factor (1.0 by default)
        self._parent_dt_factor: float = 1.0
        """ The factor between the parent's :py:attr:`.dt` and this module's :py:attr:`.dt`. Given by ``parent.dt / self.dt`` """

        # - A module constructed while another TimedModule is initialising is a child
        self._is_child: bool = TimedModule.__in_TimedModule_init
        """ Flag indicating that this is a child module """

        # - Record that we are currently initialising the module tree
        TimedModule.__in_TimedModule_init = True

        # - Wrap `evolve()` method to perform timestep updates
        self.__evolve = self.evolve
        self.evolve = self._evolve_wrapper

        # - Remember "add events" argument
        self._add_events = add_events

    def __post_init__(self) -> None:
        """
        Perform post-initialisation work for :py:class:`.TimedModule`

        Sets the :py:attr:`.dt` attribute of the root module of a :py:class:`.TimedModule` tree from all sub-modules. Restores the class-level initialisation flag from :py:attr:`._is_child`.
        """
        # - Find least-common-multiple `dt` for base module
        if not self._is_child:
            self._set_dt()

        # - Restore in_init flag: only the root module clears it
        TimedModule.__in_TimedModule_init = self._is_child

    def _set_dt(self, max_factor: float = 100) -> None:
        """
        Set a time step size for the network which is the lowest common multiple of all sub-module's :py:attr:`.dt` s.

        :param float max_factor: Factor by which the module :py:attr:`.dt` may exceed the largest sub-module :py:attr:`.dt` before an error is raised. Default: 100.

        :raises ValueError: If a sensible :py:attr:`.dt` cannot be found
        """
        if self.modules():
            # - Collect all `dt` values in the tree as exact decimals
            dt_list = [
                Decimal(str(dt)) for _, dt in leaves(self.attributes_named("dt"))
            ]

            # - Determine least common multiple
            t_lcm = dt_list[0]
            for dt in dt_list[1:]:
                try:
                    t_lcm = lcm(t_lcm, dt)
                except ValueError as err:
                    raise ValueError(
                        "Network: dt is too small for one or more layers. Try larger"
                        + " value or decrease `decimal_base`."
                    ) from err

            # - Compare in Decimal arithmetic: `Decimal * float` raises TypeError,
            #   so `max_factor` is converted rather than multiplied directly
            max_allowed = Decimal(str(max_factor)) * max(dt_list)
            if (
                # If result is way larger than largest dt, assume it hasn't worked
                t_lcm > max_allowed
                # Also make sure that t_lcm is indeed a multiple of all dt's
                or any(not is_multiple(t_lcm, dt) for dt in dt_list)
            ):
                raise ValueError(
                    "Network: Couldn't find a reasonable common time step "
                    + f"(layer dt's: {dt_list}, found: {t_lcm})"
                )

            # - Store base-level time step, for now as float for compatibility
            self.dt = float(t_lcm)

            # - Store number of layer time steps per global time step for each layer
            for mod in self.modules().values():
                if hasattr(mod, "dt"):
                    mod._parent_dt_factor = float(round(self.dt / mod.dt))

    def _evolve_wrapper(
        self,
        ts_input=None,
        duration=None,
        num_timesteps=None,
        kwargs_timeseries=None,
        record: bool = False,
        *args,
        **kwargs,
    ) -> Tuple[TimeSeries, Dict, Dict]:
        """
        Wrap a call to :py:meth:`.evolve` to update the internal time-steps count

        See :py:meth:`.evolve` for calling syntax.
        """
        # - Determine number of timesteps
        num_timesteps = self._determine_timesteps(ts_input, duration, num_timesteps)

        # - Call wrapped evolve
        ts_output, state_dict, record_dict = self.__evolve(
            ts_input,
            duration,
            num_timesteps,
            kwargs_timeseries,
            record,
            *args,
            **kwargs,
        )

        # - We could re-wrap outputs as TimeSeries here, if desired

        # - Update internal time
        self._timestep += num_timesteps

        return ts_output, state_dict, record_dict

    def _determine_timesteps(
        self,
        ts_input: Optional[TimeSeries] = None,
        duration: Optional[float] = None,
        num_timesteps: Optional[int] = None,
    ) -> int:
        """
        Determine how many time steps to evolve with the given input specification

        :param Optional[TimeSeries] ts_input: TxM or Tx1 time series of input signals for this layer
        :param Optional[float] duration: Duration of the desired evolution, in seconds. If not provided, ``num_timesteps`` or the duration of ``ts_input`` will be used to determine evolution time
        :param Optional[int] num_timesteps: Number of evolution time steps, in units of the parent's :py:attr:`.dt`. If not provided, ``duration`` or the duration of ``ts_input`` will be used to determine evolution time

        :raises TypeError: If no evolution length can be determined, or ``num_timesteps`` is not an ``int``
        :raises ValueError: If ``ts_input`` ends before the current time, or ``num_timesteps`` is negative

        :return int: num_timesteps: Number of evolution time steps, in units of this module's :py:attr:`.dt`
        """
        if num_timesteps is None:
            # - Determine `num_timesteps`
            if duration is None:
                # - Determine duration
                if ts_input is None:
                    raise TypeError(
                        self.full_name
                        + ": One of 'num_timesteps', 'ts_input' or 'duration' must be supplied."
                    )

                if ts_input.periodic:
                    # - Use duration of periodic TimeSeries, if possible
                    duration = ts_input.duration

                else:
                    # - Evolve until the end of the input TimeSeries
                    duration = ts_input.t_stop - self.t
                    if duration <= 0:
                        raise ValueError(
                            self.full_name
                            + ": Cannot determine an appropriate evolution duration."
                            + " 'ts_input' finishes before the current evolution time."
                        )

            # - `duration` is already in seconds, so divide by this module's own `dt`
            num_timesteps = int(np.floor((duration + tol_abs) / self.dt))

        else:
            if not isinstance(num_timesteps, int):
                raise TypeError(
                    self.full_name + ": 'num_timesteps' must be a non-negative integer."
                )
            elif num_timesteps < 0:
                raise ValueError(
                    self.full_name + ": 'num_timesteps' must be a non-negative integer."
                )

            # - Convert parent num_timesteps to self-compatible num_timesteps
            num_timesteps = int(np.ceil(num_timesteps * self._parent_dt_factor))

        return num_timesteps

    def _gen_time_trace(self, t_start: float, num_timesteps: int) -> np.ndarray:
        """
        Generate a time trace starting at ``t_start``, of length ``num_timesteps`` with time step :py:attr:`.dt`

        :param float t_start: Start time, in seconds
        :param int num_timesteps: Number of time steps to generate, in units of :py:attr:`.dt`

        :return ndarray: Generated time trace
        """
        # - Generate a trace
        time_trace = np.arange(num_timesteps) * self.dt + t_start

        return time_trace

    def _gen_timeseries(self, output: np.ndarray, **kwargs) -> TimeSeries:
        """
        Wrap a clocked / rasterised output array into a :py:class:`.TimeSeries` object

        Output :py:class:`.TimeSeries` will be of the appropriate subclass, and will be named nicely.

        Args:
            output (np.ndarray): The clocked or rasterised output data ``(T, N)``. If more than two dimensions are present, only ``output[0]`` is used
            **kwargs: Additional keyword arguments to :py:class:`.TimeSeries`

        Returns:
            TimeSeries: The data in ``output`` wrapped into a :py:class:`.TimeSeries` object
        """
        # - Keep only the first entry along a leading extra dimension
        if len(output.shape) > 2:
            output = output[0]

        if self.spiking_output:
            return self._gen_tsevent(output, **kwargs)
        else:
            return self._gen_tscontinuous(output, **kwargs)

    def _gen_tsevent(
        self,
        output: np.ndarray,
        dt: Optional[float] = None,
        t_start: Optional[float] = None,
        name: Optional[str] = None,
        periodic: bool = False,
        num_channels: Optional[int] = None,
        spikes_at_bin_start: bool = False,
    ) -> TSEvent:
        """
        Wrap a rasterised output array as a :py:class:`.TSEvent` object to present as output for this module

        Output :py:class:`.TSEvent` s will be named nicely, with correct start times, durations, etc. Several attributes of the :py:class:`.TSEvent` object can be set as arguments here.

        Args:
            output (np.ndarray): A rasterised event array ``(T, N)``
            dt (Optional[float]): The time-step of the rasterised array ``output``. If not provided, the module :py:attr:`.dt` will be used
            t_start (Optional[float]): The start time of the output series, in seconds. If not provided, the module time before evolution will be used
            name (Optional[str]): The desired name of the :py:class:`.TSEvent` object. If not provided, the object will be named nicely according to the module name
            periodic (bool): Flag to indicate whether the returned :py:class:`.TSEvent` should be periodic. Default: ``False``, the :py:class:`.TSEvent` will not be periodic
            num_channels (Optional[int]): The desired number of total channels for the output :py:class:`.TSEvent` object. If not provided, the output size :py:attr:`.size_out` of the current module will be used
            spikes_at_bin_start (bool): If ``False`` (default), spike events will be considered to fall in the middle of the time bin they fall in. If ``True``, all spike events will be considered to occur at the start of the time bin they fall in.

        Returns:
            TSEvent: The wrapped output raster as a :py:class:`.TSEvent` object
        """
        # - Build a name for the time series
        if name is None:
            if self.name:
                name = f"Output events '{self.name}'"
            else:
                name = "Output events"

        # - Create and return a new event time series
        return TSEvent.from_raster(
            raster=output,
            dt=self.dt if dt is None else dt,
            t_start=self.t if t_start is None else t_start,
            name=name,
            periodic=periodic,
            num_channels=self.size_out if num_channels is None else num_channels,
            spikes_at_bin_start=spikes_at_bin_start,
        )

    def _gen_tscontinuous(
        self,
        output: np.ndarray,
        dt: Optional[float] = None,
        t_start: Optional[float] = None,
        name: Optional[str] = None,
        periodic: bool = False,
        interp_kind: str = "previous",
    ) -> TSContinuous:
        """
        Wrap a rasterised output array as a :py:class:`TSContinuous` object to present as output for this module

        Output :py:class:`.TSContinuous` s will be named nicely, with correct start times, durations, etc. Several attributes of the :py:class:`.TSContinuous` object can be set as arguments here.

        Args:
            output (np.ndarray): A clocked time series data array ``(T, N)``
            dt (Optional[float]): The time-step of the clocked array ``output``. If not provided, the module :py:attr:`.dt` will be used
            t_start (Optional[float]): The start time of the output :py:class:`.TSContinuous` object, in seconds. If not provided, the module time before evolution will be used
            name (Optional[str]): The desired name of the :py:class:`.TSContinuous` object. If not provided, the object will be named nicely according to the module name
            periodic (bool): Flag to indicate whether the returned :py:class:`.TSContinuous` should be periodic. Default: ``False``, the :py:class:`.TSContinuous` will not be periodic
            interp_kind (str): The style of interpolation to apply to the returned :py:class:`.TSContinuous` object. Default: ``"previous"``

        Returns:
            TSContinuous: The wrapped output data as a `TSContinuous` object
        """
        # - Build a name for the time series
        if name is None:
            if self.name:
                name = f"Output samples '{self.name}'"
            else:
                name = "Output samples"

        # - Create and return a new continuous time series
        return TSContinuous.from_clocked(
            samples=output,
            dt=self.dt if dt is None else dt,
            t_start=self.t if t_start is None else t_start,
            periodic=periodic,
            name=name,
            interp_kind=interp_kind,
        )

    @abstractmethod
    def evolve(
        self,
        ts_input: Union[TimeSeries, np.ndarray] = None,
        duration: float = None,
        num_timesteps: int = None,
        kwargs_timeseries: dict = None,
        record: bool = False,
        *args,
        **kwargs,
    ) -> Tuple[TimeSeries, Dict, Dict]:
        """
        Evolve the state of this module over time

        Warnings:
            If you are seeing this message in documentation for a :py:class:`.TimedModule` subclass, then THIS CLASS HAS NOT PROVIDED DOCUMENTATION FOR ITS EVOLVE METHOD. PLEASE UPDATE THE DOCUMENTATION TO INCLUDE SPECIFIC DETAILS FOR THIS CLASS.

        You need to implement an :py:meth:`.evolve` method for each class which inherits from :py:class:`.TimedModule`.

        Here is an example :py:meth:`.evolve` method that rasterises a time series and uses the rasterised version for further processing. The output data is re-wrapped as a time series and returned.

        .. code-block:: python

            def evolve(...):
                # - Rasterise input and prepare input time steps
                time_base, input_raster, num_timesteps = self._prepare_input(
                    ts_input, duration, num_timesteps
                )

                # - Call sub-modules, do your evolution, etc.

                # - Return and wrap outputs if necessary
                return (
                    self._gen_timeseries(output, **kwargs_timeseries),
                    new_state,
                    record_dict,
                )

        Here is an example :py:meth:`.evolve` method that uses :py:class:`.TimeSeries` objects natively. Any rasterisation would be taken care of by submodules, if and when required.

        .. code-block:: python

            def evolve(...):
                new_state = {}
                record = {}

                x1, new_state1, record1 = self.submodule(input_ts)
                new_state.update({'submodule': new_state1})
                record.update({'submodule': record1})

                x2, new_state2, record2 = self.submodule2(x1)
                new_state.update({'submodule2': new_state2})
                record.update({'submodule2': record2})

                return x2, new_state, record

        You can of course use a mixture of these approaches.

        Args:
            ts_input (Union[TimeSeries, np.ndarray]): The input time series over which to evolve
            duration (float): The duration over which to evolve, in seconds
            num_timesteps (int): The number of time steps (in terms of the :py:attr:`.dt` attribute of this module) to evolve over
            kwargs_timeseries (Optional[dict]): Any additional arguments to pass when generating output time series
            record (bool): If ``True``, this module and sub-modules must record their state during evolution and return it in the ``record_state`` dict. If ``False`` (default), no recording is requested
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            tuple: (output_ts, new_state, record_state)
                output_ts :py:class:`.TimeSeries`: A time series containing the output time series produced by this module.
                new_state dict: A dictionary containing the updated state of this module and sub-modules, after evolution
                record_state dict: If the argument ``record`` is ``True``, ``record_state`` must contain a dictionary of the recorded states of this and all sub-modules during evolution. Otherwise it may be an empty dict.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs) -> Tuple[TimeSeries, Dict, Dict]:
        """
        Evolve the state of this :py:class:`.TimedModule` over time

        Args:
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            tuple: (output_ts, new_state, record_state)
                output_ts (TimeSeries): The output time series produced by this module.
                new_state (dict): A dictionary containing the updated state of this and all sub-modules after evolution
                record_state (dict): If the argument ``record`` is ``True``, ``record_state`` must contain a dictionary of the recorded states of this and all sub-modules during evolution. Otherwise it may be an empty dict.
        """
        return self.evolve(*args, **kwargs)

    @property
    def input_type(self) -> type:
        """type: The :py:class:`.TimeSeries` class accepted by this module"""
        if self.spiking_input:
            return TSEvent
        else:
            return TSContinuous

    @property
    def output_type(self) -> type:
        """type: The :py:class:`.TimeSeries` class returned by this module"""
        if self.spiking_output:
            return TSEvent
        else:
            return TSContinuous

    @property
    def t(self) -> float:
        """float: The current evolution time of this layer, in seconds"""
        return self._timestep * self.dt

    @t.setter
    def t(self, new_t) -> None:
        # - Allow a small fraction of a step, so that e.g. 0.3 / 0.1 == 2.999...96
        #   still maps to step 3 instead of being floored to step 2
        self._timestep = int(np.floor(new_t / self.dt + tol_rel))

    def reset_time(self) -> None:
        """
        Reset the internal time of this module and all sub-modules to zero
        """
        # - Reset own time
        self._timestep = 0

        # - Get the sub-module registry; each entry holds the module object at index 0
        _, submodules = self._get_attribute_registry()

        # - Reset submodule time. The `reset_time` check must be made on the
        #   module object inside each registry entry, not on the entry itself
        for entry in submodules.values():
            submodule = entry[0]
            if hasattr(submodule, "reset_time"):
                submodule.reset_time()

    def reset_all(self) -> None:
        """
        Reset the internal state and time of this module and all sub-modules
        """
        self.reset_state()
        self.reset_time()
class TimedModuleWrapper(TimedModule):
    """
    Wrap a low-level Rockpool :py:class:`.Module` automatically into a :py:class:`.TimedModule` object

    Use this class to automatically convert a :py:class:`.Module` subclass, implementing the low-level API of Rockpool, into a :py:class:`.TimedModule` object that supports the high-level time series API directly.

    Notes:
        Only a single output argument may be returned from the wrapped :py:class:`.TimedModule`. However, multiple return arguments from the internal module can be handled through the ``output_num`` argument to :py:meth:`.__init__`.

        Recorded state is only converted into :py:class:`.TimeSeries` objects if the wrapped module provides a ``_wrap_recorded_state`` method. Otherwise it is returned unchanged.

    Examples:
        Construct a low-level module, wrap it into a :py:class:`.TimedModule`:

        >>> from rockpool.nn.modules import RateEuler, TimedModuleWrapper
        >>> mod = RateEuler(...)
        >>> tmod = TimedModuleWrapper(mod)

    See Also:
        If you want to convert a :py:class:`.Module` object, use this class.

        If you need to convert a Rockpool v1 :py:class:`~rockpool.nn.layers.Layer` subclass, use either the :py:class:`.LayerToTimedModule` or the `.astimedmodule` decorator.

        For more information, see :ref:`/in-depth/api-high-level.ipynb`.
    """

    def __init__(
        self,
        module: Module,
        output_num: int = 0,
        dt: Optional[float] = None,
        add_events: bool = True,
        *args,
        **kwargs,
    ):
        """
        Wrap a low-level module as a high-level :py:class:`.TimedModule`

        Args:
            module (:py:class:`.Module`): The module to wrap. Must inherit from :py:class:`.Module`.
            output_num (int): If the evolution function of ``module`` returns multiple outputs, specify here which of them to wrap into a time series. :py:class:`.TimedModuleWrapper` only supports returning one output argument from :py:meth:`.evolve`.
            dt (float): The timestep to set for ``module``, if ``module.dt`` does not exist. Note that ``module.dt`` will not be overridden by this argument!
            add_events (bool): If ``True``, then multiple events per time bin will be summed when converting to a raster. If ``False``, only a single event will be retained per time bin. Default: ``True``, sum events in each time bin.

        Raises:
            TypeError: If ``module`` is not a :py:class:`.Module`
            KeyError: If ``module`` has no ``dt`` and no ``dt`` argument is given
        """
        # - The superclass is not initialised yet, so `self.full_name` may not be
        #   usable here; use the class name in error messages instead
        class_name = type(self).__name__

        # - Check that we are wrapping a Module object
        if not isinstance(module, Module):
            raise TypeError(class_name + ": `module` must be a 'Module' object.")

        # - Warn that an extra `dt` is ignored
        if dt is not None and hasattr(module, "dt"):
            warn(
                "`dt` argument to `TimedModuleWrapper` is ignored if the module already has a `dt` attribute."
            )

        # - Assign a `dt`, if the submodule doesn't already have one
        if not hasattr(module, "dt"):
            if dt is None:
                raise KeyError(
                    class_name
                    + ": If 'module' has no `dt`, it must be passed as an argument."
                )

            module.dt = SimulationParameter(dt)

        # - Initialise superclass
        super().__init__(
            shape=(module.size_in, module.size_out),
            spiking_input=module.spiking_input,
            spiking_output=module.spiking_output,
            dt=module.dt,
            add_events=add_events,
            *args,
            **kwargs,
        )

        # - Keep a handle to the submodule
        self._module = module

        # - Remember which output to select
        self._output_num = output_num

    @property
    def module(self):
        """`Module`: The wrapped module"""
        return self._module

    def __repr__(self) -> str:
        """str: A representation of this module as a string"""
        return f"{super().__repr__()} with {self._module.full_name} as module"

    def evolve(
        self,
        ts_input: Optional[TimeSeries] = None,
        duration: Optional[float] = None,
        num_timesteps: Optional[int] = None,
        kwargs_timeseries: Optional[Dict] = None,
        record: bool = False,
        *args,
        **kwargs,
    ) -> Tuple[TimeSeries, Any, Any]:
        """
        Evolve the wrapped :py:class:`.Module`, handling :py:class:`.TimeSeries` input and output

        Args:
            ts_input (Optional[TimeSeries]): The input data for this evolution. If not provided, zero input will be used
            duration (Optional[float]): The duration over which to evolve this module, in seconds. If not provided, it will be inferred
            num_timesteps (Optional[int]): The number of time steps over which to evolve this module, in units of `dt`. If not provided, it will be inferred
            kwargs_timeseries (Optional[dict]): Additional keyword arguments to pass when generating the output time series
            record (bool): If ``True``, a dictionary containing a record of state during evolution for this and all submodules will be returned. If ``False`` (default), no record is requested.
            *args: Additional positional arguments (not forwarded to the wrapped module)
            **kwargs: Additional keyword arguments (not forwarded to the wrapped module)

        Returns:
            tuple: (output_ts, new_state, record_dict)
                output_ts (`.TimeSeries`): The output of this module, wrapped as a `.TimeSeries`.
                new_state (dict): A dictionary containing the updated state of this and all sub-modules after evolution
                record_dict (dict): The state recorded during evolution if ``record`` is ``True``, as returned by the wrapped module
        """
        # - Rasterise input time series
        time_base, input_data, num_timesteps = self._prepare_input(
            ts_input, duration, num_timesteps
        )

        # - Call evolution method of wrapped module
        output, state_dict, record_dict = self._module.evolve(input_data, record=record)

        # - Select the requested output, if more than one is returned
        if isinstance(output, tuple):
            output = output[self._output_num]

        # - Convert output to TimeSeries
        if kwargs_timeseries is None:
            kwargs_timeseries = {}

        ts_out = self._gen_timeseries(output, **kwargs_timeseries)

        # - Use the optional `_wrap_recorded_state` method to convert the recorded state to TimeSeries objects
        if record:
            wrap_recorded_state = getattr(self._module, "_wrap_recorded_state", None)
            if wrap_recorded_state is not None:
                # - A zero-length evolution gives an empty time base; fall back to the current time
                t_start = time_base[0] if len(time_base) > 0 else self.t
                record_dict = wrap_recorded_state(record_dict, t_start)

        return ts_out, state_dict, record_dict

    def reset_state(self) -> None:
        """
        Reset the state of the wrapped module

        The wrapped module's ``reset_state`` return value replaces the held module. If it returns ``None``, the existing module is kept.
        """
        reset_module = self._module.reset_state()
        if reset_module is not None:
            self._module = reset_module
[docs]class LayerToTimedModule(TimedModule):
"""
An adapter class to wrap a Rockpool v1 :py:class:`.Layer` object, converting the object to support the :py:class:`.TimedModule` high-level Rockpool v2 API
Use this class to automagically convert a Rockpool v1 :py:class:`.Layer` object into a Rockpool v2 :py:class:`.TimedModule` object. This class is used internally by the :py:func:`~.timed_module.astimedmodule` decorator to convert a v1 class into a v2 class.
Examples:
Construct a v1 :py:class:`.RecRateEulerV1` layer, and convert it to a v2 :py:class:`.TimedModule`:
>>> from rockpool.nn.layers import RecRateEulerV1
>>> from rockpool.nn.modules.timed_module import LayerToTimedModule
>>> lyr = RecRateEuler(...)
>>> tmod = LayerToTimedModule(lyr)
>>> output_ts, new_state, record = tmod(input_ts)
See Also:
If you want to convert a :py:class:`.Module` object implementing the low-level v2 API to the high-level :py:class:`.TimedModule` v2 API, use the :py:class:`.TimedModuleWrapper` class.
For more information, see :ref:`/in-depth/api-high-level.ipynb`.
"""
[docs] def __init__(
self,
layer: "Layer",
parameters: Iterable[str] = None,
states: Iterable[str] = None,
simulation_parameters: Iterable[str] = None,
):
"""
Wrap a v1 :py:class:`.Layer` object as a v2 :py:class:`.TimedModule` object
Args:
layer (:py:class:`.Layer`): The v1 layer object to wrap
parameters (Iterable[str]): A list (or tuple) containing the names of all attributes of `layer` that should be registered as Rockpool :py:class:`.Parameter` s
states (Iterable[str]): A list (or tuple) containing the names of all attributes of `layer` that should be registered as Rockpool :py:class:`.State` s
simulation_parameters (Iterable[str]): A list (or tuple) containing the names of all attributes of `layer` that should be registered as Rockpool :py:class:`.SimulationParameter` s
"""
from rockpool.nn.layers.layer import Layer
if not isinstance(layer, Layer):
raise TypeError("LayerToTimedModule can only wrap a Rockpool v1 Layer.")
spiking_input = layer.input_type is TSEvent
spiking_output = layer.output_type is TSEvent
super().__init__(
shape=(layer.size_in, layer.size),
dt=layer.dt,
spiking_input=spiking_input,
spiking_output=spiking_output,
)
# - Record layer as submodule
self._module: Layer = layer
""" Layer: The wrapped layer object """
self._name: str = layer.name
""" str: The name of the wrapped layer """
# - Record parameters
if parameters is not None:
for param in parameters:
self._register_attribute(param, Parameter(getattr(self._module, param)))
# - Record states
if states is not None:
for state in states:
self._register_attribute(state, State(getattr(self._module, state)))
# - Record simulation parameters
if simulation_parameters is not None:
for sim_param in simulation_parameters:
self._register_attribute(
sim_param, SimulationParameter(getattr(self._module, sim_param))
)
[docs] def reset_time(self) -> None:
"""Reset the internal clock for this `TimedModule`"""
super().reset_time()
self._module.reset_time()
[docs] def reset_state(self) -> None:
"""Reset the internal state for this `TimedModule`"""
super().reset_state()
self._module.reset_state()
[docs] def evolve(
self,
ts_input: Optional[TimeSeries] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
kwargs_timeseries: Optional[dict] = None,
record: bool = False,
*args,
**kwargs,
) -> Tuple[TimeSeries, dict, dict]:
"""
Evolve the wrapped layer, handle inputs and outputs
Args:
ts_input (Optional[`TimeSeries`]): The input time series to evolve over. If not provided, zero input will be used
duration (Optional[float]): The duration of the evolution, in seconds. If not provided it will be inferred
num_timesteps (Optional[int]): The duration of evolution in integer units of :py:attr:`dt`. If not provided it will be inferred
kwargs_timeseries (Optional[dict]): Additional keyword arguments to pass when creating the return `TimeSeries` object
record (bool): If ``True``, a dictionary of recorded state will be returned for the module. If ``False`` (default), no recorded state is requested.
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
tuple: ts_output, new_state, record_dict
ts_output (`TimeSeries`): A time series containing the output of the module evolution
new_state (dict): A dictionary containing the state of the module post evolution
record_dict (dict): If ``record == True``, a dictionary of recorded state will be returned for the module. If ``record == False`` (default), no recorded state is requested
"""
# - Call submodule layer to evolve
ts_output = self._module.evolve(
ts_input, duration, num_timesteps, *args, **kwargs
)
# - Return output, state and record dict
return ts_output, self.state(), {}
def __setattr__(self, key: str, value: Any) -> None:
    """
    Set an attribute, forwarding registered attributes to the wrapped layer

    If ``key`` names a registered attribute and the wrapped layer (``self._module``) already exists, the value is first set on the wrapped layer and then read back from it. Any validation or conversion done by the layer is therefore reflected in the stored value. In all cases the (possibly converted) value is finally stored via the superclass ``__setattr__``.

    Args:
        key (str): The name of the attribute to set
        value (Any): The value to set to the attribute
    """
    # - `hasattr` guards the period during construction before `_module` has been assigned
    forward_to_layer = self._has_registered_attribute(key) and hasattr(
        self, "_module"
    )

    if forward_to_layer:
        layer = self._module
        setattr(layer, key, value)

        # - Read back the value as validated / converted by the wrapped layer
        value = getattr(layer, key)

    # - Always store via the superclass as well
    super().__setattr__(key, value)
def __getattr__(self, key: str) -> Any:
    """
    Get an attribute of the wrapped layer, if it has been registered

    This is only called when normal attribute lookup fails. Registered attributes are read from the wrapped layer (``self._module``). Any other name raises ``AttributeError``.

    Args:
        key (str): The name of the attribute to get

    Returns:
        Any: The value of the attribute, as read from the wrapped layer

    Raises:
        AttributeError: If ``key`` is one of the private ``ModuleBase`` registries, or is not a registered attribute
    """
    # - Fail fast for the private `ModuleBase` registries. Presumably these are missing
    #   early during construction, and `_has_registered_attribute` would otherwise
    #   recurse back into `__getattr__` (NOTE(review): confirm against `ModuleBase`)
    if key in ("_ModuleBase__registered_attributes", "_ModuleBase__modules"):
        raise AttributeError

    # - Unregistered names are not delegated to the wrapped layer
    if not self._has_registered_attribute(key):
        raise AttributeError(
            f"Attribute {key} not found in TimedModule class {self.class_name} named {self.name}"
        )

    # - Registered attributes live on the wrapped layer
    return getattr(self._module, key)
# def _get_attribute_family(self, type_name: str, family: str = None):
# # - Get matching attributes
# return super()._get_attribute_family(type_name, family)
def astimedmodule(
    v1_cls: Optional[type] = None,
    parameters: Optional[Iterable[str]] = None,
    states: Optional[Iterable[str]] = None,
    simulation_parameters: Optional[Iterable[str]] = None,
) -> Union[type, Callable]:
    """
    Convert a Rockpool v1 class to a v2 class

    This decorator transparently converts a Rockpool v1 :py:class:`.Layer` subclass to a Rockpool v2 high-level API :py:class:`.TimedModule` subclass.

    You can specify the parameter, state and simulation parameter attributes of the v1 layer to expose via the v2 API.

    Evolution should just work™, and ideally you won't need to modify anything in the v1 code to use the class within the v2 API. Depending on the complexity of the v1 layer, this may or may not be the case.

    Examples:
        Specify a simple v1 layer, and convert it to a v2 :py:class:`.TimedModule`:

        .. code-block:: python

            from rockpool.nn.layers import Layer
            from rockpool.nn.modules.timed_module import astimedmodule

            @astimedmodule(
                parameters = ['tau_mem', 'tau_syn', 'bias'],
                states = ['v_mem', 'i_syn'],
                simulation_parameters = ['noise_std']
            )
            class my_v1_layer(Layer):
                def __init__(...):
                    ...

                def evolve(...):
                    ...

    See Also:
        For more information, see :ref:`/in-depth/api-high-level.ipynb`.

    Args:
        v1_cls (Optional[type]): A v1 :py:class:`.Layer` subclass to wrap. If ``None``, a decorator is returned instead.
        parameters (Optional[Iterable[str]]): An iterable of strings (or a single string) specifying the names of attributes provided by ``v1_cls`` that should be automatically registered as Rockpool :py:class:`.Parameter` s.
        states (Optional[Iterable[str]]): An iterable of strings (or a single string) specifying the names of attributes provided by ``v1_cls`` that should be automatically registered as Rockpool :py:class:`.State` s.
        simulation_parameters (Optional[Iterable[str]]): An iterable of strings (or a single string) specifying the names of attributes provided by ``v1_cls`` that should be automatically registered as Rockpool :py:class:`.SimulationParameter` s.

    Returns:
        Union[type, Callable]: If ``v1_cls`` is given, a :py:class:`.LayerToTimedModule` subclass wrapping ``v1_cls``. Otherwise, a decorator that takes a v1 :py:class:`.Layer` subclass and returns such a wrapped class.

    Raises:
        ValueError: If the decorated class is not a subclass of the v1 :py:class:`.Layer` class
    """

    def _as_name_list(names):
        """Normalise an attribute-name specification; ``None``, lists and tuples pass through unchanged"""
        if names is None or isinstance(names, (tuple, list)):
            return names

        # - A single name (or any non-iterable value) is wrapped in a list
        if isinstance(names, (str, bytes)) or not isinstance(names, abc.Iterable):
            return [names]

        # - Other iterables (sets, generators, dict keys, ...) are materialised, so they are
        #   neither nested inside a list nor exhausted after the first instantiation
        return list(names)

    # - Be lenient if any parameters are not lists/tuples
    parameters = _as_name_list(parameters)
    states = _as_name_list(states)
    simulation_parameters = _as_name_list(simulation_parameters)

    from rockpool.nn.layers import Layer

    def wrapper_function(v1_cls):
        if not issubclass(v1_cls, Layer):
            raise ValueError(
                "`@astimedmodule` may only be applied to Rockpool v1 `Layer` subclasses."
            )

        # - Define a wrapping class
        @functools.wraps(v1_cls, updated=())
        class wrapper(LayerToTimedModule):
            def __init__(self, *args, **kwargs):
                # - Instantiate layer
                layer = v1_cls(*args, **kwargs)

                # - Wrap layer
                super().__init__(layer, parameters, states, simulation_parameters)

            # - A per-class `evolve`, so the docstring assigned below belongs to this class
            #   only. Without it, `wrapper.evolve` is the shared `LayerToTimedModule.evolve`,
            #   and assigning its `__doc__` would overwrite the docs of every wrapped class.
            #   `functools.wraps` keeps the base signature visible to `inspect.signature`.
            @functools.wraps(LayerToTimedModule.evolve)
            def evolve(self, *args, **kwargs):
                return super().evolve(*args, **kwargs)

        # - Update docstrings for wrapped class
        if v1_cls.__doc__ is None:
            # - Inherit docs from parent class
            wrapper.__doc__ = v1_cls.__base__.__doc__
        else:
            wrapper.__doc__ = v1_cls.__doc__

        wrapper.__init__.__doc__ = v1_cls.__init__.__doc__
        wrapper.evolve.__doc__ = v1_cls.evolve.__doc__

        # - Return the decorated class
        return wrapper

    # - Support both `@astimedmodule` and `@astimedmodule(...)` usage
    if v1_cls is None:
        return wrapper_function
    else:
        return wrapper_function(v1_cls)