Source code for nn.modules.native.instant

"""
Encapsulate a simple instantaneous function as a jax module
"""

from rockpool.nn.modules import Module
from rockpool.parameters import SimulationParameter
from rockpool.typehints import P_Callable

from rockpool.utilities.backend_management import (
    backend_available,
    missing_backend_shim,
)

from warnings import warn

from typing import Callable, Union, Tuple, Any

__all__ = ["Instant", "InstantJax", "InstantTorch"]


class InstantMixin:
    """
    Wrap a callable function as an instantaneous Rockpool module
    """

    def __init__(
        self,
        shape: Union[int, tuple] = None,
        function: Callable = lambda x: x,
        *args,
        **kwargs,
    ):
        """
        Wrap a callable function as an instantaneous Rockpool module

        Args:
            shape (Optional[tuple]):
            function (Callable): A scalar function of its arguments, with a single output. Default: identity function
        """
        # - Check that a shape was provided
        if shape is None:
            raise ValueError("The `shape` argument to `Instant` may not be `None`.")

        # - Call superclass init
        super().__init__(shape=shape, *args, **kwargs)

        # - Store the function
        self.function: P_Callable = SimulationParameter(function)

        if not hasattr(self, "_auto_batch"):
            raise NotImplementedError(
                "_auto_batch must be implemented by superclasses!"
            )

    def evolve(
        self,
        input,
        record: bool = False,
    ) -> Tuple[Any, dict, dict]:
        input, _ = self._auto_batch(input)
        return self.function(input), {}, {}


[docs]class Instant(InstantMixin, Module): """ Wrap a callable function as an instantaneous Rockpool module """ pass
if backend_available("jax"): from rockpool.nn.modules.jax.jax_module import JaxModule from jax.tree_util import Partial
[docs] class InstantJax(InstantMixin, JaxModule): """ Wrap a callable function as an instantaneous Rockpool module, with a Jax backend """
[docs] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.function = Partial(self.function)
else: InstantJax = missing_backend_shim("InstantJax", "jax") if backend_available("torch"): from rockpool.nn.modules.torch.torch_module import TorchModule
[docs] class InstantTorch(InstantMixin, TorchModule): """ Wrap a callable function as an instantaneous Rockpool module, with a Torch backend """
else: InstantTorch = missing_backend_shim("InstantTorch", "torch")