# Source code for rockpool.devices.xylo.syns61201.xylo_mapper

"""
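Mapper module for Xylo core v2 (SYNS61201)

- Create a graph using the :py:meth:`~.graph.GraphModule.as_graph` API
- Call :py:func:`.mapper` to obtain a hardware specification
- Pass the specification to :py:func:`~rockpool.devices.xylo.syns61201.config_from_specification` to build a configuration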

"""

import numpy as np

import copy

from rockpool.graph import (
    GraphModuleBase,
    GenericNeurons,
    AliasConnection,
    LinearWeights,
    SetList,
    bag_graph,
    find_modules_of_subclass,
    find_recurrent_modules,
)
from .xylo_graph_modules import (
    Xylo2HiddenNeurons,
    Xylo2OutputNeurons,
    Xylo2Neurons,
)
from rockpool.devices.xylo.syns61300.xylo_mapper import (
    xylo_drc,
    DRCError,
    DRCWarning,
    check_drc,
    assign_ids_to_class,
)

from typing import List, Callable, Set, Optional, Union

# - Public API of this module. The DRC error / warning types are imported from the
#   syns61300 mapper and re-exported here for convenience.
__all__ = [
    "mapper",
    "DRCError",
    "DRCWarning",
]


def _num_synapses(neurons: Xylo2Neurons) -> int:
    """
    Return the number of input synapses per neuron used by a neuron module

    A neuron module with more input nodes than output nodes is treated as using two synapses per neuron; otherwise one.

    Args:
        neurons (Xylo2Neurons): The neuron graph module to inspect

    Returns:
        int: ``2`` or ``1``
    """
    return 2 if len(neurons.input_nodes) > len(neurons.output_nodes) else 1


def _find_connected_neurons(
    nodes, modules_attr: str, module: GraphModuleBase
) -> Xylo2Neurons:
    """
    Return the first Xylo v2 neuron module attached to a collection of graph nodes

    Args:
        nodes: The graph nodes to search
        modules_attr (str): Which modules of each node to search: ``"source_modules"`` or ``"sink_modules"``
        module (GraphModuleBase): The module owning ``nodes``; only used to build the error message

    Returns:
        Xylo2Neurons: The first :py:class:`.Xylo2Neurons` module found, in node order

    Raises:
        DRCError: If no :py:class:`.Xylo2Neurons` module is attached to ``nodes``
    """
    for node in nodes:
        for candidate in getattr(node, modules_attr):
            if isinstance(candidate, Xylo2Neurons):
                return candidate

    raise DRCError(f"Module {module} is not connected to any Xylo v2 neuron module.")


def mapper(
    graph: GraphModuleBase,
    weight_dtype: Union[np.dtype, str] = "float",
    threshold_dtype: Union[np.dtype, str] = "float",
    dash_dtype: Union[np.dtype, str] = "float",
    max_hidden_neurons: int = 1000,
    max_output_neurons: int = 8,
) -> dict:
    """
    Map a computational graph onto the Xylo v2 (SYNS61201) architecture

    This function performs a DRC of the computational graph to ensure it can be mapped onto the Xylo v2 (SYNS61201) architecture.

    Warnings:
        :py:func:`mapper` operates **in-place** on the graph, and may modify it. If you need the un-mapped graph, you may need to call :py:meth:`.Module.as_graph` again on your :py:class:`.Module`.

    It then allocates neurons and converts the network weights into a specification for Xylo. This specification can be used to create a config object with :py:func:`~rockpool.devices.xylo.syns61201.config_from_specification`.

    Args:
        graph (GraphModuleBase): The graph to map
        weight_dtype (Union[np.dtype, str]): Data type for mapped weight and bias parameters. Default: ``"float"``
        threshold_dtype (Union[np.dtype, str]): Data type for mapped threshold parameters. Default: ``"float"``
        dash_dtype (Union[np.dtype, str]): Data type for mapped dash (bitshift time constant) parameters. Default: ``"float"``
        max_hidden_neurons (int): Maximum number of available hidden neurons. Default: ``1000``, matching Xylo hardware
        max_output_neurons (int): Maximum number of available output neurons. Default: ``8``, matching Xylo hardware

    Returns:
        dict: A dictionary of specifications for Xylo v2, containing the mapped computational graph

    Raises:
        DRCError: If neuron modules cannot be converted to Xylo v2 neuron classes, if hardware neurons cannot be allocated, or if a weight or alias module is not connected to Xylo v2 neurons. Design-rule violations are reported by :py:func:`check_drc`.
    """
    # - Check design rules
    check_drc(graph, xylo_drc)

    # --- Replace neuron modules with known graph classes ---

    # - Get output spiking layer from output nodes
    output_neuron_mods: Set[GenericNeurons] = set()
    for on in graph.output_nodes:
        for sm in on.source_modules:
            if isinstance(sm, GenericNeurons):
                output_neuron_mods.add(sm)

    # - Replace these output neurons with Xylo2OutputNeurons
    for on in output_neuron_mods:
        try:
            Xylo2OutputNeurons._convert_from(on)
        except Exception as e:
            raise DRCError(f"Error replacing output neuron module {on}.") from e

    # - Replace all other neurons with Xylo2HiddenNeurons
    _, modules = bag_graph(graph)
    for m in modules:
        if isinstance(m, GenericNeurons) and not isinstance(m, Xylo2OutputNeurons):
            try:
                Xylo2HiddenNeurons._convert_from(m)
            except Exception as e:
                raise DRCError(f"Error replacing module {m}.") from e

    # --- Assign neurons to HW neurons ---

    # - Enumerate hidden neurons: HW IDs [0, max_hidden_neurons)
    available_hidden_neuron_ids = list(range(max_hidden_neurons))
    try:
        allocated_hidden_neurons = assign_ids_to_class(
            graph, Xylo2HiddenNeurons, available_hidden_neuron_ids
        )
    except Exception as e:
        raise DRCError("Failed to allocate HW resources for hidden neurons.") from e

    # - Enumerate output neurons: HW IDs immediately following the hidden neuron IDs
    available_output_neuron_ids = list(
        range(max_hidden_neurons, max_hidden_neurons + max_output_neurons)
    )
    try:
        allocated_output_neurons = assign_ids_to_class(
            graph, Xylo2OutputNeurons, available_output_neuron_ids
        )
    except Exception as e:
        raise DRCError("Failed to allocate HW resources for output neurons.") from e

    num_hidden_neurons = len(allocated_hidden_neurons)
    num_output_neurons = len(allocated_output_neurons)

    # - Map each allocated HW ID to its row / column in the weight matrices
    #   (replaces repeated O(n) `list.index` lookups)
    hidden_row = {hw_id: row for row, hw_id in enumerate(allocated_hidden_neurons)}
    output_col = {hw_id: col for col, hw_id in enumerate(allocated_output_neurons)}

    # - Enumerate input channels
    input_channels = list(range(len(graph.input_nodes)))

    # - Use the second synapse if any hidden neuron module requires it
    hidden_neurons: SetList[Xylo2HiddenNeurons] = find_modules_of_subclass(
        graph, Xylo2HiddenNeurons
    )
    num_hidden_synapses = (
        2 if any(_num_synapses(hn) == 2 for hn in hidden_neurons) else 1
    )

    # --- Map weights and build Xylo weight matrices ---

    # - Build an input weight matrix
    input_weight_mod: LinearWeights = graph.input_nodes[0].sink_modules[0]
    target_neurons: Xylo2Neurons = input_weight_mod.output_nodes[0].sink_modules[0]
    # ^ Since DRC passed, we know this is valid
    weight_num_synapses = _num_synapses(target_neurons)
    these_dest_indices = [hidden_row[hw_id] for hw_id in target_neurons.hw_ids]

    # - Allocate and assign the input weights
    w_in = np.zeros(
        (len(input_channels), num_hidden_neurons, num_hidden_synapses),
        weight_dtype,
    )
    w_in[
        np.ix_(input_channels, these_dest_indices, list(range(weight_num_synapses)))
    ] = input_weight_mod.weights.reshape(
        (len(input_channels), len(these_dest_indices), weight_num_synapses)
    )

    # - Build recurrent and output weight matrices
    w_rec = np.zeros(
        (num_hidden_neurons, num_hidden_neurons, num_hidden_synapses),
        weight_dtype,
    )
    w_out = np.zeros((num_hidden_neurons, num_output_neurons), weight_dtype)

    # - Get all weights other than the input weights
    weights: SetList[LinearWeights] = find_modules_of_subclass(graph, LinearWeights)
    weights.remove(input_weight_mod)

    # - For each weight module, place the weights in the right place
    for w in weights:
        # - Destination neurons are downstream of the weight outputs; sources are upstream of its inputs
        target_neurons = _find_connected_neurons(w.output_nodes, "sink_modules", w)
        num_target_syns = _num_synapses(target_neurons)
        source_neurons = _find_connected_neurons(w.input_nodes, "source_modules", w)

        # - Get source and target HW IDs
        source_ids = source_neurons.hw_ids
        target_ids = target_neurons.hw_ids

        # - Does this go in the recurrent or output weights?
        if isinstance(target_neurons, Xylo2HiddenNeurons):
            # - Recurrent weights
            these_weights = np.reshape(
                w.weights, (len(source_ids), len(target_ids), num_target_syns)
            )
            these_source_indices = [hidden_row[hw_id] for hw_id in source_ids]
            these_dest_indices = [hidden_row[hw_id] for hw_id in target_ids]
            w_rec[
                np.ix_(
                    these_source_indices, these_dest_indices, np.arange(num_target_syns)
                )
            ] = these_weights

        elif isinstance(target_neurons, Xylo2OutputNeurons):
            # - Output weights
            these_source_indices = [hidden_row[hw_id] for hw_id in source_ids]
            these_dest_indices = [output_col[hw_id] for hw_id in target_ids]
            w_out[np.ix_(these_source_indices, these_dest_indices)] = w.weights

        else:
            raise DRCError(
                f"Unexpected target of weight graph module {w}. Expected XyloHiddenNeurons or XyloOutputNeurons."
            )

    # - If we are not using synapse 2, drop the trailing synapse axis
    if num_hidden_synapses == 1:
        w_in = np.reshape(w_in, (len(input_channels), num_hidden_neurons))
        w_rec = np.reshape(w_rec, (num_hidden_neurons, num_hidden_neurons))

    # --- Extract parameters from nodes ---

    output_neurons: SetList[Xylo2OutputNeurons] = find_modules_of_subclass(
        graph, Xylo2OutputNeurons
    )

    dash_mem = np.zeros(num_hidden_neurons, dash_dtype)
    dash_mem_out = np.zeros(num_output_neurons, dash_dtype)
    dash_syn = np.zeros(num_hidden_neurons, dash_dtype)
    dash_syn_2 = np.zeros(num_hidden_neurons, dash_dtype)
    dash_syn_out = np.zeros(num_output_neurons, dash_dtype)
    threshold = np.zeros(num_hidden_neurons, threshold_dtype)
    threshold_out = np.zeros(num_output_neurons, threshold_dtype)
    bias = np.zeros(num_hidden_neurons, weight_dtype)
    bias_out = np.zeros(num_output_neurons, weight_dtype)

    for n in hidden_neurons:
        # NOTE(review): hidden-neuron HW IDs are used directly as array indices here.
        #   This assumes IDs were allocated contiguously from 0 out of
        #   `available_hidden_neuron_ids` — confirm against `assign_ids_to_class`.
        these_indices = n.hw_ids
        dash_mem[these_indices] = n.dash_mem

        if _num_synapses(n) == 2:
            # - dash_syn holds (synapse 1, synapse 2) pairs per neuron
            dash_syn_pairs = np.array(n.dash_syn).reshape((-1, 2))
            for i, index in enumerate(these_indices):
                dash_syn[index] = dash_syn_pairs[i][0]
                dash_syn_2[index] = dash_syn_pairs[i][1]
        else:
            for i, index in enumerate(these_indices):
                dash_syn[index] = n.dash_syn[i]

        threshold[these_indices] = n.threshold
        bias[these_indices] = n.bias

    for n in output_neurons:
        these_indices = [output_col[hw_id] for hw_id in n.hw_ids]
        dash_mem_out[these_indices] = n.dash_mem

        for i, index in enumerate(these_indices):
            dash_syn_out[index] = n.dash_syn[i]

        threshold_out[these_indices] = n.threshold
        bias_out[these_indices] = n.bias

    # - Use the first non-None `dt` found among all Xylo v2 neuron modules
    neurons: SetList[Xylo2Neurons] = find_modules_of_subclass(graph, Xylo2Neurons)
    dt = next((n.dt for n in neurons if n.dt is not None), None)

    # --- Extract aliases from nodes ---

    aliases = find_modules_of_subclass(graph, AliasConnection)

    list_aliases = [[] for _ in range(num_hidden_neurons)]
    for a in aliases:
        # - Source neurons drive the alias input nodes
        source_neurons = _find_connected_neurons(a.input_nodes, "source_modules", a)

        # - Target neurons are looked up via the *source* modules of the alias output
        #   nodes, i.e. the neuron module that drives those nodes
        target_neurons = _find_connected_neurons(a.output_nodes, "source_modules", a)

        # - Add to the aliases list, keyed by source HW ID
        for source, target in zip(source_neurons.hw_ids, target_neurons.hw_ids):
            list_aliases[source].append(target)

    return {
        "mapped_graph": graph,
        "weights_in": w_in,
        "weights_out": w_out,
        "weights_rec": w_rec,
        "dash_mem": dash_mem,
        "dash_mem_out": dash_mem_out,
        "dash_syn": dash_syn,
        "dash_syn_2": dash_syn_2,
        "dash_syn_out": dash_syn_out,
        "threshold": threshold,
        "threshold_out": threshold_out,
        "bias": bias,
        "bias_out": bias_out,
        "weight_shift_in": 0,
        "weight_shift_rec": 0,
        "weight_shift_out": 0,
        "aliases": list_aliases,
        "dt": dt,
    }