Source code for transform.quantize_methods

"""
Quantisation methods for Xylo

Defines the post-training quasntization methods :py:func:`.global_quantize` and :py:func:`.channel_quantize`.

"""

import numpy as np
import copy

__all__ = ["global_quantize", "channel_quantize"]


def validate_weights_to_2d(data):
    assert len(data.shape) == 2, "Output weights must be 2D."
    return data


def validate_weights_to_3d(data):
    assert (
        len(data.shape) >= 2 and len(data.shape) <= 3
    ), "Weight arrays must be 2D or 3D."
    return np.expand_dims(data, -1) if len(data.shape) < 3 else data


[docs]def global_quantize( weights_in: np.ndarray, weights_rec: np.ndarray, weights_out: np.ndarray, threshold: np.ndarray, threshold_out: np.ndarray, dash_mem: np.ndarray, dash_mem_out: np.ndarray, dash_syn: np.ndarray, dash_syn_2: np.ndarray, dash_syn_out: np.ndarray, bias: np.ndarray = None, bias_out: np.ndarray = None, fuzzy_scaling: bool = False, bits_per_weight: int = 8, bits_per_threshold: int = 16, *_, **__, ): """ Quantize a Xylo model for deployment, using global parameter scaling The figure below illustrates the groups of weights which are considered for quantization. Under this global method, all weights in the network are considered together when scaling and quantizing weights and thresholds. Input and recurrent weights are considered together as a group; output weights are considered separately when quantizing. Dashes are rounded and cast to integer. :: target ------------------- s ------------------- - o -**- -**- -**- -**- - u -**- -**- -**- -**- - r -**- -**- -**- -**- - c -**- -**- -**- -**- - e -**- -**- -**- -**- - ------------------- Examples: >>> specs = xylo.devices.mapper(net.as_graph(), weight_dtype="float", threshold_dtype="float") >>> specs.update(global_quantize(**specs, fuzzy_scaling = True)) >>> xylo.devices.XyloSim.from_specifications(specs) Args: weights_in (np.ndarray): Input weight matrix weights_rec (np.ndarray): Recurrent weight matrix weights_out (np.ndarray): Output weight matrix threshold (np.ndarray): Firing threshold for hidden neurons threshold_out (np.ndarray): Firing threshold for output neurons dash_mem (np.ndarray): Dash for membrane potential of hidden neurons dash_mem_out (np.ndarray): Dash for membrane potential of output neurons dash_syn (np.ndarray): Dash for synaptic current of hidden neurons dash_syn_2 (np.ndarray): Dash for second synaptic current of hidden neurons dash_syn_out (np.ndarray): Dash for synaptic current of output neurons bias (np.ndarray): Bias for hidden neurons bias_out (np.ndarray): Bias for output neurons fuzzy_scaling (bool): If ``True``, scale and clip weights to 2*std dev. If ``False`` (default), scale and clip to maximum absolute weight. bits_per_weight (int): Number of bits per integer signed weight. Default: ``8`` bits_per_threshold (int): Number of bits per integer signed threshold. Default: ``16`` Returns: dict: `model_quan` which can be used to update a Xylo specification dictionary """ w_in = validate_weights_to_3d(copy.copy(weights_in)) w_rec = validate_weights_to_3d(copy.copy(weights_rec)) w_out = validate_weights_to_2d(copy.copy(weights_out)) threshold = copy.copy(threshold) threshold_out = copy.copy(threshold_out) max_w_quan = 2 ** (bits_per_weight - 1) - 1 max_th_quan = 2 ** (bits_per_threshold - 1) - 1 if fuzzy_scaling: # detect outliers weights_rec_ = np.concatenate([np.ravel(w_in), np.ravel(w_rec)]) weights_rec_ = weights_rec_[weights_rec_ != 0] weights_out_ = np.ravel(w_out) weights_out_ = weights_out_[weights_out_ != 0] max_w = np.abs(weights_rec_.mean()) + 2 * weights_rec_.std() max_w_out = np.abs(weights_out_.mean()) + 2 * weights_out_.std() else: max_w = 0 max_w = np.max([max_w, np.max(np.abs(w_in))]) max_w = np.max([max_w, np.max(np.abs(w_rec))]) if not bias is None: max_w = np.max([max_w, np.max(np.abs(bias))]) max_w_out = np.max([0, np.max(np.abs(w_out))]) if not bias_out is None: max_w_out = np.max([max_w_out, np.max(np.abs(bias_out))]) # determine scaling value if max_w != 0: scaling = max_w_quan / max_w else: scaling = 1 if max_w_out != 0: scaling_out = max_w_quan / max_w_out else: scaling_out = 1 # scale weights weights_in = np.round(w_in * scaling).astype(int) weights_rec = np.round(w_rec * scaling).astype(int) weights_out = np.round(w_out * scaling_out).astype(int) # scale thresholds threshold = np.round(threshold * scaling).astype(int) threshold_out = np.round(threshold_out * scaling_out).astype(int) # bias if not bias is None: bias = np.round(bias * scaling).astype(int) # bias out if not bias_out is None: bias_out = np.round(bias_out * scaling).astype(int) # if the threshold exceed boundary if np.abs(np.max(threshold)) > max_th_quan: limited_scaling = max_th_quan / np.max(threshold) threshold = np.round(threshold * limited_scaling).astype(int) weights_in = np.round(weights_in * limited_scaling).astype(int) weights_rec = np.round(weights_rec * limited_scaling).astype(int) if np.abs(np.max(threshold_out)) > max_th_quan: limited_scaling = max_th_quan / np.max(threshold_out) threshold_out = np.round(threshold_out * limited_scaling).astype(int) weights_out = np.round(weights_out * limited_scaling).astype(int) weights_rec = np.round(weights_rec * limited_scaling).astype(int) # round and cast all dashes to integer dash_mem = np.round(dash_mem).astype(int) dash_mem_out = np.round(dash_mem_out).astype(int) dash_syn = np.round(dash_syn).astype(int) dash_syn_2 = np.round(dash_syn_2).astype(int) dash_syn_out = np.round(dash_syn_out).astype(int) model_quan = { "weights_in": weights_in, "weights_rec": weights_rec, "weights_out": weights_out, "threshold": threshold, "threshold_out": threshold_out, "dash_mem": dash_mem, "dash_mem_out": dash_mem_out, "dash_syn": dash_syn, "dash_syn_2": dash_syn_2, "dash_syn_out": dash_syn_out, } if not bias is None: model_quan["bias"] = bias if not bias_out is None: model_quan["bias_out"] = bias_out return model_quan
[docs]def channel_quantize( weights_in: np.ndarray, weights_rec: np.ndarray, weights_out: np.ndarray, threshold: np.ndarray, threshold_out: np.ndarray, dash_mem: np.ndarray, dash_mem_out: np.ndarray, dash_syn: np.ndarray, dash_syn_2: np.ndarray, dash_syn_out: np.ndarray, bias: np.ndarray = None, bias_out: np.ndarray = None, bits_per_weight: int = 8, bits_per_threshold: int = 16, *_, **__, ): """ Quantize a Xylo model for deployment, using per-channel parameter scaling The figure below illustrates the groups of weights which are considered for quantization. Under this per-channel method, all input weights to a single target neuron are considered together when scaling and quantizing weights and thresholds. Input and recurrent weights are considered together as a group; output weights are considered separately when quantizing. Dashes are rounded and cast to integer. :: target ------------------- s ------------------- - o -++- -**- -##- -oo- - u -++- -**- -##- -oo- - r -++- -**- -##- -oo- - c -++- -**- -##- -oo- - e -++- -**- -##- -oo- - ------------------- Examples: >>> specs = xylo.devices.mapper(net.as_graph(), weight_dtype="float", threshold_dtype="float") >>> specs.update(channel_quantize(**specs, bits_per_weight = 12)) >>> xylo.devices.XyloSim.from_specifications(specs) Args: weights_in (np.ndarray): Input weight matrix weights_rec (np.ndarray): Recurrent weight matrix weights_out (np.ndarray): Output weight matrix threshold (np.ndarray): Firing threshold for hidden neurons threshold_out (np.ndarray): Firing threshold for output neurons dash_mem (np.ndarray): Dash for membrane potential of hidden neurons dash_mem_out (np.ndarray): Dash for membrane potential of output neurons dash_syn (np.ndarray): Dash for synaptic current of hidden neurons dash_syn_2 (np.ndarray): Dash for second synaptic current of hidden neurons dash_syn_out (np.ndarray): Dash for synaptic current of output neurons bias (np.ndarray): Bias for hidden neurons bias_out (np.ndarray): Bias for output neurons bits_per_weight (int): Number of bits per integer signed weight. Default: ``8`` bits_per_threshold (int): Number of bits per integer signed threshold. Default: ``16`` Returns: dict: `model_quan` which can be used to update a Xylo specification dictionary """ w_in = validate_weights_to_3d(copy.copy(weights_in)) w_rec = validate_weights_to_3d(copy.copy(weights_rec)) w_out = validate_weights_to_2d(copy.copy(weights_out)) threshold = copy.copy(threshold) threshold_out = copy.copy(threshold_out) max_w_quan = 2 ** (bits_per_weight - 1) - 1 max_th_quan = 2 ** (bits_per_threshold - 1) - 1 Nin, Nien, Nsyn = w_in.shape Nhid, _, _ = w_rec.shape Noen, Nout = w_out.shape # quantize input weight, recurrent weight, threshold # two weight matrix need to stack together to consider per-channel quantization w_in_quan = np.zeros(shape=w_in.shape) w_rec_quan = np.zeros(shape=w_rec.shape) threshold_quan = np.zeros(shape=threshold.shape) for hidden_id in range(Nhid): # if two synaptic connection is used if Nsyn > 1: # - Find the maximum output weight for this neuron max_w = 0 max_w = ( np.max([max_w, np.max(np.abs(w_in[:, hidden_id, :]))]) if hidden_id < Nien else max_w ) max_w = np.max([max_w, np.max(np.abs(w_rec[:, hidden_id, :]))]) # - Scale and quantise weights, thresholds and biases if max_w != 0: scaling = max_w_quan / max_w if hidden_id < Nien: w_in_quan[:, hidden_id, :] = np.round( w_in[:, hidden_id, :] * scaling ) w_rec_quan[:, hidden_id, :] = np.round(w_rec[:, hidden_id, :] * scaling) threshold_quan[hidden_id] = np.round(threshold[hidden_id] * scaling) if not bias is None: bias[hidden_id] = np.round(bias[hidden_id] * scaling) # if the threshold exceeds boundary if np.abs(threshold_quan[hidden_id]) > max_th_quan: limited_scaling = max_th_quan / threshold[hidden_id] threshold_quan[hidden_id] = np.round( threshold[hidden_id] * limited_scaling ) if hidden_id <= Nien: w_in_quan[:, hidden_id, :] = np.round( w_in[:, hidden_id, :] * limited_scaling ) w_rec_quan[:, hidden_id, :] = np.round( w_rec[:, hidden_id, :] * limited_scaling ) else: threshold_quan[hidden_id] = np.round(threshold[hidden_id]) # if only one synaptic connection is used elif Nsyn == 1: # - Find the maximum output weight for this neuron max_w = 0 max_w = ( np.max([max_w, np.max(np.abs(w_in[:, hidden_id]))]) if hidden_id < Nien else max_w ) max_w = np.max([max_w, np.max(np.abs(w_rec[:, hidden_id]))]) # - Scale and quantise weights, thresholds and biases if max_w != 0: scaling = max_w_quan / max_w if hidden_id < Nien: w_in_quan[:, hidden_id] = np.round(w_in[:, hidden_id] * scaling) w_rec_quan[:, hidden_id] = np.round(w_rec[:, hidden_id] * scaling) threshold_quan[hidden_id] = np.round(threshold[hidden_id] * scaling) # if the threshold exceeds boundary if np.abs(threshold_quan[hidden_id]) > max_th_quan: limited_scaling = max_th_quan / threshold[hidden_id] threshold_quan[hidden_id] = np.round( threshold[hidden_id] * limited_scaling ) if hidden_id <= Nien: w_in_quan[:, hidden_id] = np.round( w_in[:, hidden_id] * limited_scaling ) w_rec_quan[:, hidden_id] = np.round( w_rec[:, hidden_id] * limited_scaling ) else: threshold_quan[hidden_id] = np.round(threshold[hidden_id]) # quantize output weight, threshold_out w_out_quan = np.zeros(shape=w_out.shape) threshold_out_quan = np.zeros(shape=threshold_out.shape) for output_id in range(Nout): max_w = 0 max_w = np.max([max_w, np.max(np.abs(w_out[:, output_id]))]) if max_w != 0: scaling = max_w_quan / max_w w_out_quan[:, output_id] = np.round(w_out[:, output_id] * scaling) threshold_out_quan[output_id] = np.round(threshold_out[output_id] * scaling) if not bias_out is None: bias_out[output_id] = np.round(bias_out[output_id] * scaling) # if the threshold exceed boundary if np.abs(threshold_out_quan[output_id]) > max_th_quan: limited_scaling = max_th_quan / threshold_out[output_id] threshold_out_quan[output_id] = np.round( threshold_out[output_id] * limited_scaling ) w_out_quan[:, output_id] = np.round( w_out[:, output_id] * limited_scaling ) else: threshold_out_quan[output_id] = np.round(threshold_out[output_id]) # make sure all types are int weights_in = w_in_quan.astype(int) weights_rec = w_rec_quan.astype(int) weights_out = w_out_quan.astype(int) dash_mem = np.round(dash_mem).astype(int) dash_mem_out = np.round(dash_mem_out).astype(int) dash_syn = np.round(dash_syn).astype(int) dash_syn_2 = np.round(dash_syn_2).astype(int) dash_syn_out = np.round(dash_syn_out).astype(int) threshold = threshold_quan.astype(int) threshold_out = threshold_out_quan.astype(int) if not bias is None: bias = np.round(bias).astype(int) if not bias_out is None: bias_out = np.round(bias_out).astype(int) model_quan = { "weights_in": weights_in, "weights_rec": weights_rec, "weights_out": weights_out, "threshold": threshold, "threshold_out": threshold_out, "dash_mem": dash_mem, "dash_mem_out": dash_mem_out, "dash_syn": dash_syn, "dash_syn_2": dash_syn_2, "dash_syn_out": dash_syn_out, } if not bias is None: model_quan["bias"] = bias if not bias_out is None: model_quan["bias_out"] = bias_out return model_quan