# Source code for rockpool.nn.combinators.residual

"""
Implements the :py:class:`.Residual` combinator, with helper classes for the native, Jax and Torch backends
"""

from rockpool.nn.modules.module import Module, ModuleBase
from rockpool.nn.combinators.sequential import SequentialMixin, JaxSequential
from rockpool.graph import AliasConnection, as_GraphHolder

from rockpool.utilities.backend_management import (
    backend_available,
    missing_backend_shim,
)

from typing import Tuple, Any


class ResidualMixin(SequentialMixin):
    """
    Mixin adding a residual (skip) connection on top of :py:class:`SequentialMixin`

    The submodules are evolved exactly as a sequential stack; the block input is then
    added to the stack output. The assembled stack must therefore satisfy
    ``size_in == size_out``, which is checked on construction.
    """

    def __init__(self, *args, **kwargs):
        # - Build the underlying sequential stack before checking its sizes
        super().__init__(*args, **kwargs)

        # - The input is added to the output, so the two sizes have to agree
        sizes_differ = self.size_in != self.size_out
        if sizes_differ:
            raise ValueError(
                "`size_in` and `size_out` must be identical for a residual block."
            )

    def evolve(self, input_data, record: bool = False) -> Tuple[Any, Any, Any]:
        """
        Evolve the sequential stack, then add the block input to its output

        Args:
            input_data: Input passed to the sequential stack, and added to the stack output
            record (bool): Forwarded unchanged to the sequential evolution

        Returns:
            (output, new_state, record_dict): ``stack_output + input_data``, followed by
            the state and record dictionaries returned by the sequential evolution
        """
        stack_output, new_state, recorded = super().evolve(input_data, record)
        residual_output = stack_output + input_data
        return residual_output, new_state, recorded

    def as_graph(self):
        """
        Build a graph for this residual block

        The graph of the sequential stack is obtained first. An :py:class:`AliasConnection`
        is then built over its input and output nodes and returned wrapped by
        :py:func:`as_GraphHolder`.
        """
        seq_graph = super().as_graph()

        alias = AliasConnection(
            seq_graph.input_nodes,
            seq_graph.output_nodes,
            f"Residual_{self.name}_aliases",
            None,
        )
        return as_GraphHolder(alias)


class ModResidual(ResidualMixin, Module):
    """
    Residual combinator for native (non-Jax, non-Torch) Rockpool modules

    All residual behaviour is inherited from :py:class:`ResidualMixin`;
    :py:class:`Module` provides the native module base class.
    """


if backend_available("jax"):
    from rockpool.nn.modules.jax.jax_module import JaxModule
    from jax import numpy as jnp

    class JaxResidual(ResidualMixin, JaxSequential):
        """
        The :py:class:`.Residual` combinator for jax modules

        ``ResidualMixin`` is listed first, matching :py:class:`ModResidual` and
        :py:class:`TorchResidual`. Its ``__init__`` (size check), ``evolve`` (adds the
        input to the output) and ``as_graph`` therefore take precedence over any
        method defined directly on :py:class:`JaxSequential`. Because those mixin
        methods chain via ``super()``, the ``JaxSequential`` implementations are
        still invoked.
        """

else:
    JaxResidual = missing_backend_shim("JaxResidual", "jax")

    # - Placeholder so that the ``isinstance(item, JaxModule)`` check in
    #   ``Residual`` still resolves when jax is not installed; no submodule
    #   will be an instance of this empty class
    class JaxModule:
        pass


if not backend_available("torch"):
    # - Torch is not installed: expose a shim under the public name
    #   (presumably reports the missing backend when used — see missing_backend_shim)
    TorchResidual = missing_backend_shim("TorchResidual", "torch")

    class TorchModule:
        """
        Placeholder so ``isinstance(item, TorchModule)`` in ``Residual`` still resolves
        without torch; no submodule will be an instance of this empty class
        """

else:
    from rockpool.nn.modules.torch.torch_module import TorchModule

    class TorchResidual(ResidualMixin, TorchModule):
        """
        Residual combinator for torch modules

        Residual behaviour comes from :py:class:`ResidualMixin`; :py:class:`TorchModule`
        provides the torch module base class.
        """


def Residual(*args, **kwargs) -> ModuleBase:
    """
    Build a residual block over a sequential stack of modules

    :py:class:`.Residual` accepts any number of modules. The shapes of the modules must be compatible -- the output size :py:attr:`~.Module.size_out` of each module must match the input size :py:attr:`~.Module.size_in` of the following module. Because the block input is added to the output of the stack, the stack as a whole must also have identical :py:attr:`~.Module.size_in` and :py:attr:`~.Module.size_out`.

    The backend is chosen by scanning the positional arguments in order. The first argument that is a Jax module selects :py:class:`JaxResidual`, and the first that is a Torch module selects :py:class:`TorchResidual`; whichever appears first wins. If no positional argument is a Jax or Torch module, :py:class:`ModResidual` is used. Keyword arguments are not inspected when choosing the backend.

    Examples:

        Build a :py:class:`.Residual` stack containing ``mod0``, ``mod1`` and ``mod2``; a :py:class:`.Module` is returned. When evolving this stack, signals are passed through ``mod0``, then ``mod1``, then ``mod2``, and the input to the stack is added to the output of ``mod2``:

        >>> Residual(mod0, mod1, mod2)

        Index into a :py:class:`.Residual` stack using Python indexing to access one of its submodules:

        >>> mod = Residual(mod0, mod1, mod2)
        >>> mod[0]

    Args:
        *args: Any number of modules to connect. The :py:attr:`~.Module.size_out` attribute of one module must match the :py:attr:`~.Module.size_in` attribute of the following module.
        **kwargs: Additional keyword arguments, forwarded unchanged to the selected residual class.

    Returns:
        A :py:class:`.Module` subclass object that encapsulates the provided modules

    Raises:
        ValueError: If the ``size_in`` and ``size_out`` of the assembled stack differ
    """
    # - Check for Jax and Torch submodules; the first one found decides the backend
    for item in args:
        if isinstance(item, JaxModule):
            return JaxResidual(*args, **kwargs)

        if isinstance(item, TorchModule):
            return TorchResidual(*args, **kwargs)

    # - Use ModResidual if no JaxModule or TorchModule is in the submodules
    return ModResidual(*args, **kwargs)