Graph mapping

You can extend the computational graphing capabilities of Rockpool by adding new graph.GraphModule subclasses. These classes can be converted into one another, and the graph can be analysed in order to map networks onto hardware.

Subclassing GraphModule

All GraphModule classes are dataclasses, and use the @dataclass decorator. As in the example below, you must decorate your subclass with @dataclass(eq=False, repr=False) to be compatible with the graph mapping subsystem.

The subsystem requires equality to be defined by object identity (hence eq=False). GraphModule also provides its own human-readable __repr__() method, so repr=False stops the dataclass decorator from generating a __repr__() that would override it.
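The effect of eq=False can be seen with plain dataclasses alone. In this sketch, ValueEq and IdentityEq are hypothetical classes, not part of Rockpool: with the default eq=True, two instances with the same field values compare equal and the class is made unhashable, whereas eq=False restores identity-based equality and hashing.

```python
from dataclasses import dataclass

# Two otherwise identical classes, differing only in the eq flag
@dataclass
class ValueEq:
    size: int

@dataclass(eq=False, repr=False)
class IdentityEq:
    size: int

# Default eq=True: field values decide equality, and __hash__ is set to None
a, b = ValueEq(4), ValueEq(4)
print(a == b)                    # True: two distinct objects compare equal
print(ValueEq.__hash__ is None)  # True: instances cannot be set members or dict keys

# eq=False: equality falls back to object identity, and hashing works
c, d = IdentityEq(4), IdentityEq(4)
print(c == d)       # False
print(len({c, d}))  # 2
```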

# - Switch off warnings
import warnings
warnings.filterwarnings('ignore')

# - Rockpool imports
import rockpool.graph as rg
from dataclasses import dataclass

@dataclass(eq=False, repr=False)
class MyGraphModule(rg.GraphModule):
    # - Define parameters as for any dataclass
    param1: float
    param2: int
    param3: list


GraphModule provides a __post_init__() method that can be used to perform any validity checks after initialisation. __post_init__() also ensures that the input_nodes and output_nodes are correctly connected to the module being created.

If you override __post_init__(), you must call super().__post_init__().

@dataclass(eq=False, repr=False)
class MyGraphModule(MyGraphModule):
    # - Any initialisation checks can be performed in a __post_init__() method
    def __post_init__(self, *args, **kwargs):
        # - You *must* call super().__post_init__()
        super().__post_init__(*args, **kwargs)

        # - Fields are accessed through self
        if self.param1 < self.param2:
            raise ValueError('param1 must be >= param2')
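To see why the super() call matters, here is a minimal sketch with hypothetical stand-ins: BaseModule plays the role of GraphModule, and a plain list plays the role of a graph node, collecting the modules that consume it. A subclass that skips super().__post_init__() is never wired to its input.

```python
from dataclasses import dataclass
from typing import List

@dataclass(eq=False, repr=False)
class BaseModule:
    # Hypothetical stand-in for GraphModule; each "node" is a list of its sink modules
    input_nodes: List[list]

    def __post_init__(self, *args, **kwargs):
        # Register this module as a sink of each input node
        for node in self.input_nodes:
            node.append(self)

@dataclass(eq=False, repr=False)
class Checked(BaseModule):
    param1: float
    param2: int

    def __post_init__(self, *args, **kwargs):
        # Call the base class first, so the node wiring still happens
        super().__post_init__(*args, **kwargs)
        # Then validate, reading fields through self
        if self.param1 < self.param2:
            raise ValueError("param1 must be >= param2")

@dataclass(eq=False, repr=False)
class Forgetful(BaseModule):
    param1: float

    def __post_init__(self, *args, **kwargs):
        # Missing super().__post_init__(): the input node is never connected
        pass

node_a, node_b = [], []
ok = Checked([node_a], 2.0, 1)
bad = Forgetful([node_b], 2.0)
print(node_a[0] is ok)  # True
print(len(node_b))      # 0
```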


GraphModule provides several methods:

• _factory() (class method): Instantiate an object with self-created input and output nodes

• __post_init__(): Perform any post-initialisation checks on the module

• add_input(): Add a GraphNode as an input of this module

• add_output(): Add a GraphNode as an output of this module

• remove_input(): Remove a GraphNode as an input of this module

• remove_output(): Remove a GraphNode as an output of this module

• _convert_from() (class method): Try to convert a GraphModule of a different subclass to an object of the current subclass
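The pattern behind _factory() can be sketched with hypothetical Node, Module and Neurons stand-ins rather than the Rockpool classes. The factory creates fresh nodes and forwards any remaining arguments positionally, which works because dataclass fields inherited from a base class come first in the generated __init__(). The real _factory() may take further base-class arguments, so check its signature before relying on positional order.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass(eq=False)
class Node:
    # Hypothetical stand-in for GraphNode: the modules that consume this node
    sink_modules: List[object] = field(default_factory=list)

@dataclass(eq=False)
class Module:
    # Hypothetical stand-in for GraphModule
    input_nodes: List[Node]
    output_nodes: List[Node]
    name: str

    def __post_init__(self):
        # Register this module as the sink of each of its input nodes
        for n in self.input_nodes:
            n.sink_modules.append(self)

    @classmethod
    def _factory(cls, size_in: int, size_out: int, name: str, *args):
        # Build fresh nodes, then forward subclass-specific fields positionally
        inputs = [Node() for _ in range(size_in)]
        outputs = [Node() for _ in range(size_out)]
        return cls(inputs, outputs, name, *args)

@dataclass(eq=False)
class Neurons(Module):
    thresholds: List[int]

neurons = Neurons._factory(2, 2, "lif", [10, 12])
print(len(neurons.input_nodes), neurons.thresholds)  # 2 [10, 12]
```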

Transforming GraphModules

GraphModule provides a class method _convert_from(), which is used to convert GraphModule objects between subclasses. Conversion rules must be defined explicitly; there is no automatic conversion between classes. If you do not override _convert_from(), other GraphModule subclasses cannot be converted to objects of your class.

Below is an example implementation of _convert_from().

import numpy as np
import rockpool.graph as rg
from typing import List
from dataclasses import dataclass

@dataclass(eq=False, repr=False)
class MyNeurons(rg.GenericNeurons):
    thresholds: List[int]
    dt: float

    @classmethod
    def _convert_from(cls, mod: rg.GraphModule) -> rg.GraphModule:
        if isinstance(mod, cls):
            # - Already an object of this class: no need to do anything
            return mod

        elif isinstance(mod, rg.LIFNeuronWithSynsRealValue):
            # - Convert from a real-valued LIF neuron
            # - dt must be set on the source module for the conversion
            if mod.dt is None:
                raise ValueError(
                    f"Graph module of type {type(mod).__name__} has no dt set, so cannot convert time constants when converting to {cls.__name__}."
                )

            # - Round the real-valued thresholds of the source module to integers
            thresholds = np.round(np.array(mod.threshold)).astype(int).tolist()

            # - Build a new module of this class to insert into the graph.
            #   Base-class fields (name, computational_module) precede the subclass fields
            neurons = cls._factory(
                len(mod.input_nodes),
                len(mod.output_nodes),
                mod.name,
                mod.computational_module,
                thresholds,
                mod.dt,
            )

            # - Replace the source module in the graph with the new one, and return it
            rg.replace_module(mod, neurons)
            return neurons

        else:
            raise ValueError(
                f"Graph module of type {type(mod).__name__} cannot be converted to a {cls.__name__}"
            )


In the example above, the rules match specific subclasses of GraphModule with isinstance() and convert each one explicitly; any other type raises a ValueError.
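The same dispatch logic can be exercised without Rockpool. This sketch uses hypothetical RealNeurons and IntNeurons stand-ins and leaves out the graph rewiring done by rg.replace_module(); it shows only the type matching and the threshold rounding. Python's round(), like np.round(), rounds halves to even.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for a real-valued and an integer neuron module
@dataclass(eq=False)
class RealNeurons:
    name: str
    threshold: List[float]
    dt: Optional[float]

@dataclass(eq=False)
class IntNeurons:
    name: str
    thresholds: List[int]
    dt: float

    @classmethod
    def _convert_from(cls, mod):
        if isinstance(mod, cls):
            # Already the target type
            return mod
        if isinstance(mod, RealNeurons):
            # Refuse to convert without a dt
            if mod.dt is None:
                raise ValueError(f"{type(mod).__name__} has no dt set")
            # Round each real-valued threshold to the nearest integer
            return cls(mod.name, [int(round(t)) for t in mod.threshold], mod.dt)
        # No rule matches: fail loudly rather than guess
        raise ValueError(f"Cannot convert {type(mod).__name__} to {cls.__name__}")

converted = IntNeurons._convert_from(RealNeurons("lif", [0.4, 1.6, 2.5], 1e-3))
print(converted.thresholds)  # [0, 2, 2]
```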

Creating a mapper

The steps in mapping a graph onto some target hardware are generally:

• Check the design rules (DRC); once the graph passes, you can rely on many assumptions about its structure

• Convert neuron graph module types to types that match the hardware

• Assign hardware IDs to neurons, weights, inputs and outputs

• Pull the required data from the graph and build an equivalent hardware configuration

Currently there is a mapper for the Xylo architecture in devices.xylo.mapper(). Look through the code there for an example of building a mapper.
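As an illustration of the hardware ID step, the sketch below hands out contiguous blocks of IDs to neuron populations. assign_hw_ids(), the (name, size) input format and the max_neurons limit are all hypothetical; a real mapper such as the Xylo one works on graph modules, and its allocation will differ.

```python
from typing import Dict, List, Tuple

def assign_hw_ids(populations: List[Tuple[str, int]], max_neurons: int) -> Dict[str, List[int]]:
    """Hypothetical helper: give each neuron population a contiguous block of hardware IDs."""
    hw_ids: Dict[str, List[int]] = {}
    next_id = 0
    for name, size in populations:
        # Allocate IDs next_id .. next_id + size - 1 to this population
        hw_ids[name] = list(range(next_id, next_id + size))
        next_id += size

    # Refuse to map a network that needs more neurons than the hardware provides
    if next_id > max_neurons:
        raise ValueError(f"Network needs {next_id} neurons; hardware provides {max_neurons}")
    return hw_ids

# Illustrative sizes only; max_neurons=8 is not a real hardware figure
ids = assign_hw_ids([("hidden", 3), ("output", 2)], max_neurons=8)
print(ids)  # {'hidden': [0, 1, 2], 'output': [3, 4]}
```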

DRC checks

The suggested way to perform DRC checks is to write a set of functions, each of which defines a single design rule as a check over a graph. If the graph violates the design rule, the function should raise an error.

Below are examples of a few design rules.

import rockpool.graph as rg
from typing import List, Callable

# - Define an error class for DRC violations
class DRCError(ValueError):
    pass

def output_nodes_have_neurons_as_source(graph: rg.GraphModule):
    # - All output nodes must have a source that is a neuron
    for n in graph.output_nodes:
        for s in n.source_modules:
            if not isinstance(s, rg.GenericNeurons):
                raise DRCError(
                    f"All network outputs must be directly from neurons.\nA network output node {n} has a source {s} which is not a neuron."
                )

def first_module_is_a_weight(graph: rg.GraphModule):
    # - The first module after the input must be a set of weights
    for inp in graph.input_nodes:
        for sink in inp.sink_modules:
            if not isinstance(sink, rg.LinearWeights):
                raise DRCError(
                    f"The network input must go first through a weight.\nA network input node {inp} has a sink module {sink} which is not a LinearWeight."
                )

def le_16_input_channels(graph: rg.GraphModule):
    # - Only 16 input channels are supported
    if len(graph.input_nodes) > 16:
        raise DRCError(
            f"Xylo only supports up to 16 input channels. The network requires {len(graph.input_nodes)} input channels."
        )
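Design rules that differ only in a limit can be generated by a small factory function. In this sketch, max_input_channels() is a hypothetical helper, and SimpleNamespace objects stand in for real graphs, since the rule only reads input_nodes.

```python
from types import SimpleNamespace
from typing import Callable

class DRCError(ValueError):
    pass

def max_input_channels(limit: int) -> Callable[[object], None]:
    """Hypothetical helper: build a design rule allowing at most `limit` network inputs."""
    def rule(graph) -> None:
        n = len(graph.input_nodes)
        if n > limit:
            raise DRCError(f"Only {limit} input channels are supported; the network requires {n}.")

    # Give the rule a descriptive name, since a DRC runner may report rules by __name__
    rule.__name__ = f"le_{limit}_input_channels"
    return rule

le_16 = max_input_channels(16)

# Stand-in graphs: only the attribute the rule reads is provided
small = SimpleNamespace(input_nodes=[object()] * 8)
large = SimpleNamespace(input_nodes=[object()] * 20)

le_16(small)  # passes silently
try:
    le_16(large)
except DRCError as e:
    print(e)  # Only 16 input channels are supported; the network requires 20.
```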


Now we show a suggested way to collect the rules and perform a DRC.

# - Collect a list of DRC functions
xylo_drc = [
    output_nodes_have_neurons_as_source,
    first_module_is_a_weight,
    le_16_input_channels,
]

def check_drc(graph, design_rules: List[Callable[[rg.GraphModule], None]]):
    """
    Perform a design rule check
    """
    for dr in design_rules:
        try:
            dr(graph)
        except DRCError as e:
            # - Re-raise, prefixed with the name of the rule that failed
            raise DRCError(
                f"Design rule {dr.__name__} triggered an error:\n"
                + "".join([f"{msg}" for msg in e.args])
            ) from e

# - To perform the DRC check, use the function like so:
# check_drc(graph, xylo_drc)
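check_drc() stops at the first violation. An alternative, sketched below with a hypothetical check_drc_all() and toy rules, runs every rule and reports all violations in a single DRCError, which can save repeated fix-and-rerun cycles on larger networks.

```python
from types import SimpleNamespace
from typing import Callable, List

class DRCError(ValueError):
    pass

def check_drc_all(graph, design_rules: List[Callable]) -> None:
    """Hypothetical variant: run every rule, then report all violations together."""
    failures = []
    for dr in design_rules:
        try:
            dr(graph)
        except DRCError as e:
            # Record the rule name and its message, then keep checking
            failures.append(f"{dr.__name__}: {e}")
    if failures:
        raise DRCError("Design rule check failed:\n" + "\n".join(failures))

# Toy rules and a stand-in graph, just to exercise the runner
def has_inputs(graph):
    if not graph.input_nodes:
        raise DRCError("The network has no inputs.")

def has_outputs(graph):
    if not graph.output_nodes:
        raise DRCError("The network has no outputs.")

empty = SimpleNamespace(input_nodes=[], output_nodes=[])
try:
    check_drc_all(empty, [has_inputs, has_outputs])
except DRCError as e:
    print(e)
```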