Module graph.utils

Utilities for generating and manipulating computational graphs

See also

See Computational graphs in Rockpool for an introduction to computational graphs.

Functions overview

bag_graph(graph[, nodes_bag, modules_bag])

Convert a graph into a collection of connection nodes and modules, by traversal

connect_modules(source, dest[, ...])

Connect two GraphModules together

find_modules_of_subclass(graph, cls)

Search a graph for all GraphModules of a specific class or any subclass

find_recurrent_modules(graph)

Search for graph modules that are connected in a one-module loop

replace_module(target_module, replacement_module)

Replace a graph module with a different module


graph.utils.bag_graph(graph: GraphModuleBase, nodes_bag: SetList[GraphNode] | None = None, modules_bag: SetList[GraphModule] | None = None) → Tuple[SetList[GraphNode], SetList[GraphModule]]

Convert a graph into a collection of connection nodes and modules, by traversal

A graph will be traversed, following all connections. The connection GraphNodes and GraphModules will be collected and returned in two collections. Any GraphHolder modules will be ignored and discarded.

Parameters:

  • graph (GraphModuleBase) – A graph to analyse

  • nodes_bag (SetList) – The initial nodes bag. Used in recursive calls. Default: None

  • modules_bag (SetList) – The initial modules bag. Used in recursive calls. Default: None

Returns:

nodes, modules: nodes will be a SetList containing all the reachable GraphNodes in graph, and modules will be a SetList containing all the reachable GraphModules in graph.

Return type:

Tuple[SetList[GraphNode], SetList[GraphModule]]
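The traversal described above can be sketched in plain Python. This is a minimal illustration, not Rockpool's implementation: GraphNode and GraphModule here are hypothetical stand-ins, and their source_modules, sink_modules, input_nodes and output_nodes attributes are assumptions made for the sketch. Plain lists take the place of SetList, and GraphHolder handling is omitted.

```python
from typing import List, Optional, Tuple


class GraphNode:
    """Hypothetical stand-in for a connection node (not Rockpool's class)."""

    def __init__(self) -> None:
        self.source_modules: list = []  # modules that write to this node
        self.sink_modules: list = []  # modules that read from this node


class GraphModule:
    """Hypothetical stand-in for a graph module (not Rockpool's class)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.input_nodes = [GraphNode() for _ in range(n_in)]
        self.output_nodes = [GraphNode() for _ in range(n_out)]
        for node in self.input_nodes:
            node.sink_modules.append(self)
        for node in self.output_nodes:
            node.source_modules.append(self)


def bag_graph(
    module: GraphModule,
    nodes_bag: Optional[List[GraphNode]] = None,
    modules_bag: Optional[List[GraphModule]] = None,
) -> Tuple[List[GraphNode], List[GraphModule]]:
    # Plain lists stand in for SetList; `in` falls back to identity for these objects
    nodes_bag = [] if nodes_bag is None else nodes_bag
    modules_bag = [] if modules_bag is None else modules_bag

    # Stop if this module has already been visited
    if module in modules_bag:
        return nodes_bag, modules_bag
    modules_bag.append(module)

    # Visit every node attached to this module, then recurse into every module on those nodes
    for node in module.input_nodes + module.output_nodes:
        if node in nodes_bag:
            continue
        nodes_bag.append(node)
        for neighbour in node.source_modules + node.sink_modules:
            bag_graph(neighbour, nodes_bag, modules_bag)

    return nodes_bag, modules_bag


# Usage: wire module `a` (1 in, 2 out) into module `b` (2 in, 1 out) by sharing nodes
a = GraphModule(1, 2)
b = GraphModule(2, 1)
b.input_nodes = list(a.output_nodes)
for node in a.output_nodes:
    node.sink_modules.append(b)

nodes, modules = bag_graph(a)
print(len(nodes), len(modules))  # 4 nodes (1 input, 2 shared, 1 output), 2 modules
```

The recursive nodes_bag and modules_bag arguments mirror the role the documentation gives them: they carry the collections already found into each recursive call, so every node and module is visited once.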

graph.utils.connect_modules(source: GraphModuleBase, dest: GraphModuleBase, source_indices: Iterable[int] | None = None, dest_indices: Iterable[int] | None = None) → None

Connect two GraphModules together

Two graph modules can only be connected if the output and input dimensionality match across the connection, i.e. the number of source output nodes being connected equals the number of destination input nodes being connected. The output GraphNodes of the source module will be merged with the input GraphNodes of the destination module, and the destination module's original input GraphNodes will then be discarded.

If source or dest are GraphHolders, then the internal subgraphs will be connected, and the GraphHolders may be discarded.

Examples:

>>> connect_modules(mod1, mod2)
# Modules are connected in place, from all output nodes to all input nodes
>>> connect_modules(mod1, mod2, range(5))
# Connect a subset of source output nodes to the destination module
# Output nodes `mod1.output_nodes[0:5]` are connected to all input nodes `mod2.input_nodes[:]`
>>> connect_modules(mod1, mod2, None, range(3))
# All output nodes `mod1.output_nodes[:]` are connected to input nodes `mod2.input_nodes[0:3]`
>>> connect_modules(mod1, mod2, [0, 2, 4], [1, 2, 5])
# `mod1` output nodes 0, 2 and 4 are connected to `mod2` input nodes 1, 2 and 5

Parameters:

  • source (GraphModule) – The source graph module to connect

  • dest (GraphModule) – The destination graph module to connect

  • source_indices (Optional[Iterable[int]]) – The indices of the source output nodes to connect over. Default: None, use all output nodes

  • dest_indices (Optional[Iterable[int]]) – The indices of the dest input nodes to connect over. Default: None, use all input nodes
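A minimal sketch of the node-merging idea follows, using hypothetical stand-in classes (not Rockpool's GraphNode and GraphModule) and omitting GraphHolder handling. Each selected destination input node is replaced by the matching source output node, so any module that referenced the discarded node now shares the source node. The length check is this sketch's reading of the dimensionality requirement, not a documented error message.

```python
from typing import Iterable, List, Optional


class GraphNode:
    """Hypothetical stand-in for a connection node (not Rockpool's class)."""

    def __init__(self) -> None:
        self.source_modules: list = []  # modules that write to this node
        self.sink_modules: list = []  # modules that read from this node


class GraphModule:
    """Hypothetical stand-in for a graph module (not Rockpool's class)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.input_nodes = [GraphNode() for _ in range(n_in)]
        self.output_nodes = [GraphNode() for _ in range(n_out)]
        for node in self.input_nodes:
            node.sink_modules.append(self)
        for node in self.output_nodes:
            node.source_modules.append(self)


def _swap(nodes: List[GraphNode], old: GraphNode, new: GraphNode) -> List[GraphNode]:
    # Copy of `nodes` with every occurrence of `old` replaced by `new`
    return [new if n is old else n for n in nodes]


def connect_modules(
    source: GraphModule,
    dest: GraphModule,
    source_indices: Optional[Iterable[int]] = None,
    dest_indices: Optional[Iterable[int]] = None,
) -> None:
    # Default to all output nodes of `source` and all input nodes of `dest`
    src_idx = list(range(len(source.output_nodes))) if source_indices is None else list(source_indices)
    dst_idx = list(range(len(dest.input_nodes))) if dest_indices is None else list(dest_indices)

    # Assumed reading of the dimensionality rule: equal numbers of nodes on both sides
    if len(src_idx) != len(dst_idx):
        raise ValueError(f"Cannot connect {len(src_idx)} output nodes to {len(dst_idx)} input nodes")

    # Resolve every node pair before any node list is rewritten
    pairs = [(source.output_nodes[s], dest.input_nodes[d]) for s, d in zip(src_idx, dst_idx)]

    for src_node, dst_node in pairs:
        # Modules reading from the discarded node now read from the source node
        for m in dst_node.sink_modules:
            if m not in src_node.sink_modules:
                src_node.sink_modules.append(m)
            m.input_nodes = _swap(m.input_nodes, dst_node, src_node)
        # Modules writing to the discarded node now write to the source node
        for m in dst_node.source_modules:
            if m not in src_node.source_modules:
                src_node.source_modules.append(m)
            m.output_nodes = _swap(m.output_nodes, dst_node, src_node)


# Usage, mirroring the last connect_modules example: mod1 outputs 0, 2, 4 -> mod2 inputs 1, 2, 5
mod1 = GraphModule(2, 6)
mod2 = GraphModule(6, 1)
connect_modules(mod1, mod2, [0, 2, 4], [1, 2, 5])
print(mod2.input_nodes[1] is mod1.output_nodes[0])  # True
```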

graph.utils.find_modules_of_subclass(graph: GraphModuleBase, cls) → SetList[Any]

Search a graph for all GraphModules of a specific class or any subclass

The search uses isinstance() to test each module against cls, so instances of any subclass of cls will also be found.

Parameters:

  • graph (GraphModuleBase) – A graph to search

  • cls – A class to search for instances of, or instances of any subclass

Returns:

A collection of objects of the desired class

Return type:

SetList[Any]
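Because the search relies on isinstance(), it can be sketched as a traversal followed by a filter. Everything below is hypothetical and only for illustration: the GraphNode and GraphModule stand-ins, the WeightModule, NeuronModule and FastNeuronModule subclasses, the wire() helper, and reachable_modules(), which plays the role of bag_graph. None of these are Rockpool's own classes.

```python
class GraphNode:
    """Hypothetical stand-in for a connection node (not Rockpool's class)."""

    def __init__(self) -> None:
        self.source_modules: list = []  # modules that write to this node
        self.sink_modules: list = []  # modules that read from this node


class GraphModule:
    """Hypothetical stand-in for a graph module (not Rockpool's class)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.input_nodes = [GraphNode() for _ in range(n_in)]
        self.output_nodes = [GraphNode() for _ in range(n_out)]
        for node in self.input_nodes:
            node.sink_modules.append(self)
        for node in self.output_nodes:
            node.source_modules.append(self)


# Hypothetical module classes used only for this illustration
class WeightModule(GraphModule):
    pass


class NeuronModule(GraphModule):
    pass


class FastNeuronModule(NeuronModule):
    pass


def wire(source: GraphModule, dest: GraphModule) -> None:
    # Hypothetical helper: `dest` reads directly from the output nodes of `source`
    dest.input_nodes = list(source.output_nodes)
    for node in source.output_nodes:
        node.sink_modules.append(dest)


def reachable_modules(graph: GraphModule) -> list:
    # Depth-first traversal over nodes and modules; stands in for bag_graph
    seen, stack = [], [graph]
    while stack:
        module = stack.pop()
        if module in seen:
            continue
        seen.append(module)
        for node in module.input_nodes + module.output_nodes:
            stack.extend(node.source_modules + node.sink_modules)
    return seen


def find_modules_of_subclass(graph: GraphModule, cls: type) -> list:
    # isinstance() also matches instances of any subclass of `cls`
    return [m for m in reachable_modules(graph) if isinstance(m, cls)]


# Usage: weights -> neurons -> weights -> fast neurons
w1, n1 = WeightModule(2, 3), NeuronModule(3, 3)
w2, n2 = WeightModule(3, 2), FastNeuronModule(2, 2)
wire(w1, n1)
wire(n1, w2)
wire(w2, n2)

neurons = find_modules_of_subclass(w1, NeuronModule)
print(len(neurons))  # 2: n1, plus n2 as an instance of a subclass
```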

graph.utils.find_recurrent_modules(graph: GraphModuleBase) → Tuple[SetList[GraphModule]]

Search for graph modules that are connected in a one-module loop

A “recurrent module” is defined as a graph module that connects with itself via another single graph module. For example, a module of neurons may be connected to a module of weights, which in turn connects back from the output of the neurons to the input of the neurons.

Parameters:

graph (GraphModuleBase) – A graph to search

Returns:

modules, recurrent_modules: modules will be a SetList containing all the reachable GraphModules in graph, and recurrent_modules will be a SetList containing all the identified recurrent modules in graph.

Return type:

Tuple[SetList[GraphModule]]
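The one-module-loop test can be sketched as follows, again with hypothetical stand-in classes and helpers rather than Rockpool's API. A module is flagged when some other module both reads from one of its output nodes and writes to one of its input nodes. Excluding the module itself from that check is this sketch's interpretation of "via another single graph module".

```python
class GraphNode:
    """Hypothetical stand-in for a connection node (not Rockpool's class)."""

    def __init__(self) -> None:
        self.source_modules: list = []  # modules that write to this node
        self.sink_modules: list = []  # modules that read from this node


class GraphModule:
    """Hypothetical stand-in for a graph module (not Rockpool's class)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.input_nodes = [GraphNode() for _ in range(n_in)]
        self.output_nodes = [GraphNode() for _ in range(n_out)]
        for node in self.input_nodes:
            node.sink_modules.append(self)
        for node in self.output_nodes:
            node.source_modules.append(self)


class WeightModule(GraphModule):  # hypothetical
    pass


class NeuronModule(GraphModule):  # hypothetical
    pass


def wire(source: GraphModule, dest: GraphModule) -> None:
    # Hypothetical helper: `dest` reads directly from the output nodes of `source`
    dest.input_nodes = list(source.output_nodes)
    for node in source.output_nodes:
        node.sink_modules.append(dest)


def reachable_modules(graph: GraphModule) -> list:
    # Depth-first traversal over nodes and modules; stands in for bag_graph
    seen, stack = [], [graph]
    while stack:
        module = stack.pop()
        if module in seen:
            continue
        seen.append(module)
        for node in module.input_nodes + module.output_nodes:
            stack.extend(node.source_modules + node.sink_modules)
    return seen


def find_recurrent_modules(graph: GraphModule):
    modules = reachable_modules(graph)
    recurrent_modules = []
    for m in modules:
        # Other modules writing into any of m's input nodes
        upstream = [s for node in m.input_nodes for s in node.source_modules if s is not m]
        # Other modules reading from any of m's output nodes
        downstream = [d for node in m.output_nodes for d in node.sink_modules if d is not m]
        # m is recurrent if a single other module is both upstream and downstream of it
        if any(u in downstream for u in upstream):
            recurrent_modules.append(m)
    return modules, recurrent_modules


# Usage: input weights -> neurons -> readout, with a recurrent weight module around the neurons
inp, neurons = WeightModule(2, 3), NeuronModule(3, 3)
rec, out = WeightModule(3, 3), WeightModule(3, 1)
wire(inp, neurons)
wire(neurons, rec)
wire(neurons, out)
# Close the loop: `rec` writes back into the nodes that `neurons` reads from
rec.output_nodes = list(neurons.input_nodes)
for node in rec.output_nodes:
    node.source_modules.append(rec)

modules, recurrent = find_recurrent_modules(inp)
```

With this definition the weight module that closes the loop is reported as well, since it also connects back to itself through the single neuron module.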

graph.utils.replace_module(target_module: GraphModule, replacement_module: GraphModule) → None

Replace a graph module with a different module

This function removes a target graph module from a graph and replaces it with a replacement module. It removes the target module from any connection GraphNodes it is attached to, and wires in the replacement module in its place.

Parameters:

  • target_module (GraphModule) – A module inside a graph to replace

  • replacement_module (GraphModule) – A replacement module to wire into the graph, in place of target_module
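A minimal sketch of the rewiring, under the same assumptions as before: GraphNode, GraphModule and the wire() helper are hypothetical stand-ins, not Rockpool's classes. The size check and detaching the target afterwards are choices made for this sketch rather than documented behaviour.

```python
class GraphNode:
    """Hypothetical stand-in for a connection node (not Rockpool's class)."""

    def __init__(self) -> None:
        self.source_modules: list = []  # modules that write to this node
        self.sink_modules: list = []  # modules that read from this node


class GraphModule:
    """Hypothetical stand-in for a graph module (not Rockpool's class)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.input_nodes = [GraphNode() for _ in range(n_in)]
        self.output_nodes = [GraphNode() for _ in range(n_out)]
        for node in self.input_nodes:
            node.sink_modules.append(self)
        for node in self.output_nodes:
            node.source_modules.append(self)


def wire(source: GraphModule, dest: GraphModule) -> None:
    # Hypothetical helper: `dest` reads directly from the output nodes of `source`
    dest.input_nodes = list(source.output_nodes)
    for node in source.output_nodes:
        node.sink_modules.append(dest)


def replace_module(target: GraphModule, replacement: GraphModule) -> None:
    # Assumption: the replacement must expose the same number of input and output nodes
    if (len(target.input_nodes), len(target.output_nodes)) != (
        len(replacement.input_nodes),
        len(replacement.output_nodes),
    ):
        raise ValueError("Replacement module must match the target's input and output sizes")

    # Nodes feeding the target now feed the replacement instead
    for node in target.input_nodes:
        node.sink_modules = [replacement if m is target else m for m in node.sink_modules]
    # Nodes driven by the target are now driven by the replacement instead
    for node in target.output_nodes:
        node.source_modules = [replacement if m is target else m for m in node.source_modules]

    # The replacement takes over the target's connection nodes; its own nodes are discarded
    replacement.input_nodes = list(target.input_nodes)
    replacement.output_nodes = list(target.output_nodes)

    # Sketch choice: detach the target so it no longer refers to the graph
    target.input_nodes, target.output_nodes = [], []


# Usage: pre -> old -> post, then swap `old` for `new`
pre, old, post = GraphModule(2, 3), GraphModule(3, 3), GraphModule(3, 1)
wire(pre, old)
wire(old, post)
new = GraphModule(3, 3)
replace_module(old, new)
print(pre.output_nodes[0].sink_modules[0] is new)  # True
```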