Functions to implement adversarial training approaches using Jax

The tutorial 👹 Adversarial training illustrates how to use the functions in this module to implement adversarial attacks on the parameters of a network during training.

Functions overview

adversarial_loss(parameters, net, inputs, ...) – Implements a hybrid task / adversarial robustness loss.
pga_attack(params_flattened, net, rng_key, ...) – Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.

Functions

training.adversarial_jax.adversarial_loss(parameters: Union[collections.abc.Iterable, collections.abc.MutableMapping, collections.abc.Mapping], net: rockpool.nn.modules.jax.jax_module.JaxModule, inputs: numpy.ndarray, target: numpy.ndarray, task_loss: Callable[[numpy.ndarray, numpy.ndarray], float], mismatch_loss: Callable[[numpy.ndarray, numpy.ndarray], float], rng_key: Any, noisy_forward_std: float = 0.0, initial_std: float = 0.001, mismatch_level: float = 0.025, beta_robustness: float = 0.25, attack_steps: int = 10) → float

This loss function combines a task loss with a loss that evaluates how robust a network is to parameter attack. The combined loss has the form $$\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta)))$$ where $$\mathcal{A}(\Theta)$$ is a PGA-based adversary and $$\Theta$$ are the network parameters, which are perturbed by Gaussian noise during the forward pass.
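
The two numerical ingredients of this loss can be sketched in NumPy: the multiplicative Gaussian perturbation $$\Theta + \zeta \cdot R \odot |\Theta|$$ applied to the parameters, and the weighted sum of the task and robustness losses. The helper names add_parameter_noise and combined_loss are hypothetical illustrations, not part of Rockpool; a real implementation would apply the perturbation to every leaf of the parameter tree using jax.numpy.

```python
import numpy as np


def add_parameter_noise(theta, std, rng):
    """Hypothetical helper: theta + std * R * |theta|, with R ~ N(0, I)."""
    r = rng.standard_normal(theta.shape)
    # - Noise is proportional to the magnitude of each parameter
    return theta + std * r * np.abs(theta)


def combined_loss(task_loss_value, robustness_loss_value, beta_robustness=0.25):
    """Hypothetical helper: L = L_nat + beta_rob * L_rob."""
    return task_loss_value + beta_robustness * robustness_loss_value


rng = np.random.default_rng(0)
theta = np.array([0.5, -1.0, 2.0])

# - With std = 0.0 the parameters are returned unchanged
theta_noisy = add_parameter_noise(theta, 0.0, rng)

# - Because the noise scales with |theta|, a zero-valued parameter stays zero
theta_with_zero = add_parameter_noise(np.array([0.0, 1.0]), 0.1, rng)

# - Illustrative loss values only: 0.5 + 0.25 * 0.2
total = combined_loss(0.5, 0.2, beta_robustness=0.25)
```

Scaling the noise by $$|\Theta|$$ makes the perturbation relative, so large and small parameters are disturbed by a comparable fraction of their value.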

The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.

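The method combines two aspects: Gaussian noise added to the parameters during the forward pass, and a PGA-based adversarial attack on the parameters that maximizes the mismatch loss between the outputs of the nominal and attacked networks.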

See 👹 Adversarial training for an example of how to train a network using this adversarial attack during training.

Parameters
• parameters (Tree) – Parameters of the network (obtained by e.g. net.parameters())

• net (JaxModule) – A JaxModule undergoing training

• inputs (np.ndarray) – Inputs that will be passed through the network

• target (np.ndarray) – Targets for the network prediction. Can be anything as long as task_loss can cope with the type/shape

• task_loss (Callable) – Task loss. Can be anything used for training a NN (e.g. categorical cross-entropy). Signature: task_loss(net_output, target).

• mismatch_loss (Callable) – Mismatch loss between output of nominal and attacked network. Takes as input two np.ndarrays and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(net_output_star, net_output).

• rng_key (JaxRNGKey) – A Jax RNG key

• noisy_forward_std (float) – Float ($$\zeta_{forward}$$) determining the amount of noise added to the parameters in the forward pass of the network. Model: $$\Theta \leftarrow \Theta + \zeta_{forward} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})$$. Default: 0.0 (no noise is added in the forward pass)

• initial_std (float) – Initial perturbation ($$\zeta_{initial}$$) of the parameters according to $$\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})$$. Default: 0.001

• mismatch_level (float) – Size by which the adversary can perturb the weights ($$\zeta$$). Attack will be in $$[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]$$. Default: 0.025

• beta_robustness (float) – Tradeoff parameter for the adversarial regularizer. Setting it to 0.0 removes the adversarial term from the loss; this is much slower than training on the task loss alone and should be avoided. Default: 0.25

• attack_steps (int) – Number of PGA steps to be taken during each training iteration, as part of the adversarial attack. Default: 10
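
The task_loss and mismatch_loss arguments accept any callables with the signatures given above. The NumPy sketch below shows one plausible pair, a categorical cross-entropy task loss and a KL divergence between softmaxed logits as the mismatch loss. The function names, the one-hot target format, and the KL direction (nominal distribution first) are assumptions for illustration; because adversarial_loss differentiates through these losses, a real implementation would typically use jax.numpy instead of numpy.

```python
import numpy as np


def softmax(logits, axis=-1):
    # - Subtract the max for numerical stability
    z = logits - np.max(logits, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def cross_entropy_task_loss(net_output, target):
    """Hypothetical task_loss(net_output, target); target is one-hot, same shape as net_output."""
    log_p = np.log(softmax(net_output) + 1e-12)
    return -np.mean(np.sum(target * log_p, axis=-1))


def kl_mismatch_loss(net_output_star, net_output):
    """Hypothetical mismatch_loss: KL(P || P*) between nominal (P) and attacked (P*) softmax outputs."""
    eps = 1e-12
    p = softmax(net_output)
    p_star = softmax(net_output_star)
    kl = np.sum(p * (np.log(p + eps) - np.log(p_star + eps)), axis=-1)
    return np.mean(kl)


logits = np.array([[2.0, 0.5, -1.0]])
target = np.array([[1.0, 0.0, 0.0]])

ce = cross_entropy_task_loss(logits, target)
# - Identical outputs give zero mismatch; perturbed logits give a positive mismatch
kl_same = kl_mismatch_loss(logits, logits)
kl_shifted = kl_mismatch_loss(logits + np.array([[0.0, 1.0, 0.0]]), logits)
```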

Returns

The combined loss $$\mathcal{L}$$: the task loss plus $$\beta_{rob}$$ times the robustness loss

Return type

float

training.adversarial_jax.pga_attack(params_flattened: List, net: Callable[[numpy.ndarray], numpy.ndarray], rng_key: Any, inputs: numpy.ndarray, net_out_original: numpy.ndarray, tree_def_params: Any, mismatch_loss: Callable[[numpy.ndarray, numpy.ndarray], float], attack_steps: int = 10, mismatch_level: float = 0.025, initial_std: float = 0.001) → Tuple[List, Dict]

Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.

This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters $$\Theta$$ (params_flattened), we iteratively modify the parameters in order to worsen a supplied loss function, mismatch_loss. mismatch_loss compares the output of the network at the initial parameters (net_out_original) with the output of the network at the modified parameters $$\Theta^*$$. At each iteration we compute the gradient of mismatch_loss w.r.t. $$\Theta^*$$, take a step along the sign of that gradient, and project the result back onto the allowed interval $$[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]$$.

The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.
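
The PGA loop described above can be sketched in NumPy for a toy problem. This is not the Rockpool implementation: it assumes a hypothetical linear network y = x @ theta acting on a single parameter array (instead of a flattened pytree) and a squared-error mismatch loss, so the gradient can be written analytically where Rockpool would use JAX autodiff. The names toy_net, squared_mismatch and pga_attack_sketch are hypothetical.

```python
import numpy as np


def toy_net(theta, x):
    # - Hypothetical stand-in for a network: linear readout y = x @ theta
    return x @ theta


def squared_mismatch(out_star, out_orig):
    # - Hypothetical mismatch loss between attacked and original outputs
    return np.sum((out_star - out_orig) ** 2)


def pga_attack_sketch(theta, x, rng, attack_steps=10, mismatch_level=0.025, initial_std=1e-3):
    out_orig = toy_net(theta, x)
    bound = mismatch_level * np.abs(theta)   # allowed deviation per parameter
    step_size = bound / attack_steps         # (mismatch_level * |p|) / attack_steps
    # - Initial random perturbation: theta + initial_std * R * |theta|
    theta_star = theta + initial_std * rng.standard_normal(theta.shape) * np.abs(theta)
    losses = []
    for _ in range(attack_steps):
        out_star = toy_net(theta_star, x)
        losses.append(squared_mismatch(out_star, out_orig))
        # - Analytic gradient of the squared mismatch w.r.t. theta_star
        grad = 2.0 * x.T @ (out_star - out_orig)
        # - Ascent step along the sign of the gradient
        theta_star = theta_star + step_size * np.sign(grad)
        # - Project back onto [theta - bound, theta + bound]
        theta_star = np.clip(theta_star, theta - bound, theta + bound)
    return theta_star, {"losses": losses}


rng = np.random.default_rng(1)
theta = np.array([0.5, -1.0, 2.0])
x = np.array([[1.0, 2.0, -0.5]])
theta_star, info = pga_attack_sketch(theta, x, rng)
```

Because the step size is the interval half-width divided by attack_steps, the attack can reach the edge of the interval within the allotted steps, while the clipping guarantees it never leaves it.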

Parameters
• params_flattened (List) – Flattened pytree that was obtained using jax.tree_util.tree_flatten of the network parameters (obtained by net.parameters())

• net (Callable) – A function (e.g. Sequential object) that takes an np.ndarray and generates another np.ndarray

• rng_key (JaxRNGKey) – A Jax random key

• attack_steps (int) – Number of PGA steps to be taken. Default: 10

• mismatch_level (float) – Size by which the adversary can perturb the weights ($$\zeta$$). Attack will be in $$[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]$$. Default: 0.025

• initial_std (float) – Initial perturbation ($$\zeta_{initial}$$) of the parameters according to $$\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})$$. Default: 0.001

• inputs (np.ndarray) – Inputs that will be passed through the network

• net_out_original (np.ndarray) – Outputs of the network using the original weights

• tree_def_params (JaxTreeDef) – Tree structure obtained when flattening the network parameters with jax.tree_util.tree_flatten. It defines the structure used to unflatten $$\Theta$$ and $$\Theta^*$$ back into parameter trees

• mismatch_loss (Callable) – Mismatch loss. Takes as input two np.ndarrays and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(target, net_output).

Returns

Tuple comprising $$\Theta^*$$ in flattened form and a dictionary holding the grads and losses for every PGA iteration

Return type

Tuple[List, Dict]