# Source code for training.adversarial_jax

"""
Functions to implement adversarial training approaches using Jax

See Also:
    :ref:`/tutorials/adversarial_training.ipynb` illustrates how to use the functions in this module to implement adversarial attacks on the parameters of a network during training.
"""

import jax
import numpy as np

from rockpool.nn.modules.jax.jax_module import JaxModule

import jax.tree_util as tu
import jax.random as random
from jax.lax import stop_gradient
import jax.numpy as jnp
from jax import value_and_grad

from jax.tree_util import Partial

from typing import Tuple, Callable, List, Dict, Any, Optional
from rockpool.typehints import Tree, JaxTreeDef, JaxRNGKey

# - Public API of this module; the underscore-prefixed helpers are private
__all__ = ["pga_attack", "adversarial_loss"]


def _split_and_sample_normal(
    key: JaxRNGKey, shape: Tuple
) -> Tuple[JaxRNGKey, np.ndarray]:
    """
    Advance an RNG key and draw standard-normal samples of a requested shape

    The incoming key is split in two: one half is consumed to draw the samples, the other half is
    handed back so the caller can keep generating independent random numbers.

    Args:
        key (JaxRNGKey): A Jax random key (as produced by e.g. `jax.random.PRNGKey`)
        shape (tuple): Shape of the array of samples to draw

    Returns:
        (JaxRNGKey, np.ndarray): Tuple `(key, data)`, where `key` is the unused half of the split
        (to be passed to subsequent random calls) and `data` contains samples drawn from
        :math:`\\mathcal{N}(0, 1)`
    """
    next_key, sample_key = random.split(key)
    return next_key, random.normal(sample_key, shape=shape)


def _eval_target_loss(
    parameters: List,
    inputs: np.ndarray,
    target: np.ndarray,
    net: JaxModule,
    tree_def_params: JaxTreeDef,
    loss: Callable[[np.ndarray, np.ndarray], float],
) -> float:
    """
    Evaluate a loss between a target and the output of a network run with a given set of parameters.

    The network state is reset, the flattened `parameters` are unflattened using `tree_def_params`
    and installed into the network, and the network is then evolved on `inputs`. The supplied
    `loss` is finally evaluated as ``loss(target, output)``.

    Args:
        parameters (List): Flat list of parameter leaves, as produced by `jax.tree_util.tree_flatten`
        inputs (np.ndarray): Inputs that will be passed through the network
        target (np.ndarray): Target signal to compare the network output against
        net (JaxModule): Network to evaluate. Calling it must return a 3-tuple whose first element is the output
        tree_def_params (JaxTreeDef): Tree structure returned by `jax.tree_util.tree_flatten` alongside `parameters`
        loss (Callable[[np.ndarray, np.ndarray], float]): Comparison loss, called as ``loss(target, output)``. Example: KL divergence between softmaxed logits of the networks.

    Returns:
        float: The value of `loss` for the network output produced with `parameters`
    """
    # - Fresh state first, then swap in the requested parameter values
    configured_net = net.reset_state().set_attributes(
        tu.tree_unflatten(tree_def_params, parameters)
    )

    # - Only the output is needed; the remaining two return values are discarded
    evolved_output, _, _ = configured_net(inputs)

    return loss(target, evolved_output)


@Partial(
    jax.jit,
    static_argnames=[
        "tree_def_params",
        "mismatch_loss",
        "attack_steps",
        "mismatch_level",
        "initial_std",
    ],
)
def pga_attack(
    params_flattened: List,
    net: Callable[[np.ndarray], np.ndarray],
    rng_key: JaxRNGKey,
    inputs: np.ndarray,
    net_out_original: np.ndarray,
    tree_def_params: JaxTreeDef,
    mismatch_loss: Callable[[np.ndarray, np.ndarray], float],
    attack_steps: int = 10,
    mismatch_level: float = 0.025,
    initial_std: float = 1e-3,
) -> Tuple[List, Dict]:
    r"""
    Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.

    This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters :math:`\Theta` (`params_flattened`), we iteratively modify the parameters in order to worsen a supplied loss function `mismatch_loss`. `mismatch_loss` measures a comparison between the output of the network at the initial parameters (`net_out_original`) and the output of the network at the modified parameters :math:`\Theta^*`. We compute the gradient of `mismatch_loss` w.r.t. the modified parameters :math:`\Theta^*`, and step in a projected direction along the sign of the gradient.

    The step size at each iteration for a parameter ``p`` is given by ``(mismatch_level * abs(p)) / attack_steps``.

    Args:
        params_flattened (List): Flattened pytree that was obtained using `jax.tree_util.tree_flatten` of the network parameters (obtained by `net.parameters()`)
        net (Callable): A function (e.g. `Sequential` object) that takes an `np.ndarray` and generates another `np.ndarray`
        rng_key (JaxRNGKey): A Jax random key
        inputs (np.ndarray): Inputs that will be passed through the network
        net_out_original (np.ndarray): Outputs of the network using the original weights
        tree_def_params (JaxTreeDef): Tree structure obtained by calling `jax.tree_util.tree_flatten` on the network parameters. Defines how `params_flattened` / :math:`\Theta^*` are unflattened.
        mismatch_loss (Callable): Mismatch loss. Takes as input two `np.ndarray` s and returns a `float`. Example: KL divergence between softmaxed logits of the networks. Signature: ``mismatch_loss(target, net_output)``, where ``target`` is `net_out_original`.
        attack_steps (int): Number of PGA steps to be taken
        mismatch_level (float): Size by which the adversary can perturb the weights (:math:`\zeta`). The gradient steps move each parameter by at most :math:`\zeta \cdot |\Theta|` in total, so the attack lies approximately in :math:`[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]`. No clipping is applied, so the initial perturbation (see `initial_std`) can place parameters slightly outside this interval.
        initial_std (float): Initial perturbation (:math:`\zeta_{initial}`) of the parameters according to :math:`\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})`

    Returns:
        Tuple[List, Dict]: Tuple comprising :math:`\Theta^*` in flattened form and a dictionary holding the `grads` and `losses` for every PGA iteration
    """
    # - Create verbose dict
    verbose = {"grads": [], "losses": []}

    # - Initialize Theta* by adding Gaussian noise to each parameter, and fix the
    #   per-parameter step size (relative to the *original* parameter magnitude)
    theta_star = []
    step_size = []
    for p in params_flattened:
        rng_key, random_normal_var = _split_and_sample_normal(rng_key, p.shape)
        theta_star.append(p + jnp.abs(p) * initial_std * random_normal_var)
        step_size.append((mismatch_level * jnp.abs(p)) / attack_steps)

    # - Perform gradient ascent on the parameters Theta*, with respect to the provided mismatch loss
    for _ in range(attack_steps):
        # - Compute loss and gradients w.r.t. Theta* (first argument)
        loss, grads_theta_star = value_and_grad(_eval_target_loss)(
            theta_star, inputs, net_out_original, net, tree_def_params, mismatch_loss
        )

        # - Store the loss and gradients for this iteration
        verbose["losses"].append(loss)
        verbose["grads"].append(grads_theta_star)

        # - Step each parameter along the sign of its gradient, scaled to the parameter scale
        theta_star = [
            theta + step * jnp.sign(grad)
            for theta, step, grad in zip(theta_star, step_size, grads_theta_star)
        ]

    # - Return the attacked parameters
    return theta_star, verbose
@Partial(
    jax.jit,
    static_argnames=[
        "net",
        "task_loss",
        "mismatch_loss",
        "noisy_forward_std",
        "initial_std",
        "mismatch_level",
        "beta_robustness",
        "attack_steps",
    ],
)
def adversarial_loss(
    parameters: Tree,
    net: JaxModule,
    inputs: np.ndarray,
    target: np.ndarray,
    task_loss: Callable[[np.ndarray, np.ndarray], float],
    mismatch_loss: Callable[[np.ndarray, np.ndarray], float],
    rng_key: JaxRNGKey,
    noisy_forward_std: float = 0.0,
    initial_std: float = 1e-3,
    mismatch_level: float = 0.025,
    beta_robustness: float = 0.25,
    attack_steps: int = 10,
) -> float:
    r"""
    Implement a hybrid task / adversarial robustness loss

    This loss function combines a task loss with a loss that evaluates how robust a network is to parameter attack. The combined loss has the form :math:`\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta)))` where :math:`\mathcal{A}(\Theta)` is a PGA-based adversary and :math:`\Theta` are the weights of the network. For the task loss term, the weights are perturbed by Gaussian noise during the forward pass.

    The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.

    The method combines two aspects --- Gaussian noise added to the parameters during the forward pass, and an adversarial attack (:py:func:`pga_attack`) that searches, near the current parameters, for parameters that maximally change the network output. The mismatch between the nominal and attacked outputs is added to the task loss as a regularizer.

    See Also:
        :ref:`/tutorials/adversarial_training.ipynb` for an example of how to train a network using this adversarial attack during training.

    Args:
        parameters (Tree): Parameters of the network (obtained by e.g. `net.parameters()`)
        net (JaxModule): A JaxModule undergoing training
        inputs (np.ndarray): Inputs that will be passed through the network
        target (np.ndarray): Targets for the network prediction. Can be anything as long as `task_loss` can cope with the type/shape
        task_loss (Callable): Task loss. Can be anything used for training a NN (e.g. cat. cross entropy). Signature: ``task_loss(target, net_output)``.
        mismatch_loss (Callable): Mismatch loss between output of nominal and attacked network. Takes as input two `np.ndarray` s and returns a `float`. Example: KL divergence between softmaxed logits of the networks. Signature: ``mismatch_loss(net_output, net_output_star)``, where ``net_output`` is produced with the original parameters and ``net_output_star`` with the attacked parameters.
        rng_key (JaxRNGKey): A Jax RNG key
        noisy_forward_std (float): Float (:math:`\zeta_{forward}`) determining the amount of noise added to the parameters in the forward pass of the network. Model: :math:`\Theta = \Theta + \zeta_{forward} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})`. Default: ``0.``; do not use noise in the forward pass
        initial_std (float): Initial perturbation (:math:`\zeta_{initial}`) of the parameters according to :math:`\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})`
        mismatch_level (float): Size by which the adversary can perturb the weights (:math:`\zeta`). Attack will be approximately in :math:`[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]`. Default: ``0.025``
        beta_robustness (float): Tradeoff parameter for the adversarial regularizer. Setting to ``0.0`` removes the adversarial term from the loss, but the attack is still computed, so this is much slower than training on the task loss alone and should not be done. Default: ``0.25``
        attack_steps (int): Number of PGA steps to be taken during each training iteration, as part of the adversarial attack. Default: ``10``

    Returns:
        float: The calculated loss, combining task loss and adversarial attack robustness loss
    """
    # - Reset the network state
    net = net.reset_state()

    # - Add Gaussian noise to the parameters before evaluating. The noise itself is
    #   excluded from differentiation, so gradients flow only through `p`.
    params_flattened, tree_def_params = tu.tree_flatten(parameters)
    params_gaussian_flattened = []
    for p in params_flattened:
        rng_key, random_normal_var = _split_and_sample_normal(rng_key, p.shape)
        params_gaussian_flattened.append(
            p + stop_gradient(jnp.abs(p) * noisy_forward_std * random_normal_var)
        )

    # - Evaluate the task loss using the perturbed parameters: task_loss(target, output)
    loss_n = _eval_target_loss(
        params_gaussian_flattened, inputs, target, net, tree_def_params, task_loss
    )

    # - Get output for the original parameters
    # - Reset network state
    net = net.reset_state()

    # - Set parameters to the original parameters
    net = net.set_attributes(parameters)

    # - Get the network output using the original parameters
    output_theta, _, _ = net(inputs)

    # - Perform the adversarial attack to obtain the attacked parameters `theta_star`
    theta_star, _ = pga_attack(
        params_flattened=params_flattened,
        net=net,
        rng_key=rng_key,
        attack_steps=attack_steps,
        mismatch_level=mismatch_level,
        initial_std=initial_std,
        inputs=inputs,
        net_out_original=output_theta,
        tree_def_params=tree_def_params,
        mismatch_loss=mismatch_loss,
    )

    # - Compute robustness loss using the attacked parameters `theta_star`:
    #   mismatch_loss(output_theta, output_star)
    loss_r = _eval_target_loss(
        theta_star, inputs, output_theta, net, tree_def_params, mismatch_loss
    )

    # - Add the robustness loss as a regularizer
    return loss_n + beta_robustness * loss_r