training.adversarial_jax.pga_attack
- training.adversarial_jax.pga_attack(params_flattened: List, net: Callable[[ndarray], ndarray], rng_key: Any, inputs: ndarray, net_out_original: ndarray, tree_def_params: Any, mismatch_loss: Callable[[ndarray, ndarray], float], attack_steps: int = 10, mismatch_level: float = 0.025, initial_std: float = 0.001) → Tuple[List, Dict]
Performs a PGA (projected gradient ascent) attack on the parameters of a network, given inputs.
This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters \(\Theta\) (params_flattened), we iteratively modify the parameters so as to increase the supplied loss function mismatch_loss. mismatch_loss compares the output of the network at the initial parameters (net_out_original) with the output of the network at the modified parameters \(\Theta^*\). We compute the gradient of mismatch_loss w.r.t. the modified parameters \(\Theta^*\) and take a step in the projected direction given by the sign of the gradient. The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.
- Parameters:
params_flattened (List) – Flattened pytree obtained by applying jax.tree_util.tree_flatten to the network parameters (as returned by net.parameters())
net (Callable) – A function (e.g. a Sequential object) that takes an np.ndarray and returns another np.ndarray
rng_key (JaxRNGKey) – A Jax random key
attack_steps (int) – Number of PGA steps to be taken
mismatch_level (float) – Relative amount (\(\zeta\)) by which the adversary can perturb the weights. The attack is confined to \([\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]\)
initial_std (float) – Initial perturbation (\(\zeta_{initial}\)) of the parameters according to \(\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})\)
inputs (np.ndarray) – Inputs that will be passed through the network
net_out_original (np.ndarray) – Outputs of the network using the original weights
tree_def_params (JaxTreeDef) – Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened; it defines the structure of theta / theta_star
mismatch_loss (Callable) – Mismatch loss. Takes two np.ndarrays as input and returns a float. Example: KL divergence between the softmaxed logits of the networks. Signature: mismatch_loss(target, net_output).
- Returns:
Tuple comprising \(\Theta^*\) in flattened form and a dictionary holding the grads and losses for every PGA iteration.
- Return type:
Tuple[List, Dict]
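The step-size rule stated above can be checked against the mismatch budget with a short derivation. This assumes the sign-gradient updates are applied exactly as described and that no further clipping takes place (an assumption, since the implementation is not shown here). Writing \(K\) for attack_steps, \(\Theta^*_0\) for the initially perturbed parameters, and \(\mathcal{L}\) for mismatch_loss as a function of \(\Theta^*\) (notation introduced for this sketch), every update moves each element by \(\zeta \cdot |\Theta| / K\) or leaves it unchanged, so elementwise:

```latex
\begin{aligned}
\Theta^*_0 &= \Theta + \zeta_{initial} \cdot R \odot |\Theta|, \qquad R \sim \mathcal{N}(0,\mathbf{I}) \\
\Theta^*_k &= \Theta^*_{k-1} + \frac{\zeta \cdot |\Theta|}{K} \odot \operatorname{sign}\!\left(\nabla_{\Theta^*} \mathcal{L}\big(\Theta^*_{k-1}\big)\right), \qquad k = 1, \dots, K \\
\left|\Theta^*_K - \Theta^*_0\right| &\le \sum_{k=1}^{K} \frac{\zeta \cdot |\Theta|}{K} = \zeta \cdot |\Theta|
\end{aligned}
```

The PGA steps therefore use the full budget \(\zeta \cdot |\Theta|\) only when every step moves an element in the same direction; the initial perturbation of scale \(\zeta_{initial}\) comes on top of this.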
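The procedure described above can be sketched in JAX as follows. This is a minimal sketch under stated assumptions, not the library's implementation: in place of net it takes a hypothetical pure function apply_fn(params, inputs), it rebuilds the parameter pytree with jax.tree_util.tree_unflatten, and kl_mismatch_loss is a hypothetical helper implementing the KL-divergence example given for mismatch_loss.

```python
from typing import Any, Callable, Dict, List, Tuple

import jax
import jax.numpy as jnp


def kl_mismatch_loss(target: jnp.ndarray, net_output: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical helper: KL(softmax(target) || softmax(net_output)),
    # summed over the last axis and averaged over the batch
    log_p = jax.nn.log_softmax(target, axis=-1)
    log_q = jax.nn.log_softmax(net_output, axis=-1)
    return jnp.mean(jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1))


def pga_attack_sketch(
    params_flattened: List[jnp.ndarray],
    apply_fn: Callable[[Any, jnp.ndarray], jnp.ndarray],  # hypothetical: apply_fn(params, inputs)
    rng_key: Any,
    inputs: jnp.ndarray,
    net_out_original: jnp.ndarray,
    tree_def_params: Any,
    mismatch_loss: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    attack_steps: int = 10,
    mismatch_level: float = 0.025,
    initial_std: float = 0.001,
) -> Tuple[List[jnp.ndarray], Dict[str, list]]:
    # Per-parameter step size: (mismatch_level * |p|) / attack_steps
    step_sizes = [mismatch_level * jnp.abs(p) / attack_steps for p in params_flattened]

    # Initial perturbation: Theta + initial_std * R * |Theta|, with R ~ N(0, I)
    keys = jax.random.split(rng_key, len(params_flattened))
    theta_star = [
        p + initial_std * jax.random.normal(k, p.shape) * jnp.abs(p)
        for p, k in zip(params_flattened, keys)
    ]

    def loss_fn(theta_flat):
        # Rebuild the parameter pytree, run the network, compare with the original output
        params = jax.tree_util.tree_unflatten(tree_def_params, theta_flat)
        return mismatch_loss(net_out_original, apply_fn(params, inputs))

    loss_and_grad = jax.value_and_grad(loss_fn)
    verbose: Dict[str, list] = {"grads": [], "losses": []}
    for _ in range(attack_steps):
        loss, grads = loss_and_grad(theta_star)
        # Ascent step: move each element by its step size along sign(gradient)
        theta_star = [t + s * jnp.sign(g) for t, s, g in zip(theta_star, step_sizes, grads)]
        verbose["losses"].append(loss)
        verbose["grads"].append(grads)
    return theta_star, verbose
```

In this sketch the initial perturbation matters: a loss such as the KL divergence is minimised at \(\Theta\), so its gradient there is zero up to rounding, and sign steps taken from the unperturbed parameters would be driven by numerical noise rather than by the loss.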