"""
Spike-to-current layer with exponential synapses, with a Brian2 backend
"""
# - Imports
from warnings import warn
import brian2 as b2
import brian2.numpy_ as np
from brian2.units.stdunits import *
from brian2.units.allunits import *
from rockpool.timeseries import TSContinuous, TSEvent
from rockpool.nn.layers.layer import Layer
from rockpool.utilities.timedarray_shift import TimedArray as TAShift
from rockpool.nn.modules.timed_module import astimedmodule
from typing import Optional, Union, Tuple, List
# - Type alias for array-like objects
ArrayLike = Union[np.ndarray, List, Tuple]
# - Configure exports
__all__ = ["FFExpSynBrian", "eqSynapseExp"]
# - Equations for an exponential synapse
eqSynapseExp = b2.Equations(
"""
dI_syn/dt = (-I_syn + I_inp(t, i)) / tau_s : amp # Synaptic current
tau_s : second # Synapse time constant
"""
)
## - FFExpSynBrian - Class: define an exponential synapse layer (spiking input)
[docs]@astimedmodule(
parameters=["weights", "tau_syn"],
states=["state"],
simulation_parameters=["dt", "noise_std"],
)
class FFExpSynBrian(Layer):
"""Define an exponential synapse layer (spiking input), with a Brian2 backend"""
## - Constructor
def __init__(
self,
weights: Union[np.ndarray, int] = None,
dt: float = 0.1 * ms,
noise_std: float = 0 * mV,
tau_syn: float = 5 * ms,
synapse_eq=eqSynapseExp,
integrator_name: str = "rk4",
name: str = "unnamed",
):
"""
Construct an exponential synapse layer (spiking input), with a Brian2 backend
:param weights: np.array MxN weight matrix
int Size of layer -> creates one-to-one conversion layer
:param dt: float Time step for state evolution. Default: 0.1 ms
:param noise_std: float Std. dev. of noise added to this layer. Default: 0
:param tau_syn: float Output synaptic time constants. Default: 5ms
:param synapse_eq: Brian2.Equations set of synapse equations for receiver. Default: exponential
:param integrator_name: str Integrator to use for simulation. Default: 'exact'
:param name: str Name for the layer. Default: 'unnamed'
"""
warn(
"FFExpSynBrian - This layer is deprecated. You can use FFExpSyn or FFExpSynTorch instead."
)
# - Provide default dt
if dt is None:
dt = 0.1 * ms
# - Provide default weight matrix for one-to-one conversion
if isinstance(weights, int):
weights = np.identity(weights, "float")
# - Call super constructor
super().__init__(weights=weights, dt=dt, noise_std=noise_std, name=name)
# - Set up spike source to receive spiking input
self._input_generator = b2.SpikeGeneratorGroup(
self.size_in, [0], [0 * second], dt=np.asarray(dt) * second
)
# - Set up layer receiver nodes
self._neuron_group = b2.NeuronGroup(
self.size,
synapse_eq,
refractory=False,
method=integrator_name,
dt=np.asarray(dt) * second,
name="receiver_neurons",
)
# - Add source -> receiver synapses
self._inp_synapses = b2.Synapses(
self._input_generator,
self._neuron_group,
model="w : 1",
on_pre="I_syn_post += w*amp",
method=integrator_name,
dt=np.asarray(dt) * second,
name="receiver_synapses",
)
self._inp_synapses.connect()
# - Add current monitors to record reservoir outputs
self._state_monitor = b2.StateMonitor(
self._neuron_group, "I_syn", True, name="receiver_synaptic_currents"
)
# - Call Network constructor
self._net = b2.Network(
self._input_generator,
self._neuron_group,
self._inp_synapses,
self._state_monitor,
name="ff_spiking_to_exp_layer",
)
# - Record layer parameters, set weights
self.weights = weights
self.tau_syn = tau_syn
# - Store "reset" state
self._net.store("reset")
def reset_state(self):
"""Reset the internal state of the layer"""
self._neuron_group.I_syn = 0 * amp
def randomize_state(self):
"""Randomize the internal state of the layer"""
self.reset_state()
def reset_time(self):
"""
Reset the internal clock of this layer
"""
# - Sotre state variables
syn_inp = np.copy(self._neuron_group.I_syn) * amp
# - Store parameters
tau_syn = np.copy(self.tau_syn)
weights = np.copy(self.weights)
# - Reset network
self._net.restore("reset")
self._timestep = 0
# - Restork parameters
self.tau_syn = tau_syn
self.weights = weights
# - Restore state variables
self._neuron_group.I_syn = syn_inp
### --- State evolution
def evolve(
self,
ts_input: Optional[TSEvent] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
verbose: bool = False,
) -> TSContinuous:
"""
Function to evolve the states of this layer given an input
:param Optional[TSEvent] ts_input: TSEvent Input spike trian
:param Optional[float] duration: Simulation/Evolution time
:param Optional[int] num_timesteps: Number of evolution time steps
:param bool verbose: Currently no effect, just for conformity
:return TSContinuous: output spike series
"""
# - Prepare time base
time_base, __, num_timesteps = self._prepare_input(
ts_input, duration, num_timesteps
)
# - Set spikes for spike generator
if ts_input is not None:
event_times, event_channels = ts_input(
t_start=time_base[0], t_stop=time_base[-1] + self.dt
)
self._input_generator.set_spikes(
event_channels, event_times * second, sorted=False
)
else:
self._input_generator.set_spikes([], [] * second)
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_base), self.size)
* self.noise_std
* np.sqrt(2 * self.tau_syn / self.dt)
)
# noise_step = np.zeros((np.size(time_base), self.size))
# noise_step[0,:] = self.noise_std
# - Specifiy noise input currents, construct TimedArray
inp_noise = TAShift(
np.asarray(noise_step) * amp,
self.dt * second,
tOffset=self.t * second,
name="noise_input",
)
# - Perform simulation
self._net.run(
num_timesteps * self.dt * second, namespace={"I_inp": inp_noise}, level=0
)
self._timestep += num_timesteps
# - Build response TimeSeries
time_base_out = self._state_monitor.t_
use_time = self._state_monitor.t_ >= time_base[0]
time_base_out = time_base_out[use_time]
a = self._state_monitor.I_syn_.T
a = a[use_time, :]
# - Return the current state as final time point
if time_base_out[-1] != self.t:
time_base_out = np.concatenate((time_base_out, [self.t]))
a = np.concatenate((a, np.reshape(self.state, (1, self.size))))
return TSContinuous(time_base_out, a, name="Receiver current")
### --- Properties
@property
def input_type(self):
return TSEvent
@property
def weights(self):
if hasattr(self, "_inp_synapses"):
return np.reshape(self._inp_synapses.w, (self.size, -1))
else:
return self._weights
@weights.setter
def weights(self, new_w):
assert np.size(new_w) == self.size * self.size_in, (
"`new_w` must have [" + str(self.size * self.size_in) + "] elements."
)
self._weights = new_w
if hasattr(self, "_inp_synapses"):
# - Assign recurrent weights
new_w = np.asarray(new_w).reshape(self.size, -1)
self._inp_synapses.w = new_w.flatten()
@property
def state(self):
return self._neuron_group.I_syn_
@state.setter
def state(self, new_state):
self._neuron_group.I_syn = (
np.asarray(self._expand_to_net_size(new_state, "new_state")) * amp
)
@property
def tau_syn(self):
return self._neuron_group.tau_s_[0]
@tau_syn.setter
def tau_syn(self, new_tau_syn):
self._neuron_group.tau_s = np.asarray(new_tau_syn) * second
@property
def t(self):
return self._net.t_
@Layer.dt.setter
def dt(self, _):
warn("The `dt` property cannot be set for this layer")
def to_dict(self):
d = super().to_dict()
d["tau_syn"] = self.tau_syn
return d