Functions to implement adversarial training approaches using JAX

See also

👹 Adversarial training illustrates how to use the functions in this module to implement adversarial attacks on the parameters of a network during training.


adversarial_loss(parameters, net, inputs, ...)

Implement a hybrid loss that combines a task loss with an adversarial robustness loss.
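A minimal sketch of how such a hybrid loss can be assembled in JAX, assuming a toy two-layer network and mean-squared-error task and robustness losses. All names here (`net_apply`, `task_loss`, `mismatch_loss`, `hybrid_loss_sketch`) and the default values are hypothetical and are not this module's API. For brevity, the adversarial parameters come from a single signed-gradient step after a small random initial perturbation, rather than from a multi-step projected attack.

```python
import jax
import jax.numpy as jnp


def net_apply(params, x):
    # Hypothetical toy network: one tanh hidden layer, linear readout.
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def task_loss(out, target):
    # Hypothetical task loss: mean-squared error to the target.
    return jnp.mean((out - target) ** 2)


def mismatch_loss(out_nominal, out_attacked):
    # Hypothetical robustness loss: deviation of the attacked output from the nominal one.
    return jnp.mean((out_nominal - out_attacked) ** 2)


def hybrid_loss_sketch(params, x, target, rng_key,
                       beta_robustness=0.25, mismatch_level=0.025, initial_std=1e-3):
    out_nominal = net_apply(params, x)

    # Start from a small random perturbation relative to |theta|: at zero
    # perturbation the gradient of the MSE mismatch loss is exactly zero.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(rng_key, len(leaves))
    params_init = jax.tree_util.tree_unflatten(treedef, [
        p + initial_std * jnp.abs(p) * jax.random.normal(k, p.shape)
        for p, k in zip(leaves, keys)
    ])

    # One signed-gradient ascent step on the mismatch loss, scaled by |theta|.
    def attack_objective(p_star):
        return mismatch_loss(out_nominal, net_apply(p_star, x))

    g = jax.grad(attack_objective)(params_init)
    params_attacked = jax.tree_util.tree_map(
        lambda p, p0, gi: p0 + mismatch_level * jnp.abs(p) * jnp.sign(gi),
        params, params_init, g,
    )

    # Treat the perturbation as a constant when differentiating w.r.t. params.
    delta = jax.tree_util.tree_map(
        lambda pa, p: jax.lax.stop_gradient(pa - p), params_attacked, params
    )
    params_attacked = jax.tree_util.tree_map(lambda p, d: p + d, params, delta)

    # Hybrid objective: task performance plus weighted sensitivity to the attack.
    return task_loss(out_nominal, target) + beta_robustness * mismatch_loss(
        out_nominal, net_apply(params_attacked, x)
    )
```

A training step would then call `jax.value_and_grad(hybrid_loss_sketch)(params, x, target, key)` and update `params` with any optimizer. Wrapping the perturbation in `jax.lax.stop_gradient` is a design choice of this sketch: the robustness term still sends gradient to the nominal parameters through the attacked copy, but the attack itself is not differentiated.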

pga_attack(params_flattened, net, rng_key, ...)

Perform a PGA (projected gradient ascent) attack on the parameters of the network, given a set of inputs.
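A projected gradient ascent attack on network parameters can be sketched as follows, assuming a toy network and a mean-squared-error mismatch loss. The names `net_apply`, `mismatch_loss`, `pga_attack_sketch` and the default values are hypothetical and do not reproduce this module's signature. Starting from a small random perturbation, each step moves the parameters along the sign of the gradient of the mismatch loss, with a step size proportional to |θ|, and then projects them back onto the box |θ* - θ| <= mismatch_level * |θ|.

```python
import jax
import jax.numpy as jnp


def net_apply(params, x):
    # Hypothetical toy network: one tanh hidden layer, linear readout.
    return jnp.tanh(x @ params["w1"]) @ params["w2"]


def mismatch_loss(out_nominal, out_attacked):
    # Hypothetical attack objective: deviation from the nominal network output.
    return jnp.mean((out_nominal - out_attacked) ** 2)


def pga_attack_sketch(params, x, rng_key,
                      attack_steps=10, mismatch_level=0.025, initial_std=1e-3):
    out_nominal = net_apply(params, x)

    # Small random initial perturbation relative to |theta|.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(rng_key, len(leaves))
    theta_star = jax.tree_util.tree_unflatten(treedef, [
        p + initial_std * jnp.abs(p) * jax.random.normal(k, p.shape)
        for p, k in zip(leaves, keys)
    ])

    step = mismatch_level / attack_steps

    def attack_objective(p_star):
        return mismatch_loss(out_nominal, net_apply(p_star, x))

    grad_fn = jax.grad(attack_objective)
    for _ in range(attack_steps):
        g = grad_fn(theta_star)
        # Ascent step along the sign of the gradient, scaled by |theta|.
        theta_star = jax.tree_util.tree_map(
            lambda ps, p, gi: ps + step * jnp.abs(p) * jnp.sign(gi),
            theta_star, params, g,
        )
        # Projection back onto the relative box around the nominal parameters.
        theta_star = jax.tree_util.tree_map(
            lambda ps, p: jnp.clip(ps, p - mismatch_level * jnp.abs(p),
                                   p + mismatch_level * jnp.abs(p)),
            theta_star, params,
        )
    return theta_star
```

With a step size of `mismatch_level / attack_steps`, the ascent steps alone can reach but not exceed the relative bound; the explicit clip also keeps the random initial perturbation inside it.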