"""
Basic computational modules for graph definition in Rockpool
Defines :py:class:`.LinearWeights`, :py:class:`.GenericNeurons`, :py:class:`.AliasConnection`, :py:class:`.LIFNeuronWithSynsRealValue` and :py:class:`.RateNeuronWithSynsRealValue`.
"""
from rockpool.graph.graph_base import GraphModule
from dataclasses import dataclass, field
from typing import Optional
from rockpool.typehints import FloatVector
import numpy as np
from rockpool.utilities.backend_management import backend_available
# - Use the real torch `Tensor` type when `backend_available("torch")` is truthy;
#   otherwise define an empty stand-in class under the same name
if backend_available("torch"):
    from torch import Tensor
else:
    # - Placeholder type, so `isinstance(..., Tensor)` checks in this module
    #   can still be evaluated when torch is not imported
    class Tensor:
        pass
# - Public names exported by this module (used by `from <module> import *`)
__all__ = [
    "LinearWeights",
    "GenericNeurons",
    "AliasConnection",
    "LIFNeuronWithSynsRealValue",
    "RateNeuronWithSynsRealValue",
]
@dataclass(eq=False, repr=False)
class LinearWeights(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a single set of linear weights

    On construction, ``weights`` and ``biases`` are checked against the number of input and output nodes of the module, and are then stored as numpy arrays. Torch tensors are detached and moved to the CPU before conversion.

    Raises:
        ValueError: If ``weights`` does not have shape ``(len(input_nodes), len(output_nodes))``, or if ``biases`` is provided and does not have shape ``(len(output_nodes),)``.
    """

    weights: FloatVector
    """ FloatVector: The linear weights ``(Nin, Nout)`` encapsulated by this module """

    biases: Optional[FloatVector] = None
    """ FloatVector: The biases ``(Nout,)`` encapsulated by this module """

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a numpy array; torch tensors are detached and moved to the CPU first"""
        if isinstance(values, Tensor):
            return values.detach().cpu().numpy()

        return np.array(values)

    def __post_init__(self, *args, **kwargs):
        # - Check size
        #   `np.shape` is used rather than the `.shape` attribute so that plain
        #   (nested) sequences are validated instead of raising `AttributeError`.
        #   For objects that already expose `.shape`, `np.shape` returns it unchanged.
        weights_shape = np.shape(self.weights)
        if weights_shape != (len(self.input_nodes), len(self.output_nodes)):
            raise ValueError(
                f"`weights` must match size of input and output nodes. Got {weights_shape}, expected {(len(self.input_nodes), len(self.output_nodes))}."
            )

        if self.biases is not None:
            biases_shape = np.shape(self.biases)
            if biases_shape != (len(self.output_nodes),):
                raise ValueError(
                    f"`biases` must match size of input and output nodes. Got {biases_shape}, expected {(len(self.output_nodes),)}."
                )

        # - Call super-class, only once the parameters are known to be consistent
        super().__post_init__(*args, **kwargs)

        # - Convert weights and biases to numpy arrays
        self.weights = self._to_numpy(self.weights)

        if self.biases is not None:
            self.biases = self._to_numpy(self.biases)
@dataclass(eq=False, repr=False)
class GenericNeurons(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of generic neurons

    This class is used as a base class for the specific neuron subclasses in this module. It adds no fields of its own, and does not specify any parameters for the neurons.
    """

    pass
@dataclass(eq=False, repr=False)
class AliasConnection(GraphModule):
    """
    A :py:class:`.GraphModule` that encapsulates a set of alias connections

    An alias connection requires the same number of input nodes and output nodes.

    Raises:
        ValueError: If the number of input nodes differs from the number of output nodes.
    """

    def __post_init__(self, *args, **kwargs):
        # - Call super-class
        super().__post_init__(*args, **kwargs)

        # - Check size: inputs and outputs must match in number
        if len(self.input_nodes) != len(self.output_nodes):
            raise ValueError(
                f"For an alias connection, the number of inputs and outputs must be identical.\nGot {len(self.input_nodes)} and {len(self.output_nodes)}."
            )
@dataclass(eq=False, repr=False)
class LIFNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of LIF spiking neurons with synaptic and membrane dynamics, and with real-valued parameters

    All vector parameters default to an empty list, and ``dt`` defaults to ``None``.
    """

    tau_mem: FloatVector = field(default_factory=list)
    """ FloatVector: The membrane time constants of these neurons, in seconds ``(Nout,)`` """

    tau_syn: FloatVector = field(default_factory=list)
    """ FloatVector: The synaptic time constants of these neurons, in seconds ``(Nin,)`` """

    threshold: FloatVector = field(default_factory=list)
    """ FloatVector: The firing threshold parameters of these neurons ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: The bias parameters of these neurons, if present ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: The time-step used for these neurons in seconds, if present """
@dataclass(eq=False, repr=False)
class RateNeuronWithSynsRealValue(GenericNeurons):
    """
    A :py:class:`.GraphModule` that encapsulates a set of rate neurons, with synapses, and with real-valued parameters

    All vector parameters default to an empty list, and ``dt`` defaults to ``None``.
    """

    tau: FloatVector = field(default_factory=list)
    """ FloatVector: The time constants of these neurons, in seconds ``(Nout,)`` """

    bias: FloatVector = field(default_factory=list)
    """ FloatVector: The bias parameters of these neurons ``(Nout,)`` """

    dt: Optional[float] = None
    """ float: The time-step used for these neurons in seconds, if present """