Source code for nn.modules.native.linear

"""
Implements linear weight matrix modules
"""

from __future__ import annotations
from rockpool.nn.modules.module import Module
from rockpool.parameters import Parameter
from rockpool.graph import GraphModuleBase, LinearWeights, as_GraphHolder

import numpy as onp
from warnings import warn

from typing import Tuple, Any, Callable

from abc import ABC

__all__ = ["unit_eigs", "kaiming", "xavier", "Linear"]


def unit_eigs(s):
    """
    Draw a standard-normal random array of shape ``s``, divided by ``sqrt(s[0])``

    Args:
        s (tuple): Shape of the array to generate. The first entry ``s[0]`` sets the scaling factor

    Returns:
        np.ndarray: Random array with shape ``s``
    """
    # - Draw the samples first, then apply the 1 / sqrt(s[0]) scaling
    samples = onp.random.randn(*s)
    scale = onp.sqrt(s[0])
    return samples / scale


def uniform_sqrt(s):
    """
    Draw a uniform random array of shape ``s`` over the range ``(-sqrt(1 / s[0]), sqrt(1 / s[0]))``

    Args:
        s (tuple): Shape of the array to generate. The first entry ``s[0]`` sets the range

    Returns:
        np.ndarray: Random array with shape ``s``
    """
    # - Symmetric bound of the sampling interval
    bound = onp.sqrt(1 / s[0])
    return onp.random.uniform(low=-bound, high=bound, size=s)


def kaiming(s):
    """
    Draw a uniform random array of shape ``s`` over the range ``(-sqrt(6 / s[0]), sqrt(6 / s[0]))``

    Args:
        s (tuple): Shape of the array to generate. The first entry ``s[0]`` sets the range

    Returns:
        np.ndarray: Random array with shape ``s``
    """
    # - Symmetric bound of the sampling interval
    bound = onp.sqrt(6 / s[0])
    return onp.random.uniform(low=-bound, high=bound, size=s)


def xavier(s):
    """
    Draw a uniform random array of shape ``s`` over the range ``(-sqrt(6 / sum(s)), sqrt(6 / sum(s)))``

    Args:
        s (tuple): Shape of the array to generate. The sum of all entries sets the range

    Returns:
        np.ndarray: Random array with shape ``s``
    """
    # - The bound depends on the sum of all dimensions in `s`
    total = onp.sum(s)
    bound = onp.sqrt(6 / total)
    return onp.random.uniform(low=-bound, high=bound, size=s)


class LinearMixin(ABC):
    """
    Encapsulate a linear weight matrix

    This mix-in must be combined with a :py:class:`.Module` class; initialisation raises a ``TypeError`` otherwise. Concrete classes set the class attribute ``_dot`` to a callable accepting ``(input_data, weight)``, which :py:meth:`evolve` uses to apply the weights.
    """

    # - Callable ``(input_data, weight)`` used by `evolve`; ``None`` until a concrete class assigns it
    _dot = None

    def __init__(
        self,
        shape: tuple,
        weight=None,
        bias=None,
        has_bias: bool = False,
        weight_init_func: Callable = kaiming,
        bias_init_func: Callable = uniform_sqrt,
        *args,
        **kwargs,
    ):
        """
        Encapsulate a linear weight matrix, with optional biases

        `.Linear` wraps a single weight matrix, and passes data through by using the matrix as a set of weights. The shape of the matrix must be specified as a tuple ``(Nin, Nout)``. `.Linear` provides optional biases.

        A weight initialisation function may be specified. By default the weights will use Kaiming initialisation (:py:func:`.kaiming`).

        A bias initialisation function may be specified, if used. By default the biases will be initialised as uniform random over the range :math:`(-\\sqrt{1/N}, \\sqrt{1/N})`.

        Warnings:
            Standard DNN libraries by default include a bias on linear layers. These are usually not used for SNNs, where the bias is configured on the spiking neuron module. :py:class:`.Linear` layers in Rockpool use a default of ``has_bias = False``. You can force the presence of a bias on the linear layer with ``has_bias = True`` on initialisation.

        Examples:

            Build a linear weight matrix with shape ``(3, 4)``, and no biases:

            >>> Linear((3, 4))
            Linear  with shape (3, 4)

            Build a linear weight matrix with shape ``(2, 5)``, which will be initialised with zeros:

            >>> Linear((2, 5), weight_init_func = lambda s: np.zeros(s))
            Linear  with shape (2, 5)

            Provide a concrete initialisation for the linear weights:

            >>> Linear((2, 2), weight = np.array([[1, 2], [3, 4]]))
            Linear  with shape (2, 2)

            Build a linear layer including biases:

            >>> mod = Linear((2, 2), has_bias = True)
            >>> mod.parameters()
            Out:
            {'weight': array([[ 0.56655314,  0.64411151],
                    [-1.43016068, -1.538719  ]]),
             'bias': array([-0.58513867, -0.32314069])}

        Args:
            shape (tuple): The desired shape of the weight matrix. Must have two entries ``(Nin, Nout)``
            weight_init_func (Callable): The initialisation function to use for the weights. Default: Kaiming initialization; uniform on the range :math:`(-\\sqrt{6/Nin}, \\sqrt{6/Nin})`
            weight (Optional[np.array]): A concrete weight matrix to assign to the weights on initialisation. ``weight.shape`` must match the ``shape`` argument
            has_bias (bool): A boolean flag indicating that this linear layer should have a bias parameter. Default: ``False``, no bias parameter. A bias parameter is also created whenever ``bias`` is provided
            bias_init_func (Callable): The initialisation function to use for the biases. Default: Uniform / sqrt(N); uniform on the range :math:`(-\\sqrt{1/N}, \\sqrt{1/N})`
            bias (Optional[np.array]): A concrete bias vector to assign to the biases on initialisation. ``bias.shape`` must be ``(N,)``
            *args: Additional positional arguments, passed on to the superclass initialiser
            **kwargs: Additional keyword arguments, passed on to the superclass initialiser

        Raises:
            TypeError: If this mix-in is used with a class that is not a `Module`
            ValueError: If ``shape`` does not have exactly two entries
        """
        # - Base class must be `Module`
        if not isinstance(self, Module):
            raise TypeError(
                "`LinearMixin` mix-in class may only be used with `Module` classes."
            )

        # - Call superclass init
        #   Star-args are placed before the keyword argument; the binding is
        #   identical to `shape=shape, *args`, but the call order is explicit
        super().__init__(*args, shape=shape, **kwargs)

        if len(self.shape) != 2:
            raise ValueError("`shape` must specify input and output sizes for Linear.")

        # - Specify weight parameter
        self.weight = Parameter(
            weight, shape=self.shape, init_func=weight_init_func, family="weights"
        )
        """ Weight matrix of this module """

        # - Specify bias parameter
        if has_bias or bias is not None:
            self.bias = Parameter(
                bias, shape=self.size_out, init_func=bias_init_func, family="biases"
            )
            """ Bias vector of this module """
            self._has_bias = True
        else:
            # - A scalar zero keeps the addition in `evolve` valid without a bias parameter
            self.bias = 0
            self._has_bias = False

    def evolve(self, input_data, record: bool = False) -> Tuple[Any, Any, Any]:
        """
        Pass input data through the weights, and add the bias

        Args:
            input_data: The data to pass through the weights, given as first argument to ``_dot``
            record (bool): Not used by this implementation

        Returns:
            Tuple[Any, Any, Any]: ``(output, {}, {})``, where ``output`` is ``_dot(input_data, weight) + bias``. The two dictionaries are always empty
        """
        return self._dot(input_data, self.weight) + self.bias, {}, {}

    def as_graph(self) -> GraphModuleBase:
        """
        Build a `LinearWeights` graph module from this module

        The graph module is created with this module's input and output sizes, a name built from the class name, the module name and the object id, and the weights of this module. The biases are passed only if this module has a bias parameter, otherwise ``None`` is passed.

        Returns:
            GraphModuleBase: The object returned by ``LinearWeights._factory``
        """
        return LinearWeights._factory(
            self.size_in,
            self.size_out,
            f"{type(self).__name__}_{self.name}_{id(self)}",
            self,
            self.weight,
            self.bias if self._has_bias else None,
        )

    @classmethod
    def from_graph(cls, graph: LinearWeights) -> LinearMixin:
        """
        from_graph constructs a LinearMixin object from the computational graph

        If ``graph`` is not a `LinearWeights` instance, it is first converted with ``LinearWeights._convert_from``.

        :param graph: the reference computational graph to restore computational module
        :type graph: LinearWeights
        :return: a LinearMixin object
        :rtype: LinearMixin
        """
        if not isinstance(graph, LinearWeights):
            graph = LinearWeights._convert_from(graph)

        return cls(
            shape=(len(graph.input_nodes), len(graph.output_nodes)),
            weight=graph.weights,
            bias=graph.biases,
            has_bias=graph.biases is not None,
        )


class Linear(LinearMixin, Module):
    """
    Encapsulates a linear weight matrix
    """

    # - Apply the weights with the numpy dot product. Wrapped in `staticmethod`
    #   so the function is not bound to the instance when accessed as `self._dot`
    _dot = staticmethod(onp.dot)