"""
IAF neuron layers with a Brian2 backend
"""
# - Imports
from warnings import warn
import brian2 as b2
import brian2.numpy_ as np
from brian2.units.stdunits import *
from brian2.units.allunits import *
from rockpool.utilities.timedarray_shift import TimedArray as TAShift
from rockpool.timeseries import TSContinuous, TSEvent
from rockpool.nn.layers.layer import Layer
from rockpool.typehints import FloatVector
from typing import Optional, Union, Tuple, List, Any
# - Type alias for array-like objects
ArrayLike = FloatVector
from rockpool.nn.modules.timed_module import astimedmodule
# - Configure exports
__all__ = [
"FFIAFBrian",
"RecIAFBrian",
"FFIAFBrianBase",
"FFIAFSpkInBrian",
"RecIAFBrianBase",
"RecIAFSpkInBrian",
"eqNeuronIAFFF",
"eqNeuronCLIAFFF",
"eqNeuronIAFSpkInFF",
"eqNeuronIAFRec",
"eqNeuronIAFSpkInRec",
"eqSynapseExp",
"eqSynapseExpSpkInRec",
]
# - Equations for an integrate-and-fire neuron, ff-layer, analogue external input
eqNeuronIAFFF = b2.Equations(
"""
dv/dt = (v_rest - v + r_m * I_total) / tau_m : volt (unless refractory) # Neuron membrane voltage
I_total = I_inp(t, i) + I_bias : amp # Total input current
I_bias : amp # Per-neuron bias current
v_rest : volt # Rest potential
tau_m : second # Membrane time constant
r_m : ohm # Membrane resistance
v_thresh : volt # Firing threshold potential
v_reset : volt # Reset potential
"""
)
# - Equations for an integrate-and-fire neuron, ff-layer, analogue external input, constant leak
eqNeuronCLIAFFF = b2.Equations(
"""
dv/dt = (r_m * I_total) / tau_m : volt (unless refractory) # Neuron membrane voltage
I_total = I_inp(t, i) + I_bias : amp # Total input current
I_bias : amp # Per-neuron bias current
v_rest : volt # Rest potential
tau_m : second # Membrane time constant
r_m : ohm # Membrane resistance
v_thresh : volt # Firing threshold potential
v_reset : volt # Reset potential
"""
)
# - Equations for an integrate-and-fire neuron, ff-layer, spiking external input
eqNeuronIAFSpkInFF = b2.Equations(
"""
dv/dt = (v_rest - v + r_m * I_total) / tau_m : volt (unless refractory) # Neuron membrane voltage
I_total = I_syn + I_bias + I_inp(t, i) : amp # Total input current
dI_syn/dt = -I_syn / tau_s : amp # Synaptic input current
I_bias : amp # Per-neuron bias current
v_rest : volt # Rest potential
tau_m : second # Membrane time constant
    tau_s : second # Synapse time constant
r_m : ohm # Membrane resistance
v_thresh : volt # Firing threshold potential
v_reset : volt # Reset potential
"""
)
# - Equations for an integrate-and-fire neuron, recurrent layer, analogue external input
eqNeuronIAFRec = b2.Equations(
"""
dv/dt = (v_rest - v + r_m * I_total) / tau_m : volt (unless refractory) # Neuron membrane voltage
I_total = I_inp(t, i) + I_syn + I_bias : amp # Total input current
I_bias : amp # Per-neuron bias current
v_rest : volt # Rest potential
tau_m : second # Membrane time constant
r_m : ohm # Membrane resistance
v_thresh : volt # Firing threshold potential
v_reset : volt # Reset potential
"""
)
# - Equations for an integrate-and-fire neuron, recurrent layer, spiking external input
eqNeuronIAFSpkInRec = b2.Equations(
"""
dv/dt = (v_rest - v + r_m * I_total) / tau_m : volt (unless refractory) # Neuron membrane voltage
I_total = I_inp(t, i) + I_syn + I_bias : amp # Total input current
I_syn = I_syn_inp + I_syn_rec : amp # Synaptic currents
I_bias : amp # Per-neuron bias current
v_rest : volt # Rest potential
tau_m : second # Membrane time constant
r_m : ohm # Membrane resistance
v_thresh : volt # Firing threshold potential
v_reset : volt # Reset potential
"""
)
# - Equations for an exponential synapse - used for RecIAFBrian
eqSynapseExp = b2.Equations(
"""
dI_syn/dt = -I_syn / tau_s : amp # Synaptic current
tau_s : second # Synapse time constant
"""
)
# - Equations for two exponential synapses (spiking external input and recurrent) for RecIAFSpkInBrian
eqSynapseExpSpkInRec = b2.Equations(
"""
dI_syn_inp/dt = -I_syn_inp / tau_syn_inp : amp # Synaptic current, input synapses
dI_syn_rec/dt = -I_syn_rec / tau_syn_rec : amp # Synaptic current, recurrent synapses
tau_syn_inp : second # Synapse time constant, input
tau_syn_rec : second # Synapse time constant, recurrent
"""
)
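# - Note (editor's sketch, grounded in the classes below): a neuron equation
#   set and a synapse equation set are combined with `+` before being handed
#   to a `b2.NeuronGroup`, e.g.
#       b2.NeuronGroup(N, eqNeuronIAFRec + eqSynapseExp, ...)
#   as in `RecIAFBrianBase`. The combined set must define every variable that
#   `I_total` references; the external input `I_inp(t, i)` is supplied at run
#   time through the `namespace` argument of `Network.run()`.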
class FFIAFBrianBase(Layer):
"""A spiking feedforward layer with current inputs and spiking outputs"""
## - Constructor
def __init__(
self,
weights: np.ndarray,
bias: FloatVector = 15 * mA,
dt: float = 0.1 * ms,
noise_std: float = 0 * mV,
tau_mem: FloatVector = 20 * ms,
v_thresh: FloatVector = -55 * mV,
v_reset: FloatVector = -65 * mV,
v_rest: FloatVector = -65 * mV,
refractory: float = 0 * ms,
neuron_eq: Union[b2.Equations, str] = eqNeuronIAFFF,
integrator_name: str = "rk4",
name: str = "unnamed",
record: bool = False,
):
"""
Construct a spiking feedforward layer with IAF neurons, with a Brian2 back-end. Inputs are continuous currents; outputs are spiking events
:param np.array weights: Layer weight matrix [N_in, N]
        :param FloatVector bias: Nx1 vector of bias currents. Default: ``15mA``
        :param float dt: Time-step. Default: ``0.1 ms``
        :param float noise_std: Noise std. dev. per second. Default: ``0.``
        :param FloatVector tau_mem: Nx1 vector of neuron time constants. Default: ``20ms``
        :param FloatVector v_thresh: Nx1 vector of neuron thresholds. Default: ``-55mV``
        :param FloatVector v_reset: Nx1 vector of neuron reset potentials. Default: ``-65mV``
        :param FloatVector v_rest: Nx1 vector of neuron resting potentials. Default: ``-65mV``
:param float refractory: Refractory period after each spike. Default: ``0ms``
:param Union[Brian2.Equations, str] neuron_eq: Set of neuron equations. Default: IAF equation set
:param str integrator_name: Integrator to use for simulation. Default: ``'rk4'``
:param str name: Name for the layer. Default: ``'unnamed'``
:param bool record: Record membrane potential during evolutions
"""
# - Call super constructor (`asarray` is used to strip units)
super().__init__(
weights=np.asarray(weights),
dt=np.asarray(dt),
noise_std=np.asarray(noise_std),
name=name,
)
# - Set up layer neurons
self._neuron_group = b2.NeuronGroup(
self.size,
neuron_eq,
threshold="v > v_thresh",
reset="v = v_reset",
refractory=np.asarray(refractory) * second,
method=integrator_name,
dt=np.asarray(dt) * second,
name="spiking_ff_neurons",
)
self._neuron_group.v = v_rest
self._neuron_group.r_m = 1 * ohm
# - Add monitors to record layer outputs
self._layer = b2.SpikeMonitor(
self._neuron_group, record=True, name="layer_spikes"
)
# - Call Network constructor
self._net = b2.Network(self._neuron_group, self._layer, name="ff_spiking_layer")
if record:
# - Monitor for recording network potential
self.state_monitor = b2.StateMonitor(
self._neuron_group, ["v"], record=True, name="layer_potential"
)
self._net.add(self.state_monitor)
# - Record neuron parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.bias = bias
self.weights = weights
# - Store "reset" state
self._net.store("reset")
def reset_state(self):
"""Reset the internal state of the layer"""
self._neuron_group.v = self.v_rest * volt
def randomize_state(self):
"""Randomize the internal state of the layer"""
v_range = abs(self.v_thresh - self.v_reset)
self._neuron_group.v = (
np.random.rand(self.size) * v_range + self.v_reset
) * volt
def reset_time(self):
"""Reset the internal clock of this layer"""
        # - Store state variables
v_state = np.copy(self._neuron_group.v) * volt
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
bias = np.copy(self.bias)
weights = np.copy(self.weights)
# - Reset network
self._net.restore("reset")
self._timestep = 0
        # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.bias = bias
self.weights = weights
# - Restore state variables
self._neuron_group.v = v_state
### --- State evolution
def evolve(
self,
ts_input: Optional[TSContinuous] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
verbose: bool = False,
) -> TSEvent:
"""
Function to evolve the states of this layer given an input
:param Optional[`.TSContinuous`] ts_input: Input time series
:param Optional[float] duration: Simulation/Evolution time
:param Optional[int] num_timesteps: Number of evolution time steps
:param bool verbose: Currently no effect, just for conformity
:return `.TSEvent`: Output spike series
"""
# - Prepare time base
time_base, input_steps, num_timesteps = self._prepare_input(
ts_input, duration, num_timesteps
)
# - Weight inputs
neuron_inp_step = input_steps @ self.weights
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_base), self.size)
            # - The std. dev. of the noise produced in the Brian simulation is
            #   slightly smaller than expected, so correct with the empirically
            #   determined factor 1.63
* self.noise_std
* np.sqrt(2.0 * self.tau_mem / self.dt)
* 1.63
)
        # - Specify network input currents, construct TimedArray
inp_current = TAShift(
np.asarray(neuron_inp_step + noise_step) * amp,
self.dt * second,
tOffset=self.t * second,
name="external_input",
)
# - Perform simulation
self._net.run(
num_timesteps * self.dt * second, namespace={"I_inp": inp_current}, level=0
)
# - Start and stop times for output time series
t_start = self._timestep * float(self.dt)
t_stop = (self._timestep + num_timesteps) * float(self.dt)
# - Update layer time step
self._timestep += num_timesteps
# - Build response TimeSeries
use_event = self._layer.t_ >= time_base[0]
# Shift event times to middle of time bins
event_time_out = self._layer.t_[use_event] - 0.5 * self.dt
event_channel_out = self._layer.i[use_event]
return TSEvent(
np.clip(event_time_out, t_start, t_stop),
event_channel_out,
name="Layer spikes",
num_channels=self.size,
t_start=t_start,
t_stop=t_stop,
)
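    # - Usage sketch (editor's illustration; values are not from the original module):
    #       lyr = FFIAFBrianBase(weights=np.ones((2, 3)) * 1e-3)
    #       ts_inp = TSContinuous(np.arange(100) * lyr.dt, np.ones((100, 2)) * 1e-2)
    #       ts_out = lyr.evolve(ts_inp, duration=float(99 * lyr.dt))
    #   `ts_out` is a `TSEvent` holding the spikes emitted during the evolution.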
def stream(
self, duration: float, dt: float, verbose: bool = False
) -> Tuple[float, List[float]]:
"""
Stream data through this layer
:param float duration: Total duration for which to handle streaming
:param float dt: Streaming time step
:param bool verbose: Display feedback
:yield: (event_times, event_channels)
:return: Final (event_times, event_channels)
"""
# - Initialise simulation, determine how many dt to evolve for
if verbose:
print("Layer: I'm preparing")
time_trace = np.arange(0, duration + dt, dt)
num_steps = np.size(time_trace) - 1
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_trace), self.size)
            # - The std. dev. of the noise produced in the Brian simulation is
            #   slightly smaller than expected, so correct with the empirically
            #   determined factor 1.63
* self.noise_std
* np.sqrt(2.0 * self.tau_mem / self.dt)
* 1.63
)
# - Generate a TimedArray to use for step-constant input currents
inp_current = TAShift(
np.zeros((1, self._size_in)) * amp, self.dt * second, name="external_input"
)
if verbose:
print("Layer: Prepared")
# - Loop over dt steps
for step in range(num_steps):
if verbose:
print("Layer: Yielding from internal state.")
if verbose:
print("Layer: step", step)
if verbose:
print("Layer: Waiting for input...")
# - Yield current output spikes, receive input for next time step
use_events = self._layer.t_ >= time_trace[step]
if verbose:
print("Layer: Yielding {} spikes".format(np.sum(use_events)))
inp = (yield self._layer.t_[use_events], self._layer.i_[use_events])
# - Specify network input currents for this streaming step
if inp is None:
inp_current.values = noise_step[step, :]
else:
inp_current.values = (
np.reshape(inp[1][0, :], (1, -1)) + noise_step[step, :]
)
# - Reinitialise TimedArray
inp_current._init_2d()
if verbose:
print("Layer: Input was: ", inp)
# - Evolve layer (increments time implicitly)
self._net.run(dt * second, namespace={"I_inp": inp_current}, level=0)
# - Return final spikes, if any
use_events = self._layer.t_ >= time_trace[-2] # Should be duration - dt
return self._layer.t_[use_events], self._layer.i_[use_events]
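    # - Driver sketch (hypothetical): `stream` is a generator that yields the
    #   current output spikes and receives the next input via `send`. Judging
    #   from `inp[1][0, :]` above, the sent object is expected to be a tuple
    #   whose second element is a [1 x size_in] array:
    #       gen = lyr.stream(duration=1.0, dt=float(lyr.dt))
    #       spikes_t, spikes_ch = next(gen)
    #       spikes_t, spikes_ch = gen.send((None, np.ones((1, lyr.size_in))))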
def to_dict(self) -> dict:
"""
Convert parameters of `self` to a dict if they are relevant for reconstructing an identical layer.
"""
config = super().to_dict()
config["bias"] = self.bias.tolist()
config["tau_mem"] = self.tau_mem.tolist()
config["v_thresh"] = self.v_thresh.tolist()
config["v_reset"] = self.v_reset.tolist()
config["v_rest"] = self.v_rest.tolist()
config["refractory"] = self.refractory
config["neuron_eq"] = self._neuron_group.equations
config["integrator_name"] = self._neuron_group.method
config["record"] = hasattr(self, "state_monitor")
return config
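    # - Round-trip sketch (assumes the base-class entries such as `weights`,
    #   `dt`, `noise_std` and `name` are contributed by `super().to_dict()`):
    #       config = lyr.to_dict()
    #       lyr2 = FFIAFBrianBase(**config)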
### --- Properties
@property
def output_type(self):
"""(`.TSEvent`) Output time series class for this layer (`.TSEvent`)"""
return TSEvent
@property
def refractory(self):
"""Returns the refractory period"""
return self._neuron_group._refractory
@property
def state(self):
"""Returns the membrane potentials"""
return self._neuron_group.v_
@state.setter
def state(self, new_state):
self._neuron_group.v = (
np.asarray(self._expand_to_net_size(new_state, "new_state")) * volt
)
@property
def tau_mem(self):
"""Return the membrane time constants"""
return self._neuron_group.tau_m_
@tau_mem.setter
def tau_mem(self, new_tau_mem):
self._neuron_group.tau_m = (
np.asarray(self._expand_to_net_size(new_tau_mem, "new_tau_mem")) * second
)
@property
def bias(self):
"""Retruns the biases"""
return self._neuron_group.I_bias_
@bias.setter
def bias(self, new_bias):
self._neuron_group.I_bias = (
np.asarray(self._expand_to_net_size(new_bias, "new_bias")) * amp
)
@property
def v_thresh(self):
"""Returns the spiking threshold"""
return self._neuron_group.v_thresh_
@v_thresh.setter
def v_thresh(self, new_v_thresh):
self._neuron_group.v_thresh = (
np.asarray(self._expand_to_net_size(new_v_thresh, "new_v_thresh")) * volt
)
@property
def v_rest(self):
"""Returns the resting potential"""
return self._neuron_group.v_rest_
@v_rest.setter
def v_rest(self, new_v_rest):
self._neuron_group.v_rest = (
np.asarray(self._expand_to_net_size(new_v_rest, "new_v_rest")) * volt
)
@property
def v_reset(self):
"""Returns the reset potential"""
return self._neuron_group.v_reset_
@v_reset.setter
def v_reset(self, new_v_reset):
self._neuron_group.v_reset = (
np.asarray(self._expand_to_net_size(new_v_reset, "new_v_reset")) * volt
)
@property
def t(self):
"""Returns the current time of the simulation"""
return self._net.t_
@Layer.dt.setter
def dt(self, _):
warn("The `dt` property cannot be set for this layer")
@astimedmodule(
parameters=[
"weights",
"bias",
"tau_mem",
"v_thresh",
"v_reset",
"v_rest",
],
simulation_parameters=[
"dt",
"noise_std",
"refractory",
],
)
class FFIAFBrian(FFIAFBrianBase):
pass
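# - `FFIAFBrian` is `FFIAFBrianBase` wrapped as a Rockpool timed module: the
#   names in `parameters` are exposed as module parameters and those in
#   `simulation_parameters` are fixed at construction. Construction sketch
#   (editor's illustration):
#       mod = FFIAFBrian(weights=np.ones((2, 3)) * 1e-3)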
@astimedmodule(
parameters=[
"weights",
"bias",
"tau_mem",
"tau_syn",
"v_thresh",
"v_reset",
"v_rest",
],
simulation_parameters=[
"dt",
"noise_std",
"refractory",
],
)
class FFIAFSpkInBrian(FFIAFBrianBase):
"""Spiking feedforward layer with spiking inputs and outputs"""
## - Constructor
def __init__(
self,
weights: np.ndarray,
bias: np.ndarray = 10 * mA,
dt: float = 0.1 * ms,
noise_std: float = 0 * mV,
tau_mem: np.ndarray = 20 * ms,
tau_syn: np.ndarray = 20 * ms,
v_thresh: np.ndarray = -55 * mV,
v_reset: np.ndarray = -65 * mV,
v_rest: np.ndarray = -65 * mV,
refractory: float = 0 * ms,
        neuron_eq: Union[b2.Equations, str] = eqNeuronIAFSpkInFF,
integrator_name: str = "rk4",
name: str = "unnamed",
record: bool = False,
):
"""
Construct a spiking feedforward layer with IAF neurons, with a Brian2 back-end. In- and outputs are spiking events
:param np.array weights: MxN weight matrix.
:param np.array bias: Nx1 bias vector. Default: 10mA
:param float dt: Time-step. Default: 0.1 ms
:param float noise_std: Noise std. dev. per second. Default: 0
:param np.array tau_mem: Nx1 vector of neuron time constants. Default: 20ms
:param np.array tau_syn: Nx1 vector of synapse time constants. Default: 20ms
:param np.array v_thresh: Nx1 vector of neuron thresholds. Default: -55mV
        :param np.array v_reset: Nx1 vector of neuron reset potentials. Default: -65mV
        :param np.array v_rest: Nx1 vector of neuron resting potentials. Default: -65mV
:param float refractory: Refractory period after each spike. Default: 0ms
:param Brian2.Equations neuron_eq: set of neuron equations. Default: IAF equation set
:param str integrator_name: Integrator to use for simulation. Default: 'rk4'
:param str name: Name for the layer. Default: 'unnamed'
:param bool record: Record membrane potential during evolutions
"""
# - Call Layer constructor
Layer.__init__(
self,
weights=weights,
dt=np.asarray(dt),
noise_std=np.asarray(noise_std),
name=name,
)
# - Set up spike source to receive spiking input
self._input_generator = b2.SpikeGeneratorGroup(
self.size_in, [0], [0 * second], dt=np.asarray(dt) * second
)
# - Set up layer neurons
self._neuron_group = b2.NeuronGroup(
self.size,
neuron_eq,
threshold="v > v_thresh",
reset="v = v_reset",
refractory=np.asarray(refractory) * second,
method=integrator_name,
dt=np.asarray(dt) * second,
name="spiking_ff_neurons",
)
self._neuron_group.v = v_rest
self._neuron_group.r_m = 1 * ohm
# - Add source -> receiver synapses
self._inp_synapses = b2.Synapses(
self._input_generator,
self._neuron_group,
model="w : 1",
on_pre="I_syn_post += w*amp",
method=integrator_name,
dt=np.asarray(dt) * second,
name="receiver_synapses",
)
self._inp_synapses.connect()
# - Add monitors to record layer outputs
self._layer = b2.SpikeMonitor(
self._neuron_group, record=True, name="layer_spikes"
)
# - Call Network constructor
self._net = b2.Network(
self._input_generator,
self._inp_synapses,
self._neuron_group,
self._layer,
name="ff_spiking_layer",
)
if record:
# - Monitor for recording network potential
self.state_monitor = b2.StateMonitor(
self._neuron_group, ["v"], record=True, name="layer_potential"
)
self._net.add(self.state_monitor)
# - Record neuron parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn = tau_syn
self.bias = bias
self.weights = weights
# - Store "reset" state
self._net.store("reset")
def evolve(
self,
ts_input: Optional[TSEvent] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
verbose: bool = False,
) -> TSEvent:
"""
Evolve the states of this layer given an input
:param Optional[`.TSEvent`] ts_input: Input spike train
:param Optional[float] duration: Simulation/Evolution time
:param Optional[int] num_timesteps: Number of evolution time steps
:param bool verbose: Currently no effect, just for conformity
:return `.TSEvent`: Output spike series
"""
# - Prepare time base
num_timesteps = self._determine_timesteps(ts_input, duration, num_timesteps)
time_base = self.t + np.arange(num_timesteps) * self.dt
# - Set spikes for spike generator
if ts_input is not None:
event_times, event_channels = ts_input(
t_start=time_base[0], t_stop=time_base[-1] + self.dt
)
self._input_generator.set_spikes(
event_channels, event_times * second, sorted=False
)
else:
self._input_generator.set_spikes([], [] * second)
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_base), self.size)
            # - The std. dev. of the noise produced in the Brian simulation is
            #   slightly smaller than expected, so correct with the empirically
            #   determined factor 1.63
* self.noise_std
* np.sqrt(2.0 * self.tau_mem / self.dt)
* 1.63
)
        # - Specify noise input currents, construct TimedArray
inp_noise = TAShift(
np.asarray(noise_step) * amp,
self.dt * second,
tOffset=self.t * second,
name="noise_input",
)
# - Perform simulation
self._net.run(
num_timesteps * self.dt * second, namespace={"I_inp": inp_noise}, level=0
)
# - Start and stop times for output time series
t_start = self._timestep * float(self.dt)
t_stop = (self._timestep + num_timesteps) * float(self.dt)
# - Update layer time
self._timestep += num_timesteps
# - Build response TimeSeries
use_event = self._layer.t_ >= time_base[0]
event_time_out = self._layer.t_[use_event]
event_channel_out = self._layer.i[use_event]
return TSEvent(
np.clip(event_time_out, t_start, t_stop),
event_channel_out,
name="Layer spikes",
num_channels=self.size,
t_start=t_start,
t_stop=t_stop,
)
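    # - Usage sketch with spiking input (editor's illustration; values are arbitrary):
    #       lyr = FFIAFSpkInBrian(weights=np.ones((2, 3)) * 0.1)
    #       ts_inp = TSEvent([0.01, 0.02], [0, 1], t_stop=0.05)
    #       ts_out = lyr.evolve(ts_inp, duration=0.05)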
def reset_time(self):
"""Resets the time of the simulation"""
# - Store state variables
v_state = np.copy(self._neuron_group.v) * volt
syn_inp = np.copy(self._neuron_group.I_syn) * amp
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
tau_syn = np.copy(self.tau_syn)
bias = np.copy(self.bias)
weights = np.copy(self.weights)
self._net.restore("reset")
self._timestep = 0
        # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn = tau_syn
self.bias = bias
self.weights = weights
# - Restore state variables
self._neuron_group.v = v_state
self._neuron_group.I_syn = syn_inp
def reset_state(self):
""".reset_state() - arguments:: reset the internal state of the layer
Usage: .reset_state()
"""
self._neuron_group.v = self.v_rest * volt
self._neuron_group.I_syn = 0 * amp
def reset_all(self, keep_params=True):
"""Resets the network completely
        :param bool keep_params: Keep the current parameter values of the layer if ``True``
"""
if keep_params:
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
tau_syn = np.copy(self.tau_syn)
bias = np.copy(self.bias)
weights = np.copy(self.weights)
self.reset_state()
self._net.restore("reset")
self._timestep = 0
if keep_params:
            # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn = tau_syn
self.bias = bias
self.weights = weights
def randomize_state(self):
""".randomize_state() - arguments:: randomize the internal state of the layer
Usage: .randomize_state()
"""
v_range = abs(self.v_thresh - self.v_reset)
self._neuron_group.v = (
np.random.rand(self.size) * v_range + self.v_reset
) * volt
self._neuron_group.I_syn = np.random.rand(self.size) * amp
def pot_kernel(self, t):
"""pot_kernel - response of the membrane potential to an
incoming spike at a single synapse with
weight 1*amp (not considering v_rest)
"""
t = t.reshape(-1, 1)
fConst = (
self.tau_syn / (self.tau_syn - self.tau_mem) * self._neuron_group.r_m * amp
)
return fConst * (np.exp(-t / self.tau_syn) - np.exp(-t / self.tau_mem))
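    # - Background: with an exponential synaptic current I_syn(t) = exp(-t / tau_s) * amp
    #   and the membrane equation dv/dt = (-v + r_m * I_syn) / tau_m, the response
    #   to a single unit-weight input spike is
    #       v(t) = r_m * amp * tau_s / (tau_s - tau_m) * (exp(-t / tau_s) - exp(-t / tau_m)),
    #   which is exactly the kernel computed above.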
def train(
self,
ts_target: Any,
ts_input: TSContinuous,
is_first: bool,
is_last: bool,
method: str = "mst",
**kwargs,
) -> None:
"""
Wrapper to standardize training syntax across layers. Use specified training method to train layer for current batch.
:param Any ts_target: Target time series for current batch. Can be skipped for ``"mst"`` method.
:param TSContinuous ts_input: Input to the layer during the current batch.
:param bool is_first: Set ``True`` to indicate that this batch is the first in training procedure.
:param bool is_last: Set ``True`` to indicate that this batch is the last in training procedure.
:param str method: String indicating which training method to choose. Currently only multi-spike tempotron ("mst") is supported.
:param kwargs: ``kwargs`` will be passed on to corresponding training method. For `"mst"` method, arguments ``duration`` and ``t_start`` must be provided.
"""
# - Choose training method
if method in {"mst", "multi-spike tempotron"}:
if "duration" not in kwargs.keys():
raise TypeError(
f"FFIAFSpkInBrian `{self.name}`: For multi-spike tempotron, argument "
+ "`duration` must be provided."
)
if "t_start" not in kwargs.keys():
raise TypeError(
f"FFIAFSpkInBrian `{self.name}`: For multi-spike tempotron, argument "
+ "`t_start` must be provided."
)
self.train_mst_simple(
ts_input=ts_input, is_first=is_first, is_last=is_last, **kwargs
)
else:
raise ValueError(
f"FFIAFSpkInBrian `{self.name}`: Training method `{method}` is currently "
+ "not supported. Use `mst` for multi-spike tempotron."
)
def train_mst_simple(
self,
duration: float,
t_start: float,
ts_input: TSEvent,
target_counts: np.ndarray = None,
lambda_: float = 1e-5,
eligibility_ratio: float = 0.1,
momentum: float = 0,
is_first: bool = True,
is_last: bool = False,
verbose: bool = False,
):
"""
train_mst_simple - Use the multi-spike tempotron learning rule
from Guetig2017, in its simplified version,
where no gradients are calculated
"""
assert hasattr(self, "state_monitor"), (
"Layer needs to be instantiated with record=True for "
+ "this learning rule."
)
# - End time of current batch
t_stop = t_start + duration
if ts_input is not None:
event_times, event_channels = ts_input(t_start=t_start, t_stop=t_stop)
else:
print("No ts_input defined, assuming input to be 0.")
event_times, event_channels = [], []
# - Prepare target
if target_counts is None:
target_counts = np.zeros(self.size)
else:
assert (
np.size(target_counts) == self.size
), "Target array size must match layer size ({}).".format(self.size)
## -- Determine eligibility for each neuron and synapse
eligibility = np.zeros((self.size_in, self.size))
# - Iterate over source neurons
for source_id in range(self.size_in):
if verbose:
print(
"\rProcessing input {} of {}".format(source_id + 1, self.size_in),
end="",
)
# - Find spike timings
event_time_source = event_times[event_channels == source_id]
# - Sum individual correlations over input spikes, for all synapses
for t_spike_in in event_time_source:
# - Membrane potential between input spike time and now (transform to v_rest at 0)
v_mem = (
self.state_monitor.v.T[self.state_monitor.t_ >= t_spike_in]
- self.v_rest * volt
)
# - Kernel between input spike time and now
kernel = self.pot_kernel(
self.state_monitor.t_[self.state_monitor.t_ >= t_spike_in]
- t_spike_in
)
# - Add correlations to eligibility matrix
                eligibility[source_id, :] += np.sum(kernel * v_mem, axis=0)
## -- For each neuron sort eligibilities and choose synapses with largest eligibility
eligible = int(eligibility_ratio * self.size_in)
        # - Mark, for each neuron, the `eligible` synapses with the largest eligibility
        is_eligible = np.argsort(eligibility, axis=0)[: -eligible - 1 : -1]
## -- Compare target number of events with spikes and perform weight updates for chosen synapses
# - Numbers of (output) spike times for each neuron
use_out_events = (self._layer.t_ >= t_start) & (self._layer.t_ <= t_stop)
spikes_out_neurons = self._layer.i[use_out_events]
spike_counts = np.array(
[np.sum(spikes_out_neurons == n_id) for n_id in range(self.size)]
)
# - Updates to eligible synapses of each neuron
updates = np.zeros(self.size)
# - Negative update if spike count too high
updates[spike_counts > target_counts] = -lambda_
# - Positive update if spike count too low
updates[spike_counts < target_counts] = lambda_
# - Reset previous weight changes that are used for momentum heuristic
if is_first:
self._dw_previous = np.zeros_like(self.weights)
        # - Accumulate updates to be made to weights
dw_current = np.zeros_like(self.weights)
# - Update only eligible synapses
for target_id in range(self.size):
dw_current[is_eligible[:, target_id], target_id] += updates[target_id]
# - Include previous weight changes for momentum heuristic
dw_current += momentum * self._dw_previous
# - Perform weight update
self.weights += dw_current
# - Store weight changes for next iteration
self._dw_previous = dw_current
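    # - Summary of the rule implemented above: for each neuron, the
    #   `eligibility_ratio` fraction of input synapses with the largest
    #   kernel-weighted correlation between input spikes and the membrane
    #   potential receive a weight change of +/- `lambda_`, pushing the output
    #   spike count towards `target_counts`; `momentum` mixes in the previous
    #   batch's weight change.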
def to_dict(self) -> dict:
"""
        Convert parameters of `self` to a dict if they are relevant for reconstructing an identical layer.
"""
config = super().to_dict()
config["tau_syn"] = self.tau_syn.tolist()
return config
@property
def input_type(self):
"""Returns input type class"""
return TSEvent
@property
def refractory(self):
"""Returns the refractory period"""
return self._neuron_group._refractory
@property
def weights(self):
"""Returns the weights of the connections"""
return np.array(self._inp_synapses.w).reshape(self.size_in, self.size)
@weights.setter
def weights(self, new_w):
assert (
new_w.shape == (self.size_in, self.size)
or new_w.shape == self._inp_synapses.w.shape
), "weights must be of dimensions ({}, {}) or flat with size {}.".format(
self.size_in, self.size, self.size_in * self.size
)
self._inp_synapses.w = np.array(new_w).flatten()
@property
def tau_syn(self):
"""Returns the synaptic time constants"""
return self._neuron_group.tau_s_
@tau_syn.setter
def tau_syn(self, new_tau_syn):
self._neuron_group.tau_s = (
np.asarray(self._expand_to_net_size(new_tau_syn, "new_tau_syn")) * second
)
## - RecIAFBrianBase - Class: define a spiking recurrent layer with exponential synaptic outputs
class RecIAFBrianBase(Layer):
"""A spiking recurrent layer with current inputs and spiking outputs, using a Brian2 backend"""
## - Constructor
def __init__(
self,
weights: np.ndarray = None,
bias: FloatVector = 10.5 * mA,
dt: float = 0.1 * ms,
noise_std: float = 0 * mV,
tau_mem: FloatVector = 20 * ms,
tau_syn_r: FloatVector = 50 * ms,
v_thresh: FloatVector = -55 * mV,
v_reset: FloatVector = -65 * mV,
v_rest: FloatVector = -65 * mV,
refractory: float = 0 * ms,
        neuron_eq: Union[b2.Equations, str] = eqNeuronIAFRec,
        rec_syn_eq: Union[b2.Equations, str] = eqSynapseExp,
integrator_name: str = "rk4",
name: str = "unnamed",
record: bool = False,
):
"""
Construct a spiking recurrent layer with IAF neurons, with a Brian2 back-end. Current input, spiking output
:param np.array weights: NxN weight matrix. Default: [100x100] unit-lambda matrix
:param np.array bias: Nx1 bias vector. Default: 10.5mA
:param np.array tau_mem: Nx1 vector of neuron time constants. Default: 20 ms
        :param np.array tau_syn_r: Nx1 vector of recurrent synaptic time constants. Default: 50 ms
        :param np.array v_thresh: Nx1 vector of neuron thresholds. Default: -55mV
        :param np.array v_reset: Nx1 vector of neuron reset potentials. Default: -65mV
        :param np.array v_rest: Nx1 vector of neuron resting potentials. Default: -65mV
        :param float refractory: Refractory period after each spike. Default: 0ms
        :param Brian2.Equations neuron_eq: set of neuron equations. Default: IAF equation set
        :param Brian2.Equations rec_syn_eq: set of synapse equations for recurrent connections. Default: exponential
        :param str integrator_name: Integrator to use for simulation. Default: 'rk4'
:param str name: Name for the layer. Default: 'unnamed'
:param bool record: Record membrane potential during evolutions
"""
assert (
np.atleast_2d(weights).shape[0] == np.atleast_2d(weights).shape[1]
), "Layer `{}`: weights must be a square matrix.".format(name)
# - Call super constructor
super().__init__(
weights=weights,
dt=np.asarray(dt),
noise_std=np.asarray(noise_std),
name=name,
)
# - Set up reservoir neurons
self._neuron_group = b2.NeuronGroup(
self.size,
neuron_eq + rec_syn_eq,
threshold="v > v_thresh",
reset="v = v_reset",
refractory=np.asarray(refractory) * second,
method=integrator_name,
dt=np.asarray(dt) * second,
name="reservoir_neurons",
)
self._neuron_group.v = v_rest
self._neuron_group.r_m = 1 * ohm
# - Add recurrent weights (all-to-all)
self._rec_synapses = b2.Synapses(
self._neuron_group,
self._neuron_group,
model="w : 1",
on_pre="I_syn_post += w*amp",
method=integrator_name,
dt=np.asarray(dt) * second,
name="reservoir_recurrent_synapses",
)
self._rec_synapses.connect()
# - Add spike monitor to record layer outputs
self._spike_monitor = b2.SpikeMonitor(
self._neuron_group, record=True, name="layer_spikes"
)
# - Call Network constructor
self._net = b2.Network(
self._neuron_group,
self._rec_synapses,
self._spike_monitor,
name="recurrent_spiking_layer",
)
if record:
# - Monitor for recording network potential
self._v_monitor = b2.StateMonitor(
self._neuron_group,
["v", "I_syn", "I_total"],
record=True,
name="layer_neurons",
)
self._net.add(self._v_monitor)
# - Record neuron / synapse parameters
self.weights = weights
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn_r = tau_syn_r
self.bias = bias
self._neuron_eq = neuron_eq
self._rec_syn_eq = rec_syn_eq
# - Store "reset" state
self._net.store("reset")
def reset_state(self):
"""Reset the internal state of the layer"""
self._neuron_group.v = self.v_rest * volt
self._neuron_group.I_syn = 0 * amp
def randomize_state(self):
"""Randomize the internal state of the layer"""
v_range = abs(self.v_thresh - self.v_reset)
self._neuron_group.v = (
np.random.rand(self.size) * v_range + self.v_reset
) * volt
self._neuron_group.I_syn = np.random.rand(self.size) * amp
def reset_time(self):
"""
Reset the internal clock of this layer
"""
# - Store state variables
v_state = np.copy(self._neuron_group.v) * volt
syn_inp = np.copy(self._neuron_group.I_syn) * amp
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
tau_syn_r = np.copy(self.tau_syn_r)
bias = np.copy(self.bias)
weights = np.copy(self.weights)
# - Reset network
self._net.restore("reset")
self._timestep = 0
        # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn_r = tau_syn_r
self.bias = bias
self.weights = weights
# - Restore state variables
self._neuron_group.v = v_state
self._neuron_group.I_syn = syn_inp
def to_dict(self) -> dict:
"""
Convert parameters of ``self`` to a dict if they are relevant for reconstructing an identical layer.
"""
config = super().to_dict()
config["rec_syn_eq"] = self._rec_syn_eq
config["neuron_eq"] = self._neuron_eq
config["tau_mem"] = self.tau_mem.tolist()
config["tau_syn_r"] = self.tau_syn_r.tolist()
config["v_thresh"] = self.v_thresh.tolist()
config["v_reset"] = self.v_reset.tolist()
config["v_rest"] = self.v_rest.tolist()
config["refractory"] = self.refractory
config["integrator_name"] = self._neuron_group.method
config["record"] = self.hasattr("state_monitor")
return config
### --- State evolution
def evolve(
self,
ts_input: Optional[TSContinuous] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
verbose: bool = False,
) -> TSEvent:
"""
Evolve the states of this layer given an input
:param Optional[`.TSContinuous`] ts_input: Input spike train
:param Optional[float] duration: Simulation/Evolution time
:param Optional[int] num_timesteps: Number of evolution time steps
:param bool verbose: Currently no effect, just for conformity
:return `.TSEvent`: Output spike series
"""
# - Prepare time base
time_base, input_steps, num_timesteps = self._prepare_input(
ts_input, duration, num_timesteps
)
# - Store stuff for debugging
self.time_base = time_base
self.input_steps = input_steps
self.num_timesteps = num_timesteps
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_base), self.size)
            # - The std. dev. of the noise produced in the Brian simulation is
            #   slightly smaller than expected, so correct with the empirically
            #   determined factor 1.63
* self.noise_std
* np.sqrt(2.0 * self.tau_mem / self.dt)
* 1.63
)
        # - Specify network input currents, construct TimedArray
inp_current = TAShift(
np.asarray(input_steps + noise_step) * amp,
self.dt * second,
tOffset=self.t * second,
name="external_input",
)
# - Perform simulation
self._net.run(
num_timesteps * self.dt * second, namespace={"I_inp": inp_current}, level=0
)
# - Start and stop times for output time series
t_start = self._timestep * float(self.dt)
t_stop = (self._timestep + num_timesteps) * float(self.dt)
# - Update layer time step
self._timestep += num_timesteps
# - Build response TimeSeries
use_event = self._spike_monitor.t_ >= time_base[0]
event_time_out = self._spike_monitor.t_[use_event]
event_channel_out = self._spike_monitor.i[use_event]
return TSEvent(
np.clip(event_time_out, t_start, t_stop),
event_channel_out,
name="Layer spikes",
num_channels=self.size,
t_start=t_start,
t_stop=t_stop,
)
### --- Properties
@property
def output_type(self):
"""(`.TSEvent`) Output time series data type for this layer (`.TSEvent`)"""
return TSEvent
@property
def weights(self):
"""(np.ndarray) Recurrent weights for this layer"""
if hasattr(self, "_rec_synapses"):
return np.reshape(self._rec_synapses.w, (self.size, -1))
else:
return self._weights
@weights.setter
def weights(self, new_w):
assert new_w is not None, "Layer `{}`: weights must not be None.".format(
self.name
)
        assert np.size(new_w) == self.size**2, (
            "Layer `{}`: `new_w` must have [{}] elements.".format(
                self.name, self.size**2
            )
        )
self._weights = new_w
if hasattr(self, "_rec_synapses"):
            # - Assign recurrent weights
new_w = np.asarray(new_w).reshape(self.size, -1)
self._rec_synapses.w = new_w.flatten()
@property
def state(self):
"""(np.ndarray) Membrane potential for the neurons in this layer [N,]"""
return self._neuron_group.v_
@state.setter
def state(self, new_state):
self._neuron_group.v = (
np.asarray(self._expand_to_net_size(new_state, "new_state")) * volt
)
@property
def refractory(self):
"""(np.ndarray) Refractory period for the neurons in this layer [N,]"""
return self._neuron_group._refractory
@property
def tau_mem(self):
"""(np.ndarray) Membrane time constants for the neurons in this layer [N,]"""
return self._neuron_group.tau_m_
@tau_mem.setter
def tau_mem(self, new_tau_mem):
self._neuron_group.tau_m = (
np.asarray(self._expand_to_net_size(new_tau_mem, "new_tau_mem")) * second
)
@property
def tau_syn_r(self):
"""(np.ndarray) Synaptic time constants for recurrent synapses in this layer [N**2,]"""
return self._neuron_group.tau_s_
@tau_syn_r.setter
    def tau_syn_r(self, new_tau_syn_r):
        self._neuron_group.tau_s = (
            np.asarray(self._expand_to_net_size(new_tau_syn_r, "new_tau_syn_r")) * second
)
@property
def bias(self):
"""(np.ndarray) Bias currents for the neurons in this layer [N,]"""
return self._neuron_group.I_bias_
@bias.setter
def bias(self, new_bias):
self._neuron_group.I_bias = (
np.asarray(self._expand_to_net_size(new_bias, "new_bias")) * amp
)
@property
def v_thresh(self):
"""(np.ndarray) Threshold potentials for the neurons in this layer [N,]"""
return self._neuron_group.v_thresh_
@v_thresh.setter
def v_thresh(self, new_v_thresh):
self._neuron_group.v_thresh = (
np.asarray(self._expand_to_net_size(new_v_thresh, "new_v_thresh")) * volt
)
@property
def v_rest(self):
"""(np.ndarray) Resting potential for the neurons in this layer [N,]"""
return self._neuron_group.v_rest_
@v_rest.setter
def v_rest(self, new_v_rest):
self._neuron_group.v_rest = (
np.asarray(self._expand_to_net_size(new_v_rest, "new_v_rest")) * volt
)
@property
def v_reset(self):
"""(np.ndarray) Reset potential for the neurons in this layer [N,]"""
return self._neuron_group.v_reset_
@v_reset.setter
def v_reset(self, new_v_reset):
self._neuron_group.v_reset = (
np.asarray(self._expand_to_net_size(new_v_reset, "new_v_reset")) * volt
)
@property
def t(self):
"""(float) Current layer time in s"""
return self._net.t_
@Layer.dt.setter
def dt(self, _):
warn(
"Layer `{}`: The `dt` property cannot be set for this layer".format(
self.name
)
)
@astimedmodule(
parameters=[
"weights",
"bias",
"tau_mem",
"tau_syn_r",
"v_thresh",
"v_reset",
"v_rest",
],
simulation_parameters=[
"dt",
"noise_std",
"refractory",
],
)
class RecIAFBrian(RecIAFBrianBase):
pass
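# - `RecIAFBrian` wraps `RecIAFBrianBase` as a timed module, analogous to
#   `FFIAFBrian` above. Construction sketch (editor's illustration):
#       mod = RecIAFBrian(weights=0.05 * np.random.randn(10, 10))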
# - Spiking recurrent layer with spiking in- and outputs
@astimedmodule(
parameters=[
"weights",
"bias",
"tau_mem",
"tau_syn_inp",
"tau_syn_rec",
"v_thresh",
"v_reset",
"v_rest",
],
simulation_parameters=[
"dt",
"noise_std",
"refractory",
],
)
class RecIAFSpkInBrian(RecIAFBrianBase):
"""Spiking recurrent layer with spiking in- and outputs, and a Brian2 backend"""
## - Constructor
def __init__(
self,
weights_in: np.ndarray,
weights_rec: np.ndarray,
bias: np.ndarray = 10.5 * mA,
dt: float = 0.1 * ms,
noise_std: float = 0 * mV,
tau_mem: np.ndarray = 20 * ms,
tau_syn_inp: np.ndarray = 50 * ms,
tau_syn_rec: np.ndarray = 50 * ms,
v_thresh: np.ndarray = -55 * mV,
v_reset: np.ndarray = -65 * mV,
v_rest: np.ndarray = -65 * mV,
        refractory: float = 0 * ms,
        neuron_eq: Union[b2.Equations, str] = eqNeuronIAFSpkInRec,
        synapse_eq: Union[b2.Equations, str] = eqSynapseExpSpkInRec,
integrator_name: str = "rk4",
name: str = "unnamed",
record: bool = False,
):
"""
Construct a spiking recurrent layer with IAF neurons, with a Brian2 back-end. In- and outputs are spiking events
:param np.array weights_in: MxN input weight matrix.
:param np.array weights_rec: NxN recurrent weight matrix.
:param np.array bias: Nx1 bias vector. Default: 10.5mA
:param float dt: Time-step. Default: 0.1 ms
:param float noise_std: Noise std. dev. per second. Default: 0
:param np.array tau_mem: Nx1 vector of neuron time constants. Default: 20ms
        :param np.array tau_syn_inp: Nx1 vector of input synapse time constants. Default: 50ms
        :param np.array tau_syn_rec: Nx1 vector of recurrent synapse time constants. Default: 50ms
        :param np.array v_thresh: Nx1 vector of neuron thresholds. Default: -55mV
        :param np.array v_reset: Nx1 vector of neuron reset potentials. Default: -65mV
        :param np.array v_rest: Nx1 vector of neuron resting potentials. Default: -65mV
:param float refractory: Refractory period after each spike. Default: 0ms
:param Brian2.Equations neuron_eq: set of neuron equations. Default: IAF equation set
        :param Brian2.Equations synapse_eq: set of synapse equations for recurrent connections. Default: exponential
:param str integrator_name: Integrator to use for simulation. Default: 'rk4'
:param str name: Name for the layer. Default: 'unnamed'
:param bool record: Record membrane potential during evolutions
"""
# - Call Layer constructor
Layer.__init__(
self,
weights=weights_in,
dt=np.asarray(dt),
noise_std=np.asarray(noise_std),
name=name,
)
# - Set up spike source to receive spiking input
self._input_generator = b2.SpikeGeneratorGroup(
self.size_in, [0], [0 * second], dt=np.asarray(dt) * second
)
# - Set up layer neurons
self._neuron_group = b2.NeuronGroup(
self.size,
neuron_eq + synapse_eq,
threshold="v > v_thresh",
reset="v = v_reset",
refractory=np.asarray(refractory) * second,
method=integrator_name,
dt=np.asarray(dt) * second,
name="spiking_ff_neurons",
)
self._neuron_group.v = v_rest
self._neuron_group.r_m = 1 * ohm
# - Add source -> receiver synapses
self._inp_synapses = b2.Synapses(
self._input_generator,
self._neuron_group,
model="w : 1",
on_pre="I_syn_inp_post += w*amp",
method=integrator_name,
dt=np.asarray(dt) * second,
name="receiver_synapses",
)
self._inp_synapses.connect()
# - Add recurrent synapses
self._rec_synapses = b2.Synapses(
self._neuron_group,
self._neuron_group,
model="w : 1",
on_pre="I_syn_rec_post += w*amp",
method=integrator_name,
dt=np.asarray(dt) * second,
name="recurrent_synapses",
)
self._rec_synapses.connect()
# - Add monitors to record layer outputs
self._spike_monitor = b2.SpikeMonitor(
self._neuron_group, record=True, name="layer_spikes"
)
# - Call Network constructor
self._net = b2.Network(
self._input_generator,
self._inp_synapses,
self._rec_synapses,
self._neuron_group,
self._spike_monitor,
name="rec_spiking_layer",
)
if record:
# - Monitor for recording network potential
self._v_monitor = b2.StateMonitor(
self._neuron_group,
["v", "I_syn_inp", "I_syn_rec"],
record=True,
name="layer_neurons",
)
self._net.add(self._v_monitor)
# - Record neuron parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn_inp = tau_syn_inp
self.tau_syn_rec = tau_syn_rec
self.bias = bias
self.weights_in = weights_in
self.weights_rec = weights_rec
self._neuron_eq = neuron_eq
self._synapse_eq = synapse_eq
# - Store "reset" state
self._net.store("reset")
def evolve(
self,
ts_input: Optional[TSEvent] = None,
duration: Optional[float] = None,
num_timesteps: Optional[int] = None,
verbose: bool = False,
) -> TSEvent:
"""
Evolve the states of this layer given an input
:param Optional[`.TSEvent`] ts_input: Input spike train
:param Optional[float] duration: Simulation/Evolution time
:param Optional[int] num_timesteps: Number of evolution time steps
:param bool verbose: Currently no effect, just for conformity
:return `.TSEvent`: Output spike series
"""
# - Prepare time base
num_timesteps = self._determine_timesteps(ts_input, duration, num_timesteps)
time_base = self.t + np.arange(num_timesteps) * self.dt
# - Set spikes for spike generator
if ts_input is not None:
event_times, event_channels = ts_input(
t_start=time_base[0], t_stop=time_base[-1] + self.dt
)
self._input_generator.set_spikes(
event_channels, event_times * second, sorted=False
)
else:
self._input_generator.set_spikes([], [] * second)
# - Generate a noise trace
noise_step = (
np.random.randn(np.size(time_base), self.size)
            # - The std. dev. of the noise produced in the Brian simulation is
            #   slightly smaller than expected, so correct with the empirically
            #   determined factor 1.63
* self.noise_std
* np.sqrt(2.0 * self.tau_mem / self.dt)
* 1.63
)
        # - Specify noise input currents, construct TimedArray
inp_noise = TAShift(
np.asarray(noise_step) * amp,
self.dt * second,
tOffset=self.t * second,
name="noise_input",
)
# - Perform simulation
self._net.run(
num_timesteps * self.dt * second, namespace={"I_inp": inp_noise}, level=0
)
# - Start and stop times for output time series
t_start = self._timestep * float(self.dt)
t_stop = (self._timestep + num_timesteps) * float(self.dt)
# - Update layer time step
self._timestep += num_timesteps
# - Build response TimeSeries
use_event = self._spike_monitor.t_ >= time_base[0]
event_time_out = self._spike_monitor.t_[use_event]
event_channel_out = self._spike_monitor.i[use_event]
return TSEvent(
np.clip(event_time_out, t_start, t_stop),
event_channel_out,
name="Layer spikes",
num_channels=self.size,
t_start=t_start,
t_stop=t_stop,
)
def reset_time(self):
"""Reset the time for this layer"""
# - Store state variables
v_state = np.copy(self._neuron_group.v) * volt
v_syn_rec = np.copy(self._neuron_group.I_syn_rec) * amp
v_syn_inp = np.copy(self._neuron_group.I_syn_inp) * amp
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
tau_syn_inp = np.copy(self.tau_syn_inp)
tau_syn_rec = np.copy(self.tau_syn_rec)
bias = np.copy(self.bias)
weights_in = np.copy(self.weights_in)
weights_rec = np.copy(self.weights_rec)
self._net.restore("reset")
self._timestep = 0
        # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn_inp = tau_syn_inp
self.tau_syn_rec = tau_syn_rec
self.bias = bias
self.weights_in = weights_in
self.weights_rec = weights_rec
# - Restore state variables
self._neuron_group.v = v_state
self._neuron_group.I_syn_inp = v_syn_inp
self._neuron_group.I_syn_rec = v_syn_rec
def reset_state(self):
"""Reset the internal state of the layer"""
self._neuron_group.v = self.v_rest * volt
self._neuron_group.I_syn_inp = 0 * amp
self._neuron_group.I_syn_rec = 0 * amp
def reset_all(self, keep_params=True):
"""Reset all state of this layer (time and internal state)"""
if keep_params:
# - Store parameters
v_thresh = np.copy(self.v_thresh)
v_reset = np.copy(self.v_reset)
v_rest = np.copy(self.v_rest)
tau_mem = np.copy(self.tau_mem)
tau_syn_rec = np.copy(self.tau_syn_rec)
tau_syn_inp = np.copy(self.tau_syn_inp)
bias = np.copy(self.bias)
weights_in = np.copy(self.weights_in)
weights_rec = np.copy(self.weights_rec)
self.reset_state()
self._net.restore("reset")
self._timestep = 0
if keep_params:
            # - Restore parameters
self.v_thresh = v_thresh
self.v_reset = v_reset
self.v_rest = v_rest
self.tau_mem = tau_mem
self.tau_syn_inp = tau_syn_inp
self.tau_syn_rec = tau_syn_rec
self.bias = bias
self.weights_in = weights_in
self.weights_rec = weights_rec
def randomize_state(self):
"""Randomize the internal state of the layer"""
v_range = abs(self.v_thresh - self.v_reset)
self._neuron_group.v = (
np.random.rand(self.size) * v_range + self.v_reset
) * volt
self._neuron_group.I_syn_inp = (
np.random.randn(self.size) * np.mean(np.abs(self.weights_in)) * amp
)
self._neuron_group.I_syn_rec = (
np.random.randn(self.size) * np.mean(np.abs(self.weights_rec)) * amp
)
def to_dict(self) -> dict:
"""
Convert parameters of `self` to a dict if they are relevant for reconstructing an identical layer
"""
config = super().to_dict()
config.pop("weights")
config.pop("tau_syn_r")
config.pop("rec_syn_eq")
config["weights_in"] = self.weights_in
config["weights_rec"] = self.weights_rec
config["tau_syn_inp"] = self.tau_syn_inp.tolist()
config["tau_syn_rec"] = self.tau_syn_rec.tolist()
config["synapse_eq"] = self._synapse_eq
return config
@property
def input_type(self):
"""(~.TSEvent`) Input time series class accepted by this layer (`.TSEvent`)"""
return TSEvent
@property
def weights(self):
"""(np.ndarray) Recurrent synaptic weights for this layer [N, N]"""
return self.weights_rec
@weights.setter
def weights(self, new_w):
self.weights_rec = new_w
@property
def weights_in(self):
"""(np.ndarray) Input weights for this layer [M, N]"""
return np.array(self._inp_synapses.w).reshape(self.size_in, self.size)
@weights_in.setter
def weights_in(self, new_w):
assert new_w is not None, "Layer `{}`: weights_in must not be None.".format(
self.name
)
assert (
new_w.shape == (self.size_in, self.size)
or new_w.shape == self._inp_synapses.w.shape
), "Layer `{}`: weights must be of dimensions ({}, {}) or flat with size {}.".format(
self.name, self.size_in, self.size, self.size_in * self.size
)
self._inp_synapses.w = np.array(new_w).flatten()
@property
def weights_rec(self):
"""(np.ndarray) Recurrent synaptic weights for this layer [N, N]"""
return np.array(self._rec_synapses.w).reshape(self.size, self.size)
@weights_rec.setter
def weights_rec(self, new_w):
assert new_w is not None, "Layer `{}`: weights_rec must not be None.".format(
self.name
)
assert (
new_w.shape == (self.size, self.size)
            or new_w.shape == self._rec_synapses.w.shape
), "Layer `{}`: weights_rec must be of dimensions ({}, {}) or flat with size {}.".format(
self.name, self.size, self.size, self.size * self.size
)
self._rec_synapses.w = np.array(new_w).flatten()
@property
def tau_syn_inp(self):
"""(np.ndarray) Input synaptic time constants for this layer [M, N]"""
return self._neuron_group.tau_syn_inp
@tau_syn_inp.setter
def tau_syn_inp(self, new_tau_syn):
self._neuron_group.tau_syn_inp = (
np.asarray(self._expand_to_net_size(new_tau_syn, "tau_syn_inp")) * second
)
@property
def tau_syn_rec(self):
"""(np.ndarray) Recurrent synaptic time constants for this layer [N, N]"""
return self._neuron_group.tau_syn_rec
@tau_syn_rec.setter
def tau_syn_rec(self, new_tau_syn):
self._neuron_group.tau_syn_rec = (
np.asarray(self._expand_to_net_size(new_tau_syn, "tau_syn_rec")) * second
)
@property
def tau_syn_r(self):
        warn(
            "Layer `{}`: This layer has no attribute `tau_syn_r`. ".format(self.name)
            + "You might want to consider `tau_syn_rec` or `tau_syn_inp`."
        )
@tau_syn_r.setter
def tau_syn_r(self, *args, **kwargs):
        warn(
            "Layer `{}`: This layer has no attribute `tau_syn_r`. ".format(self.name)
            + "You might want to consider `tau_syn_rec` or `tau_syn_inp`."
        )
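# - Editor's sketch: a minimal smoke test of the base layers. All values are
#   illustrative and this block is not part of the original module; it assumes
#   `brian2` and `rockpool` are importable.
if __name__ == "__main__":
    # - Feed-forward layer driven by a constant current on two input channels
    ff_layer = FFIAFBrianBase(weights=np.ones((2, 3)) * 1e-3)
    times = np.arange(100) * float(ff_layer.dt)
    ts_inp = TSContinuous(times, np.ones((100, 2)) * 1e-2)
    ts_out = ff_layer.evolve(ts_inp, duration=float(99 * ff_layer.dt))
    print("FF layer produced {} spikes".format(len(ts_out.times)))
    # - Recurrent layer with random weights and current input on all 3 neurons
    rec_layer = RecIAFBrianBase(weights=0.05 * np.random.randn(3, 3))
    ts_rec = rec_layer.evolve(
        TSContinuous(times, np.ones((100, 3)) * 1e-2),
        duration=float(99 * rec_layer.dt),
    )
    print("Recurrent layer produced {} spikes".format(len(ts_rec.times)))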