Source code for nn.modules.native.rate

"""
Contains an implementation of a non-spiking rate module
"""

# - Rockpool imports
from rockpool.nn.modules.module import Module
from rockpool.parameters import Parameter, State, SimulationParameter
from .linear import kaiming, unit_eigs

from rockpool.graph import (
    RateNeuronWithSynsRealValue,
    LinearWeights,
    GraphModuleBase,
    as_GraphHolder,
)

# -- Imports
import numpy as np
import numpy.random as rand

from typing import Optional, Union, Any, Callable, Tuple
from rockpool.typehints import FloatVector, P_float, P_int, P_tensor

__all__ = ["Rate"]


# -- Define useful neuron transfer functions
def H_ReLU(x: FloatVector, threshold: FloatVector) -> FloatVector:
    return (x - threshold) * ((x - threshold) > 0.0)


def H_tanh(x: FloatVector, threshold: FloatVector) -> FloatVector:
    return np.tanh(x - threshold)


def H_sigmoid(x: FloatVector, threshold: FloatVector) -> FloatVector:
    return (np.tanh(x - threshold) + 1) / 2


[docs]class Rate(Module): """ Encapsulates a population of rate neurons, supporting feed-forward and recurrent modules Examples: Instantiate a feed-forward module with 8 neurons: >>> mod = Rate(8,) RateEulerJax 'None' with shape (8,) Instantiate a recurrent module with 12 neurons: >>> mod_rec = Rate(12, has_rec = True) RateEulerJax 'None' with shape (12,) Instantiate a feed-forward module with defined time constants: >>> mod = Rate(7, tau = np.arange(7,) * 10e-3) RateEulerJax 'None' with shape (7,) This module implements the update equations: .. math:: \dot{X} = -X + i(t) + W_{rec} H(X) + bias + \sigma \zeta_t X = X + \dot{x} * dt / \tau H(x, t) = relu(x, t) = (x - t) * ((x - t) > 0) """
[docs] def __init__( self, shape: Union[int, Tuple[int, int], Tuple[int]], tau: Optional[FloatVector] = None, bias: Optional[FloatVector] = None, threshold: Optional[FloatVector] = None, w_rec: Optional[np.ndarray] = None, weight_init_func: Callable = unit_eigs, has_rec: bool = False, activation_func: Union[str, Callable] = H_ReLU, dt: float = 1e-3, noise_std: float = 0.0, *args: list, **kwargs: dict, ): """ Instantiate a non-spiking rate module, either feed-forward or recurrent. Args: shape (Tuple[int]): A tuple containing the numer of this module. tau (float): A scalar or vector defining the initialisation time constants for the module. If a vector is provided, it must match the output size of the module. Default: ``20ms`` for each unit bias (float): A scalar or vector defining the initialisation bias values for the module. If a vector is provided, it must match the output size of the module. Default: ``0.`` w_rec (np.ndarray): An optional matrix defining the initialisation recurrent weights for the module. weight_init_func (Callable): A function used to initialise the recurrent weights, if used. Default: :py:func:`.unit_eigs`; initialise such that recurrent feedback has eigenvalues distributed within the unit circle. has_rec (bool): A flag parameter indicating whether the module has recurrent connections or not. Default: ``False``, no recurrent connections. activation_func (Callable): The activation function of the neurons. This can be provided as a string ``['ReLU', 'sigmoid', 'tanh']``, or as a function that accepts a vector of neural states and returns the vector of output activations. Default: ``'ReLU'``. dt (float): The Euler solver time-step. Default: ``1e-3`` noise_std (float): The std. dev. of normally-distributed noise added to the neural state at each time step. Default: ``0.`` rng_key (Any): A Jax PRNG key to initialise the module with. Default: not provided, the module PRNG will be initialised with a random number. *args: Additional positional arguments **kwargs: Additional keyword arguments """ # - Call the superclass initialiser super().__init__( shape=shape, spiking_input=False, spiking_output=False, *args, **kwargs ) if self.size_out != self.size_in: raise ValueError("Rate module must have `size_out` == `size_in`.") # - Initialise state self.x: Union[np.ndarray, State] = State( shape=self.size_out, init_func=np.zeros ) """A vector ``(N,)`` of the internal state of each neuron""" # - Should we be recurrent or FFwd? if not has_rec: # - Feed-forward mode if w_rec is not None: raise ValueError( "If `shape` is unidimensional, then `w_rec` may not be provided as an argument." ) else: # - Recurrent mode self.w_rec: P_tensor = Parameter( w_rec, family="weights", init_func=weight_init_func, shape=(self.size_out, self.size_in), ) """The recurrent weight matrix ``(N, N)`` for this module""" # - Set parameters self.tau: P_tensor = Parameter( tau, family="taus", init_func=lambda s: np.ones(s) * 20e-3, shape=[(self.size_out,), ()], ) """ The vector ``(N,)`` of time constants :math:`\\tau` for each neuron""" self.bias: P_tensor = Parameter( bias, "bias", init_func=lambda s: np.zeros(s), shape=[(self.size_out,), ()], ) """The vector ``(N,)`` of bias currents for each neuron""" self.threshold: P_tensor = Parameter( threshold, family="thresholds", shape=[(self.size_out,), ()], init_func=np.zeros, ) """ (Tensor) Unit thresholds `(Nout,)` or `()` """ self.dt: P_float = SimulationParameter(dt) """The Euler solver time step for this module""" self.noise_std: P_float = SimulationParameter(noise_std) """The std. dev. :math:`\\sigma` of noise added to internal neuron states at each time step""" # - Check and assign the activation function if isinstance(activation_func, str): # - Handle a string argument if activation_func.lower() in ["relu", "r"]: act_fn = H_ReLU elif activation_func.lower() in ["sigmoid", "sig", "s"]: act_fn = H_sigmoid elif activation_func.lower() in ["tanh", "t"]: act_fn = H_tanh else: raise ValueError( 'If `activation_func` is provided as a string argument, it must be one of ["ReLU", "sigmoid", "tanh"].' ) elif callable(activation_func): # - Handle a callable function act_fn = activation_func """The activation function of the neurons in the module""" else: raise ValueError( "Argument `activation_func` must be a string or a function." ) # - Assign activation function self.act_fn: Union[Callable, SimulationParameter] = SimulationParameter(act_fn)
[docs] def evolve( self, input_data: np.ndarray, record: bool = False, ): input_data, (x,) = self._auto_batch(input_data, (self.x,)) batches, num_timesteps, n_inputs = input_data.shape # - Get evolution constants alpha = self.dt / np.clip(self.tau, 10 * self.dt, np.inf) noise_zeta = self.noise_std * np.sqrt(self.dt) w_rec = self.w_rec if hasattr(self, "w_rec") else None # - Reservoir state step function (forward Euler solver) def forward(x, inp): """ reservoir_step() - Single step of recurrent reservoir :param x: np.ndarray Current state and activation of reservoir units :param inp: np.ndarray Inputs to each reservoir unit for the current step :return: (new_state, new_activation), (rec_input, activation) """ state, activation = x rec_input = np.dot(activation, w_rec) if w_rec is not None else 0.0 dstate = -state + inp + self.bias + rec_input state = state + dstate * alpha activation = self.act_fn(state, self.threshold) return (state, activation), (rec_input, state, activation) # - Compute random numbers for reservoir noise noise = noise_zeta * rand.normal(size=input_data.shape) inputs = input_data + noise # - Loop over time rec_inputs = np.zeros((batches, num_timesteps, self.size_out)) res_state = np.zeros((batches, num_timesteps, self.size_out)) outputs = np.zeros((batches, num_timesteps, self.size_out)) for b in range(batches): for t in range(num_timesteps): # - Solve layer dynamics for this time-step ( (x[b], _), ( this_rec_i, this_r_s, this_out, ), ) = forward((x[b], self.act_fn(x[b], self.threshold)), inputs[b, t, :]) # - Keep a record of the layer dynamics rec_inputs[b, t, :] = this_rec_i res_state[b, t, :] = this_r_s outputs[b, t, :] = this_out self.x = x[0] record_dict = {"rec_input": rec_inputs, "x": res_state} if record else {} return outputs, self.state(), record_dict
[docs] def as_graph(self) -> GraphModuleBase: # - Generate a GraphModule for the neurons neurons = RateNeuronWithSynsRealValue._factory( self.size_in, self.size_out, f"{type(self).__name__}_{self.name}_{id(self)}", self, self.tau, self.bias, self.dt, ) # - Include recurrent weights if present if len(self.attributes_named("w_rec")) > 0: # - Weights are connected over the existing input and output nodes w_rec_graph = LinearWeights( neurons.output_nodes, neurons.input_nodes, f"{type(self).__name__}_recurrent_{self.name}_{id(self)}", self, self.w_rec, ) # - Return a graph containing neurons and optional weights return as_GraphHolder(neurons)