"""
Functions to implement adversarial training approaches using Jax

See :ref:`/tutorials/adversarial_training.ipynb` for an illustration of how to use the functions in this module to implement adversarial attacks on the parameters of a network during training.
"""
import jax
import numpy as np

from rockpool.nn.modules.jax.jax_module import JaxModule

import jax.tree_util as tu
import jax.random as random
import jax.numpy as jnp

from jax.lax import stop_gradient

from functools import partial

from typing import Tuple, Callable, List, Dict, Any, Optional
from rockpool.typehints import Tree, JaxTreeDef, JaxRNGKey

def _split_and_sample_normal(
    key: JaxRNGKey, shape: Tuple
) -> Tuple[JaxRNGKey, np.ndarray]:
"""
Split an RNG key and generate random data of a given shape following a standard Gaussian distribution

Args:
key (JaxRNGKey): Array of two ints. A Jax random key
shape (tuple): The shape that the random normal data should have

Returns:
(JaxRNGKey, np.ndarray): Tuple of (key,data). key is the new key that can be used in subsequent computations and data is the Gaussian data
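
    A minimal usage sketch (the shape ``(2, 3)`` is arbitrary)::

        key = jax.random.PRNGKey(0)
        key, data = _split_and_sample_normal(key, (2, 3))
        # - `data` has shape (2, 3), drawn from a standard normal distribution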
"""
    key, subkey = random.split(key)
    val = random.normal(subkey, shape=shape)
    return key, val

def _eval_target_loss(
    parameters: List,
    inputs: np.ndarray,
    target: np.ndarray,
    net: JaxModule,
    tree_def_params: JaxTreeDef,
    loss: Callable[[np.ndarray, np.ndarray], float],
) -> float:
"""
Calculate the loss of the network output against a target output.

This function resets the states of the network, unflattens the parameters theta_star and assigns them to the network net and then evaluates the network on the inputs using the adversarial weights. Following, the loss function is evaluated using the target signal target against the newly generated outputs. The method returns the loss value.

Args:
parameters (List): Set of parameters to use during evolution. Flattened pytree that was obtained using jax.tree_util.tree_flatten
inputs (np.ndarray): Inputs that will be passed through the network
target (np.ndarray): Target network outputs to comapre against
net (Callable): A function (e.g. Sequential object) that takes an np.ndarray and generates another np.ndarray
tree_def_params (JaxTreeDef): Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened. Basically defining the shape of theta/theta_star
loss (Callable[[np.ndarray, np.ndarray], float]): Comparison loss function. Takes as input two np.ndarray s with the same shape and returns a float: loss(target, output). Example: KL divergence between softmaxed logits of the networks.

Returns:
float: The loss evaluated on the outputs generated by the network using parameters, against the target
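
    A minimal sketch of how this helper is called internally, assuming ``net`` is a :py:class:`.JaxModule` and ``mse`` is a placeholder comparison loss function::

        params_flat, tree_def = tu.tree_flatten(net.parameters())
        loss_value = _eval_target_loss(params_flat, inputs, target, net, tree_def, mse)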
"""
    # - Reset the network
    net = net.reset_state()

    # - Set the network parameters to `parameters`
    net = net.set_attributes(tu.tree_unflatten(tree_def_params, parameters))

    # - Evolve the network
    output, _, _ = net(inputs)

    # - Return comparison loss
    return loss(target, output)

@partial(
    jax.jit,
    static_argnames=(
        "net",
        "tree_def_params",
        "mismatch_loss",
        "attack_steps",
        "mismatch_level",
        "initial_std",
    ),
)
def pga_attack(
    params_flattened: List,
    net: Callable[[np.ndarray], np.ndarray],
    rng_key: JaxRNGKey,
    inputs: np.ndarray,
    net_out_original: np.ndarray,
    tree_def_params: JaxTreeDef,
    mismatch_loss: Callable[[np.ndarray, np.ndarray], float],
    attack_steps: int = 10,
    mismatch_level: float = 0.025,
    initial_std: float = 1e-3,
) -> Tuple[List, Dict]:
"""
Performs the PGA (projected gradient ascent) based attack on the parameters of the network given inputs.

This function performs an attack on the parameters of a network, using the gradient of a supplied loss. Starting from an initial set of parameters :math:\\Theta (params_flattened), we iteratively modify the parameters in order to worsen a supplied loss function mismatch_loss. mismatch_loss measures a comparison between the output of the network at the initial parameters (net_out_original) and the output of the network at the modified parameters :math:\\Theta^*. We compute the gradient of mismatch_loss w.r.t. the modified parameters :math:\\Theta^*, and step in a projected direction along the sign of the gradient.

The step size at each iteration for a parameter p is given by (mismatch_level * abs(p)) / attack_steps.

Args:
params_flattened (List): Flattened pytree that was obtained using jax.tree_util.tree_flatten of the network parameters (obtained by net.parameters())
net (Callable): A function (e.g. Sequential object) that takes an np.ndarray and generates another np.ndarray
rng_key (JaxRNGKey): A Jax random key
attack_steps (int): Number of PGA steps to be taken
mismatch_level (float): Size by which the adversary can perturb the weights (:math:\zeta). Attack will be in :math:[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]
initial_std (float): Initial perturbation (:math:\zeta_{initial}) of the parameters according to :math:\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})
inputs (np.ndarray): Inputs that will be passed through the network
net_out_original (np.ndarray): Outputs of the network using the original weights
tree_def_params (JaxTreeDef): Tree structure obtained by calling jax.tree_util.tree_flatten on theta_star_unflattened. Basically defining the shape of theta/theta_star
mismatch_loss (Callable): Mismatch loss. Takes as input two np.ndarray s and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(target, net_output).

Returns:
Tuple[List, Dict]: Tuple comprising :math:\Theta^* in flattened form and a dictionary holding the grads and losses for every PGA iteration
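
    A minimal usage sketch, assuming ``net`` is a Rockpool :py:class:`.JaxModule` and ``kl_loss`` is a placeholder comparison loss function::

        params_flat, tree_def = tu.tree_flatten(net.parameters())
        net_out, _, _ = net(inputs)
        theta_star, verbose = pga_attack(
            params_flattened=params_flat,
            net=net,
            rng_key=jax.random.PRNGKey(0),
            inputs=inputs,
            net_out_original=net_out,
            tree_def_params=tree_def,
            mismatch_loss=kl_loss,
        )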
"""
    # - Create verbose dict
    verbose = {"grads": [], "losses": []}

    # - Initialize Theta* by adding Gaussian noise to each parameter
    theta_star = []
    step_size = []
    for p in params_flattened:
        rng_key, random_normal_var = _split_and_sample_normal(rng_key, p.shape)
        theta_star.append(p + jnp.abs(p) * initial_std * random_normal_var)
        step_size.append((mismatch_level * jnp.abs(p)) / attack_steps)

    # - Perform gradient ascent on the parameters Theta*, with respect to the provided mismatch loss
    for _ in range(attack_steps):
        # - Compute loss and gradients w.r.t. the attacked parameters
        loss, grads_theta_star = jax.value_and_grad(_eval_target_loss)(
            theta_star, inputs, net_out_original, net, tree_def_params, mismatch_loss
        )

        # - Store the loss and gradients for this iteration
        verbose["losses"].append(loss)
        verbose["grads"].append(grads_theta_star)

        # - Step each parameter in the direction of the gradient, scaled to the parameter scale
        for idx in range(len(theta_star)):
            theta_star[idx] = theta_star[idx] + step_size[idx] * jnp.sign(
                grads_theta_star[idx]
            )

    # - Return the attacked parameters
    return theta_star, verbose

@partial(
    jax.jit,
    static_argnames=(
        "net",
        "task_loss",
        "mismatch_loss",
        "noisy_forward_std",
        "initial_std",
        "mismatch_level",
        "beta_robustness",
        "attack_steps",
    ),
)
def adversarial_loss(
    parameters: Tree,
    net: JaxModule,
    inputs: np.ndarray,
    target: np.ndarray,
    task_loss: Callable[[np.ndarray, np.ndarray], float],
    mismatch_loss: Callable[[np.ndarray, np.ndarray], float],
    rng_key: JaxRNGKey,
    noisy_forward_std: float = 0.0,
    initial_std: float = 1e-3,
    mismatch_level: float = 0.025,
    beta_robustness: float = 0.25,
    attack_steps: int = 10,
) -> float:
"""

This loss function combines a task loss with a loss that evaluates how robust a network is to parameter attack. The combined loss has the form :math:\mathcal{L} = \mathcal{L}_{nat}(f(X,\Theta),y) + \\beta_{rob} \cdot \mathcal{L}_{rob}(f(X,\Theta),f(X,\mathcal{A}(\Theta))) where :math:\mathcal{A}(\Theta) is an PGA-based adversary and :math:\Theta are the weights of the input that are perturbed by Gaussian noise during the forward pass.

The goal is to train a network that performs a desired task, but where the trained network is insensitive to modification of its parameters. This approach is useful for neuromorphic hardware that exhibits uncontrolled parameter mismatch on deployment.

The method combines two aspects --- Gaussian noise added to the parameters during the forward pass, and

:ref:/tutorials/adversarial_training.ipynb for an example of how to train a network using this adversarial attack during training.

Args:
parameters (Tree): Parameters of the network (obtained by e.g. net.parameters())
net (JaxModule): A JaxModule undergoing training
inputs (np.ndarray): Inputs that will be passed through the network
target (np.ndarray): Targets for the network prediction. Can be anything as long as training_loss can cope with the type/shape
task_loss (Callable): Task loss. Can be anything used for training a NN (e.g. cat. cross entropy). Signature: task_loss(net_output, target).
mismatch_loss (Callable): Mismatch loss between output of nominal and attacked network. Takes as input two np.ndarray s and returns a float. Example: KL divergence between softmaxed logits of the networks. Signature: mismatch_loss(net_output_star, net_output).
rng_key (JaxRNGKey): A Jax RNG key
noisy_forward_std (float): Float (:math:\zeta_{forward}) determining the amound of noise added to the parameters in the forward pass of the network. Model: :math:\Theta = \Theta + \zeta_{forward} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I}). Default: 0.; do not use noise in the forward pass
initial_std (float): Initial perturbation (:math:\zeta_{initial}) of the parameters according to :math:\Theta + \zeta_{initial} \cdot R \odot |\Theta| \; ; R \sim \mathcal{N}(0,\mathbf{I})
mismatch_level (float): Size by which the adversary can perturb the weights (:math:\zeta). Attack will be in :math:[\Theta-\zeta \cdot |\Theta|,\Theta+\zeta \cdot |\Theta|]. Default: 0.025
beta_robustness (float): Tradeoff parameter for the adversarial regularizer. Setting to 0.0 trains without adversarial loss but is much slower and should not be done. Default: 0.25
attack_steps (int): Number of PGA steps to be taken during each training iteration, as part of the adversarial attack. Default: 10

Returns:
float: The calculated loss, combining task loss and adversarial attack robustness loss
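
    A minimal training-step sketch, assuming ``net`` is a :py:class:`.JaxModule`, and ``mse`` and ``kl_loss`` are placeholder task and mismatch loss functions::

        import jax

        # - Build a value-and-gradient function over this loss
        loss_vgf = jax.value_and_grad(adversarial_loss)

        # - Evaluate the loss and its gradient w.r.t. the network parameters
        loss, grads = loss_vgf(
            net.parameters(),
            net=net,
            inputs=inputs,
            target=target,
            task_loss=mse,
            mismatch_loss=kl_loss,
            rng_key=jax.random.PRNGKey(0),
        )

    ``grads`` has the same tree structure as ``net.parameters()``, and can be passed to a standard Jax optimizer update step.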
"""
    # - Reset the network state
    net = net.reset_state()

    # - Add Gaussian noise to the parameters before evaluating
    params_flattened, tree_def_params = tu.tree_flatten(parameters)

    params_gaussian_flattened = []
    for p in params_flattened:
        rng_key, random_normal_var = _split_and_sample_normal(rng_key, p.shape)
        params_gaussian_flattened.append(
            p + stop_gradient(jnp.abs(p) * noisy_forward_std * random_normal_var)
        )

    # - Evaluate the task loss using the perturbed parameters
    loss_n = _eval_target_loss(
        params_gaussian_flattened, inputs, target, net, tree_def_params, task_loss
    )

    # - Get output for the original parameters
    # - Reset network state
    net = net.reset_state()

    # - Set parameters to the original parameters
    net = net.set_attributes(parameters)

    # - Get the network output using the original parameters
    output_theta, _, _ = net(inputs)

    # - Perform the adversarial attack to obtain the attacked parameters theta_star
    theta_star, _ = pga_attack(
        params_flattened=params_flattened,
        net=net,
        rng_key=rng_key,
        attack_steps=attack_steps,
        mismatch_level=mismatch_level,
        initial_std=initial_std,
        inputs=inputs,
        net_out_original=output_theta,
        tree_def_params=tree_def_params,
        mismatch_loss=mismatch_loss,
    )

    # - Compute the robustness loss using the attacked parameters theta_star
    loss_r = _eval_target_loss(
        theta_star, inputs, output_theta, net, tree_def_params, mismatch_loss
    )

    # - Add the robustness loss as a regularizer
    return loss_n + beta_robustness * loss_r