Module training.adversarial_jax
Functions to implement adversarial training approaches using Jax
See also
👹 Adversarial training illustrates how to use the functions in this module to implement adversarial attacks on the parameters of a network during training.
Functions overview
adversarial_loss | Implement a hybrid task / adversarial robustness loss
pga_attack | Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.
Functions
- training.adversarial_jax.adversarial_loss(parameters: Iterable | MutableMapping | Mapping, net: JaxModule, inputs: ndarray, target: ndarray, task_loss: Callable[[ndarray, ndarray], float], mismatch_loss: Callable[[ndarray, ndarray], float], rng_key: Any, noisy_forward_std: float = 0.0, initial_std: float = 0.001, mismatch_level: float = 0.025, beta_robustness: float = 0.25, attack_steps: int = 10) -> float
Implement a hybrid task / adversarial robustness loss
This loss function combines a task loss with a loss that evaluates how robust a network is to parameter attack. The combined loss has the form \(\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta)))\), where \(\mathcal{L}_{nat}\) is the task loss (task_loss), \(\mathcal{L}_{rob}\) is the mismatch loss (mismatch_loss), \(\beta_{rob}\) is set by beta_robustness, \(\mathcal{A}(\Theta)\) is a PGA-based adversary, and \(\Theta\) are the parameters of the network, which are perturbed by Gaussian noise during the forward pass when noisy_forward_std is non-zero.
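A minimal sketch of how the two terms of \(\mathcal{L}\) combine, assuming the network is a plain function f(X, theta) and that some callable adversary stands in for \(\mathcal{A}\). The names hybrid_loss, f, mse and scale_attack are hypothetical; scale_attack is a fixed 10 % rescaling used only to make the robustness term non-zero, not a PGA attack, and forward-pass noise is left out.

```python
import jax
import jax.numpy as jnp


def hybrid_loss(theta, X, y, f, task_loss, mismatch_loss, adversary, beta_rob=0.25):
    # Hypothetical helper: L = L_nat(f(X, Θ), y) + beta_rob * L_rob(f(X, Θ), f(X, A(Θ)))
    out_nominal = f(X, theta)               # f(X, Θ)
    out_attacked = f(X, adversary(theta))   # f(X, A(Θ))
    loss_nat = task_loss(out_nominal, y)
    # Argument order follows the documented signature mismatch_loss(net_output_star, net_output)
    loss_rob = mismatch_loss(out_attacked, out_nominal)
    return loss_nat + beta_rob * loss_rob


def mse(a, b):
    # Mean squared error, used here as both task loss and mismatch loss
    return jnp.mean((a - b) ** 2)


def f(X, W):
    # Toy linear "network"
    return X @ W


X = jnp.ones((4, 3))
W = jnp.ones((3, 2))
y = jnp.zeros((4, 2))

# Hypothetical stand-in adversary: scale every weight by 10 % (not a PGA attack)
scale_attack = lambda theta: 1.1 * theta

loss = hybrid_loss(W, X, y, f, mse, mse, scale_attack)
# Gradient of the combined loss w.r.t. the weights only (argnums=0 by default)
grad_W = jax.grad(hybrid_loss)(W, X, y, f, mse, mse, scale_attack)
```

With X and W all ones the nominal output is 3 everywhere, so the task term is 9 and the robustness term is \(0.3^2 = 0.09\), giving \(9 + 0.25 \cdot 0.09 = 9.0225\).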
The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.
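The method combines two aspects: Gaussian noise added to the parameters during the forward pass, and an adversarial attack on the parameters, whose effect on the network output is penalized by the robustness term \(\mathcal{L}_{rob}\).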
See also
👹 Adversarial training for an example of how to train a network with this adversarial attack.
- Parameters:
parameters (Tree) – Parameters of the network (obtained by e.g. net.parameters())
net (JaxModule) – A JaxModule undergoing training
inputs (np.ndarray) – Inputs that will be passed through the network
target (np.ndarray) – Targets for the network prediction. Can be anything as long as task_loss can cope with the type/shape
task_loss (Callable) – Task loss. Can be anything used for training a NN (e.g. categorical cross entropy). Signature: task_loss(net_output, target).
mismatch_loss (Callable) – Mismatch loss between the output of the nominal and the attacked network. Takes as input two np.ndarrays and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(net_output_star, net_output).
rng_key (JaxRNGKey) – A Jax RNG key
noisy_forward_std (float) – Float (\(\zeta_{forward}\)) determining the amount of noise added to the parameters in the forward pass of the network. Model: \(\Theta = \Theta + \zeta_{forward} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\). Default: 0.0; do not use noise in the forward pass
initial_std (float) – Initial perturbation (\(\zeta_{initial}\)) of the parameters according to \(\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\). Default: 0.001
mismatch_level (float) – Size by which the adversary can perturb the weights (\(\zeta\)). Attack will be in \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\). Default: 0.025
beta_robustness (float) – Tradeoff parameter for the adversarial regularizer. Setting it to 0.0 trains without the adversarial loss, but is much slower than training on task_loss directly and should not be done. Default: 0.25
attack_steps (int) – Number of PGA steps to be taken during each training iteration, as part of the adversarial attack. Default: 10
- Returns:
The calculated loss, combining task loss and adversarial attack robustness loss
- Return type:
float
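The multiplicative noise model used by noisy_forward_std and initial_std, the attack interval set by mismatch_level, and the KL-divergence example given for mismatch_loss can be sketched as follows. perturb_params, attack_interval and kl_mismatch_loss are hypothetical helpers, not functions of this module, and the direction of the KL divergence (nominal output as the reference distribution) is an assumption.

```python
import jax
import jax.numpy as jnp


def perturb_params(params, rng_key, std):
    # Hypothetical helper: Θ + std · R ⊙ |Θ| with R ~ N(0, I), applied to every leaf of a pytree
    leaves, tree_def = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(rng_key, len(leaves))
    noisy = [p + std * jnp.abs(p) * jax.random.normal(k, p.shape) for p, k in zip(leaves, keys)]
    return jax.tree_util.tree_unflatten(tree_def, noisy)


def attack_interval(theta, mismatch_level):
    # Hypothetical helper: bounds [Θ - ζ|Θ|, Θ + ζ|Θ|] that the adversary may explore
    return theta - mismatch_level * jnp.abs(theta), theta + mismatch_level * jnp.abs(theta)


def kl_mismatch_loss(net_output_star, net_output):
    # Hypothetical helper: KL(softmax(net_output) || softmax(net_output_star)), averaged over the batch
    log_p = jax.nn.log_softmax(net_output, axis=-1)
    log_q = jax.nn.log_softmax(net_output_star, axis=-1)
    return jnp.mean(jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1))


params = {"w": jnp.array([1.0, -2.0, 0.0]), "b": jnp.array(0.5)}
noisy = perturb_params(params, jax.random.PRNGKey(0), std=0.1)

# For ζ = 0.025: 1.0 -> [0.975, 1.025], -2.0 -> [-2.05, -1.95], 0.0 -> [0.0, 0.0]
lo, hi = attack_interval(params["w"], 0.025)

logits = jnp.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.0]])
kl_same = kl_mismatch_loss(logits, logits)          # identical outputs give zero mismatch
kl_scaled = kl_mismatch_loss(2.0 * logits, logits)  # sharper attacked output gives positive mismatch
```

Because the noise is scaled by \(|\Theta|\), parameters that are exactly zero are never perturbed, and their attack interval collapses to a single point.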
- training.adversarial_jax.pga_attack(params_flattened: List, net: Callable[[ndarray], ndarray], rng_key: Any, inputs: ndarray, net_out_original: ndarray, tree_def_params: Any, mismatch_loss: Callable[[ndarray, ndarray], float], attack_steps: int = 10, mismatch_level: float = 0.025, initial_std: float = 0.001) -> Tuple[List, Dict]
Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.
This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters \(\Theta\) (params_flattened), we iteratively modify the parameters in order to worsen a supplied loss function mismatch_loss. mismatch_loss compares the output of the network at the initial parameters (net_out_original) with the output of the network at the modified parameters \(\Theta^*\). We compute the gradient of mismatch_loss w.r.t. the modified parameters \(\Theta^*\), and step in a projected direction along the sign of the gradient. The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.
- Parameters:
params_flattened (List) – Flattened pytree obtained using jax.tree_util.tree_flatten on the network parameters (obtained by net.parameters())
net (Callable) – A function (e.g. a Sequential object) that takes an np.ndarray and generates another np.ndarray
rng_key (JaxRNGKey) – A Jax random key
attack_steps (int) – Number of PGA steps to be taken
mismatch_level (float) – Size by which the adversary can perturb the weights (\(\zeta\)). Attack will be in \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\)
initial_std (float) – Initial perturbation (\(\zeta_{initial}\)) of the parameters according to \(\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\)
inputs (np.ndarray) – Inputs that will be passed through the network
net_out_original (np.ndarray) – Outputs of the network using the original weights
tree_def_params (JaxTreeDef) – Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened. It defines the structure of theta / theta_star
mismatch_loss (Callable) – Mismatch loss. Takes as input two np.ndarrays and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(target, net_output).
- Returns:
Tuple comprising \(\Theta^*\) in flattened form and a dictionary holding the grads and losses for every PGA iteration
- Return type:
Tuple[List, Dict]
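The PGA procedure described for pga_attack can be sketched in JAX as follows. This is not the module's implementation: it assumes the network is a plain function forward(params, inputs) rather than a JaxModule, and pga_attack_sketch, forward, mse and the toy linear model are hypothetical. It flattens the parameters with jax.tree_util.tree_flatten, rebuilds them with tree_unflatten inside the attack objective, takes sign-of-gradient ascent steps of size (mismatch_level * abs(p)) / attack_steps, and records grads and losses for every iteration.

```python
import jax
import jax.numpy as jnp


def pga_attack_sketch(params_flattened, tree_def_params, forward, inputs, net_out_original,
                      mismatch_loss, rng_key, attack_steps=10, mismatch_level=0.025,
                      initial_std=1e-3):
    # Hypothetical re-implementation for illustration; forward(params, inputs) replaces a JaxModule
    # Per-parameter step size (mismatch_level * |p|) / attack_steps, computed from the starting Θ
    step_sizes = [mismatch_level * jnp.abs(p) / attack_steps for p in params_flattened]

    # Initial perturbation Θ* = Θ + initial_std * R ⊙ |Θ|, R ~ N(0, I)
    keys = jax.random.split(rng_key, len(params_flattened))
    theta_star = [p + initial_std * jnp.abs(p) * jax.random.normal(k, p.shape)
                  for p, k in zip(params_flattened, keys)]

    def attack_objective(theta_star_flat):
        # Rebuild the parameter pytree and compare the attacked output with the original output
        params_star = jax.tree_util.tree_unflatten(tree_def_params, theta_star_flat)
        return mismatch_loss(net_out_original, forward(params_star, inputs))

    loss_and_grad = jax.value_and_grad(attack_objective)
    history = {"grads": [], "losses": []}
    for _ in range(attack_steps):
        loss, grads = loss_and_grad(theta_star)
        # Ascent step: move each parameter along the sign of its gradient
        theta_star = [t + s * jnp.sign(g) for t, s, g in zip(theta_star, step_sizes, grads)]
        history["grads"].append(grads)
        history["losses"].append(loss)
    return theta_star, history


# Toy usage with a hypothetical linear "network"
def forward(params, x):
    return x @ params["w"] + params["b"]


def mse(a, b):
    return jnp.mean((a - b) ** 2)


params = {"w": jnp.array([[1.0, -0.5], [0.25, 2.0]]), "b": jnp.array([0.1, -0.3])}
params_flattened, tree_def_params = jax.tree_util.tree_flatten(params)
inputs = jnp.array([[1.0, 2.0], [-1.0, 0.5], [0.3, -0.7]])
net_out_original = forward(params, inputs)

theta_star, history = pga_attack_sketch(
    params_flattened, tree_def_params, forward, inputs, net_out_original, mse,
    jax.random.PRNGKey(0), attack_steps=10, mismatch_level=0.025, initial_std=1e-3)
attacked_params = jax.tree_util.tree_unflatten(tree_def_params, theta_star)
```

In this sketch each parameter moves by at most attack_steps × (mismatch_level · |p|) / attack_steps = mismatch_level · |p| beyond its initial perturbation, so the step schedule alone keeps the attack inside \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\) (up to that perturbation); an explicit projection such as jnp.clip onto the interval would be an alternative design. Starting exactly at \(\Theta^* = \Theta\) would place the attack at a minimum of the mismatch loss, where the gradient is zero, which the small initial perturbation avoids.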