training.adversarial_jax.adversarial_loss
- training.adversarial_jax.adversarial_loss(parameters: Iterable | MutableMapping | Mapping, net: JaxModule, inputs: ndarray, target: ndarray, task_loss: Callable[[ndarray, ndarray], float], mismatch_loss: Callable[[ndarray, ndarray], float], rng_key: Any, noisy_forward_std: float = 0.0, initial_std: float = 0.001, mismatch_level: float = 0.025, beta_robustness: float = 0.25, attack_steps: int = 10) → float
Implement a hybrid task / adversarial robustness loss
This loss function combines a task loss with a loss that evaluates how robust a network is to parameter attack. The combined loss has the form \(\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta)))\), where \(\mathcal{A}(\Theta)\) is a projected gradient ascent (PGA) adversary and \(\Theta\) are the network parameters, which are optionally perturbed by Gaussian noise during the forward pass.
The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.
The method combines two aspects: Gaussian noise added to the parameters during the forward pass, and a PGA-based adversarial attack on the parameters, whose effect on the network output is penalized by the mismatch loss. The structure of the computation is sketched below.
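The following is a minimal illustrative sketch of that structure, not the library implementation; `forward` and `pga_attack` are hypothetical stand-ins for the network's forward pass and the adversary \(\mathcal{A}(\Theta)\):

```python
# Illustrative sketch of the hybrid loss (not the library implementation).
# `forward` and `pga_attack` are hypothetical stand-ins for the network's
# forward pass and the PGA adversary A(Theta).
def hybrid_loss_sketch(params, forward, pga_attack, inputs, target,
                       task_loss, mismatch_loss, beta_rob):
    output = forward(params, inputs)             # nominal network output
    l_nat = task_loss(output, target)            # task performance term

    params_star = pga_attack(params)             # attacked parameters A(Theta)
    output_star = forward(params_star, inputs)   # output under attack
    l_rob = mismatch_loss(output_star, output)   # robustness term

    return l_nat + beta_rob * l_rob              # L = L_nat + beta_rob * L_rob
```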
See also
👹 Adversarial training for an example of how to train a network using this adversarial attack during training.
- Parameters:
  - parameters (Tree) – Parameters of the network (obtained e.g. via net.parameters())
  - net (JaxModule) – A JaxModule undergoing training
  - inputs (np.ndarray) – Inputs that will be passed through the network
  - target (np.ndarray) – Targets for the network prediction. Can be anything, as long as task_loss can cope with the type and shape
  - task_loss (Callable) – Task loss. Can be any loss used for training a NN (e.g. categorical cross-entropy). Signature: task_loss(net_output, target)
  - mismatch_loss (Callable) – Mismatch loss between the outputs of the nominal and the attacked network. Takes two np.ndarrays and returns a float. Example: KL divergence between the softmaxed logits of the two networks (see the sketch after this list). Signature: mismatch_loss(net_output_star, net_output)
  - rng_key (JaxRNGKey) – A Jax RNG key
  - noisy_forward_std (float) – Scale (\(\zeta_{forward}\)) determining the amount of noise added to the parameters in the forward pass of the network, following \(\Theta \leftarrow \Theta + \zeta_{forward} \cdot R \odot |\Theta|\), \(R \sim \mathcal{N}(0,\mathbf{I})\) (see the sketch after this list). Default: 0.0 (no noise is added in the forward pass)
  - initial_std (float) – Scale (\(\zeta_{initial}\)) of the initial perturbation of the parameters, following \(\Theta \leftarrow \Theta + \zeta_{initial} \cdot R \odot |\Theta|\), \(R \sim \mathcal{N}(0,\mathbf{I})\). Default: 0.001
  - mismatch_level (float) – Size (\(\zeta\)) by which the adversary can perturb the weights. The attack is constrained to \([\Theta - \zeta \cdot |\Theta|, \Theta + \zeta \cdot |\Theta|]\). Default: 0.025
  - beta_robustness (float) – Trade-off parameter (\(\beta_{rob}\)) for the adversarial regularizer. Setting it to 0.0 trains without the adversarial loss, but still pays the cost of computing the attack; in that case use a plain task loss instead. Default: 0.25
  - attack_steps (int) – Number of PGA steps taken during each training iteration, as part of the adversarial attack. Default: 10
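As referenced in the list above, here is a minimal sketch of the multiplicative noise model used for noisy_forward_std and initial_std, together with an example mismatch_loss. Both are illustrative assumptions: perturb operates on a flat ndarray rather than the parameter tree the function actually accepts, and kl_mismatch_loss is just one valid choice of mismatch loss.

```python
import jax
import jax.numpy as jnp

def perturb(theta: jnp.ndarray, zeta: float, rng_key) -> jnp.ndarray:
    # Theta + zeta * R (element-wise product) |Theta|, with R ~ N(0, I)
    R = jax.random.normal(rng_key, theta.shape)
    return theta + zeta * R * jnp.abs(theta)

def kl_mismatch_loss(net_output_star: jnp.ndarray, net_output: jnp.ndarray) -> float:
    # KL divergence between the softmaxed logits of the attacked and nominal
    # networks; signature matches mismatch_loss(net_output_star, net_output)
    log_p = jax.nn.log_softmax(net_output_star)
    log_q = jax.nn.log_softmax(net_output)
    return jnp.sum(jnp.exp(log_p) * (log_p - log_q))
```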
- Returns:
The calculated loss, combining the task loss and the weighted adversarial robustness loss
- Return type:
float
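A hedged usage sketch follows. It assumes that net (a JaxModule), inputs and target come from an existing training setup; the helper losses ce_loss and kl_loss are assumptions for illustration, not part of the API. The combined loss can be differentiated with jax.value_and_grad as usual.

```python
import jax
import jax.numpy as jnp
from training.adversarial_jax import adversarial_loss  # import path as in the signature above

def ce_loss(net_output, target):
    # Assumed task loss: categorical cross-entropy on logits
    return -jnp.mean(jnp.sum(target * jax.nn.log_softmax(net_output), axis=-1))

def kl_loss(net_output_star, net_output):
    # Assumed mismatch loss: KL between softmaxed logits (see sketch above)
    log_p = jax.nn.log_softmax(net_output_star)
    log_q = jax.nn.log_softmax(net_output)
    return jnp.sum(jnp.exp(log_p) * (log_p - log_q))

def loss_fn(parameters, inputs, target, rng_key):
    return adversarial_loss(
        parameters, net, inputs, target,
        task_loss=ce_loss,
        mismatch_loss=kl_loss,
        rng_key=rng_key,
        noisy_forward_std=0.0,
        mismatch_level=0.025,
        beta_robustness=0.25,
        attack_steps=10,
    )

# `net`, `inputs` and `target` are placeholders for an existing JaxModule and data
rng_key = jax.random.PRNGKey(0)
loss, grads = jax.value_and_grad(loss_fn)(net.parameters(), inputs, target, rng_key)
```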