nn.modules.timed_module.TimedModuleWrapper

class nn.modules.timed_module.TimedModuleWrapper(*args, **kwargs)

Bases: TimedModule

Wrap a low-level Rockpool Module automatically into a TimedModule object

Use this class to automatically convert a Module subclass, implementing the low-level API of Rockpool, into a TimedModule object that supports the high-level time series API directly.

Notes

Only a single output argument may be returned from the wrapped TimedModule. However, multiple return arguments from the internal module can be handled through the output_num argument to __init__().

Recorded state from the wrapped module is not currently converted automatically into TimeSeries objects.
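
As a rough illustration of the first note, the sketch below shows how one output might be picked from a low-level module that returns several. It does not use Rockpool: select_output and the toy data are hypothetical names introduced only for this example, and the real wrapper may handle this differently.

```python
from typing import Any


def select_output(raw_output: Any, output_num: int = 0) -> Any:
    """Pick one output from a low-level evolution result (hypothetical helper)."""
    # A tuple is treated as "multiple return arguments";
    # anything else is assumed to already be the single output.
    if isinstance(raw_output, tuple):
        return raw_output[output_num]
    return raw_output


# Toy stand-in for a module that returns two outputs per evolution
spikes, membrane = [0, 1, 0], [0.1, 0.9, 0.2]
chosen = select_output((spikes, membrane), output_num=1)
```

Here chosen holds the second output only, which is the one that would then be wrapped into a time series.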

Examples

Construct a low-level module, then wrap it into a TimedModule:

>>> from rockpool.nn.modules import RateEuler, TimedModuleWrapper
>>> mod = RateEuler(...)
>>> tmod = TimedModuleWrapper(mod)

See also

If you want to convert a Module object, use this class.

If you need to convert a Rockpool v1 Layer subclass, use either the LayerToTimedModule class or the astimedmodule decorator.

For more information, see ⏱ High-level TimedModule API.

Attributes overview

class_name

Class name of self

full_name

The full name of this module (class plus module name)

input_type

The TimeSeries class accepted by this module

module

The wrapped module

name

The name of this module, or an empty string if None

output_type

The TimeSeries class returned by this module

shape

The shape of this module

size

(DEPRECATED) The output size of this module

size_in

The input size of this module

size_out

The output size of this module

spiking_input

If True, this module receives spiking input.

spiking_output

If True, this module sends spiking output.

t

The current evolution time of this layer, in seconds

dt

The simulation and input rasterisation timestep for this TimedModule

Methods overview

__init__(module[, output_num, dt, add_events])

Wrap a low-level module as a high-level TimedModule

as_graph()

Convert this module to a computational graph

attributes_named(name)

Search for attributes of this or submodules by name

evolve([ts_input, duration, num_timesteps, ...])

Evolve the wrapped Module, handling TimeSeries input and output

modules()

Return a dictionary of all sub-modules of this module

parameters([family])

Return a nested dictionary of module and submodule Parameters

reset_all()

Reset the internal state and time of this module and all sub-modules

reset_parameters()

Reset all parameters in this module

reset_state()

Reset the state of this module

reset_time()

Reset the internal time of this module and all sub-modules to zero

set_attributes(new_attributes)

Set the attributes and sub-module attributes from a dictionary

simulation_parameters([family])

Return a nested dictionary of module and submodule SimulationParameters

state([family])

Return a nested dictionary of module and submodule States

__in_TimedModule_init: bool

A flag indicating that this TimedModule is currently being initialised

__init__(module: Module, output_num: int = 0, dt: float | None = None, add_events: bool = True, *args, **kwargs)

Wrap a low-level module as a high-level TimedModule

Parameters:
  • module (Module) – The module to wrap. Must inherit from Module.

  • output_num (int) – If the output of the evolution function for module returns multiple outputs, then here you should specify which of the outputs to wrap into a time series to return. TimedModuleWrapper only supports returning one output argument from evolve().

  • dt (float) – The timestep to set for module, if module.dt does not exist. Note that module.dt will not be overridden by this argument!

  • add_events (bool) – If True, then multiple events per time bin will be summed when converting to a raster. If False, only a single event will be retained per time bin. Default: True, sum events in each time bin.
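
To make the effect of add_events concrete, here is a minimal sketch of binning events into a raster. It is not Rockpool's implementation: rasterise_events and its (time, channel) event format are assumptions made for this example.

```python
from typing import List, Tuple


def rasterise_events(
    events: List[Tuple[float, int]],
    dt: float,
    num_timesteps: int,
    num_channels: int,
    add_events: bool = True,
) -> List[List[int]]:
    """Bin (time, channel) events into a T x N raster (illustrative only)."""
    raster = [[0] * num_channels for _ in range(num_timesteps)]
    for time, channel in events:
        step = int(time // dt)
        if 0 <= step < num_timesteps:
            if add_events:
                raster[step][channel] += 1  # sum events in the bin
            else:
                raster[step][channel] = 1  # keep at most one event per bin
    return raster


events = [(0.1, 0), (0.2, 0), (1.3, 1)]
summed = rasterise_events(events, dt=1.0, num_timesteps=2, num_channels=2)
clipped = rasterise_events(
    events, dt=1.0, num_timesteps=2, num_channels=2, add_events=False
)
```

The two events on channel 0 fall in the same bin, so summed records a count of 2 there while clipped records 1.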

_abc_impl = <_abc._abc_data object>

_determine_timesteps(ts_input: TimeSeries | None = None, duration: float | None = None, num_timesteps: int | None = None) int

Determine how many time steps to evolve with the given input specification

Parameters:
  • ts_input (Optional[TimeSeries]) – TxM or Tx1 time series of input signals for this layer

  • duration (Optional[float]) – Duration of the desired evolution, in seconds. If not provided, num_timesteps or the duration of ts_input will be used to determine evolution time

  • num_timesteps (Optional[int]) – Number of evolution time steps, in units of dt. If not provided, duration or the duration of ts_input will be used to determine evolution time

Return int:

num_timesteps: Number of evolution time steps
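
One way to picture how the three inputs interact is the sketch below. The precedence (num_timesteps first, then duration, then the input's duration) and the rounding tolerance are assumptions for illustration, not a statement of what the library does; determine_timesteps and input_duration are hypothetical names.

```python
import math
from typing import Optional


def determine_timesteps(
    dt: float,
    input_duration: Optional[float] = None,
    duration: Optional[float] = None,
    num_timesteps: Optional[int] = None,
) -> int:
    """Resolve the number of evolution steps (assumed precedence)."""
    if num_timesteps is not None:
        return int(num_timesteps)
    if duration is None:
        duration = input_duration
    if duration is None:
        raise ValueError("Provide num_timesteps, duration or an input duration")
    # Small tolerance so that e.g. 0.3 / 0.1 yields 3 rather than 2
    return int(math.floor(duration / dt + 1e-9))


steps = determine_timesteps(dt=0.1, duration=0.3)
```

The tolerance matters because 0.3 / 0.1 evaluates to slightly less than 3 in floating point.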

_evolve_wrapper(ts_input=None, duration=None, num_timesteps=None, kwargs_timeseries=None, record: bool = False, *args, **kwargs) Tuple[TimeSeries, Dict, Dict]

Wrap a call to evolve() to update the internal time-steps count

See evolve() for calling syntax.
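
The idea of counting time steps around an evolution call can be sketched as follows. StepCounter and double are toy, hypothetical stand-ins written for this example; they only show the bookkeeping pattern, not the actual wrapper.

```python
from typing import Callable, List


class StepCounter:
    """Toy object that tracks evolution time as a step count (hypothetical)."""

    def __init__(self, dt: float, evolve: Callable):
        self.dt = dt
        self._timestep = 0
        self._evolve = evolve

    @property
    def t(self) -> float:
        # Current time in seconds, derived from the integer step count
        return self._timestep * self.dt

    def evolve_wrapper(self, raster: List[float]):
        result = self._evolve(raster)
        # Advance the clock only after the evolution has succeeded
        self._timestep += len(raster)
        return result


def double(raster):
    """Toy evolution function returning (output, new_state, record)."""
    return [2 * x for x in raster], {}, {}


counter = StepCounter(dt=0.5, evolve=double)
out, _, _ = counter.evolve_wrapper([1.0, 2.0, 3.0])
```

Keeping an integer step count and deriving t from it avoids accumulating floating-point error over many evolutions.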

_force_set_attributes

(bool) If True, do not sanity-check attributes when setting.

_gen_time_trace(t_start: float, num_timesteps: int) ndarray

Generate a time trace starting at t_start, of length num_timesteps with time step dt

Parameters:
  • t_start (float) – Start time, in seconds

  • num_timesteps (int) – Number of time steps to generate, in units of dt

Return ndarray:

Generated time trace
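
A time trace of this kind could be generated as in the sketch below. It assumes the trace includes t_start and has exactly num_timesteps points; gen_time_trace takes dt as an explicit argument here only to keep the example self-contained.

```python
from typing import List


def gen_time_trace(t_start: float, num_timesteps: int, dt: float) -> List[float]:
    """Return num_timesteps time points spaced by dt, starting at t_start."""
    # Multiplying by the index avoids accumulating rounding error
    return [t_start + i * dt for i in range(num_timesteps)]


trace = gen_time_trace(t_start=1.0, num_timesteps=4, dt=0.25)
# trace is [1.0, 1.25, 1.5, 1.75]
```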

_gen_timeseries(output: ndarray, **kwargs) TimeSeries

Wrap a clocked / rasterised output array into a TimeSeries object

Output TimeSeries will be of the appropriate subclass, and will be named nicely.

Parameters:
  • output (np.ndarray) – The clocked or rasterised output data (T, N)

  • **kwargs – Additional keyword arguments to TimeSeries

Returns:

The data in output wrapped into a TimeSeries object

Return type:

TimeSeries

_gen_tscontinuous(output: ndarray, dt: float | None = None, t_start: float | None = None, name: str | None = None, periodic: bool = False, interp_kind: str = 'previous') TSContinuous

Wrap a rasterised output array as a TSContinuous object to present as output for this module

Output TSContinuous objects will be named nicely, with correct start times, durations, etc. Several attributes of the TSContinuous object can be set as arguments here.

Parameters:
  • output (np.ndarray) – A clocked time series data array (T, N)

  • dt (Optional[float]) – The time-step of the clocked array output. If not provided, the module dt will be used

  • t_start (Optional[float]) – The start time of the output TSContinuous object, in seconds. If not provided, the module time before evolution will be used

  • name (Optional[str]) – The desired name of the TSContinuous object. If not provided, the object will be named nicely according to the module name

  • periodic (bool) – Flag to indicate whether the returned TSContinuous should be periodic. Default: False, the TSContinuous will not be periodic

  • interp_kind (str) – The style of interpolation to apply to the returned TSContinuous object. Default: "previous"

Returns:

The wrapped output data as a TSContinuous object

Return type:

TSContinuous
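
For readers unfamiliar with the default interp_kind, the sketch below shows zero-order-hold sampling, assuming "previous" has the same meaning as the interpolation kind of that name in scipy.interpolate.interp1d. The helper sample_previous is hypothetical and independent of Rockpool.

```python
from bisect import bisect_right
from typing import List


def sample_previous(times: List[float], values: List[float], t: float) -> float:
    """Zero-order hold: return the value at the latest sample time <= t."""
    idx = bisect_right(times, t) - 1
    if idx < 0:
        raise ValueError("t lies before the first sample")
    return values[idx]


times = [0.0, 0.1, 0.2]
values = [1.0, 5.0, 3.0]
held = sample_previous(times, values, 0.15)
```

Sampling between 0.1 and 0.2 returns the value stored at 0.1, so a clocked output holds each value until the next time step.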

_gen_tsevent(output: ndarray, dt: float | None = None, t_start: float | None = None, name: str | None = None, periodic: bool = False, num_channels: int | None = None, spikes_at_bin_start: bool = False) TSEvent

Wrap a rasterised output array as a TSEvent object to present as output for this module

Output TSEvent objects will be named nicely, with correct start times, durations, etc. Several attributes of the TSEvent object can be set as arguments here.

Parameters:
  • output (np.ndarray) – A rasterised event array (T, N)

  • dt (Optional[float]) – The time-step of the rasterised array output. If not provided, the module dt will be used

  • t_start (Optional[float]) – The start time of the output series, in seconds. If not provided, the module time before evolution will be used

  • name (Optional[str]) – The desired name of the TSEvent object. If not provided, the object will be named nicely according to the module name

  • periodic (bool) – Flag to indicate whether the returned TSEvent should be periodic. Default: False, the TSEvent will not be periodic

  • num_channels (Optional[int]) – The desired number of total channels for the output TSEvent object. If not provided, the output size size_out of the current module will be used

  • spikes_at_bin_start (bool) – If False (default), spike events will be considered to fall in the middle of the time bin they fall in. If True, all spike events will be considered to occur at the start of the time bin they fall in.

Returns:

The wrapped output raster as a TSEvent object

Return type:

TSEvent
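
The effect of spikes_at_bin_start can be illustrated with a small raster-to-events conversion. This is a sketch under the assumption that "middle of the time bin" means an offset of half a time step; raster_to_event_times is a hypothetical helper, not part of the library.

```python
from typing import List, Tuple


def raster_to_event_times(
    raster: List[List[int]],
    dt: float,
    t_start: float = 0.0,
    spikes_at_bin_start: bool = False,
) -> List[Tuple[float, int]]:
    """Convert a T x N raster into (time, channel) pairs (illustrative)."""
    offset = 0.0 if spikes_at_bin_start else 0.5
    events = []
    for step, row in enumerate(raster):
        for channel, count in enumerate(row):
            # One event per count; repeated events share the same time
            events.extend([(t_start + (step + offset) * dt, channel)] * count)
    return events


raster = [[1, 0], [0, 2]]
mid = raster_to_event_times(raster, dt=1.0)
start = raster_to_event_times(raster, dt=1.0, spikes_at_bin_start=True)
```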

_get_attribute_family(type_name: str, family: Tuple | List | str | None = None) dict

Search for attributes of this module and submodules that match a given family

This method can be used to conveniently get all weights for a network, all time constants, or any other family of parameters. Parameter families are defined simply by a string: "weights" for weights, "taus" for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months' time).

Parameters:
  • type_name (str) – The class of parameters to search for. Must be one of ["Parameter", "SimulationParameter", "State"] or another future subclass of ParameterBase

  • family (Union[str, Tuple[str]]) – A string or list or tuple of strings, that define one or more attribute families to search for

Returns:

A nested dictionary of attributes that match the provided type_name and family

Return type:

dict
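
The family lookup can be pictured as a recursive filter over a nested registry. The registry layout below (name mapped to a (value, type_name, family) tuple, with nested dictionaries for sub-modules) is invented for this sketch and is not the library's internal format.

```python
from typing import Any, Dict, Optional, Sequence, Union

# Hypothetical registry: name -> (value, type_name, family), or a nested dict
registry = {
    "w_rec": ([[0.1]], "Parameter", "weights"),
    "tau": ([0.02], "Parameter", "taus"),
    "dt": (1e-3, "SimulationParameter", None),
    "sub": {
        "w_in": ([[0.5]], "Parameter", "weights"),
        "v": ([0.0], "State", None),
    },
}


def get_attribute_family(
    reg: Dict[str, Any],
    type_name: str,
    family: Optional[Union[str, Sequence[str]]] = None,
) -> Dict[str, Any]:
    """Collect matching attributes from a nested registry (sketch)."""
    families = [family] if isinstance(family, str) else family
    matches: Dict[str, Any] = {}
    for name, entry in reg.items():
        if isinstance(entry, dict):
            # Recurse into sub-modules; keep only non-empty results
            sub = get_attribute_family(entry, type_name, family)
            if sub:
                matches[name] = sub
        else:
            value, this_type, this_family = entry
            if this_type == type_name and (
                families is None or this_family in families
            ):
                matches[name] = value
    return matches


weights = get_attribute_family(registry, "Parameter", "weights")
```

The result keeps the nesting of the original registry, so weights contains w_rec at the top level and w_in under sub.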

_get_attribute_registry() Tuple[Dict, Dict]

Return or initialise the attribute registry for this module

Returns:

registered_attributes, registered_modules

Return type:

(tuple)

_has_registered_attribute(name: str) bool

Check if the module has a registered attribute

Parameters:

name (str) – The name of the attribute to check

Returns:

True if the attribute name is in the attribute registry, False otherwise.

Return type:

bool

_in_Module_init

(bool) If exists and True, indicates that the module is in the __init__ chain.

_is_child: bool

Flag indicating that this is a child module

_name: str | None

Name of this module, if assigned

_parent_dt_factor: float

The factor between the parent’s dt and this module’s dt. Given by self.dt / parent.dt

_prepare_input(ts_input: TimeSeries | None = None, duration: float | None = None, num_timesteps: int | None = None) Tuple[ndarray, ndarray, int]

Sample input, set up time base

This function checks an input signal, and prepares a discretised time base according to the time step of the current module

Parameters:
  • ts_input (Optional[TimeSeries]) – TimeSeries of TxM or Tx1 Input signals for this layer

  • duration (Optional[float]) – Duration of the desired evolution, in seconds. If not provided, then either num_timesteps or the duration of ts_input will define the evolution time

  • num_timesteps (Optional[int]) – Integer number of evolution time steps, in units of dt. If not provided, then duration or the duration of ts_input will define the evolution time

Return (ndarray, ndarray, int):

(time_base, input_raster, num_timesteps)

time_base: (T1) Discretised time base for evolution

input_raster: (T1xN) Discretised input signal for layer

num_timesteps: Actual number of evolution time steps, in units of dt

_prepare_input_continuous(ts_input: TSContinuous | None = None, duration: float | None = None, num_timesteps: int | None = None) Tuple[ndarray, ndarray, int]

Sample input, set up time base

This function checks an input signal, and prepares a discretised time base according to the time step of the current module

Parameters:
  • ts_input (Optional[TSContinuous]) – TSContinuous of TxM or Tx1 Input signals for this layer

  • duration (Optional[float]) – Duration of the desired evolution, in seconds. If not provided, then either num_timesteps or the duration of ts_input will define the evolution time

  • num_timesteps (Optional[int]) – Integer number of evolution time steps, in units of dt. If not provided, then duration or the duration of ts_input will define the evolution time

Return (ndarray, ndarray, int):

(time_base, input_raster, num_timesteps)

time_base: (T1) Discretised time base for evolution

input_raster: (T1xN) Discretised input signal for layer

num_timesteps: Actual number of evolution time steps, in units of dt

_prepare_input_events(ts_input: TSEvent | None = None, duration: float | None = None, num_timesteps: int | None = None, add_events: bool = False) Tuple[ndarray, ndarray, int]

Sample input from a TSEvent time series, set up evolution time base

This function checks an input signal, and prepares a discretised time base according to the time step of the current module

Parameters:
  • ts_input (Optional[TSEvent]) – TimeSeries of TxM or Tx1 Input signals for this layer

  • duration (Optional[float]) – Duration of the desired evolution, in seconds. If not provided, then either num_timesteps or the duration of ts_input will determine evolution time

  • num_timesteps (Optional[int]) – Number of evolution time steps, in units of dt. If not provided, then either duration or the duration of ts_input will determine evolution time

Return (ndarray, ndarray, int):

(time_base, spike_raster, num_timesteps)

time_base: (T1x1) Vector of time points; the time base for the rasterisation

spike_raster: (T1xM) Boolean or integer raster containing spike information

num_timesteps: Actual number of evolution time steps, in units of dt

_register_attribute(name: str, val: ParameterBase)

Record an attribute in the attribute registry

Parameters:
  • name (str) – The name of the attribute to register

  • val (ParameterBase) – The ParameterBase subclass object to register. e.g. Parameter, SimulationParameter or State.

_register_module(name: str, mod: ModuleBase)

Register a sub-module in the module registry

Parameters:
  • name (str) – The name of the module to register

  • mod (ModuleBase) – The ModuleBase object to register

_reset_attribute(name: str) ModuleBase

Reset an attribute to its initialisation value

Parameters:

name (str) – The name of the attribute to reset

Returns:

For compatibility with the functional API

Return type:

self (Module)

_set_dt(max_factor: float = 100) None

Set a time step size for the network which is the lowest common multiple of the dt values of all sub-modules.

Parameters:

max_factor (float) – Factor by which the module dt may exceed the largest sub-module dt before an error is raised. Default: 100.

Raises:

ValueError – If a sensible dt cannot be found
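
A lowest common multiple of non-integer time steps can be computed with exact fractions, as in the sketch below. This is one possible approach, not necessarily the one used internally; common_dt is a hypothetical function, and the max_factor check follows the parameter description above.

```python
from fractions import Fraction
from math import gcd
from typing import Iterable


def common_dt(sub_dts: Iterable[float], max_factor: float = 100) -> float:
    """Lowest common multiple of several time steps (illustrative sketch)."""
    # Represent each dt as an exact fraction so the LCM is well defined
    fracs = [Fraction(dt).limit_denominator(10**9) for dt in sub_dts]
    num, den = fracs[0].numerator, fracs[0].denominator
    for f in fracs[1:]:
        num = num * f.numerator // gcd(num, f.numerator)  # lcm of numerators
        den = gcd(den, f.denominator)  # gcd of denominators
    result = Fraction(num, den)
    if result > max_factor * max(fracs):
        raise ValueError("No sensible common dt found for these time steps")
    return float(result)


dt = common_dt([0.001, 0.0025])
```

With sub-module steps of 1 ms and 2.5 ms the common step is 5 ms, which is 5 and 2 whole steps of the sub-modules respectively.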

_shape

The shape of this module

_spiking_input: bool

Whether this module receives spiking input

_spiking_output: bool

Whether this module produces spiking output

_submodulenames: List[str]

Registry of sub-module names

_timestep: int

The current time-step count in units of dt

_wrap_recorded_state(recorded_dict: dict, t_start: float) Dict[str, TimeSeries]

Convert a recorded dictionary to a TimeSeries representation

This method is optional, and is provided to make the timed() conversion to a TimedModule work better. You should override this method in your custom Module, to wrap each element of your recorded state dictionary as a TimeSeries

Parameters:
  • recorded_dict (dict) – A recorded state dictionary as returned by evolve()

  • t_start (float) – The initial time of the recorded state, to use as the starting point of the time series

Returns:

The mapped recorded state dictionary, wrapped as TimeSeries objects

Return type:

Dict[str, TimeSeries]

as_graph() GraphModuleBase

Convert this module to a computational graph

Returns:

The computational graph corresponding to this module

Return type:

GraphModuleBase

Raises:

NotImplementedError – If as_graph() is not implemented for this subclass

attributes_named(name: Tuple[str] | List[str] | str) dict

Search for attributes of this or submodules by name

Parameters:

name (Union[str, Tuple[str]]) – The name of the attribute to search for

Returns:

A nested dictionary of attributes that match name

Return type:

dict

property class_name: str

Class name of self

Type:

str

dt: float | SimulationParameter

The simulation and input rasterisation timestep for this TimedModule

Type:

float

evolve(ts_input: TimeSeries | None = None, duration: float | None = None, num_timesteps: int | None = None, kwargs_timeseries: Dict | None = None, record: bool = False, *args, **kwargs) Tuple[TimeSeries, Any, Any]

Evolve the wrapped Module, handling TimeSeries input and output

Parameters:
  • ts_input (Optional[TimeSeries]) – The input data for this evolution. If not provided, zero input will be used

  • duration (Optional[float]) – The duration over which to evolve this module, in seconds. If not provided, it will be inferred

  • num_timesteps (Optional[int]) – The number of time steps over which to evolve this module, in units of dt. If not provided, it will be inferred

  • kwargs_timeseries (Optional[dict]) – Additional keyword arguments to pass when generating the output time series

  • record (bool) – If True, a dictionary containing a record of state during evolution for this and all submodules will be returned. If False (default), no record is requested.

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns:

(output_ts, new_state, record_dict)

output_ts (TimeSeries): The output of this module, wrapped as a TimeSeries.

new_state (dict): A dictionary containing the updated state of this and all sub-modules after evolution.

record_dict (dict): If record is True, a dictionary containing a record of state during evolution for this and all submodules. If record is False (default), no record is requested.

Return type:

tuple

property full_name: str

The full name of this module (class plus module name)

Type:

str

property input_type: type

The TimeSeries class accepted by this module

Type:

type

property module

The wrapped module

Type:

Module

modules() Dict

Return a dictionary of all sub-modules of this module

Returns:

A dictionary containing all sub-modules. Each item will be named with the sub-module name.

Return type:

dict

property name: str

The name of this module, or an empty string if None

Type:

str

property output_type: type

The TimeSeries class returned by this module

Type:

type

parameters(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule Parameters

Use this method to inspect the Parameters from this and all submodules. The optional argument family allows you to search for Parameters in a particular family, for example "weights" for all weights of this module and nested submodules.

Although the family argument is an arbitrary string, reasonable choices are "weights" for weights, "taus" for time constants and "biases" for biases.

Examples

Obtain a dictionary of all Parameters for this module (including submodules):

>>> mod.parameters()
dict{ ... }

Obtain a dictionary of Parameters from a particular family:

>>> mod.parameters("weights")
dict{ ... }

Parameters:

family (str) – The family of Parameters to search for. Default: None; return all parameters.

Returns:

A nested dictionary of Parameters of this module and all submodules

Return type:

dict

reset_all() None

Reset the internal state and time of this module and all sub-modules

reset_parameters()

Reset all parameters in this module

Returns:

The updated module is returned for compatibility with the functional API

Return type:

Module

reset_state() None

Reset the state of this module

Returns:

The updated module is returned for compatibility with the functional API

Return type:

Module

reset_time() None

Reset the internal time of this module and all sub-modules to zero

set_attributes(new_attributes: dict) ModuleBase

Set the attributes and sub-module attributes from a dictionary

This method can be used with the dictionary returned from module evolution to set the new state of the module. It can also be used to set multiple parameters of a module and submodules.

Examples

Use the functional API to evolve, obtain new states, and set those states:

>>> _, new_state, _ = mod(input)
>>> mod = mod.set_attributes(new_state)

Obtain a parameter dictionary, modify it, then set the parameters back:

>>> params = mod.parameters()
>>> params['w_input'] *= 0.
>>> mod.set_attributes(params)
Parameters:

new_attributes (dict) – A nested dictionary containing parameters of this module and sub-modules.
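
The nested-dictionary update described above can be sketched with a toy class. ToyModule is a hypothetical stand-in that only mimics the shape of the call; it is not a Rockpool module and omits any validation the real method may perform.

```python
from typing import Any, Dict


class ToyModule:
    """Minimal stand-in for a module with attributes and sub-modules."""

    def __init__(self, **attrs: Any):
        for key, value in attrs.items():
            setattr(self, key, value)

    def set_attributes(self, new_attributes: Dict[str, Any]) -> "ToyModule":
        for key, value in new_attributes.items():
            current = getattr(self, key, None)
            if isinstance(current, ToyModule) and isinstance(value, dict):
                # A nested dictionary addresses a sub-module
                current.set_attributes(value)
            else:
                setattr(self, key, value)
        # Returning self mirrors the functional style used in the examples
        return self


net = ToyModule(w_input=1.0, layer=ToyModule(tau=0.02, v=0.0))
net = net.set_attributes({"w_input": 0.0, "layer": {"v": 0.5}})
```

Only the keys present in the dictionary are changed, so layer.tau keeps its original value.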

property shape: tuple

The shape of this module

Type:

tuple

simulation_parameters(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule SimulationParameters

Use this method to inspect the SimulationParameters from this and all submodules. The optional argument family allows you to search for SimulationParameters in a particular family.

Examples

Obtain a dictionary of all SimulationParameters for this module (including submodules):

>>> mod.simulation_parameters()
dict{ ... }

Parameters:

family (str) – The family of SimulationParameters to search for. Default: None; return all SimulationParameter attributes.

Returns:

A nested dictionary of SimulationParameters of this module and all submodules

Return type:

dict

property size: int

(DEPRECATED) The output size of this module

Type:

int

property size_in: int

The input size of this module

Type:

int

property size_out: int

The output size of this module

Type:

int

property spiking_input: bool

If True, this module receives spiking input. If False, this module expects continuous input.

Type:

bool

property spiking_output

If True, this module sends spiking output. If False, this module sends continuous output.

Type:

bool

state(family: Tuple | List | str | None = None) Dict

Return a nested dictionary of module and submodule States

Use this method to inspect the States from this and all submodules. The optional argument family allows you to search for States in a particular family.

Examples

Obtain a dictionary of all States for this module (including submodules):

>>> mod.state()
dict{ ... }

Parameters:

family (str) – The family of States to search for. Default: None; return all State attributes.

Returns:

A nested dictionary of States of this module and all submodules

Return type:

dict

property t: float

The current evolution time of this layer, in seconds

Type:

float