training.adversarial_jax.pga_attack(params_flattened: List, net: Callable[[ndarray], ndarray], rng_key: Any, inputs: ndarray, net_out_original: ndarray, tree_def_params: Any, mismatch_loss: Callable[[ndarray, ndarray], float], attack_steps: int = 10, mismatch_level: float = 0.025, initial_std: float = 0.001) → Tuple[List, Dict]

Performs a PGA (projected gradient ascent) attack on the parameters of a network for the given inputs.

This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters \(\Theta\) (params_flattened), we iteratively modify the parameters in order to worsen a supplied loss function mismatch_loss. mismatch_loss compares the output of the network at the initial parameters (net_out_original) with the output of the network at the modified parameters \(\Theta^*\). At each iteration we compute the gradient of mismatch_loss w.r.t. the modified parameters \(\Theta^*\), and step in a projected direction along the sign of the gradient.

The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.
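The procedure described above can be sketched in a few lines of JAX. This is a minimal illustration, not the library implementation: it attacks a single parameter array rather than a flattened pytree, `pga_sketch` and `toy_loss` are hypothetical names, and the projection is assumed to be an element-wise clip to \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\).

```python
import jax
import jax.numpy as jnp


def pga_sketch(theta, loss_fn, rng_key, attack_steps=10,
               mismatch_level=0.025, initial_std=0.001):
    """Illustrative PGA loop on one parameter array (hypothetical helper)."""
    # Random start: theta + initial_std * R * |theta|, with R ~ N(0, I)
    noise = jax.random.normal(rng_key, theta.shape)
    theta_star = theta + initial_std * noise * jnp.abs(theta)

    # Per-parameter step size: (mismatch_level * |p|) / attack_steps
    step_size = mismatch_level * jnp.abs(theta) / attack_steps

    # Bounds of the allowed region (assumption: projection is a clip)
    lower = theta - mismatch_level * jnp.abs(theta)
    upper = theta + mismatch_level * jnp.abs(theta)

    losses = []
    for _ in range(attack_steps):
        loss, grad = jax.value_and_grad(loss_fn)(theta_star)
        losses.append(float(loss))
        # Ascent step along the sign of the gradient, then project back
        theta_star = jnp.clip(theta_star + step_size * jnp.sign(grad), lower, upper)
    return theta_star, losses


# Toy "network": element-wise tanh(theta * x)
theta = jnp.array([1.0, -2.0, 0.5])
x = jnp.array([0.3, 0.1, -0.7])
out_original = jnp.tanh(theta * x)


def toy_loss(theta_star):
    # Squared difference between original and attacked outputs
    return jnp.sum((jnp.tanh(theta_star * x) - out_original) ** 2)


theta_star, losses = pga_sketch(theta, toy_loss, jax.random.PRNGKey(0))
```

Because the step size scales with `abs(p)`, large and small weights are perturbed by the same relative amount, and a weight that is exactly zero is never moved.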

Parameters:

  • params_flattened (List) – Flattened pytree of the network parameters, obtained by calling jax.tree_util.tree_flatten on the parameters returned by net.parameters()

  • net (Callable) – A function (e.g. Sequential object) that takes an np.ndarray and generates another np.ndarray

  • rng_key (JaxRNGKey) – A Jax random key

  • attack_steps (int) – Number of PGA steps to be taken

  • mismatch_level (float) – Size by which the adversary can perturb the weights (\(\zeta\)). Attack will be in \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\)

  • initial_std (float) – Initial perturbation (\(\zeta_{initial}\)) of the parameters according to \(\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\)

  • inputs (np.ndarray) – Inputs that will be passed through the network

  • net_out_original (np.ndarray) – Outputs of the network using the original weights

  • tree_def_params (JaxTreeDef) – Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened. It defines the pytree structure of theta / theta_star

  • mismatch_loss (Callable) – Mismatch loss. Takes two np.ndarray objects as input and returns a float. Example: KL divergence between the softmaxed logits of the two networks. Signature: mismatch_loss(target, net_output).
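As a hedged sketch of how the arguments listed above might be prepared, the snippet below flattens a parameter pytree and defines a KL-based mismatch loss with the signature `mismatch_loss(target, net_output)`. The `params` dictionary is a hypothetical stand-in for the parameters of a real network, and the KL direction (target relative to attacked output) and the batch averaging are assumptions, not something this page specifies.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree standing in for a real network's parameters
params = {"layer_0": {"weight": jnp.ones((3, 2)), "bias": jnp.zeros(2)}}

# Flatten into a list of leaves plus a tree definition
params_flattened, tree_def_params = jax.tree_util.tree_flatten(params)


def kl_mismatch_loss(target, net_output):
    """KL(softmax(target) || softmax(net_output)), averaged over the batch."""
    log_p = jax.nn.log_softmax(target, axis=-1)
    log_q = jax.nn.log_softmax(net_output, axis=-1)
    kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
    return jnp.mean(kl)


logits = jnp.array([[2.0, 0.5, -1.0]])
# Identical outputs give zero mismatch; shifted logits give a positive value
same = kl_mismatch_loss(logits, logits)
shifted = kl_mismatch_loss(logits, logits + jnp.array([[0.0, 1.0, 0.0]]))
```

Working in log-space with `jax.nn.log_softmax` avoids taking the logarithm of very small probabilities, which helps keep the gradient used by the attack finite.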


Returns:

Tuple comprising \(\Theta^*\) in flattened form and a dictionary holding the gradients and losses for every PGA iteration

Return type:

Tuple[List, Dict]
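Since \(\Theta^*\) is returned in flattened form, it typically needs to be restored to the original pytree structure before use. A minimal sketch, assuming the same tree definition that was passed as tree_def_params; `theta_star_flattened` here is a hypothetical stand-in for the first element of the returned tuple:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and their flattened form
params = {"weight": jnp.array([1.0, -2.0]), "bias": jnp.array([0.5])}
params_flattened, tree_def_params = jax.tree_util.tree_flatten(params)

# Stand-in for the flattened attacked parameters returned by the attack
theta_star_flattened = [p * 1.01 for p in params_flattened]

# Rebuild the nested structure from the flat list and the tree definition
theta_star = jax.tree_util.tree_unflatten(tree_def_params, theta_star_flattened)
```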