Source code for utilities.backend_management

"""
Utility functionality for managing backends

To check a standard backend, use :py:func:`.backend_available`. To check a non-standard backend specification, use :py:func:`.check_backend`.

To build a shim class that raises an error on instantiation, for when a required backend is not available, use :py:func:`.missing_backend_shim`.
"""
import importlib
from importlib import util
from typing import List, Union, Tuple, Optional, Dict

# - Configure exports
__all__ = ["backend_available", "check_backend", "missing_backend_shim", "AbortImport"]


class AbortImport(Exception):
    pass


[docs]def check_torch_cuda_available() -> bool: try: from torch.cuda import is_available return is_available() except: return False
[docs]def check_samna_available() -> bool: """ check_samna_available controls if samna package is "installed" and "usable" The default `backend_available()` operation cannot correctly identifies the samna availability. In the case that one installed samna and then uninstalled via pip * `pip install samna` then `pip uninstall samna` samna leaves a trace and `import samna` does not raise an error even though the package is not usable. :return: true if samna package available :rtype: bool """ try: import samna try: samna.__version__ except: return False except: return False return True
# - Maintain a cache of checked backends __checked_backends: Dict[str, bool] = {} # - Specifications for common backends __backend_specs: Dict[str, tuple] = { "numpy": (), "numba": (), "jax": (["jax", "jaxlib"],), "torch": (), "sinabs": (), "sinabs-exodus": (["sinabs", "sinabs.exodus"], check_torch_cuda_available()), "brian": (["brian2"],), "cuda": (["torch"], check_torch_cuda_available()), "samna": (["samna"], check_samna_available()), }
[docs]def check_backend( backend_name: str, required_modules: Optional[Union[Tuple[str], List[str]]] = None, check_flag: bool = True, ) -> bool: """ Check if a backend is available, and register it in a list of available backends Args: backend_name (str): The name of this backend to check for and register required_modules (Optional[List[str]]): A list of required modules to search for. If ``None`` (default), check the backend name check_flag (bool): A manual check that can be performed externally, to see if the backend is available Returns: bool: The backend is available """ # - See if the backend check is already cached if backend_name in __checked_backends: return __checked_backends[backend_name] # - If no list of required modules, just check the backend name if required_modules is None: required_modules = [backend_name] requirements_met = check_flag for spec in required_modules: try: # - Check the required module is installed requirements_met = requirements_met and (util.find_spec(spec) is not None) # - Try to import the module importlib.import_module(spec) except Exception as e: requirements_met = False if not requirements_met: break # - Register the backend as having been checked if backend_name not in __checked_backends: __checked_backends.update({backend_name: requirements_met}) # - Let the caller know if we passed the check return requirements_met
[docs]def backend_available(*backend_names) -> bool: """ Report if a backend is available for use This function returns immediately if the named backend has already been checked previously. Otherwise, if the backend is "standard", then it will be checked for availability. If the backend is non-standard, it cannot be checked automatically. In that case you must use :py:func:`.check_backend` directly. Args: backend_name0, backend_name1, ... (str): A backend to check Returns: bool: ``True`` iff the backend is available for use """ def check_single_backend(backend_name): if backend_name in __checked_backends: return __checked_backends[backend_name] elif backend_name in __backend_specs: return check_backend(backend_name, *__backend_specs[backend_name]) else: return check_backend(backend_name) return all([check_single_backend(be) for be in backend_names])
[docs]def missing_backend_shim(class_name: str, backend_name: str): """ Make a class constructor that raises an error about a missing backend Examples: Generate a `LIFTorch` class shim, that will raise an error on instantiation. >>> LIFTorch = missing_backend_shim('LIFTorch', 'torch') >>> LIFTorch((3,), tau_syn = 10e-3) ModuleNotFoundError: Missing the `torch` backend. `LIFTorch` objects, and others relying on `torch` are not available. Args: class_name (str): The intended class name backend_name (str): The required backend that is missing Returns: Class: A class that raises an error on construction """ class MBSMeta(type): def __getattr__(cls, *args): raise ModuleNotFoundError( f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available." ) class MissingBackendShim(metaclass=MBSMeta): """ BACKEND MISSING FOR THIS CLASS """ def __init__(self, *args, **kwargs): raise ModuleNotFoundError( f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available." ) return MissingBackendShim
[docs]def missing_backend_error(class_name: str, backend_name: str): """ Raise a ``ModuleNotFoundError`` exception, with information about a missing backend Args: class_name (str): Name of a class which is unavailable backend_name (str): "User-facing" of the backend which is unavailable Raises: ModuleNotFoundError: Describe the missing backend """ def __init__(self, *args, **kwargs): raise ModuleNotFoundError( f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available." ) return __init__
[docs]def list_backends(): """ Print a list of computational backends available in this session """ print("Backends available to Rockpool:") for backend in __backend_specs.keys(): print(f"{backend:>15}: {backend_available(backend)}")
[docs]def torch_version_satisfied( req_major: int = 0, req_minor: int = 0, req_patch: int = 0 ) -> bool: """ Check if the installed version of torch satisfies a minimum version requirement i.e. torch 2.0.0 >= 1.12.0 : True torch 1.12.0 >= 1.12.0 : True torch 1.11.0 >= 1.12.0 : False Args: req_major (int): The minimum major version required req_minor (int): The minimum minor version required req_patch (int): The minimum patch version required Returns: bool: The installed version of torch satisfies the minimum version requirement """ if not backend_available("torch"): return False import torch # - Check torch version lib_major, lib_minor, lib_patch = torch.__version__.split(".") patch_vers = lib_patch.split("+") if len(patch_vers) > 1: lib_patch, *lib_cuda = patch_vers if int(lib_major) > req_major: return True elif int(lib_major) == req_major: if int(lib_minor) > req_minor: return True elif int(lib_minor) == req_minor: if int(lib_patch) >= req_patch: return True else: return False else: return False else: return False