def deep_linear_model_loss_regularized_dw0_lp_da1_lq(w, x, y, reg_coeff,
                                                     norm_type1, norm_type2):
  """Penalize by ||w0|| + ||dL/da1||, where w0 is the first-layer weight and a1 its output."""
  dloss_da = jax.grad(deep_linear_model_loss, argnums=1)
  loss = deep_linear_model_loss(w, x, y)
  a1 = w[0].T @ x
  reg = norm_f(w[0], norm_type1) + norm_f(dloss_da(w[1:], a1, y), norm_type2)
  return loss + reg_coeff * reg

def compute_min_norm_solution(x, y, norm_type):
  """Compute the min-norm solution using a convex-program solver."""
  w = cp.Variable((x.shape[0], 1))
  if norm_type == 'linf':
    # Compute the minimal L_infinity norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm_inf(w)), constraints)
  elif norm_type == 'l2':
    # Compute the minimal L_2 norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm2(w)), constraints)
  elif norm_type == 'l1':
    # Compute the minimal L_1 norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm1(w)), constraints)
  elif norm_type[0] == 'l':
    # Compute the minimal L_p norm solution.
    p = float(norm_type[1:])
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.pnorm(w, p)), constraints)
  elif norm_type == 'dft1':
    # Compute the minimal Fourier L1 norm (||F(w)||_1) solution.
    w = cp.Variable((x.shape[0], 1), complex=True)
    dft = np.matrix(scipy.linalg.dft(x.shape[0], scale='sqrtn'))
    constraints = [cp.multiply(y, (cp.real(w).T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm1(dft @ w)), constraints)
  prob.solve()
  logging.info('Min %s-norm solution found (norm=%.4f)', norm_type,
               float(norm_f(w.value, norm_type)))
  return cp.real(w).value

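# Hedged usage sketch (illustration, not part of the original API): recover the
# minimum L2-norm separator on a tiny separable problem. Assumes the module
# imports used above (`np` for numpy, `cp` for cvxpy) and the dim x num data
# layout used throughout this file. The helper name is hypothetical.
def _example_min_l2_norm_solution():
  x = np.array([[1., -1.],
                [1., -1.]])   # two 2-D points, one per column
  y = np.array([[1., -1.]])   # linearly separable +/-1 labels
  w_min = compute_min_norm_solution(x, y, 'l2')
  # For this symmetric problem the solution is w = [0.5, 0.5]^T, so the margin
  # 1 / ||w||_2 equals sqrt(2).
  return w_min
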
def two_linear_model_loss_regularized_dy0dx_lp(w0, w1, x, y, reg_coeff,
                                               norm_type):
  """Penalize by ||dy0/dx||, optimal when the second layer is fixed."""
  dy0_dx_f = jax.grad(two_linear_model_y0_mean, argnums=1)
  dy0_dx = dy0_dx_f(w0, x)
  loss = two_linear_model_loss(w0, w1, x, y)
  reg = norm_f(dy0_dx, norm_type)
  return loss + reg_coeff * reg

def train(model_param, train_test_data, predict_f, loss_f, loss_adv_f,
          linearize_f, normalize_f, loss_and_prox_op, summary, config,
          rng_key):
  """Train a model and log risks."""
  dloss_dw = jax.grad(loss_f, argnums=0)
  dloss_adv_dx = jax.grad(loss_adv_f, argnums=1)
  train_data = train_test_data[0]
  xtrain, ytrain = train_data

  # Precompute min-norm solutions.
  if config.enable_cvxpy:
    min_norm_w = {}
    for norm_type in config.available_norm_types:
      min_norm_w[norm_type] = compute_min_norm_solution(xtrain, ytrain,
                                                        norm_type)
    if config.adv.eps_from_cvxpy:
      dual_norm = norm_type_dual(config.adv.norm_type)
      wcomp = min_norm_w[dual_norm]
      wnorm = norm_f(wcomp, dual_norm)
      margin = 1. / wnorm
      # The loop over the range of epsilons will cover 0.5 * (2 * margin).
      config.adv.eps_tot = config.adv.eps_iter = float(2 * margin)

    if config.optim.adv_train.eps_from_cvxpy:
      dual_norm = norm_type_dual(config.optim.adv_train.norm_type)
      wcomp = min_norm_w[dual_norm]
      wnorm = norm_f(wcomp, dual_norm)
      margin = 1. / wnorm
      # Train with exactly the optimal margin as epsilon. Subtracting 0.1 is
      # needed, otherwise convergence is slow; the value is also sensitive
      # (1e-2 would not work).
      config.optim.adv_train.eps_tot = float(margin) - .1
      config.optim.adv_train.eps_iter = float(margin) - .1

  if config['optim']['name'] == 'cvxpy':
    norm_type = config['optim']['norm']
    cvxpy_sol = compute_min_norm_solution(xtrain, ytrain, norm_type)
    model_param = jnp.array(cvxpy_sol)

  # Train loop.
  optim_step, optim_options = optim.get_optimizer_step(config['optim'])
  niters = optim_options['niters']
  for step in range(1, niters):
    # Take one optimization step.
    if config['optim']['name'] != 'cvxpy':
      if config['optim']['adv_train']['enable']:
        # Adversarial training.
        rng_key, rng_subkey = jax.random.split(rng_key)
        x_adv = adversarial.find_adversarial_samples(
            train_data, loss_adv_f, dloss_adv_dx, linearize_f, model_param,
            normalize_f, config.optim.adv_train, rng_key)
        train_data_new = x_adv, ytrain
      else:
        # Standard training.
        train_data_new = train_data
      if config['optim']['name'] == 'fista':
        model_param, optim_options = optim_step(train_data_new,
                                                loss_and_prox_op, model_param,
                                                optim_options)
      else:
        model_param, optim_options = optim_step(train_data_new, loss_f,
                                                model_param, optim_options)

    # Log risks and other statistics.
    if (step + 1) % config.log_interval == 0:
      # Evaluate risk on train/test sets.
      for do_train in [True, False]:
        data = train_test_data[0] if do_train else train_test_data[1]
        prefix = 'risk/train' if do_train else 'risk/test'
        risk = evaluate_risks(data, predict_f, loss_f, model_param)
        for rname, rvalue in risk.items():
          summary.scalar('%s/%s' % (prefix, rname), rvalue, step=step)
        rng_key, rng_subkey = jax.random.split(rng_key)
        risk = evaluate_adversarial_risk(data, predict_f, loss_adv_f,
                                         dloss_adv_dx, linearize_f,
                                         model_param, normalize_f, config,
                                         rng_subkey)
        for rname, rvalue in risk.items():
          summary.array('%s/%s' % (prefix, rname), rvalue, step=step)

      grad = dloss_dw(model_param, xtrain, ytrain)
      grad_ravel, _ = ravel_pytree(grad)
      model_param_ravel, _ = ravel_pytree(model_param)
      for norm_type in config.available_norm_types:
        # Log the norm of the gradient w.r.t. various norms.
        if not norm_type.startswith('dft'):
          summary.scalar('grad/norm/' + norm_type,
                         norm_f(grad_ravel, norm_type), step=step)

        # Log weight norm.
        if not norm_type.startswith('dft'):
          wnorm = norm_f(model_param_ravel, norm_type)
          summary.scalar('weight/norm/' + norm_type, wnorm, step=step)

        # Log margin for the equivalent linearized single-layer model.
        linear_param = linearize_f(model_param)
        min_loss = jnp.min(ytrain * (linear_param.T @ xtrain))
        wcomp = linear_param / min_loss
        wnorm = norm_f(wcomp, norm_type)
        margin = jnp.sign(min_loss) * 1 / wnorm
        summary.scalar('margin/' + norm_type, margin, step=step)
        summary.scalar('weight/linear/norm/' + norm_type, wnorm, step=step)

        # Cosine similarity between the current params and min-norm solution.
        if config.enable_cvxpy:
          def cos_sim(a, b):
            return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

          min_norm_w_ravel, _ = ravel_pytree(min_norm_w[norm_type])
          cs = cos_sim(linear_param.flatten(), min_norm_w_ravel)
          summary.scalar('csim_to_wmin/' + norm_type, cs, step=step)

      if 'step_size' in optim_options:
        summary.scalar('optim/step_size', optim_options['step_size'],
                       step=step)
      logging.info('Epoch: [%d/%d]\t%s', step + 1, niters,
                   summary.last_scalars_to_str(config.log_keys))
      logging.flush()
      summary.flush()

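# Hedged sketch (hypothetical helper): the normalized margin of a linear
# predictor on data (x, y) w.r.t. a given norm, computed exactly as in the
# logging block of the train loop above. Assumes `norm_f` and the dim x num /
# 1 x num data layout used in this file.
def _example_normalized_margin(linear_param, x, y, norm_type):
  # Smallest signed score over the data; negative if any point is
  # misclassified.
  min_loss = jnp.min(y * (linear_param.T @ x))
  # Rescale so the closest point has score 1, then invert the weight norm.
  wcomp = linear_param / min_loss
  return jnp.sign(min_loss) * 1 / norm_f(wcomp, norm_type)
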
def deep_linear_model_normalize_param(params, norm_type):
  """Normalizes the last-layer weights by the norm of the product of weights."""
  norm_p = norm_f(deep_linear_model_linearize_param(params), norm_type)
  params_new = copy.deepcopy(params)
  params_new[-1] = params[-1] / jnp.maximum(1e-7, norm_p)
  return params_new

def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||dL/dw||."""
  loss_and_dloss_dw_f = jax.value_and_grad(linear_model_loss, argnums=0)
  loss, dloss_dw = loss_and_dloss_dw_f(w, x, y)
  reg = norm_f(dloss_dw, norm_type)
  return loss + reg_coeff * reg

def linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||w||."""
  loss = linear_model_loss(w, x, y)
  reg = norm_f(w, norm_type)
  return loss + reg_coeff * reg

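# Hedged usage sketch (hypothetical helper): the two regularized losses above
# share one signature, so they can be swapped behind a single training loss.
# Assumes `linear_model_init_param(dim)` as called in get_model_functions and
# the dim x num data layout used in this file.
def _example_regularized_linear_losses(x, y):
  w = linear_model_init_param(x.shape[0])
  loss_w = linear_model_loss_regularized_w_lp(w, x, y, reg_coeff=1e-3,
                                              norm_type='l2')
  loss_dw = linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff=1e-3,
                                                norm_type='l2')
  return loss_w, loss_dw
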
def get_model_functions(rng_key, dim, arch, nlayers, regularizer, reg_coeff,
                        r):
  """Returns model init/predict/loss functions given the model name."""
  loss_and_prox_op = None
  if arch == 'linear':
    init_f, predict_f = linear_model_init_param, linear_model_predict
    if regularizer == 'none':
      loss_f = linear_model_loss
      prox_op = lambda x, _: x
    elif re.match('w_l.*', regularizer) or re.match('w_dft.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_w_lp, reg_coeff,
                                norm_type)
      prox_op = lambda v, lam: get_prox_op(norm_type)(v, lam * reg_coeff)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dx_lp,
                                reg_coeff, norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dw_lp,
                                reg_coeff, norm_type)
    model_param = init_f(dim)
    loss_adv_f = linear_model_loss
    linearize_param = lambda p: p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p, n))
    loss_and_prox_op = (linear_model_loss, prox_op)
  elif arch == 'deep_linear':
    init_f = deep_linear_model_init_param
    predict_f = deep_linear_model_predict
    if regularizer == 'none':
      loss_f = deep_linear_model_loss
    elif re.match('w_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_w_lp,
                                reg_coeff, norm_type)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dx_lp,
                                reg_coeff, norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dw_lp,
                                reg_coeff, norm_type)
    elif re.match('dw0_l.*_da1_l.*', regularizer):
      norm_type = regularizer.split('_')
      norm_type1, norm_type2 = norm_type[1], norm_type[3]
      loss_f = loss_f_with_args(
          deep_linear_model_loss_regularized_dw0_lp_da1_lq, reg_coeff,
          norm_type1, norm_type2)
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = deep_linear_model_loss
    linearize_param = deep_linear_model_linearize_param
    normalize_param = deep_linear_model_normalize_param
  elif arch == 'two_linear_fixed_w0':
    w0, model_param = two_linear_w0fixed_init_param(dim, r, rng_key)
    predict_f = lambda w, x: two_linear_model_predict(w0, w, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w1_lp(
          w0, w, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    linearize_param = lambda p: w0.T @ p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(w0.T @ p, n))
  elif 'two_linear_fixed_w1' in arch:
    non_isotropic = arch == 'two_linear_fixed_w1_noniso'
    model_param, w1 = two_linear_w1fixed_init_param(dim, r, rng_key,
                                                    non_isotropic)
    predict_f = lambda w, x: two_linear_model_predict(w, w1, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    elif re.match('dy0dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy0dx_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(
          w, w1, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    linearize_param = lambda p: p.T @ w1
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p.T @ w1, n))
  elif 'conv_linear' in arch:
    init_f = conv_linear_model_init_param
    predict_f = conv_linear_model_predict
    if arch == 'conv_linear_exp':
      conv_linear_model_loss = conv_linear_model_exp_loss
    else:
      conv_linear_model_loss = conv_linear_model_logistic_loss
    if regularizer == 'none':
      loss_f = conv_linear_model_loss
      prox_op = lambda x, _: x
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = conv_linear_model_loss
    linearize_param = conv_linear_model_linearize_param
    normalize_param = conv_linear_model_normalize_param
    loss_and_prox_op = (conv_linear_model_loss, prox_op)
  return (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
          normalize_param, loss_and_prox_op)

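# Hedged usage sketch (hypothetical): how the arch/regularizer strings select
# the model functions. The argument values are illustrative only; `nlayers`
# and `r` are unused for the 'linear' arch, and 'w_l2' selects the L2 weight
# penalty via the regex matching above.
def _example_get_model_functions(rng_key, dim=10):
  (model_param, predict_f, loss_f, loss_adv_f, linearize_f, normalize_f,
   loss_and_prox_op) = get_model_functions(
       rng_key, dim, arch='linear', nlayers=1, regularizer='w_l2',
       reg_coeff=1e-3, r=1)
  return model_param, predict_f, loss_f
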
def two_linear_model_loss_regularized_w1_lp(w0, w1, x, y, reg_coeff,
                                            norm_type):
  """Penalize by ||w1||; optimal when the first layer is fixed, and converges faster than penalizing dy/dx."""
  loss = two_linear_model_loss(w0, w1, x, y)
  reg = norm_f(w1, norm_type)
  return loss + reg_coeff * reg

def deep_linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||w||, the norm of all layer weights flattened together."""
  loss = deep_linear_model_loss(w, x, y)
  w_unravel = ravel_pytree(w)[0]
  reg = norm_f(w_unravel, norm_type)
  return loss + reg_coeff * reg

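# Hedged sketch (hypothetical helper): ravel_pytree flattens the list of
# per-layer weight matrices into a single vector, so the penalty above is the
# chosen norm over all entries of all layers jointly. Assumes ravel_pytree
# (jax.flatten_util) and norm_f as used elsewhere in this file.
def _example_ravel_deep_params():
  w = [jnp.ones((3, 4)), jnp.ones((4, 1))]
  w_flat, _ = ravel_pytree(w)
  # w_flat has shape (16,): 12 entries from the first layer, 4 from the second.
  return norm_f(w_flat, 'l2')
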
def find_adversarial_samples(data, loss_f, dloss_dx, linearize_f,
                             model_param0, normalize_f, config, rng_key):
  """Generates an adversarial example in the epsilon-ball centered at data.

  Args:
    data: An array of size dim x num, with input vectors as the initialization
      for the adversarial attack.
    loss_f: Loss function for the attack.
    dloss_dx: The gradient function of the adversarial loss w.r.t. the input.
    linearize_f: Linearize function used for the DFT attack.
    model_param0: Current model parameters.
    normalize_f: A function to normalize the weights of the model.
    config: Dictionary of hyperparameters.
    rng_key: JAX random number generator key.

  Returns:
    An array of size dim x num, one adversarial example per input data point.

  Note: can the gradient w.r.t. parameters be zero while the gradient w.r.t.
  inputs is not? No. For a linear model f = w * x:
    df/dw = x
    dL/dw = dL/df1 * df1/dw = G * x, so dL/dw ~= 0 and x != 0 imply G ~= 0.
    dL/dx = dL/df1 * df1/dx = G * w, so G ~= 0 implies dL/dx ~= 0.
  """
  x0, y = data
  eps_iter, eps_tot = config.eps_iter, config.eps_tot
  norm_type = config.norm_type

  # - Normalize model params; this prevents artificially small gradients.
  # For both linear and non-linear models the norm of the gradient can be
  # artificially very small. This might be less of an issue for linear models
  # and when we use sign(grad) as the adversarial step.
  # For separable data, the norm of the weights of the separating classifier
  # can grow to increase the confidence and decrease the gradient, but
  # adversarial examples within the epsilon ball still exist.
  # For linear models, divide one layer by the norm of the product of weights.
  model_param = model_param0
  if config.pre_normalize:
    model_param = normalize_f(model_param0, norm_type)

  # - Reason for starting from a random delta instead of zero:
  # A non-linear model can have zero dL/dx at x0 but a huge d^2L/dx^2, which
  # means a gradient-based attack fails if it always starts the optimization
  # from x0, but succeeds if it starts from a nearby point with a non-zero
  # gradient. It is not trivial what the distribution of the initial
  # perturbation should be. Uniform within the epsilon ball has its merits,
  # but then we would have to use different distributions for different norm
  # balls. We instead sample from a simple distribution (here, normal) and
  # clip delta to lie within the norm ball.
  delta = jax.random.normal(rng_key, x0.shape)
  if not config.rand_init:
    # Makes it harder to find the optimal adversarial direction for linear
    # models.
    delta *= 0
  assert eps_iter <= eps_tot, 'eps_iter > eps_tot'
  delta = norm_projection(delta, norm_type, eps_iter)
  options = {'bound_step': True, 'step_size': 1000.}
  m_buffer = None
  for _ in range(config.niters):
    x_adv = x0 + delta
    # Untargeted attack: increases the loss for the correct label.
    if config.step_dir == 'sign_grad':
      # Linf attack; FGSM and PGD use only the sign of the gradient.
      grad = dloss_dx(model_param, x_adv, y)
      adv_step = config.lr * jnp.sign(grad)
    elif config.step_dir == 'grad':
      # For the L2 attack.
      grad = dloss_dx(model_param, x_adv, y)
      adv_step = config.lr * grad
    elif config.step_dir == 'grad_max':
      grad = dloss_dx(model_param, x_adv, y)
      grad_max = grad * (jnp.abs(grad) == jnp.abs(grad).max())
      adv_step = config.lr * grad_max
    elif config.step_dir == 'dftinf_sd':
      # Steepest-descent step w.r.t. the DFT-infinity norm.
      grad = dloss_dx(model_param, x_adv, y)
      adv_step = sd_dir(grad, eps_iter)
      adv_step = config.lr * jnp.real(adv_step)

    # - Reason for having both a per-step epsilon and a total epsilon:
    # Needed for non-linear models. Increases attack success if dL/dx at x0 is
    # huge and f(x) is correct on the entire shell of the norm ball but wrong
    # inside the norm ball.
    delta_i = norm_projection(adv_step, norm_type, eps_iter)
    delta = norm_projection(delta + delta_i, norm_type, eps_tot)

  delta = norm_projection(delta, norm_type, eps_tot)
  if config.post_normalize:
    delta_norm = jax.vmap(lambda x: norm_f(x, norm_type), (1,), 0)(delta)
    delta = delta / jnp.maximum(1e-12, delta_norm) * eps_tot
    delta = norm_projection(delta, norm_type, eps_tot)
  x_adv = x0 + delta
  return x_adv

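# Hedged usage sketch (hypothetical): a PGD-style L-infinity attack on a linear
# model. The config fields mirror the attributes read above (eps_iter, eps_tot,
# norm_type, niters, lr, step_dir, rand_init, pre_normalize, post_normalize).
# Using ml_collections.ConfigDict for attribute-style access is an assumption,
# as are the helper name and the availability of linear_model_loss and norm_f
# in this scope.
def _example_linf_attack(model_param, data, rng_key):
  import ml_collections
  attack_config = ml_collections.ConfigDict({
      'eps_iter': 0.1, 'eps_tot': 0.3, 'norm_type': 'linf', 'niters': 10,
      'lr': 0.1, 'step_dir': 'sign_grad', 'rand_init': True,
      'pre_normalize': False, 'post_normalize': False,
  })
  dloss_dx = jax.grad(linear_model_loss, argnums=1)
  normalize_f = lambda p, n: p / jnp.maximum(1e-7, norm_f(p, n))
  linearize_f = lambda p: p
  return find_adversarial_samples(data, linear_model_loss, dloss_dx,
                                  linearize_f, model_param, normalize_f,
                                  attack_config, rng_key)
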