# Source code for devices.dynapse.quantization.autoencoder_quantization

"""
Dynap-SE autoencoder-based quantization: top-level function that calls the unsupervised learning method once per cluster
"""

from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

from .autoencoder import learn_weights


__all__ = ["autoencoder_quantization"]


def _mask_cluster_weights(
    weights: Optional[np.ndarray], cluster_mask: np.ndarray
) -> Optional[np.ndarray]:
    """Return a zeroed copy of ``weights`` keeping only the columns (axis 1) selected by ``cluster_mask``; ``None`` passes through."""
    if weights is None:
        return None

    masked = np.zeros_like(weights)
    masked[:, cluster_mask] = weights[:, cluster_mask]
    return masked


def autoencoder_quantization(
    ## cluster
    n_cluster: int,
    core_map: List[int],
    ## Input
    weights_in: Optional[np.ndarray],
    weights_rec: Optional[np.ndarray],
    Iscale: float,
    n_bits: Optional[int] = 4,
    ## Optimization
    fixed_epoch: bool = False,
    num_epoch: int = int(1e7),
    num_epoch_checkpoint: int = int(1e3),
    eps: float = 1e-6,
    record_loss: bool = True,
    optimizer: str = "adam",
    step_size: Union[float, Callable[[int], float]] = 1e-4,
    opt_params: Optional[Dict[str, Any]] = None,
    *args,
    **kwargs,
) -> Dict[str, List[Union[np.ndarray, float]]]:
    """
    autoencoder_quantization executes the unsupervised weight configuration learning approach
    `rockpool.devices.dynapse.quantization.autoencoder.learn.learn_weights` for each cluster separately.
    For every cluster id ``n`` in ``range(n_cluster)`` the function builds copies of the input and
    recurrent weights that are zero everywhere except in the columns (axis 1) whose ``core_map`` entry
    equals ``n``, and passes those copies to ``learn_weights``.

    :param n_cluster: total number of clusters, neural cores allocated
    :type n_cluster: int
    :param core_map: core mapping for real hardware neurons (neuron_id : core_id). A plain list or a numpy array is accepted
    :type core_map: List[int]
    :param weights_in: input layer weights used in Dynap-SE2 simulation, or None to skip
    :type weights_in: Optional[np.ndarray]
    :param weights_rec: recurrent layer (in-device neurons) weights used in Dynap-SE2 simulation, or None to skip
    :type weights_rec: Optional[np.ndarray]
    :param Iscale: base weight scaling current in Amperes used in simulation
    :type Iscale: float
    :param n_bits: number of target weight bits, defaults to 4
    :type n_bits: Optional[int], optional
    :param fixed_epoch: used fixed number of epochs or control the convergence by loss decrease, defaults to False
    :type fixed_epoch: bool, optional
    :param num_epoch: the fixed number of epochs as global limit, defaults to 10,000,000
    :type num_epoch: int, optional
    :param num_epoch_checkpoint: at this point (number of epochs), pipeline checks the loss decrease and decide to continue or not, defaults to 1,000.
    :type num_epoch_checkpoint: int, optional
    :param eps: the epsilon tolerance value. If the loss does not decrease more than this for five consecutive checkpoints, then training stops. defaults to 1e-6
    :type eps: float, optional
    :param record_loss: record the loss evolution or not, defaults to True
    :type record_loss: bool, optional
    :param optimizer: one of the optimizer defined in `jax.example_libraries.optimizers` : , defaults to "adam"
    :type optimizer: str, optional
    :param step_size: positive scalar, or a callable representing a step size schedule that maps the iteration index to a positive scalar. , defaults to 1e-4
    :type step_size: Union[float, Callable[[int], float]], optional
    :param opt_params: optimizer parameters dictionary, defaults to None, which is treated as an empty dictionary
    :type opt_params: Optional[Dict[str, Any]]
    :return: A dictionary of quantized weights and parameters and the quantization loss; every value is a list holding one entry per cluster, in cluster order
    :rtype: Dict[str, List[Union[np.ndarray, float]]]
    """
    # ``*args`` / ``**kwargs`` are accepted for call compatibility only; they are not forwarded.

    # A fresh dict per call instead of a shared mutable default argument.
    if opt_params is None:
        opt_params = {}

    # ``core_map`` is documented as a list. Comparing a Python list with an int yields
    # the scalar ``False`` rather than an element-wise mask, which would silently select
    # no columns at all, so convert to an array once up front.
    core_ids = np.asarray(core_map)

    spec = {
        "weights_in": [],
        "sign_in": [],
        "weights_rec": [],
        "sign_rec": [],
        "Iw_0": [],
        "Iw_1": [],
        "Iw_2": [],
        "Iw_3": [],
        "quantization_loss": [],
    }

    for n in range(n_cluster):
        cluster_mask = core_ids == n

        w_in = _mask_cluster_weights(weights_in, cluster_mask)
        w_rec = _mask_cluster_weights(weights_rec, cluster_mask)

        cluster_result = learn_weights(
            w_in,
            w_rec,
            Iscale,
            n_bits,
            fixed_epoch,
            num_epoch,
            num_epoch_checkpoint,
            eps,
            record_loss,
            optimizer,
            step_size,
            opt_params,
        )

        # Collect the per-cluster results; only the keys listed in ``spec`` are read.
        for key in spec:
            spec[key].append(cluster_result[key])

    return spec