"""
Utility functionality for managing backends
To check a standard backend, use :py:func:`.backend_available`. To check a non-standard backend specification, use :py:func:`.check_backend`.
To build a shim class that raises an error on instantiation, for when a required backend is not available, use :py:func:`.missing_backend_shim`.
"""
import importlib
from importlib import util
from typing import List, Union, Tuple, Optional, Dict
# - Configure exports
__all__ = ["backend_available", "check_backend", "missing_backend_shim", "AbortImport"]
class AbortImport(Exception):
pass
[docs]def check_torch_cuda_available() -> bool:
try:
from torch.cuda import is_available
return is_available()
except:
return False
[docs]def check_samna_available() -> bool:
"""
check_samna_available controls if samna package is "installed" and "usable"
The default `backend_available()` operation cannot correctly identifies the samna availability.
In the case that one installed samna and then uninstalled via pip
* `pip install samna` then `pip uninstall samna`
samna leaves a trace and `import samna` does not raise an error even though the package is not usable.
:return: true if samna package available
:rtype: bool
"""
try:
import samna
try:
samna.__version__
except:
return False
except:
return False
return True
# - Maintain a cache of checked backends
__checked_backends: Dict[str, bool] = {}
# - Specifications for common backends
__backend_specs: Dict[str, tuple] = {
"numpy": (),
"numba": (),
"jax": (["jax", "jaxlib"],),
"torch": (),
"sinabs": (),
"sinabs-exodus": (["sinabs", "sinabs.exodus"], check_torch_cuda_available()),
"brian": (["brian2"],),
"cuda": (["torch"], check_torch_cuda_available()),
"samna": (["samna"], check_samna_available()),
}
[docs]def check_backend(
backend_name: str,
required_modules: Optional[Union[Tuple[str], List[str]]] = None,
check_flag: bool = True,
) -> bool:
"""
Check if a backend is available, and register it in a list of available backends
Args:
backend_name (str): The name of this backend to check for and register
required_modules (Optional[List[str]]): A list of required modules to search for. If ``None`` (default), check the backend name
check_flag (bool): A manual check that can be performed externally, to see if the backend is available
Returns:
bool: The backend is available
"""
# - See if the backend check is already cached
if backend_name in __checked_backends:
return __checked_backends[backend_name]
# - If no list of required modules, just check the backend name
if required_modules is None:
required_modules = [backend_name]
requirements_met = check_flag
for spec in required_modules:
try:
# - Check the required module is installed
requirements_met = requirements_met and (util.find_spec(spec) is not None)
# - Try to import the module
importlib.import_module(spec)
except Exception as e:
requirements_met = False
if not requirements_met:
break
# - Register the backend as having been checked
if backend_name not in __checked_backends:
__checked_backends.update({backend_name: requirements_met})
# - Let the caller know if we passed the check
return requirements_met
[docs]def backend_available(*backend_names) -> bool:
"""
Report if a backend is available for use
This function returns immediately if the named backend has already been checked previously.
Otherwise, if the backend is either a defined standard backend, or is a simple importable python module, then it will be checked for availability.
If the backend is non-standard, it cannot be checked automatically.
In that case you must use :py:func:`.check_backend` directly.
Args:
backend_name0, backend_name1, ... (str): A backend to check
Returns:
bool: ``True`` iff the backend is available for use
"""
def check_single_backend(backend_name):
if backend_name in __checked_backends:
return __checked_backends[backend_name]
elif backend_name in __backend_specs:
return check_backend(backend_name, *__backend_specs[backend_name])
else:
return check_backend(backend_name)
return all([check_single_backend(be) for be in backend_names])
[docs]def missing_backend_shim(class_name: str, backend_name: str):
"""
Make a class constructor that raises an error about a missing backend
Examples:
Generate a `LIFTorch` class shim, that will raise an error on instantiation.
>>> LIFTorch = missing_backend_shim('LIFTorch', 'torch')
>>> LIFTorch((3,), tau_syn = 10e-3)
ModuleNotFoundError: Missing the `torch` backend. `LIFTorch` objects, and others relying on `torch` are not available.
Args:
class_name (str): The intended class name
backend_name (str): The required backend that is missing
Returns:
Class: A class that raises an error on construction
"""
class MBSMeta(type):
def __getattr__(cls, *args):
raise ModuleNotFoundError(
f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available."
)
class MissingBackendShim(metaclass=MBSMeta):
"""
BACKEND MISSING FOR THIS CLASS
"""
def __init__(self, *args, **kwargs):
raise ModuleNotFoundError(
f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available."
)
return MissingBackendShim
[docs]def missing_backend_error(class_name: str, backend_name: str):
"""
Raise a ``ModuleNotFoundError`` exception, with information about a missing backend
Args:
class_name (str): Name of a class which is unavailable
backend_name (str): "User-facing" of the backend which is unavailable
Raises:
ModuleNotFoundError: Describe the missing backend
"""
def __init__(self, *args, **kwargs):
raise ModuleNotFoundError(
f"Missing the `{backend_name}` backend. `{class_name}` objects, and others relying on `{backend_name}` are not available."
)
return __init__
[docs]def list_backends():
"""
Print a list of computational backends available in this session
"""
print("Backends available to Rockpool:")
for backend in __backend_specs.keys():
print(f"{backend:>15}: {backend_available(backend)}")
[docs]def torch_version_satisfied(
req_major: int = 0, req_minor: int = 0, req_patch: int = 0
) -> bool:
"""
Check if the installed version of torch satisfies a minimum version requirement
i.e.
torch 2.0.0 >= 1.12.0 : True
torch 1.12.0 >= 1.12.0 : True
torch 1.11.0 >= 1.12.0 : False
Args:
req_major (int): The minimum major version required
req_minor (int): The minimum minor version required
req_patch (int): The minimum patch version required
Returns:
bool: The installed version of torch satisfies the minimum version requirement
"""
if not backend_available("torch"):
return False
import torch
# - Check torch version
lib_major, lib_minor, lib_patch = torch.__version__.split(".")
patch_vers = lib_patch.split("+")
if len(patch_vers) > 1:
lib_patch, *lib_cuda = patch_vers
if int(lib_major) > req_major:
return True
elif int(lib_major) == req_major:
if int(lib_minor) > req_minor:
return True
elif int(lib_minor) == req_minor:
if int(lib_patch) >= req_patch:
return True
else:
return False
else:
return False
else:
return False