"""
Mapper package for Xylo core v2
- Create a graph using the :py:meth:`~.graph.GraphModule.as_graph` API
- Call :py:func:`.mapper`
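
Example:
    A minimal usage sketch (``net`` is a hypothetical Rockpool module; any network that passes the Xylo DRC follows the same pattern)::

        from rockpool.devices.xylo.syns61201 import mapper, config_from_specification

        graph = net.as_graph()  # - Extract the computational graph from the network
        spec = mapper(graph)  # - Check design rules and map onto the Xylo v2 architecture
        config, is_valid, msg = config_from_specification(**spec)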
"""

import warnings
from typing import Callable, List, Optional, Set, Union

import numpy as np

from rockpool.graph import (
    GraphModuleBase,
    GenericNeurons,
    AliasConnection,
    LinearWeights,
    SetList,
    bag_graph,
    find_modules_of_subclass,
    find_recurrent_modules,
)

from .xylo_graph_modules import (
    Xylo2HiddenNeurons,
    Xylo2OutputNeurons,
    Xylo2Neurons,
)
__all__ = ["mapper", "DRCError", "DRCWarning"]


class DRCError(ValueError):
    """Raised when a computational graph violates a Xylo design rule"""


class DRCWarning(Warning, DRCError):
    """Issued for a non-fatal design-rule violation; subclasses :py:class:`DRCError` so it can also be trapped as an error"""


def output_nodes_have_neurons_as_source(graph: GraphModuleBase) -> None:
# - All output nodes must have a source that is a neuron
for n in graph.output_nodes:
for s in n.source_modules:
if not isinstance(s, GenericNeurons):
raise DRCError(
f"All network outputs must be directly from neurons.\nA network output node {n} has a source {s} which is not a neuron."
)


def input_to_neurons_is_a_weight(graph: GraphModuleBase) -> None:
# - Every neuron module must have weights on the input
neurons = find_modules_of_subclass(graph, GenericNeurons)
for n in neurons:
for inp in n.input_nodes:
for sm in inp.source_modules:
if not isinstance(sm, LinearWeights):
raise DRCError(
f"All neurons must receive inputs only from weight nodes.\nA neuron node {n} has a source module {sm} which is not a LinearWeight."
)


def first_module_is_a_weight(graph: GraphModuleBase) -> None:
# - The first module after the input must be a set of weights
for inp in graph.input_nodes:
for sink in inp.sink_modules:
if not isinstance(sink, LinearWeights):
raise DRCError(
f"The network input must go first through a weight.\nA network input node {inp} has a sink module {sink} which is not a LinearWeight."
)


def le_16_input_channels(graph: GraphModuleBase) -> None:
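    # - Xylo v2 supports at most 16 input channels; warn if the network needs more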
if len(graph.input_nodes) > 16:
warnings.warn(
DRCWarning(
f"Xylo only supports up to 16 input channels. The network requires {len(graph.input_nodes)} input channels."
),
DRCWarning,
)


def le_8_output_channels(graph: GraphModuleBase) -> None:
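    # - Xylo v2 supports at most 8 output channels; warn if the network needs more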
if len(graph.output_nodes) > 8:
warnings.warn(
DRCWarning(
f"Xylo only supports up to 8 output channels. The network requires {len(graph.output_nodes)} output channels."
),
DRCWarning,
)


def all_neurons_have_same_dt(graph: GraphModuleBase) -> None:
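    # - Collect `dt` from every neuron module that defines one; all values must agree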
neurons: SetList[GenericNeurons] = find_modules_of_subclass(graph, GenericNeurons)
dt: Optional[float] = None
for n in neurons:
if hasattr(n, "dt"):
dt = n.dt if dt is None else dt
if dt is not None and n.dt is not None and not np.isclose(dt, n.dt):
raise DRCError("All neurons in the network must share a common `dt`.")
if dt is None:
raise DRCError(
"The network must specify a `dt` for at least one neuron module."
)


def output_neurons_cannot_be_recurrent(graph: GraphModuleBase) -> None:
_, recurrent_modules = find_recurrent_modules(graph)
output_neurons = SetList()
for n in graph.output_nodes:
for s in n.source_modules:
if isinstance(s, GenericNeurons):
output_neurons.add(s)
rec_output_neurons = set(output_neurons).intersection(recurrent_modules)
if len(rec_output_neurons) > 0:
raise DRCError(
f"Output neurons may not be recurrent.\nFound output neurons {rec_output_neurons} that are recurrent."
)


def no_consecutive_weights(graph: GraphModuleBase) -> None:
all_weights: List[LinearWeights] = find_modules_of_subclass(graph, LinearWeights)
for w in all_weights:
for i_n in w.input_nodes:
for sm in i_n.source_modules:
if isinstance(sm, LinearWeights):
raise DRCError(
f"Inputs to linear weights may not be linear weights.\nFound linear weights {sm} as source module -> to linear weights {w}."
)
for o_n in w.output_nodes:
for sm in o_n.sink_modules:
if isinstance(sm, LinearWeights):
raise DRCError(
f"Outputs of linear weights may not be linear weights.\nFound linear weights {w} with output sink module -> {sm}."
)


def alias_inputs_must_be_neurons(graph: GraphModuleBase) -> None:
all_aliases: List[AliasConnection] = find_modules_of_subclass(
graph, AliasConnection
)
for a in all_aliases:
for i_n in a.input_nodes:
for source in i_n.source_modules:
if not isinstance(source, (GenericNeurons, AliasConnection)):
raise DRCError(
f"Inputs to alias connections must be neurons or another alias.\nFound source module {source} as source -> to aliases {a}."
)


def alias_output_nodes_must_have_neurons_as_input(graph: GraphModuleBase) -> None:
all_aliases: List[AliasConnection] = find_modules_of_subclass(
graph, AliasConnection
)
for a in all_aliases:
for o_n in a.output_nodes:
for source in o_n.source_modules:
if not isinstance(source, (GenericNeurons, AliasConnection)):
raise DRCError(
f"Alias connections must have neurons as the last block before the output.\nFound aliases {a} with module {source} as the last module in the graph."
)


def at_least_two_neuron_layers_needed(graph: GraphModuleBase) -> None:
all_neurons: List[GenericNeurons] = find_modules_of_subclass(graph, GenericNeurons)
if len(all_neurons) < 2:
raise DRCError(
"At least two layers of neurons are required to map to hidden and output layers on Xylo."
)


def weight_nodes_have_no_biases(graph: GraphModuleBase) -> None:
all_weights: List[LinearWeights] = find_modules_of_subclass(graph, LinearWeights)
for w in all_weights:
if w.biases is not None:
warnings.warn(
f"Bias parameters of LinearWeights modules are *not* transferred to Xylo.\nFound weights {w} with biases. Set `has_bias = False` for this module .",
DRCWarning,
)


xylo_drc: List[Callable[[GraphModuleBase], None]] = [
output_nodes_have_neurons_as_source,
input_to_neurons_is_a_weight,
first_module_is_a_weight,
le_16_input_channels,
le_8_output_channels,
all_neurons_have_same_dt,
output_neurons_cannot_be_recurrent,
no_consecutive_weights,
alias_inputs_must_be_neurons,
alias_output_nodes_must_have_neurons_as_input,
at_least_two_neuron_layers_needed,
weight_nodes_have_no_biases,
]
""" List[Callable[[GraphModuleBase], None]]: The collection of design rules for Xylo """


def check_drc(
graph: GraphModuleBase, design_rules: List[Callable[[GraphModuleBase], None]]
):
"""
Perform a design rule check over a graph
Args:
graph (GraphModuleBase): A graph to check
design_rules (List[Callable[[GraphModuleBase], None]]): A list of functions, each of which performs a DRC over a graph
Raises:
DRCError: If a design rule is violated
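
    Example:
        A sketch of validating a graph against the full Xylo rule set, assuming ``graph`` was produced by :py:meth:`~.graph.GraphModule.as_graph`::

            check_drc(graph, xylo_drc)  # - Raises `DRCError` naming the first violated rule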
"""
for dr in design_rules:
try:
dr(graph)
except DRCError as e:
raise DRCError(
f"Design rule {dr.__name__} triggered an error:\n"
+ "".join([f"{msg}" for msg in e.args])
            ) from e


def assign_ids_to_module(m: GraphModuleBase, available_ids: List[int]) -> List[int]:
"""
    Assign IDs from a list to a single graph module

    This function sets the :py:attr:`~.graph.GraphModule.hw_ids` attribute for a single :py:class:`.graph.GraphModule`, by assigning IDs greedily from a list. The allocated IDs are removed from the ``available_ids`` list, are set in the graph module, and are returned as a list.

    Examples:
        >>> output_ids = list(range(16))
        >>> allocated_ids = assign_ids_to_module(mod, output_ids)
        >>> print(allocated_ids)
        [0, 1, 2, 3, 4, 5]
        >>> print(output_ids)
        [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

    Args:
        m (GraphModuleBase): The module to assign IDs to
        available_ids (List[int]): A list of integer unique HW IDs that can be allocated from. These IDs will be allocated to the graph module.

    Returns:
        List[int]: The hardware IDs that were allocated to the graph module
"""
num_needed_ids = len(m.output_nodes)
if len(available_ids) < num_needed_ids:
raise DRCError(f"Exceeded number of available resources for graph module {m}.")
# - Allocate the IDs and remove them from the available list
m.hw_ids = available_ids[:num_needed_ids]
del available_ids[:num_needed_ids]
# - Annotate the original computational module with the allocated hardware IDs, if possible
if m.computational_module is not None:
m.computational_module._hw_ids = m.hw_ids
return m.hw_ids


def assign_ids_to_class(
graph: GraphModuleBase, cls, available_ids: List[int]
) -> List[int]:
"""
    Assign IDs from a list to a class of graph module

    This function sets the :py:attr:`~.graph.GraphModule.hw_ids` attribute for all :py:class:`.graph.GraphModule` s of a chosen subclass, by assigning IDs greedily from a list. The allocated IDs are removed from the ``available_ids`` list, are set in the graph modules, and are returned as a list.

    Examples:
        >>> output_ids = list(range(16))
        >>> allocated_ids = assign_ids_to_class(graph, Xylo2OutputNeurons, output_ids)
        >>> print(allocated_ids)
        [0, 1, 2, 3, 4, 5]
        >>> print(output_ids)
        [6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

    Args:
        graph (GraphModuleBase): The graph to search over
        cls: The :py:class:`~.graph.GraphModule` subclass to search for, to assign IDs to
        available_ids (List[int]): A list of integer unique HW IDs that can be allocated from. These IDs will be allocated to the graph modules.

    Returns:
        List[int]: The hardware IDs that were allocated to the graph modules
"""
    # - Find all modules of the chosen subclass
    modules = find_modules_of_subclass(graph, cls)

    # - Allocate HW IDs to each module, then flatten into a single list
    allocated_ids = [assign_ids_to_module(m, available_ids) for m in modules]
    allocated_ids = [nid for ids in allocated_ids for nid in ids]

    return allocated_ids


def mapper(
graph: GraphModuleBase,
weight_dtype: Union[np.dtype, str] = "float",
threshold_dtype: Union[np.dtype, str] = "float",
dash_dtype: Union[np.dtype, str] = "float",
max_hidden_neurons: int = 1000,
max_output_neurons: int = 8,
) -> dict:
"""
    Map a computational graph onto the Xylo v2 (SYNS61201) architecture

    This function performs a DRC of the computational graph to ensure it can be mapped onto the Xylo v2 (SYNS61201) architecture. It then allocates neurons and converts the network weights into a specification for Xylo. This specification can be used to create a config object with :py:func:`~rockpool.devices.xylo.syns61201.config_from_specification`.

    Warnings:
        :py:func:`mapper` operates **in-place** on the graph, and may modify it. If you need the un-mapped graph, you may need to call :py:meth:`.Module.as_graph` again on your :py:class:`.Module`.

    Args:
        graph (GraphModuleBase): The graph to map
        weight_dtype (Union[np.dtype, str]): Data type for mapped weight parameters. Default: ``"float"``
        threshold_dtype (Union[np.dtype, str]): Data type for mapped threshold parameters. Default: ``"float"``
        dash_dtype (Union[np.dtype, str]): Data type for mapped dash (bitshift time constant) parameters. Default: ``"float"``
        max_hidden_neurons (int): Maximum number of available hidden neurons. Default: ``1000``, matching Xylo hardware
        max_output_neurons (int): Maximum number of available output neurons. Default: ``8``, matching Xylo hardware

    Returns:
        dict: A dictionary of specifications for Xylo v2, containing the mapped computational graph
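
    Example:
        A hedged sketch of a full mapping pass (``net`` is a hypothetical Rockpool module that passes the Xylo DRC)::

            spec = mapper(net.as_graph(), weight_dtype="int8")
            print(spec["weights_in"].shape, spec["dt"])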
"""
# - Check design rules
check_drc(graph, xylo_drc)
# --- Replace neuron modules with known graph classes ---
# - Get output spiking layer from output nodes
output_neurons: Set[GenericNeurons] = set()
for on in graph.output_nodes:
for sm in on.source_modules:
if isinstance(sm, GenericNeurons):
output_neurons.add(sm)
# - Replace these output neurons with XyloOutputNeurons
for on in output_neurons:
try:
Xylo2OutputNeurons._convert_from(on)
except Exception as e:
raise DRCError(f"Error replacing output neuron module {on}.") from e
# - Replace all other neurons with XyloHiddenNeurons
nodes, modules = bag_graph(graph)
for m in modules:
if isinstance(m, GenericNeurons) and not isinstance(m, Xylo2OutputNeurons):
try:
Xylo2HiddenNeurons._convert_from(m)
except Exception as e:
raise DRCError(f"Error replacing module {m}.") from e
# --- Assign neurons to HW neurons ---
# - Enumerate hidden neurons
available_hidden_neuron_ids = list(range(max_hidden_neurons))
try:
allocated_hidden_neurons = assign_ids_to_class(
graph, Xylo2HiddenNeurons, available_hidden_neuron_ids
)
except Exception as e:
raise DRCError("Failed to allocate HW resources for hidden neurons.") from e
# - Enumerate output neurons
available_output_neuron_ids = list(
range(max_hidden_neurons, max_hidden_neurons + max_output_neurons)
)
try:
allocated_output_neurons = assign_ids_to_class(
graph, Xylo2OutputNeurons, available_output_neuron_ids
)
except Exception as e:
raise DRCError("Failed to allocate HW resources for output neurons.") from e
# - Enumerate input channels
input_channels = list(range(len(graph.input_nodes)))
# - How many synapses are we using for hidden neurons?
hidden_neurons: SetList[Xylo2HiddenNeurons] = find_modules_of_subclass(
graph, Xylo2HiddenNeurons
)
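    # - Any neuron with more input nodes than output nodes requires the second synapse (synapse 2)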
num_hidden_synapses = 1
for hn in hidden_neurons:
if len(hn.input_nodes) > len(hn.output_nodes):
num_hidden_synapses = 2
# --- Map weights and build Xylo weight matrices ---
# - Build an input weight matrix
input_weight_mod: LinearWeights = graph.input_nodes[0].sink_modules[0]
target_neurons: Xylo2Neurons = input_weight_mod.output_nodes[0].sink_modules[0]
# ^ Since DRC passed, we know this is valid
weight_num_synapses = (
2 if len(target_neurons.input_nodes) > len(target_neurons.output_nodes) else 1
)
target_ids = target_neurons.hw_ids
these_dest_indices = [allocated_hidden_neurons.index(id) for id in target_ids]
# - Allocate and assign the input weights
w_in = np.zeros(
(len(input_channels), len(allocated_hidden_neurons), num_hidden_synapses),
weight_dtype,
)
w_in[
np.ix_(input_channels, these_dest_indices, list(range(weight_num_synapses)))
] = input_weight_mod.weights.reshape(
(len(input_channels), len(these_dest_indices), weight_num_synapses)
)
# - Build a recurrent weight matrix
w_rec = np.zeros(
(
len(allocated_hidden_neurons),
len(allocated_hidden_neurons),
num_hidden_synapses,
),
weight_dtype,
)
w_rec_source_ids = allocated_hidden_neurons
w_rec_dest_ids = allocated_hidden_neurons
# - Build an output weight matrix
w_out = np.zeros(
(len(allocated_hidden_neurons), len(allocated_output_neurons)), weight_dtype
)
w_out_source_ids = allocated_hidden_neurons
w_out_dest_ids = allocated_output_neurons
# - Get all weights
weights: SetList[LinearWeights] = find_modules_of_subclass(graph, LinearWeights)
weights.remove(input_weight_mod)
# - For each weight module, place the weights in the right place
for w in weights:
# - Find the destination neurons
sm = SetList(
[
sm
for n in w.output_nodes
for sm in n.sink_modules
if isinstance(sm, Xylo2Neurons)
]
)
target_neurons: Xylo2Neurons = sm[0]
# - How many target synapses per neuron?
num_target_syns = (
2
if len(target_neurons.input_nodes) > len(target_neurons.output_nodes)
else 1
)
# - Find the source neurons
sm = SetList(
[
sm
for n in w.input_nodes
for sm in n.source_modules
if isinstance(sm, Xylo2Neurons)
]
)
source_neurons: Xylo2Neurons = sm[0]
# - Get source and target HW IDs
source_ids = source_neurons.hw_ids
target_ids = target_neurons.hw_ids
# - Does this go in the recurrent or output weights?
if isinstance(target_neurons, Xylo2HiddenNeurons):
# - Recurrent weights
these_weights = np.reshape(
w.weights, (len(source_ids), len(target_ids), num_target_syns)
)
these_source_indices = [w_rec_source_ids.index(id) for id in source_ids]
these_dest_indices = [w_rec_dest_ids.index(id) for id in target_ids]
# - Assign weights
w_rec[
np.ix_(
these_source_indices, these_dest_indices, np.arange(num_target_syns)
)
] = these_weights
elif isinstance(target_neurons, Xylo2OutputNeurons):
# - Output weights
these_source_indices = [w_out_source_ids.index(id) for id in source_ids]
these_dest_indices = [w_out_dest_ids.index(id) for id in target_ids]
# - Assign weights
w_out[np.ix_(these_source_indices, these_dest_indices)] = w.weights
else:
raise DRCError(
f"Unexpected target of weight graph module {w}. Expected XyloHiddenNeurons or XyloOutputNeurons."
)
# - If we are not using synapse 2, we need to trim the weights
if num_hidden_synapses == 1:
w_in = np.reshape(w_in, (len(input_channels), len(allocated_hidden_neurons)))
w_rec = np.reshape(
w_rec, (len(allocated_hidden_neurons), len(allocated_hidden_neurons))
)
# --- Extract parameters from nodes ---
hidden_neurons: SetList[Xylo2HiddenNeurons] = find_modules_of_subclass(
graph, Xylo2HiddenNeurons
)
output_neurons: SetList[Xylo2OutputNeurons] = find_modules_of_subclass(
graph, Xylo2OutputNeurons
)
num_hidden_neurons = len(allocated_hidden_neurons)
num_output_neurons = len(allocated_output_neurons)
dash_mem = np.zeros(num_hidden_neurons, dash_dtype)
dash_mem_out = np.zeros(num_output_neurons, dash_dtype)
dash_syn = np.zeros(num_hidden_neurons, dash_dtype)
dash_syn_2 = np.zeros(num_hidden_neurons, dash_dtype)
dash_syn_out = np.zeros(num_output_neurons, dash_dtype)
threshold = np.zeros(num_hidden_neurons, threshold_dtype)
threshold_out = np.zeros(num_output_neurons, threshold_dtype)
bias = np.zeros(num_hidden_neurons, weight_dtype)
bias_out = np.zeros(num_output_neurons, weight_dtype)
for n in hidden_neurons:
these_indices = n.hw_ids
dash_mem[these_indices] = n.dash_mem
if len(n.input_nodes) > len(n.output_nodes):
dash_syn_reshape = np.array(n.dash_syn).reshape((-1, 2))
for i, index in enumerate(these_indices):
dash_syn[index] = dash_syn_reshape[i][0]
dash_syn_2[index] = dash_syn_reshape[i][1]
else:
for i, index in enumerate(these_indices):
dash_syn[index] = n.dash_syn[i]
threshold[these_indices] = n.threshold
bias[these_indices] = n.bias
for n in output_neurons:
these_indices = [allocated_output_neurons.index(id) for id in n.hw_ids]
dash_mem_out[these_indices] = n.dash_mem
for i, index in enumerate(these_indices):
dash_syn_out[index] = n.dash_syn[i]
threshold_out[these_indices] = n.threshold
bias_out[these_indices] = n.bias
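    # - Extract a common `dt` from the graph; the `all_neurons_have_same_dt` DRC rule guarantees consistency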
neurons: SetList[Xylo2Neurons] = find_modules_of_subclass(graph, Xylo2Neurons)
dt = None
for n in neurons:
dt = n.dt if dt is None else dt
# --- Extract aliases from nodes ---
aliases = find_modules_of_subclass(graph, AliasConnection)
list_aliases = [[] for _ in range(num_hidden_neurons)]
for a in aliases:
# - Find the source neurons
sm = SetList(
[
sm
for n in a.input_nodes
for sm in n.source_modules
if isinstance(sm, Xylo2Neurons)
]
)
source_neurons: Xylo2Neurons = sm[0]
# - Find the destination neurons
sm = SetList(
[
sm
for n in a.output_nodes
for sm in n.source_modules
if isinstance(sm, Xylo2Neurons)
]
)
target_neurons: Xylo2Neurons = sm[0]
# - Get the source and target HW IDs
source_ids = source_neurons.hw_ids
target_ids = target_neurons.hw_ids
# - Add to the aliases list
for source, target in zip(source_ids, target_ids):
list_aliases[source].append(target)
return {
"mapped_graph": graph,
"weights_in": w_in,
"weights_out": w_out,
"weights_rec": w_rec,
"dash_mem": dash_mem,
"dash_mem_out": dash_mem_out,
"dash_syn": dash_syn,
"dash_syn_2": dash_syn_2,
"dash_syn_out": dash_syn_out,
"threshold": threshold,
"threshold_out": threshold_out,
"bias": bias,
"bias_out": bias_out,
"weight_shift_in": 0,
"weight_shift_rec": 0,
"weight_shift_out": 0,
"aliases": list_aliases,
"dt": dt,
}