Source code for graph.utils

"""
Utilities for generating and manipulating computational graphs

See Also:
    See :ref:`/advanced/graph_overview.ipynb` for an introduction to computational graphs.

"""


from rockpool.graph.graph_base import (
    GraphModule,
    GraphHolder,
    GraphNode,
    SetList,
    GraphModuleBase,
)

import copy

from typing import Any, Optional, Iterable, Tuple

# - Public names exported by this module on a star-import
__all__ = [
    "connect_modules",
    "bag_graph",
    "find_modules_of_subclass",
    "replace_module",
    "find_recurrent_modules",
]


def connect_modules(
    source: GraphModuleBase,
    dest: GraphModuleBase,
    source_indices: Optional[Iterable[int]] = None,
    dest_indices: Optional[Iterable[int]] = None,
) -> None:
    """
    Connect two :py:class:`.GraphModule` s together

    Connecting two graph modules can only occur if the output and input dimensionality match across the connection. The output :py:class:`.GraphNode` s from the source module will be merged with the input :py:class:`.GraphNode` s of the destination module. The :py:class:`.GraphNode` s of the destination module will then be discarded.

    If ``source`` or ``dest`` are :py:class:`.GraphHolder` s, then the internal subgraphs will be connected, and the :py:class:`.GraphHolder` s may be discarded.

    Examples:
        >>> connect_modules(mod1, mod2)
        # Modules are connected in place, from all output node to all input nodes

        >>> connect_modules(mod1, mod2, range(5))
        # Connect a subset of source output nodes to the destination module
        # Output nodes `mod1.output_nodes[0:5]` are connected to all input nodes `mod2.input_nodes[:]`

        >>> connect_modules(mod1, mod2, None, range(3))
        # All output nodes `mod1.output_nodes[:]` are connected to input nodes `mod2.input_nodes[0:3]`

        >>> connect_modules(mod1, mod2, [0, 2, 4], [1, 2, 5])
        # `mod1` output nodes 0, 2 and 4 are connected to `mod2` input nodes 1, 2, 5

    Args:
        source (GraphModule): The source graph module to connect
        dest (GraphModule): The destination graph module to connect
        source_indices (Optional[Iterable[int]]): The indices of the ``source`` output nodes to connect over. Any iterable is accepted, including generators. Default: ``None``, use all output nodes
        dest_indices (Optional[Iterable[int]]): The indices of ``dest`` input nodes to connect over. Any iterable is accepted, including generators. Default: ``None``, use all input nodes

    Raises:
        ValueError: If the number of selected source output nodes differs from the number of selected destination input nodes
    """
    # - Get indices for source and destination nodes
    #   Materialise them as lists: the indices are counted and then traversed twice,
    #   which a one-shot iterable (e.g. a generator) could not support
    source_indices = list(
        range(len(source.output_nodes)) if source_indices is None else source_indices
    )
    dest_indices = list(
        range(len(dest.input_nodes)) if dest_indices is None else dest_indices
    )
    len_source_nodes = len(source_indices)
    len_dest_nodes = len(dest_indices)

    # - Check channel dimensions
    if len_source_nodes != len_dest_nodes:
        raise ValueError(
            f"Connecting {source.name} and {dest.name}. Number of output nodes {len_source_nodes} does not match number of input nodes {len_dest_nodes}."
        )

    # - Wire up modules over nodes. Keep only the output nodes from the source module.
    for source_node_index, dest_node_index in zip(source_indices, dest_indices):
        # - Get corresponding source and dest nodes to merge
        s_o_node = source.output_nodes[source_node_index]
        d_i_node = dest.input_nodes[dest_node_index]

        # - Copy all sinks and sources from dest node into source node
        for sink_module in d_i_node.sink_modules:
            s_o_node.add_sink(sink_module)

        for source_module in d_i_node.source_modules:
            s_o_node.add_source(source_module)

        # - Remove all module links from dest node
        del d_i_node.source_modules[:]
        del d_i_node.sink_modules[:]

    # - Replace input node in all sink objects with the connected output nodes
    #   We need to use a copy to get the correct index; cannot use
    #   `[dest.input_nodes[i] for i in dest_index]` since we are modifying `dest.input_nodes`
    dest_nodes = copy.copy(dest.input_nodes)

    for source_node_index, dest_node_index in zip(source_indices, dest_indices):
        # - Get corresponding source and dest nodes
        s_o_node = source.output_nodes[source_node_index]
        d_i_node = dest_nodes[dest_node_index]

        # - For all source modules to this node, replace the node on the output
        for sm in s_o_node.source_modules:
            if d_i_node in sm.output_nodes:
                this_source_node_index = sm.output_nodes.index(d_i_node)
                sm.output_nodes[this_source_node_index] = s_o_node

        # - For all sink modules to this node, replace the node on the input
        for sink in s_o_node.sink_modules:
            if d_i_node in sink.input_nodes:
                this_dest_node_index = sink.input_nodes.index(d_i_node)
                sink.input_nodes[this_dest_node_index] = s_o_node

        # - Replace destination input nodes, if required
        if d_i_node in dest.input_nodes:
            this_dest_node_index = dest.input_nodes.index(d_i_node)
            dest.input_nodes[this_dest_node_index] = s_o_node

        # - Catch "pass through" modules as destination module, replace output nodes as well
        if d_i_node in dest.output_nodes:
            this_dest_node_index = dest.output_nodes.index(d_i_node)
            dest.output_nodes[this_dest_node_index] = s_o_node
def bag_graph(
    graph: GraphModuleBase,
    nodes_bag: Optional[SetList[GraphNode]] = None,
    modules_bag: Optional[SetList[GraphModule]] = None,
) -> Tuple[SetList[GraphNode], SetList[GraphModule]]:
    """
    Convert a graph into a collection of connection nodes and modules, by traversal

    A graph will be traversed, following all connections. The connection :py:class:`.GraphNode` s and :py:class:`.GraphModule` s will be collected and returned in two collections.

    Any :py:class:`.GraphHolder` modules will be ignored and discarded from the returned modules collection.

    Args:
        graph (GraphModuleBase): A graph to analyse
        nodes_bag (SetList): The initial nodes bag. If provided, it is extended in place. Default: ``None``
        modules_bag (SetList): The initial modules bag. If provided, it is extended in place with every visited module. Default: ``None``

    Returns:
        Tuple[SetList[GraphNode], SetList[GraphModule]]: nodes, modules. `nodes` will be a :py:class:`SetList` containing all the reachable :py:class:`GraphNode` s in `graph`. `modules` will be a :py:class:`SetList` containing all the reachable :py:class:`GraphModule` s in `graph`.
    """
    nodes_bag = SetList() if nodes_bag is None else nodes_bag
    modules_bag = SetList() if modules_bag is None else modules_bag

    # - Traverse the graph, filling both bags in place
    _bag_graph_visit(graph, nodes_bag, modules_bag)

    # - Remove all GraphHolders from the modules bag
    #   This is done once here, rather than at every level of the traversal
    modules_bag_to_return = SetList()
    for mod in modules_bag:
        if not isinstance(mod, GraphHolder):
            modules_bag_to_return.add(mod)

    return nodes_bag, modules_bag_to_return


def _bag_graph_visit(module, nodes_bag, modules_bag) -> None:
    """
    Recursively collect ``module`` and everything reachable from it into the bags

    Args:
        module (GraphModuleBase): The module to visit
        nodes_bag (SetList): Collection of nodes seen so far, modified in place
        modules_bag (SetList): Collection of modules seen so far, modified in place
    """
    # - Have we seen this module before?
    if module in modules_bag:
        return

    # - Add this module to the bag
    modules_bag.add(module)

    # - Add input and output nodes to bag
    for node in module.input_nodes:
        nodes_bag.add(node)

    for node in module.output_nodes:
        nodes_bag.add(node)

    # - Recurse over input nodes, then over output nodes
    for nodes in (module.input_nodes, module.output_nodes):
        for node in nodes:
            # - For each node, recurse over source modules, then over sink modules
            for neighbours in (node.source_modules, node.sink_modules):
                for neighbour in neighbours:
                    _bag_graph_visit(neighbour, nodes_bag, modules_bag)
def find_modules_of_subclass(graph: GraphModuleBase, cls) -> SetList[Any]:
    """
    Collect every module in a graph that is an instance of a given class

    Matching is performed with `isinstance`, so instances of any subclass of ``cls`` are returned as well.

    Args:
        graph (GraphModuleBase): The graph to search
        cls: The class to match modules against

    Returns:
        SetList[Any]: A collection of the modules in ``graph`` that are instances of ``cls``
    """
    # - Bag the graph; only the modules collection is of interest here
    reachable_modules = bag_graph(graph)[1]

    # - Keep the modules of the requested class
    matching_modules = (mod for mod in reachable_modules if isinstance(mod, cls))
    return SetList(matching_modules)
def replace_module(target_module: GraphModule, replacement_module: GraphModule) -> None:
    """
    Replace a graph module with a different module

    This function removes a target graph module from a graph, and replaces it with a replacement module. It removes the target module from any connection :py:class:`.GraphNode` s, and wires in the replacement module instead.

    Args:
        target_module (GraphModule): A module inside a graph to replace
        replacement_module (GraphModule): A replacement module to wire into the graph, in place of ``target_module``

    Raises:
        ValueError: If the number of input nodes, or the number of output nodes, differs between ``target_module`` and ``replacement_module``
    """
    # - Check that the input and output numbers match
    if len(target_module.input_nodes) != len(replacement_module.input_nodes):
        raise ValueError("Number of input nodes do not match")

    if len(target_module.output_nodes) != len(replacement_module.output_nodes):
        raise ValueError("Number of output nodes do not match")

    # - Remove target_module from input and output nodes, replace with replacement_module
    for n in target_module.input_nodes:
        n.remove_sink(target_module)
        n.add_sink(replacement_module)

    for n in target_module.output_nodes:
        n.remove_source(target_module)
        n.add_source(replacement_module)

    # - Include original input and output nodes in replacement_module
    #   Note that the node lists themselves are shared with `target_module`, not copied
    replacement_module.input_nodes = target_module.input_nodes
    replacement_module.output_nodes = target_module.output_nodes
def find_recurrent_modules(
    graph: GraphModuleBase,
) -> Tuple[SetList[GraphModule], SetList[GraphModule]]:
    """
    Search for graph modules that are connected in a one-module loop

    A "recurrent module" is defined as a graph module that connects with itself via another single graph module. e.g. a module of neurons, connected to a module of weights that itself connects recurrently back from output of the neurons to the input of the neurons.

    Args:
        graph (GraphModuleBase): A graph to search

    Returns:
        Tuple[SetList[GraphModule], SetList[GraphModule]]: modules, recurrent_modules. `modules` will be a :py:class:`SetList` containing all the reachable :py:class:`GraphModule` s in `graph`. `recurrent_modules` is a collection containing all identified recurrent modules in the graph
    """
    _, modules = bag_graph(graph)

    recurrent_modules = SetList()
    for m in modules:
        # - Get a collection of all source modules
        source_modules = SetList()
        for i_node in m.input_nodes:
            source_modules.extend(i_node.source_modules)

        # - Get a collection of all destination modules
        dest_modules = SetList()
        for o_node in m.output_nodes:
            dest_modules.extend(o_node.sink_modules)

        # - Detect duplicates between source and destination modules:
        #   a module that both feeds `m` and is fed by `m` closes a one-module loop
        if set(source_modules).intersection(dest_modules):
            recurrent_modules.add(m)

    return modules, recurrent_modules