training.adversarial_jax.pga_attack(params_flattened: List, net: Callable[[numpy.ndarray], numpy.ndarray], rng_key: Any, inputs: numpy.ndarray, net_out_original: numpy.ndarray, tree_def_params: Any, mismatch_loss: Callable[[numpy.ndarray, numpy.ndarray], float], attack_steps: int = 10, mismatch_level: float = 0.025, initial_std: float = 0.001) → Tuple[List, Dict]

Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.

This function attacks the parameters of a network using the gradient of a supplied loss. Starting from an initial set of parameters $$\Theta$$ (params_flattened), it iteratively modifies the parameters in order to worsen the supplied loss function mismatch_loss. mismatch_loss compares the output of the network at the initial parameters (net_out_original) with the output of the network at the modified parameters $$\Theta^*$$. At each iteration, the gradient of mismatch_loss w.r.t. the modified parameters $$\Theta^*$$ is computed, and a projected step is taken along the sign of that gradient.

The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.
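The step-size rule can be illustrated with a minimal JAX sketch. This is not the library's implementation: the helper name pga_step, the made-up gradient values, and the explicit clipping to $$[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]$$ are assumptions made for illustration only.

```python
import jax.numpy as jnp


def pga_step(theta, theta_star, grad, mismatch_level=0.025, attack_steps=10):
    """One sign-gradient ascent step (hypothetical helper, not part of the library)."""
    # Per-parameter step size: (mismatch_level * abs(p)) / attack_steps
    step_size = mismatch_level * jnp.abs(theta) / attack_steps
    # Move along the sign of the gradient to worsen the loss
    theta_star = theta_star + step_size * jnp.sign(grad)
    # Assumption: project back onto the allowed interval around the original parameters
    lower = theta - mismatch_level * jnp.abs(theta)
    upper = theta + mismatch_level * jnp.abs(theta)
    return jnp.clip(theta_star, lower, upper)


theta = jnp.array([1.0, -2.0, 0.0])   # original parameters
grad = jnp.array([0.3, -0.5, 1.0])    # made-up gradient, held fixed for illustration
theta_star = theta
for _ in range(10):
    theta_star = pga_step(theta, theta_star, grad)
```

Because the step size scales with abs(p), attack_steps steps in a constant direction move each parameter by exactly mismatch_level * abs(p), and a parameter that is exactly zero is never perturbed.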

Parameters
• params_flattened (List) – Flattened pytree of the network parameters, obtained by calling jax.tree_util.tree_flatten on the parameters returned by net.parameters()

• net (Callable) – A function (e.g. Sequential object) that takes an np.ndarray and generates another np.ndarray

• rng_key (JaxRNGKey) – A Jax random key

• attack_steps (int) – Number of PGA steps to be taken

• mismatch_level (float) – Size by which the adversary can perturb the weights ($$\zeta$$). Attack will be in $$[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]$$

• initial_std (float) – Initial perturbation ($$\zeta_{initial}$$) of the parameters according to $$\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})$$

• inputs (np.ndarray) – Inputs that will be passed through the network

• net_out_original (np.ndarray) – Outputs of the network using the original weights

• tree_def_params (JaxTreeDef) – Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened. It defines the structure of theta and theta_star

• mismatch_loss (Callable) – Mismatch loss. Takes as input two np.ndarrays and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(target, net_output).
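The KL-divergence example given for mismatch_loss could look roughly like the sketch below. The function name kl_mismatch_loss, the direction of the divergence (target relative to net_output), and the averaging over the batch are assumptions; the documentation above only fixes the signature mismatch_loss(target, net_output). The function is written with jax.numpy so that it can be differentiated.

```python
import jax
import jax.numpy as jnp


def kl_mismatch_loss(target, net_output):
    """KL(softmax(target) || softmax(net_output)), averaged over the batch.

    Hypothetical example of a mismatch loss; both inputs are logits
    of shape (batch, classes).
    """
    log_p = jax.nn.log_softmax(target, axis=-1)
    log_q = jax.nn.log_softmax(net_output, axis=-1)
    # Sum over classes of p * (log p - log q), then mean over the batch
    kl_per_sample = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
    return jnp.mean(kl_per_sample)
```

The loss is zero when both sets of logits induce the same distribution and grows as the attacked network's predictions drift away from the original ones, which is the quantity the attack tries to increase.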

Returns

Tuple comprising $$\Theta^*$$ in flattened form and a dictionary holding the grads and losses for every PGA iteration

Return type

Tuple[List, Dict]
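To make the shape of the return value concrete, here is a self-contained sketch of the overall procedure on a toy linear model. It does not call pga_attack and is not the library's implementation: the function pga_sketch, the functional form net_fn(params, inputs), the dictionary keys "grads" and "losses", and the mean-squared-error mismatch loss are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def pga_sketch(params, net_fn, rng_key, inputs, mismatch_loss,
               attack_steps=10, mismatch_level=0.025, initial_std=0.001):
    """Hypothetical re-creation of the attack loop described above."""
    # Flatten the parameter pytree; tree_def records how to rebuild it
    theta, tree_def = jax.tree_util.tree_flatten(params)
    net_out_original = net_fn(params, inputs)

    # Random start: Theta + initial_std * R * |Theta|, with R ~ N(0, I)
    keys = jax.random.split(rng_key, len(theta))
    theta_star = [p + initial_std * jax.random.normal(k, p.shape) * jnp.abs(p)
                  for p, k in zip(theta, keys)]
    # Per-parameter step size: (mismatch_level * abs(p)) / attack_steps
    step_size = [mismatch_level * jnp.abs(p) / attack_steps for p in theta]

    def loss_fn(flat_params):
        candidate = jax.tree_util.tree_unflatten(tree_def, flat_params)
        return mismatch_loss(net_out_original, net_fn(candidate, inputs))

    verbose = {"grads": [], "losses": []}  # key names are an assumption
    for _ in range(attack_steps):
        loss, grads = jax.value_and_grad(loss_fn)(theta_star)
        verbose["losses"].append(loss)
        verbose["grads"].append(grads)
        # Ascent step along the sign of the gradient
        theta_star = [p + s * jnp.sign(g)
                      for p, s, g in zip(theta_star, step_size, grads)]
    return theta_star, verbose


# Toy linear model and data (made up for this example)
params = {"w": jnp.arange(6.0).reshape(3, 2) + 1.0, "b": jnp.array([0.5, -0.5])}
net_fn = lambda p, x: x @ p["w"] + p["b"]
mse = lambda target, out: jnp.mean((target - out) ** 2)

theta_star, verbose = pga_sketch(params, net_fn, jax.random.PRNGKey(0),
                                 jnp.ones((4, 3)), mse)
```

Here theta_star is a flat list with one array per parameter leaf, in the order produced by jax.tree_util.tree_flatten, and verbose holds one loss value and one list of gradients per attack step.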