"""
Implement the :py:class:`.Sequential` combinator, with helper classes for Jax and Torch backends
"""
from rockpool.nn.modules.module import Module, ModuleBase
from rockpool import TSContinuous, TSEvent
from copy import copy
from typing import Tuple, Any, Optional, Union
from abc import ABC
from collections import OrderedDict
import rockpool.graph as rg
__all__ = ["Sequential"]
class SequentialMixin(ABC):
    """
    Base class for :py:class:`.Sequential` modules

    Implements the behaviour shared by the :py:class:`.Sequential` variants: collecting submodules at construction, appending modules with size checks, evolving data through the chain, indexing, graph extraction and wrapping of recorded state.

    Attributes such as ``_submodulenames``, ``size_in``, ``size_out``, ``name``, ``dt`` and ``modules()`` are not defined here; they are expected to be supplied by the module base class this mixin is combined with.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise a :py:class:`.Sequential` module

        Args:
            *args: Modules to chain, in order. The first argument may instead be an ``OrderedDict`` mapping names to modules. Any argument that is not a :py:class:`.ModuleBase` instance is forwarded positionally to the superclass constructor.
            **kwargs: Forwarded to the superclass constructor.

        Raises:
            ValueError: If ``shape``, ``spiking_input`` or ``spiking_output`` is supplied as a keyword argument, or if an automatically generated module name collides with a name already collected.
        """
        # - These attributes are derived from the submodules, so explicit values are rejected
        for forbidden in ("shape", "spiking_input", "spiking_output"):
            if forbidden in kwargs:
                raise ValueError(
                    f"You may not provide a `{forbidden}` argument when building a Sequential module."
                )

        # - Extract OrderedDict modules from arguments list
        if len(args) > 0 and isinstance(args[0], OrderedDict):
            # - Work on a copy, so the caller's dictionary is not modified when further modules are collected below
            submods = OrderedDict(args[0])
            args = args[1:]
            # NOTE(review): automatic naming starts at index 1 here, regardless of how many
            # entries the dictionary holds, whereas `append` numbers by current length — confirm intent
            mod_index = 1
        else:
            submods = OrderedDict()
            mod_index = 0

        other_args = []

        # - Extract additional modules from arguments list
        for item in args:
            if isinstance(item, ModuleBase):
                # - Collect the module and define a name
                name = f"{mod_index}_{item.class_name}"

                # - Refuse to silently replace a module already collected under this name
                if name in submods:
                    raise ValueError(
                        f'Submodule "{name}" already exists in Sequential network. Cannot add a module with the same name.'
                    )

                submods[name] = item
                mod_index += 1
            else:
                other_args.append(item)

        # - Call super-class initialisation; the real shape is set by `append`
        super().__init__(*other_args, shape=(0, 0), **kwargs)

        # - Call `append` for each module
        for name, mod in submods.items():
            self.append(mod, name)

    def append(self, mod: ModuleBase, name: Optional[str] = None) -> None:
        """
        Append a module to the :py:class:`.Sequential` network stack

        Args:
            mod (Module): A rockpool :py:class:`.Module` to append to this network stack. The input size of `mod` must match the output size of the existing network.
            name (str): An optional name to assign to the new module. If ``None``, a name will automatically be generated.

        Raises:
            ValueError: If a submodule called ``name`` already exists, or if the output size of the existing network and the input size of ``mod`` are both not ``None`` and differ.
        """
        # - Get a name and index for this module
        mod_index = len(self._submodulenames)
        if name is None:
            name = f"{mod_index}_{mod.class_name}"

        if name in self._submodulenames:
            raise ValueError(
                f'Submodule "{name}" already exists in Sequential network. Cannot append a module with the same name.'
            )

        # - Check if the shapes are compatible
        if len(self._submodulenames) == 0:
            # - The first module defines the input side of the whole network
            self._shape = mod.shape
            self._spiking_input = mod._spiking_input
            self._spiking_output = mod._spiking_output
        elif (
            self.size_out != mod.size_in
            and self.size_out is not None
            and mod.size_in is not None
        ):
            raise ValueError(
                f"The output of submodule {mod_index-1} "
                + f"({type(self[-1]).__name__}) "
                + f"does not match the input shape of submodule "
                + f"{mod_index} ({type(mod).__name__}): "
                + f"{self[-1].size_out} ≠ {mod.size_in}"
            )

        # - Assign module
        setattr(self, name, mod)

        # - Fix shape and output type: the newest module defines the output side
        self._shape = (self.size_in, mod.size_out)
        self._spiking_output = mod.spiking_output

    def evolve(self, input_data, record: bool = False) -> Tuple[Any, Any, Any]:
        """
        Evolve the network by passing ``input_data`` through each submodule in order

        Args:
            input_data: Input passed to the first submodule.
            record (bool): Forwarded as ``record`` to every submodule. Default: ``False``.

        Returns:
            Tuple[Any, Any, Any]: ``(output, new_state, record_dict)``. ``output`` is the output of the last submodule (``input_data`` itself when there are no submodules). ``new_state`` maps each submodule name to the state returned by that submodule. ``record_dict`` maps each submodule name to the record it returned, and ``"{name}_output"`` to a copy of that submodule's output; these outputs are stored regardless of ``record``.
        """
        # - Initialise state and record dictionaries
        new_state_dict = {}
        record_dict = {}

        x = input_data

        # - Loop through submodules
        for submod_name in self._submodulenames:
            # - Get this submodule
            mod = getattr(self, submod_name)

            # - Push data through submodule
            x, substate, subrec = mod(x, record=record)
            new_state_dict.update({submod_name: substate})
            record_dict.update(
                {
                    submod_name: subrec,
                    f"{submod_name}_output": copy(x),
                }
            )

        # - Return output, state and record
        return x, new_state_dict, record_dict

    def __getitem__(self, item: Union[int, str]) -> Module:
        """
        Permit indexing into the sequence of modules

        Args:
            item (Union[int, str]): The index of the module to return, or the name of the module to access

        Returns:
            Module: The ``item``th module in the sequence
        """
        if isinstance(item, str):
            return Module.modules(self)[item]
        else:
            # - Translate the position into the corresponding submodule name
            return Module.modules(self)[self._submodulenames[item]]

    def as_graph(self):
        """
        Build a computational graph for the whole stack

        Calls ``as_graph()`` on each submodule, connects the graph of each submodule to the graph of the following one, and wraps the result in a ``GraphHolder`` that exposes the input nodes of the first submodule and the output nodes of the last.

        Returns:
            The ``rg.GraphHolder`` encapsulating the connected submodule graphs.

        Raises:
            IndexError: If the stack contains no submodules.
        """
        mod_graphs = []
        for mod in self:
            mod_graphs.append(mod.as_graph())

        # - Connect each submodule graph to its successor
        for source, dest in zip(mod_graphs[:-1], mod_graphs[1:]):
            rg.connect_modules(source, dest)

        return rg.GraphHolder(
            mod_graphs[0].input_nodes,
            mod_graphs[-1].output_nodes,
            f"{type(self).__name__}_{self.name}_{id(self)}",
            self,
        )

    def _wrap_recorded_state(self, state_dict: dict, t_start: float = 0.0) -> dict:
        """
        Convert a recorded-state dictionary into time series objects

        Each submodule wraps its own sub-dictionary; the ``"{name}_output"`` entries added by :py:meth:`evolve` are wrapped here as ``TSEvent`` (spiking output) or ``TSContinuous`` (otherwise).

        Args:
            state_dict (dict): Record dictionary, as returned by :py:meth:`evolve`. It is updated in place.
            t_start (float): Start time passed on to the generated time series. Default: ``0.0``.

        Returns:
            dict: The same ``state_dict`` object, with wrapped entries.
        """
        # - Wrap each sub-dictionary in turn
        for mod_name in self._submodulenames:
            mod = self.modules()[mod_name]
            state_dict[mod_name].update(
                mod._wrap_recorded_state(state_dict[mod_name], t_start)
            )

            # - Wrap recorded output for this module
            output_key = f"{mod_name}_output"

            # - Prefer the time step of the submodule, falling back to that of the network
            dt = mod.dt if hasattr(mod, "dt") else self.dt

            # - Only element 0 of the recorded output is wrapped
            #   (presumably the first batch entry — TODO confirm)
            if mod.spiking_output:
                ts_output = TSEvent.from_raster(
                    state_dict[output_key][0],
                    dt=dt,
                    name=output_key,
                    t_start=t_start,
                )
            else:
                ts_output = TSContinuous.from_clocked(
                    state_dict[output_key][0],
                    dt=dt,
                    name=output_key,
                    t_start=t_start,
                )

            state_dict.update({output_key: ts_output})

        # - Return wrapped dictionary
        return state_dict
class ModSequential(SequentialMixin, Module):
    """
    :py:class:`.Sequential` combinator for native (non-Jax, non-Torch) modules

    All behaviour comes from :py:class:`SequentialMixin` combined with the :py:class:`.Module` base class; nothing is added here.
    """
try:
    from rockpool.nn.modules.jax.jax_module import JaxModule
    from jax import numpy as jnp

    class JaxSequential(SequentialMixin, JaxModule):
        """
        The :py:class:`.Sequential` combinator for Jax modules
        """

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            """
            Unflatten a tree of modules from Jax to Rockpool

            Args:
                aux_data: A 3-tuple ``(name, shape, submodule names)``. Only the name is used here.
                children: A 5-tuple ``(params, sim_params, state, modules, init_params)``. ``modules`` must offer ``.values()``; ``init_params`` is not used.

            Returns:
                A new instance of ``cls`` built from the submodules, with its name and attributes restored.
            """
            params, sim_params, state, modules, init_params = children
            _name, _shape, _submodulenames = aux_data

            # - Rebuild the stack from the submodules; names are regenerated by the constructor
            # NOTE(review): `_submodulenames` is not used, so custom submodule names and the
            # stored ordering are not restored explicitly — confirm this is intended
            modules = tuple(modules.values())
            obj = cls(*modules)
            obj._name = _name

            # - Restore configuration
            obj = obj.set_attributes(params)
            obj = obj.set_attributes(state)
            obj = obj.set_attributes(sim_params)

            return obj

except Exception:
    # - The Jax backend is unavailable. `Exception` is caught (rather than a bare `except`)
    #   so that KeyboardInterrupt and SystemExit still propagate, while any import-time
    #   failure of the backend still falls back to the placeholders below.

    class JaxModule:
        """Placeholder, so that ``isinstance(obj, JaxModule)`` checks are always ``False``"""

        pass

    class JaxSequential:
        """
        The :py:class:`.Sequential` combinator for Jax modules
        """

        def __init__(self, *args, **kwargs):
            # - Accept any arguments, so callers always see the informative ImportError
            raise ImportError(
                "'Jax' and 'Jaxlib' backend not found. Modules relying on Jax will not be available."
            )
try:
    from rockpool.nn.modules.torch.torch_module import TorchModule
    import torch
    from torch.nn import Module as torch_nn_module

    class TorchSequential(SequentialMixin, TorchModule):
        """
        The :py:class:`.Sequential` combinator for torch modules
        """

        def __init__(
            self,
            *args,
            **kwargs,
        ):
            """
            Initialise a :py:class:`.Sequential` stack of torch modules

            Plain ``torch.nn.Module`` objects, given positionally or as values of a leading ``OrderedDict``, are first converted with ``TorchModule.from_torch``.

            Args:
                *args: Modules (or a leading ``OrderedDict`` of modules) and other arguments, passed on to the superclass constructor.
                **kwargs: Passed on to the superclass constructor.
            """
            # - Gather the objects to inspect: positional arguments, plus the values of a leading OrderedDict
            candidates = list(args)
            if args and isinstance(args[0], OrderedDict):
                candidates = list(args[0].values()) + candidates[1:]

            # - Convert torch modules to Rockpool TorchModules
            #   The return value is not used, so the conversion is relied upon to modify `item` in place
            for item in candidates:
                if isinstance(item, torch_nn_module) and not isinstance(
                    item, TorchModule
                ):
                    TorchModule.from_torch(item, retain_torch_api=False)

            # - Call super-class constructor
            super().__init__(*args, **kwargs)

        def forward(self, *args, **kwargs):
            """
            Evolve the stack and return only its output

            Args:
                *args: Passed on to ``evolve``.
                **kwargs: Passed on to ``evolve``. ``record`` defaults to ``True`` when not given.

            Returns:
                The first element of the tuple returned by ``evolve``.
            """
            # - By default, record state
            kwargs.setdefault("record", True)

            # - Return output
            return self.evolve(*args, **kwargs)[0]

except Exception:
    # - The Torch backend is unavailable. `Exception` is caught (rather than a bare `except`)
    #   so that KeyboardInterrupt and SystemExit still propagate, while any import-time
    #   failure of the backend still falls back to the placeholders below.

    class TorchModule:
        """Placeholder, so that ``isinstance(obj, TorchModule)`` checks are always ``False``"""

        pass

    class torch_nn_module:
        """Placeholder, so that ``isinstance(obj, torch_nn_module)`` checks are always ``False``"""

        pass

    class TorchSequential:
        """
        The :py:class:`.Sequential` combinator for torch modules
        """

        def __init__(self, *args, **kwargs):
            # - Accept any arguments, so callers always see the informative ImportError
            raise ImportError(
                "'Torch' backend not found. Modules relying on PyTorch will not be available."
            )
def Sequential(*args, **kwargs) -> ModuleBase:
    """
    Build a sequential stack of modules by connecting them end-to-end

    :py:class:`.Sequential` accepts any number of modules. The shapes of the modules must be compatible -- the output size :py:attr:`~.Module.size_out` of each module must match the input size :py:attr:`~.Module.size_in` of the following module.

    When provided with a list of modules, :py:class:`.Sequential` will assign module names automatically to each module. If you would like more control over module names, you can provide an `OrderedDict` to construct the network. In that case, dictionary keys will be used as module names.

    You can also append additional modules to a network with the :py:meth:`.Sequential.append` method. Module names can optionally be provided in this case as well.

    Examples:

        Build a :py:class:`.Sequential` stack, which will be returned as a :py:class:`.Module` containing ``mod0``, ``mod1`` and ``mod2``. When evolving this stack, signals will be passed through ``mod0``, then ``mod1``, then ``mod2``:

        >>> Sequential(mod0, mod1, mod2)

        Index into a :py:class:`.Sequential` stack using Python indexing:

        >>> mod = Sequential(mod0, mod1, mod2)
        >>> mod[0]
        A module with shape (xx, xx)

        Build a :py:class:`.Sequential` stack from an `OrderedDict`:

        >>> od = OrderedDict([('mod0', mod0), ('mod1', mod1)])
        >>> seq = Sequential(od)

        Build an empty :py:class:`.Sequential`, and use :py:meth:`.Sequential.append`:

        >>> seq = Sequential()
        >>> seq.append(mod0)
        >>> seq.append(mod1, 'mod1')

    Args:
        *args: Any number of modules to connect, or an `OrderedDict` of named modules as the first argument. The :py:attr:`~.Module.size_out` attribute of one module must match the :py:attr:`~.Module.size_in` attribute of the following module.
        **kwargs: Passed on unchanged to the constructor of the selected combinator class.

    Returns:
        A :py:class:`.Module` subclass object that encapsulates the provided modules
    """
    # - Gather the objects to inspect: positional arguments, plus the modules held in a leading OrderedDict
    candidates = list(args)
    if args and isinstance(args[0], OrderedDict):
        candidates = list(args[0].values()) + candidates[1:]

    # - Check for Jax and Torch submodules; the first one found selects the backend
    for item in candidates:
        if isinstance(item, JaxModule):
            return JaxSequential(*args, **kwargs)

        if isinstance(item, (TorchModule, torch_nn_module)):
            return TorchSequential(*args, **kwargs)

    # - Use ModSequential if no JaxModule or TorchModule is in the submodules
    return ModSequential(*args, **kwargs)