# Source code for devices.xylo.syns65302.xylo_graph_modules

"""
XyloAudio 3 graph modules for use with tracing and mapping
"""

import warnings

from rockpool.graph import (
    GenericNeurons,
    GraphModule,
    LIFNeuronWithSynsRealValue,
    replace_module,
)

import numpy as np

from typing import List, Optional, Union
from rockpool.typehints import IntVector, FloatVector

from dataclasses import dataclass, field

# - Public names exported by this module: the base neuron graph module, plus the hidden and output neuron graph modules
__all__ = ["XyloA3Neurons", "XyloA3HiddenNeurons", "XyloA3OutputNeurons"]


@dataclass(eq=False, repr=False)
class XyloA3Neurons(GenericNeurons):
    """
    Base class for all XyloAudio 3 neuron graph module classes

    Holds the per-neuron parameters (thresholds, biases, decay parameters) and the hardware neuron IDs allocated to the module. Use :py:meth:`_convert_from` to build an instance from another neuron graph module.
    """

    hw_ids: Union[IntVector, FloatVector] = field(default_factory=list)
    """ IntVector: The HW neuron IDs allocated to this graph module ``(N,)``. Empty means that no HW IDs have been allocated."""

    threshold: Union[IntVector, FloatVector] = field(default_factory=list)
    """ IntVector: The threshold parameters for each neuron ``(N,)`` """

    bias: Union[IntVector, FloatVector] = field(default_factory=list)
    """ IntVector: The bias parameters for each neuron ``(N,)`` """

    dash_mem: Union[IntVector, FloatVector] = field(default_factory=list)
    """ IntVector: The membrane decay parameters for each neuron ``(N,)`` """

    dash_syn: Union[IntVector, FloatVector] = field(default_factory=list)
    """ IntVector: The synapse decay parameters for each neuron. Either ``(N,)`` if only one synapse is used per neuron, or ``(2N,)`` if two synapses are used for each neuron (i.e. syn2). In this case, elements ``dash_syn[0:2]`` refer to the synapses of neuron ``0``, and so on. """

    dt: Optional[float] = None
    """ float: The ``dt`` time step used for this neuron module """

    @classmethod
    def _convert_from(cls, mod: GraphModule) -> GraphModule:
        """
        Convert a graph module to this class, replacing it in its graph

        Conversion rules, tried in order:

        - ``mod`` is already an instance of ``cls``: returned unchanged.
        - ``mod`` is a ``LIFNeuronWithSynsRealValue``: time constants are converted to "dash" decay parameters as ``log2(tau / dt)``; thresholds and biases are copied across.
        - ``mod`` is any other ``GenericNeurons``: a new module with the same number of input and output nodes is created, and every dataclass field of the new module that also exists on ``mod`` is copied across. A warning is issued if ``mod`` is a subclass of ``GenericNeurons`` with no explicit conversion rule.

        Args:
            mod (GraphModule): The graph module to convert

        Returns:
            GraphModule: ``mod`` itself if no conversion was needed, otherwise the newly created module that has replaced ``mod``

        Raises:
            ValueError: If ``mod`` is a ``LIFNeuronWithSynsRealValue`` with no ``dt`` set
            ValueError: If ``mod`` is not a type that can be converted
        """
        if isinstance(mod, cls):
            # - No need to do anything
            return mod

        if isinstance(mod, LIFNeuronWithSynsRealValue):
            # - Convert from a real-valued LIF neuron
            # - A value for `dt` is required to express time constants in units of time steps
            if mod.dt is None:
                raise ValueError(
                    f"Graph module of type {type(mod).__name__} has no `dt` set, so cannot convert time constants when converting to {cls.__name__}."
                )

            # - Convert TCs to dash parameters
            dash_mem = np.log2(np.array(mod.tau_mem) / mod.dt).tolist()

            # - Synaptic TCs are flattened row-major, so the synapses of each neuron stay adjacent
            dash_syn = np.log2(np.array(mod.tau_syn) / mod.dt).flatten().tolist()

            # - Get thresholds
            thresholds = np.array(mod.threshold).tolist()

            # - Get biases
            bias = np.array(mod.bias).tolist()

            # - Build a new neurons module to insert into the graph
            neurons = cls._factory(
                len(mod.input_nodes),
                len(mod.output_nodes),
                mod.name,
                mod.computational_module,
                [],  # Empty list for HW IDs
                thresholds,
                bias,
                dash_mem,
                dash_syn,
                mod.dt,
            )

            # - Replace the target module and return
            replace_module(mod, neurons)
            return neurons

        if isinstance(mod, GenericNeurons):
            # - Try to convert as a `GenericNeurons` base class
            #   This is a deliberate exact-type check (not `isinstance`): subclasses trigger the warning
            if type(mod) is not GenericNeurons:
                # - Warn if `mod` is actually some other derived class
                #   We might be missing an explicit conversion rule in this case
                warnings.warn(
                    f"Converting module {mod} as a GenericNeurons module to {cls.__name__} . No explicit conversion rule was found for class {type(mod).__name__}."
                )

            # - Make a new module
            neurons = cls._factory(
                len(mod.input_nodes),
                len(mod.output_nodes),
                mod.name,
            )

            # - Replace the target module
            replace_module(mod, neurons)

            # - Try to set attributes of the new module
            for attr in neurons.__dataclass_fields__:
                if hasattr(mod, attr):
                    setattr(neurons, attr, getattr(mod, attr))

            return neurons

        raise ValueError(
            f"Graph module of type {type(mod).__name__} cannot be converted to a {cls.__name__}"
        )


@dataclass(eq=False, repr=False)
class XyloA3HiddenNeurons(XyloA3Neurons):
    """
    A :py:class:`.graph.GraphModule` encapsulating XyloAudio 3 hidden neurons

    The number of input nodes must be either equal to, or exactly twice, the number of output nodes.
    """

    def __post_init__(self, *args, **kwargs):
        """
        Validate the input / output node counts, then run the parent class post-initialisation

        Args:
            *args: Forwarded unchanged to the parent ``__post_init__``
            **kwargs: Forwarded unchanged to the parent ``__post_init__``

        Raises:
            ValueError: If the number of input nodes is neither 1x nor 2x the number of output nodes
        """
        num_in = len(self.input_nodes)
        num_out = len(self.output_nodes)

        # - Accept either N or 2N input nodes for N output nodes
        if num_in not in (num_out, 2 * num_out):
            raise ValueError(
                "Number of input nodes must be 1* or 2* number of output nodes"
            )

        # - `super()` already binds `self`; do not pass the instance a second time
        super().__post_init__(*args, **kwargs)
@dataclass(eq=False, repr=False)
class XyloA3OutputNeurons(XyloA3Neurons):
    """
    A :py:class:`.graph.GraphModule` encapsulating XyloAudio 3 output neurons

    The number of input nodes must be equal to the number of output nodes.
    """

    def __post_init__(self, *args, **kwargs):
        """
        Validate the input / output node counts, then run the parent class post-initialisation

        Args:
            *args: Forwarded unchanged to the parent ``__post_init__``
            **kwargs: Forwarded unchanged to the parent ``__post_init__``

        Raises:
            ValueError: If the number of input nodes differs from the number of output nodes
        """
        if len(self.input_nodes) != len(self.output_nodes):
            raise ValueError(
                "Number of input nodes must be equal to number of output nodes"
            )

        # - `super()` already binds `self`; do not pass the instance a second time
        super().__post_init__(*args, **kwargs)