training.adversarial_jax.adversarial_loss

training.adversarial_jax.adversarial_loss(parameters: Iterable | MutableMapping | Mapping, net: JaxModule, inputs: ndarray, target: ndarray, task_loss: Callable[[ndarray, ndarray], float], mismatch_loss: Callable[[ndarray, ndarray], float], rng_key: Any, noisy_forward_std: float = 0.0, initial_std: float = 0.001, mismatch_level: float = 0.025, beta_robustness: float = 0.25, attack_steps: int = 10) → float

Implement a hybrid task / adversarial robustness loss

This loss function combines a task loss with a loss that evaluates how robust a network is to a parameter attack. The combined loss has the form \(\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta)))\), where \(\mathcal{A}(\Theta)\) is a PGA-based adversary and \(\Theta\) are the weights of the network, which are perturbed by Gaussian noise during the forward pass.
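The combined loss above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation: `forward`, `mse`, `relative_noise`, `pga_attack` and `hybrid_loss` are hypothetical names, the network is a toy linear layer, and the adversary is assumed to take signed gradient-ascent steps of size \(\zeta \cdot |\Theta| / N_{steps}\) that are projected back onto the allowed interval.

```python
import jax
import jax.numpy as jnp


def forward(theta, x):
    # Hypothetical stand-in for the network f(X, Theta): one linear layer
    return x @ theta["w"] + theta["b"]


def mse(a, b):
    # Used here as both task loss and mismatch loss, for brevity
    return jnp.mean((a - b) ** 2)


def relative_noise(theta, key, std):
    # Theta + std * R * |Theta| with R ~ N(0, I), applied to every leaf
    leaves, treedef = jax.tree_util.tree_flatten(theta)
    keys = jax.random.split(key, len(leaves))
    noisy = [
        p + std * jax.random.normal(k, p.shape) * jnp.abs(p)
        for p, k in zip(leaves, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, noisy)


def pga_attack(theta, x, key, initial_std, mismatch_level, attack_steps):
    # Assumed PGA variant: signed gradient ascent on the mismatch loss,
    # projected onto [Theta - zeta*|Theta|, Theta + zeta*|Theta|]
    out_nominal = forward(theta, x)
    theta_star = relative_noise(theta, key, initial_std)
    grad_fn = jax.grad(lambda t: mse(forward(t, x), out_nominal))
    step = mismatch_level / attack_steps

    def update(ts, g, t):
        moved = ts + step * jnp.abs(t) * jnp.sign(g)
        radius = mismatch_level * jnp.abs(t)
        return jnp.clip(moved, t - radius, t + radius)

    for _ in range(attack_steps):
        grads = grad_fn(theta_star)
        theta_star = jax.tree_util.tree_map(update, theta_star, grads, theta)
    return theta_star


def hybrid_loss(theta, x, y, key, beta_robustness=0.25, initial_std=0.001,
                mismatch_level=0.025, attack_steps=10):
    out = forward(theta, x)
    theta_star = pga_attack(theta, x, key, initial_std, mismatch_level, attack_steps)
    loss_nat = mse(out, y)                        # L_nat(f(X, Theta), y)
    loss_rob = mse(forward(theta_star, x), out)   # L_rob(f(X, Theta), f(X, A(Theta)))
    return loss_nat + beta_robustness * loss_rob


# Toy data, chosen only for illustration
key = jax.random.PRNGKey(0)
theta = {"w": jnp.array([[0.5, -1.0], [2.0, 0.25]]), "b": jnp.array([0.1, -0.2])}
x = jnp.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
y = jnp.zeros((3, 2))

loss = hybrid_loss(theta, x, y, key)
```

With `beta_robustness=0.0` the sketch reduces to the task loss alone, although the attack is still computed.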

The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.

The method combines two aspects: Gaussian noise added to the parameters during the forward pass, and a PGA-based adversarial attack on the parameters.

See also

👹 Adversarial training for an example of how to train a network using this adversarial attack during training.

Parameters:
  • parameters (Tree) – Parameters of the network (obtained by e.g. net.parameters())

  • net (JaxModule) – A JaxModule undergoing training

  • inputs (np.ndarray) – Inputs that will be passed through the network

  • target (np.ndarray) – Targets for the network prediction. Can be anything, as long as task_loss can cope with the type/shape

  • task_loss (Callable) – Task loss. Can be anything used for training a NN (e.g. cat. cross entropy). Signature: task_loss(net_output, target).

  • mismatch_loss (Callable) – Mismatch loss between the outputs of the nominal and attacked networks. Takes two np.ndarray objects as input and returns a float. Example: KL divergence between the softmaxed logits of the two networks. Signature: mismatch_loss(net_output_star, net_output).

  • rng_key (JaxRNGKey) – A Jax RNG key

  • noisy_forward_std (float) – Float (\(\zeta_{forward}\)) determining the amount of noise added to the parameters in the forward pass of the network. Model: \(\Theta = \Theta + \zeta_{forward} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\). Default: 0.0 (no noise in the forward pass)

  • initial_std (float) – Initial perturbation (\(\zeta_{initial}\)) of the parameters according to \(\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\). Default: 0.001

  • mismatch_level (float) – Size by which the adversary can perturb the weights (\(\zeta\)). Attack will be in \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\). Default: 0.025

  • beta_robustness (float) – Tradeoff parameter for the adversarial regularizer. Setting this to 0.0 trains without the adversarial loss, but is much slower than training on the task loss alone and should be avoided. Default: 0.25

  • attack_steps (int) – Number of PGA steps to be taken during each training iteration, as part of the adversarial attack. Default: 10
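As an illustration of the two callables described above, the following sketch shows one possible task_loss (categorical cross entropy) and one possible mismatch_loss (KL divergence between softmaxed logits) with the documented signatures. The assumptions here are that the network output holds unnormalised logits of shape (batch, classes), that the targets are integer class labels, and that the KL divergence is taken from the nominal to the attacked distribution; the direction is a choice made for this example.

```python
import jax
import jax.numpy as jnp


def task_loss(net_output, target):
    # Categorical cross entropy; `target` is assumed to hold integer class labels
    log_probs = jax.nn.log_softmax(net_output, axis=-1)
    picked = jnp.take_along_axis(log_probs, target[:, None], axis=-1)
    return -jnp.mean(picked)


def mismatch_loss(net_output_star, net_output):
    # KL(p || p_star) between the softmaxed logits of the nominal network (p)
    # and the attacked network (p_star), averaged over the batch
    log_p = jax.nn.log_softmax(net_output, axis=-1)
    log_p_star = jax.nn.log_softmax(net_output_star, axis=-1)
    return jnp.mean(jnp.sum(jnp.exp(log_p) * (log_p - log_p_star), axis=-1))


# Toy logits and labels, chosen only for illustration
logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 1.0, 0.0]])
labels = jnp.array([0, 1])
shifted = logits + jnp.array([0.0, 1.0, 0.0])

ce = task_loss(logits, labels)
kl_same = mismatch_loss(logits, logits)    # identical outputs give zero mismatch
kl_diff = mismatch_loss(shifted, logits)   # differing outputs give a positive value
```

Both functions return scalar JAX arrays, so they can be differentiated with jax.grad.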

Returns:

The calculated loss, combining task loss and adversarial attack robustness loss

Return type:

float