Source code for nn.modules.jax.linear_jax

"""
Linear module with a Jax backend
"""

from rockpool.nn.modules.native.linear import LinearMixin
from rockpool.nn.modules.jax.jax_module import JaxModule
import jax.numpy as jnp


class LinearJax(LinearMixin, JaxModule):
    """
    Encapsulates a linear weight matrix, with a Jax backend

    This class adds no logic of its own. It combines :py:class:`LinearMixin`
    with the :py:class:`JaxModule` base and sets the ``_dot`` primitive to
    :py:func:`jax.numpy.dot`.
    """

    # Matrix-product primitive for this backend. Presumably LinearMixin calls
    # ``self._dot(...)`` to apply the weight matrix — confirm in LinearMixin.
    # ``staticmethod`` stops Python from binding ``jnp.dot`` as a method, so
    # the instance is not passed as its first argument.
    _dot = staticmethod(jnp.dot)