# Source code for nn.modules.module

"""
Contains the module base classes for Rockpool
"""

# - Rockpool imports
import collections

from rockpool.parameters import ParameterBase
from rockpool.timeseries import TimeSeries

try:
    from rockpool.graph.graph_base import GraphModuleBase
except:
    GraphModuleBase = "GraphModuleBase"

# - Other imports
from abc import ABC, abstractmethod
from warnings import warn
from collections import ChainMap
from typing import Tuple, Any, Iterable, Dict, Optional, List, Union
import numpy as np

# - Names exported by `from <this module> import *`
__all__ = ["Module"]


class ModuleBase(ABC):
    """
    Base class for all `Module` subclasses in Rockpool

    A `ModuleBase` keeps a registry of the `ParameterBase` attributes assigned to it (e.g. `Parameter`, `SimulationParameter` and `State` objects), and a registry of its sub-modules. These registries are used to search for, set and reset attributes recursively over the module hierarchy.
    """

    def __init__(
        self,
        shape: Optional[Union[Tuple, int]] = None,
        spiking_input: bool = False,
        spiking_output: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialise this module

        Args:
            shape (Optional[Union[Tuple, int]]): The shape of the defined module. A non-iterable value (e.g. a scalar) is wrapped into a one-element tuple.
            spiking_input (bool): Whether this module receives spiking input. Default: False
            spiking_output (bool): Whether this module produces spiking output. Default: False
            *args: Additional positional arguments, passed on to ``super().__init__()``
            **kwargs: Additional keyword arguments, passed on to ``super().__init__()``
        """
        # - Set flag to specify that we are in the `__init__()` method
        self._in_Module_init = True
        """ (bool) If exists and ``True``, indicates that the module is in the ``__init__`` chain."""

        self._force_set_attributes = False
        """ (bool) If ``True``, do not sanity-check attributes when setting. """

        # - Initialise co-classes etc.
        super().__init__(*args, **kwargs)

        # - Initialise Module attributes
        self._submodulenames: List[str] = []
        """Registry of sub-module names"""

        self._name: Optional[str] = None
        """Name of this module, if assigned"""

        self._spiking_input: bool = spiking_input
        """Whether this module receives spiking input"""

        self._spiking_output: bool = spiking_output
        """Whether this module produces spiking output"""

        # - Be generous if a scalar was provided instead of a tuple
        if isinstance(shape, Iterable):
            self._shape = tuple(shape)
            """The shape of this module"""
        else:
            self._shape = (shape,)
            """The shape of this module"""

    def __repr__(self, indent: str = "") -> str:
        """
        Produce a string representation of this module

        Args:
            indent (str): The indent to prepend to each line of output

        Returns:
            str: A string representation of this module, including all sub-modules
        """
        # - String representation of this module
        text = f"{indent}{self.full_name} with shape {self._shape}"

        # - Add submodules, with increased indentation
        if self._submodulenames:
            text += " {"
            for mod_name in self._submodulenames:
                # - Call `ModuleBase.__repr__` explicitly, bypassing any `__repr__`
                #   override on the sub-module, so that `indent` can be passed
                text += "\n" + ModuleBase.__repr__(
                    getattr(self, mod_name), indent=indent + "    "
                )

            text += f"\n{indent}" + "}"

        return text

    def _get_attribute_registry(self) -> Tuple[Dict, Dict]:
        """
        Return or initialise the attribute registry for this module

        Returns:
            (tuple): registered_attributes, registered_modules
        """
        # - Create the registries on first access. `super().__setattr__` is used to
        #   avoid recursing into this class' own `__setattr__`
        if not hasattr(self, "_ModuleBase__registered_attributes") or not hasattr(
            self, "_ModuleBase__modules"
        ):
            super().__setattr__("_ModuleBase__registered_attributes", {})
            super().__setattr__("_ModuleBase__modules", {})

        # - Get the attribute and modules dictionaries in a safe way
        __registered_attributes = self.__dict__.get(
            "_ModuleBase__registered_attributes"
        )
        __modules = self.__dict__.get("_ModuleBase__modules")

        return __registered_attributes, __modules

    def __setattr__(self, name: str, val: Any):
        """
        Set an attribute for this module

        `ParameterBase` values are registered in the attribute registry and their data is stored on the module. `ModuleBase` values are registered as sub-modules. Values assigned to an already-registered attribute are shape-checked against the registered shape (unless ``self._force_set_attributes`` is ``True``).

        Args:
            name (str): The name of the attribute to set
            val (Any): The value to assign to the attribute

        Raises:
            NotImplementedError: If a `ParameterBase` is assigned before ``ModuleBase.__init__()`` was called
            ValueError: If a new `ParameterBase` is assigned to an existing attribute outside of ``__init__``, or if a new value has the wrong shape
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Check if this is a new rockpool Parameter
        if isinstance(val, ParameterBase):
            try:
                super().__getattribute__("_in_Module_init")
            except AttributeError as err:
                raise NotImplementedError(
                    "You must call `super().__init__()` in your `Module` subclass."
                ) from err

            # - New Parameters / States may only be created during `__init__`
            if (
                hasattr(self, name)
                and hasattr(self, "_in_Module_init")
                and not self._in_Module_init
            ):
                raise ValueError(
                    f'Cannot assign a new Parameter or State to an existing attribute "{name}".'
                )

            # - Register the attribute, then store only its data on the module
            self._register_attribute(name, val)
            val = val.data

        # - Are we assigning a sub-module?
        if isinstance(val, ModuleBase):
            self._register_module(name, val)

        # - Check if this is an already registered attribute
        if name in __registered_attributes:
            if hasattr(self, name):
                (_, _, _, _, shape) = __registered_attributes[name]

                if val is not None:
                    # - Should we force-set the attribute?
                    if self._force_set_attributes:
                        __registered_attributes[name][4] = np.shape(val)

                    elif np.shape(val) != shape:
                        # - Check that shapes are identical
                        raise ValueError(
                            f"The new value assigned to {name} must be of shape {shape} (got {np.shape(val)})."
                        )

            # - Keep the registry's copy of the value in sync
            __registered_attributes[name][0] = val

        # - Assign attribute to self
        super().__setattr__(name, val)

    def __delattr__(self, name: str):
        """
        Delete an attribute from this module, and remove from the attribute registry if present

        Args:
            name (str): The name of the attribute to delete
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Remove attribute from registered attributes
        if name in __registered_attributes:
            del __registered_attributes[name]

        # - Remove name from modules
        if name in __modules:
            del __modules[name]
            self._submodulenames.remove(name)

        # - Remove attribute
        super().__delattr__(name)

    def _register_attribute(self, name: str, val: ParameterBase):
        """
        Record an attribute in the attribute registry

        Args:
            name (str): The name of the attribute to register
            val (ParameterBase): The `ParameterBase` subclass object to register. e.g. `Parameter`, `SimulationParameter` or `State`.
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Record attribute in the attribute registry as a mutable list, so that
        #   the value (index 0) and shape (index 4) can be updated in place:
        #   [value, parameter class name, family, init_func, shape]
        __registered_attributes[name] = [
            val.data,
            type(val).__name__,
            val.family,
            val.init_func,
            val.shape,
        ]

    def _register_module(self, name: str, mod: "ModuleBase"):
        """
        Register a sub-module in the module registry

        Args:
            name (str): The name of the module to register
            mod (ModuleBase): The `ModuleBase` object to register

        Raises:
            ValueError: If `mod` is not a `ModuleBase`
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        if not isinstance(mod, ModuleBase):
            raise ValueError(
                f"You may only assign a `Module` subclass as a sub-module."
            )

        # - Assign module name to module
        mod._name = name

        # - Assign to appropriate attribute dictionary
        __modules[name] = [mod, type(mod).__name__]
        if name not in self._submodulenames:
            self._submodulenames.append(name)

    def set_attributes(self, new_attributes: dict) -> "ModuleBase":
        """
        Set the attributes and sub-module attributes from a dictionary

        This method can be used with the dictionary returned from module evolution to set the new state of the module. It can also be used to set multiple parameters of a module and submodules.

        Examples:
            Use the functional API to evolve, obtain new states, and set those states:

            >>> _, new_state, _ = mod(input)
            >>> mod = mod.set_attributes(new_state)

            Obtain a parameter dictionary, modify it, then set the parameters back:

            >>> params = mod.parameters()
            >>> params['w_input'] *= 0.
            >>> mod.set_attributes(params)

        Args:
            new_attributes (dict): A nested dictionary containing parameters of this module and sub-modules. Keys not matching a registered attribute or sub-module name are ignored.

        Returns:
            ModuleBase: This module, for compatibility with the functional interface
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Set self attributes
        for k, v in __registered_attributes.items():
            if k in new_attributes:
                self.__setattr__(k, new_attributes[k])

        # - Set submodule attributes
        for k, m in __modules.items():
            if k in new_attributes:
                m[0].set_attributes(new_attributes[k])

        # - Return the module, for compatibility with the functional interface
        return self

    def _get_attribute_family(
        self, type_name: str, family: Optional[Union[str, Tuple, List]] = None
    ) -> dict:
        """
        Search for attributes of this module and submodules that match a given family

        This method can be used to conveniently get all weights for a network; or all time constants; or any other family of parameters. Parameter families are defined simply by a string: ``"weights"`` for weights; ``"taus"`` for time constants, etc. These strings are arbitrary, but if you follow the conventions then future developers will thank you (that includes you in six months' time).

        Args:
            type_name (str): The class of parameters to search for. Must be one of ``["Parameter", "SimulationParameter", "State"]`` or another future subclass of :py:class:`.ParameterBase`
            family (Union[str, Tuple[str]]): A string or list or tuple of strings, that define one or more attribute families to search for

        Returns:
            dict: A nested dictionary of attributes that match the provided `type_name` and `family`
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Filter own attribute dictionary by type key
        matching_attributes = {
            k: v for (k, v) in __registered_attributes.items() if v[1] == type_name
        }

        # - Filter by family
        if family is not None:
            if not isinstance(family, (tuple, list)):
                family = (family,)

            # - Compare family names by equality, not identity: strings built at
            #   runtime are not guaranteed to be the same object as the registered ones
            list_attributes = [
                {k: v for (k, v) in matching_attributes.items() if v[2] == f}
                for f in family
            ]
            matching_attributes = dict(ChainMap(*list_attributes))

        # - Just take values using getattr
        matching_attributes = {k: getattr(self, k) for k in matching_attributes.keys()}

        # - Append sub-module attributes as nested dictionaries
        submodule_attributes = {}
        for k, m in __modules.items():
            mod_attributes = m[0]._get_attribute_family(type_name, family)

            # - When searching by family, omit sub-modules with no matches
            if (family and mod_attributes) or (not family):
                submodule_attributes[k] = mod_attributes

        # - Push submodule attributes into dictionary
        if family and submodule_attributes or not family:
            matching_attributes.update(submodule_attributes)

        # - Return nested attributes
        return matching_attributes

    def attributes_named(self, name: Union[Tuple[str], List[str], str]) -> dict:
        """
        Search for attributes of this module or submodules by name

        Args:
            name (Union[str, Tuple[str]]): The name (or names) of the attribute to search for

        Returns:
            dict: A nested dictionary of attributes that match `name`. Values are taken from the attribute registry. Sub-modules with no matches are omitted.
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Check if we were given a tuple or not
        if not isinstance(name, (tuple, list)):
            name = (name,)

        # - Filter own attribute dictionary by name keys
        list_attributes = [
            {k: v for (k, v) in __registered_attributes.items() if k == n} for n in name
        ]
        matching_attributes = dict(ChainMap(*list_attributes))

        # - Just take values from the registry
        matching_attributes = {k: v[0] for (k, v) in matching_attributes.items()}

        # - Append sub-module attributes as nested dictionaries
        submodule_attributes = {}
        for k, m in __modules.items():
            mod_attributes = m[0].attributes_named(name)

            if mod_attributes:
                submodule_attributes[k] = mod_attributes

        # - Push submodule attributes into dictionary
        if submodule_attributes:
            matching_attributes.update(submodule_attributes)

        # - Return nested attributes
        return matching_attributes

    def parameters(self, family: Optional[Union[str, Tuple, List]] = None) -> Dict:
        """
        Return a nested dictionary of module and submodule Parameters

        Use this method to inspect the Parameters from this and all submodules. The optional argument `family` allows you to search for Parameters in a particular family — for example ``"weights"`` for all weights of this module and nested submodules.

        Although the `family` argument is an arbitrary string, reasonable choices are ``"weights"``, ``"taus"`` for time constants, ``"biases"`` for biases...

        Examples:
            Obtain a dictionary of all Parameters for this module (including submodules):

            >>> mod.parameters()
            dict{ ... }

            Obtain a dictionary of Parameters from a particular family:

            >>> mod.parameters("weights")
            dict{ ... }

        Args:
            family (str): The family of Parameters to search for. Default: ``None``; return all parameters.

        Returns:
            dict: A nested dictionary of Parameters of this module and all submodules
        """
        return self._get_attribute_family("Parameter", family)

    def simulation_parameters(
        self, family: Optional[Union[str, Tuple, List]] = None
    ) -> Dict:
        """
        Return a nested dictionary of module and submodule SimulationParameters

        Use this method to inspect the SimulationParameters from this and all submodules. The optional argument `family` allows you to search for SimulationParameters in a particular family.

        Examples:
            Obtain a dictionary of all SimulationParameters for this module (including submodules):

            >>> mod.simulation_parameters()
            dict{ ... }

        Args:
            family (str): The family of SimulationParameters to search for. Default: ``None``; return all SimulationParameter attributes.

        Returns:
            dict: A nested dictionary of SimulationParameters of this module and all submodules
        """
        return self._get_attribute_family("SimulationParameter", family)

    def state(self, family: Optional[Union[str, Tuple, List]] = None) -> Dict:
        """
        Return a nested dictionary of module and submodule States

        Use this method to inspect the States from this and all submodules. The optional argument `family` allows you to search for States in a particular family.

        Examples:
            Obtain a dictionary of all States for this module (including submodules):

            >>> mod.state()
            dict{ ... }

        Args:
            family (str): The family of States to search for. Default: ``None``; return all State attributes.

        Returns:
            dict: A nested dictionary of States of this module and all submodules
        """
        return self._get_attribute_family("State", family)

    def modules(self) -> Dict:
        """
        Return a dictionary of all sub-modules of this module

        Returns:
            dict: An ordered dictionary containing all direct sub-modules, keyed by sub-module name.
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        return collections.OrderedDict([(k, m[0]) for (k, m) in __modules.items()])

    def _reset_attribute(self, name: str) -> "ModuleBase":
        """
        Reset an attribute to its initialisation value

        If no initialisation function was registered for the attribute, it is left unchanged.

        Args:
            name (str): The name of the attribute to reset

        Returns:
            self (`Module`): For compatibility with the functional API

        Raises:
            KeyError: If `name` is not a registered attribute
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Check that the attribute is registered
        if name not in __registered_attributes:
            raise KeyError(f"{name} is not a registered attribute.")

        # - Get the initialisation function and shape from the registry
        (_, _, _, init_func, shape) = __registered_attributes[name]

        # - Use the registered initialisation function, if present
        if init_func is not None:
            setattr(self, name, init_func(shape))

        return self

    def _has_registered_attribute(self, name: str) -> bool:
        """
        Check if the module has a registered attribute

        Args:
            name (str): The name of the attribute to check

        Returns:
            bool: ``True`` if the attribute `name` is in the attribute registry, ``False`` otherwise.
        """
        __registered_attributes, _ = self._get_attribute_registry()
        return name in __registered_attributes

    def reset_state(self) -> "ModuleBase":
        """
        Reset the state of this module and all sub-modules

        Returns:
            Module: The updated module is returned for compatibility with the functional API
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Get a list of states
        states = self.state()

        # - Reset own state attributes
        for k, v in __registered_attributes.items():
            if k in states:
                self._reset_attribute(k)

        # - Reset submodule states
        for k, m in __modules.items():
            m[0] = m[0].reset_state()

        return self

    def reset_parameters(self) -> "ModuleBase":
        """
        Reset all parameters in this module and all sub-modules

        Returns:
            Module: The updated module is returned for compatibility with the functional API
        """
        # - Get attribute registry
        __registered_attributes, __modules = self._get_attribute_registry()

        # - Get a list of parameters
        parameters = self.parameters()

        # - Reset own parameter attributes
        for k, v in __registered_attributes.items():
            if k in parameters:
                self._reset_attribute(k)

        # - Reset submodule parameters
        for k, m in __modules.items():
            m[0] = m[0].reset_parameters()

        return self

    @property
    def class_name(self) -> str:
        """str: Class name of ``self``"""
        # - The bare class name, without package information
        return type(self).__name__

    @property
    def name(self) -> str:
        """str: The name of this module wrapped in single quotes, or an empty string if no name is assigned"""
        try:
            name = super().__getattribute__("_name")
        except AttributeError:
            # - `_name` is not yet set (e.g. before `__init__` has run)
            return ""

        return f"'{name}'" if name else ""

    @property
    def full_name(self) -> str:
        """str: The full name of this module (class plus module name)"""
        return f"{self.class_name} {self.name}"

    @property
    def spiking_input(self) -> bool:
        """bool: If ``True``, this module receives spiking input. If ``False``, this module expects continuous input."""
        return self._spiking_input

    @property
    def spiking_output(self) -> bool:
        """bool: If ``True``, this module sends spiking output. If ``False``, this module sends continuous output."""
        return self._spiking_output

    @property
    def shape(self) -> tuple:
        """tuple: The shape of this module"""
        return self._shape

    @property
    def size(self) -> int:
        """int: (DEPRECATED) The output size of this module"""
        # - `stacklevel=2` attributes the warning to the caller, not to this property
        warn(
            "The `size` property is deprecated. Please use `size_out` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._shape[-1]

    @property
    def size_out(self) -> int:
        """int: The output size of this module (last element of ``shape``)"""
        return self._shape[-1]

    @property
    def size_in(self) -> int:
        """int: The input size of this module (first element of ``shape``)"""
        return self._shape[0]

    @abstractmethod
    def evolve(self, input_data, record: bool = False) -> Tuple[Any, Any, Any]:
        """
        Evolve the state of this module over input data

        NOTE: THIS MODULE CLASS DOES NOT PROVIDE DOCUMENTATION FOR ITS EVOLVE METHOD. PLEASE UPDATE THE DOCUMENTATION FOR THIS MODULE.

        Args:
            input_data: The input data with shape ``(T, size_in)`` to evolve with
            record (bool): If ``True``, the module should record internal state during evolution and return the record. If ``False``, no recording is required. Default: ``False``.

        Returns:
            tuple: (output, new_state, record)
                output (np.ndarray): The output response of this module with shape ``(T, size_out)``
                new_state (dict): A dictionary containing the updated state of this and all submodules after evolution
                record (dict): A dictionary containing recorded state of this and all submodules, if requested using the `record` argument
        """
        return None, None, None

    def __call__(self, input_data, *args, **kwargs):
        """
        Evolve the state of this module over input data

        If `input_data` is a 3-tuple ``(output, new_state, record)`` as returned by a previous call, the output is used as input, and this module's new state and record are merged into the provided dictionaries.

        Args:
            input_data: The input data with shape ``(T, size_in)`` to evolve this module with
            *args: Additional positional arguments, passed to :py:meth:`.evolve`
            **kwargs: Additional keyword arguments, passed to :py:meth:`.evolve`

        Returns:
            tuple: (output, new_state, record), as returned by :py:meth:`.evolve`
        """
        # - Catch the case where we have been called with the raw output of a previous call
        if isinstance(input_data, tuple) and len(input_data) == 3:
            input_data, new_state, recorded_state = input_data
            outputs, this_new_state, this_recorded_state = self.evolve(
                input_data, *args, **kwargs
            )
            # NOTE(review): `self.name` wraps the module name in single quotes (or is
            # "" if unnamed), so these keys differ from the raw sub-module names that
            # `set_attributes` looks up — confirm this is intended.
            new_state.update({self.name: this_new_state})
            recorded_state.update({self.name: this_recorded_state})
        else:
            outputs, new_state, recorded_state = self.evolve(
                input_data, *args, **kwargs
            )

        return outputs, new_state, recorded_state

    def _wrap_recorded_state(
        self, recorded_dict: dict, t_start: float
    ) -> Dict[str, TimeSeries]:
        """
        Convert a recorded dictionary to a :py:class:`TimeSeries` representation

        This method is optional, and is provided to make the :py:meth:`timed` conversion to a :py:class:`TimedModule` work better. You should override this method in your custom :py:class:`Module`, to wrap each element of your recorded state dictionary as a :py:class:`TimeSeries`. This base implementation returns `recorded_dict` unchanged.

        Args:
            recorded_dict (dict): A recorded state dictionary as returned by :py:meth:`.evolve`
            t_start (float): The initial time of the recorded state, to use as the starting point of the time series

        Returns:
            Dict[str, TimeSeries]: The mapped recorded state dictionary, wrapped as :py:class:`TimeSeries` objects
        """
        return recorded_dict

    def as_graph(self) -> GraphModuleBase:
        """
        Convert this module to a computational graph

        Returns:
            GraphModuleBase: The computational graph corresponding to this module

        Raises:
            NotImplementedError: If :py:meth:`.as_graph` is not implemented for this subclass
        """
        raise NotImplementedError(
            f"The Module class '{type(self).__name__}' used by module [{self}] does not implement the graph serialisation API"
        )


class PostInitMetaMixin(type(ModuleBase)):
    """
    Metaclass mixin that runs an optional ``__post_init__()`` hook after ``__init__()``

    When a class using this metaclass is instantiated, the object is first constructed normally. If the class defines ``__post_init__``, it is then called with the new object. Finally, the object's ``_in_Module_init`` flag is cleared, marking the end of the initialisation chain.
    """

    def __call__(cls, *args, **kwargs):
        # - Build the instance through the normal `__new__` / `__init__` machinery
        instance = super().__call__(*args, **kwargs)

        # - Run the post-initialisation hook only if the class provides one
        try:
            post_init_hook = getattr(cls, "__post_init__")
        except AttributeError:
            pass
        else:
            post_init_hook(instance)

        # - Initialisation is complete: leave the `__init__` chain
        instance._in_Module_init = False
        return instance


class Module(ModuleBase, ABC, metaclass=PostInitMetaMixin):
    """
    The native Python / numpy :py:class:`.Module` base class for Rockpool

    This class acts as the base class for all "native" modules in Rockpool. To get started with using and writing your own Rockpool modules, see :ref:`/basics/getting_started.ipynb` and :ref:`/in-depth/api-low-level.ipynb`.

    If you plan to write modules using Jax or Torch backends, you should use either :py:class:`.JaxModule` or :py:class:`.TorchModule` as base classes, respectively.

    To get started with the Jax backend, see :ref:`/in-depth/api-functional.ipynb` and :ref:`/in-depth/jax-training.ipynb`.

    To get started with the Torch backend, see :ref:`/in-depth/torch-api.ipynb` and :ref:`/in-depth/torch-training.ipynb`.
    """

    def timed(
        self,
        output_num: int = 0,
        dt: Optional[float] = None,
        add_events: bool = False,
    ):
        """
        Convert this module to a :py:class:`.TimedModule`

        Args:
            output_num (int): Specify which output of the module to take, if the module returns multiple output series. Default: ``0``, take the first (or only) output.
            dt (float): Used to provide a time-step for this module, if the module does not already have one. If ``self`` already defines a time-step, then ``self.dt`` will be used. Default: ``None``
            add_events (bool): Iff ``True``, the :py:class:`.TimedModule` will add events occurring on a single timestep on input and output. Default: ``False``, don't add time steps.

        Returns:
            :py:class:`.TimedModule`: A timed module that wraps this module
        """
        # - Imported locally, presumably to avoid a circular import at module load time
        from rockpool.nn.modules.timed_module import TimedModuleWrapper

        return TimedModuleWrapper(
            self, output_num=output_num, dt=dt, add_events=add_events
        )

    def _auto_batch(
        self,
        data: np.ndarray,
        states: Tuple = (),
        target_shapes: Tuple = None,
    ) -> Tuple[np.ndarray, Tuple[np.ndarray]]:
        """
        Automatically replicate states over batches and verify input dimensions

        Examples:
            >>> data, (state0, state1, state2) = self._auto_batch(data, (self.state0, self.state1, self.state2))

            This will verify that ``data`` has the correct final dimension (i.e. ``self.size_in``). If ``data`` has only two dimensions ``(T, Nin)``, then it will be augmented to ``(1, T, Nin)``. The individual states will be replicated out from shape ``(a, b, c, ...)`` to ``(n_batches, a, b, c, ...)`` and returned.

            If ``data`` has only a single dimension ``(T,)``, it will be expanded to ``(1, T, self.size_in)``.

            ``state0``, ``state1``, ``state2`` will be replicated out along the batch dimension.

            >>> data, (state0,) = self._auto_batch(data, (self.state0,), ((10, -1, self.size_in),))

            Attempt to replicate ``state0`` to a specified size ``(10, -1, self.size_in)``.

            NOTE(review): states are replicated with ``np.ones((n_batches, *shape))``, which rejects negative dimensions, so a ``-1`` entry in a target shape appears to raise a ``ValueError`` — confirm intended usage.

        Args:
            data (np.ndarray): Input data tensor. Either ``(batches, T, Nin)``, ``(T, Nin)`` or ``(T,)``. A final dimension of size 1 is broadcast to ``self.size_in``.
            states (Tuple): Tuple (or any iterable) of state variables. Each will be replicated out over batches by prepending a batch dimension
            target_shapes (Tuple): A tuple of target size tuples, each corresponding to each state argument. The individual states will be replicated out to match the corresponding target sizes. An entry of ``None`` keeps the shape of the corresponding state. If not provided (the default), then states will be only replicated along batches.

        Returns:
            (np.ndarray, Tuple[np.ndarray]) data, states

        Raises:
            ValueError: If ``data`` cannot be interpreted as ``(batches, T, Nin)``, if its final dimension does not match ``self.size_in``, or if fewer target shapes than states are provided
        """
        # - Materialise the states once: they are iterated more than once below,
        #   which would silently exhaust a generator
        states = tuple(states)

        # - Ensure data is a float tensor
        data = np.array(data, "float")

        # - Expand 1D and 2D input to 3D ``(batches, T, Nin)``
        if data.ndim == 1:
            data = np.expand_dims(data, 0)
            data = np.expand_dims(data, 2)
        elif data.ndim == 2:
            data = np.expand_dims(data, 0)

        if data.ndim != 3:
            raise ValueError(
                f"Input data must have shape (T,), (T, Nin) or (batches, T, Nin); got shape {data.shape}."
            )

        # - Broadcast a singleton channel dimension over all input channels
        if data.shape[-1] == 1:
            data = np.broadcast_to(data, (data.shape[0], data.shape[1], self.size_in))

        # - Get shape of input
        n_batches, _, n_connections = data.shape

        # - Check input dimensions
        if n_connections != self.size_in:
            raise ValueError(
                "Input has wrong neuron dimension. It is {}, must be {}".format(
                    n_connections, self.size_in
                )
            )

        # - Get target shapes
        if target_shapes is None:
            target_shapes = tuple(s.shape for s in states)
        else:
            target_shapes = tuple(target_shapes)

            # - `zip` would otherwise silently drop states without a target shape
            if len(target_shapes) < len(states):
                raise ValueError(
                    f"Expected a target shape for each of the {len(states)} states, got {len(target_shapes)}."
                )

            target_shapes = tuple(
                s.shape if shape is None else shape
                for s, shape in zip(states, target_shapes)
            )

        # - Replicate states along a new leading batch dimension and return
        states = tuple(
            np.ones((n_batches, *shape)) * s for s, shape in zip(states, target_shapes)
        )
        return data, states