def deep_linear_model_loss_regularized_dw0_lp_da1_lq(w, x, y, reg_coeff,
                                                     norm_type1, norm_type2):
  """Penalize by ||dL/dw0|| + ||dL/da1||, where w0, a1 are the first layer's weights and output."""
  dloss_da = jax.grad(deep_linear_model_loss, argnums=1)
  loss = deep_linear_model_loss(w, x, y)
  a1 = w[0].T @ x
  reg = norm_f(w[0], norm_type1) + norm_f(dloss_da(w[1:], a1, y), norm_type2)
  return loss + reg_coeff * reg
def compute_min_norm_solution(x, y, norm_type):
  """Compute the min-norm solution using a convex-program solver."""
  w = cp.Variable((x.shape[0], 1))
  if norm_type == 'linf':
    # Compute the minimal L_infinity-norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm_inf(w)), constraints)
  elif norm_type == 'l2':
    # Compute the minimal L_2-norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm2(w)), constraints)
  elif norm_type == 'l1':
    # Compute the minimal L_1-norm solution.
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm1(w)), constraints)
  elif norm_type[0] == 'l':
    # Compute the minimal L_p-norm solution.
    p = float(norm_type[1:])
    constraints = [cp.multiply(y, (w.T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.pnorm(w, p)), constraints)
  elif norm_type == 'dft1':
    # Compute the minimal Fourier-L1-norm (||F(w)||_1) solution.
    w = cp.Variable((x.shape[0], 1), complex=True)
    dft = scipy.linalg.dft(x.shape[0]) / np.sqrt(x.shape[0])
    constraints = [cp.multiply(y, (cp.real(w).T @ x)) >= 1]
    prob = cp.Problem(cp.Minimize(cp.norm1(dft @ w)), constraints)
  prob.solve(verbose=True)
  logging.info('Min %s-norm solution found (norm=%.4f)', norm_type,
               float(norm_f(w.value, norm_type)))
  return cp.real(w).value
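# Illustrative sketch (not part of the original pipeline): how
# compute_min_norm_solution above might be exercised on toy data. The arrays
# follow the (dim, num_samples) layout implied by the constraints above;
# the names x_toy/y_toy and the values are hypothetical.
def _example_min_norm_solution():
  x_toy = np.array([[1.0, -1.0, 2.0],
                    [0.5, -2.0, 1.0]])  # shape (dim, num_samples)
  y_toy = np.array([[1.0, -1.0, 1.0]])  # labels in {-1, +1}, shape (1, n)
  # Minimum-L2-norm and minimum-Linf-norm separators with margin >= 1.
  w_min_l2 = compute_min_norm_solution(x_toy, y_toy, 'l2')
  w_min_linf = compute_min_norm_solution(x_toy, y_toy, 'linf')
  return w_min_l2, w_min_linf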
def two_linear_model_loss_regularized_dy0dx_lp(w0, w1, x, y, reg_coeff,
                                               norm_type):
  """Penalize by ||dy0/dx||, optimal when the second layer is fixed."""
  dy0_dx_f = jax.grad(two_linear_model_y0_mean, argnums=1)
  dy0_dx = dy0_dx_f(w0, x)
  loss = two_linear_model_loss(w0, w1, x, y)
  reg = norm_f(dy0_dx, norm_type)
  return loss + reg_coeff * reg
def deep_linear_model_normalize_param(params, norm_type):
  """Normalizes the last-layer weights by the norm of the product of weights."""
  norm_p = norm_f(deep_linear_model_linearize_param(params), norm_type)
  # Copy the parameter list so the caller's params are not mutated in place.
  params_new = list(params)
  params_new[-1] = params[-1] / jnp.maximum(1e-7, norm_p)
  return params_new
def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||dL/dw||."""
  loss_and_dloss_dw_f = jax.value_and_grad(linear_model_loss, argnums=0)
  loss, dloss_dw = loss_and_dloss_dw_f(w, x, y)
  reg = norm_f(dloss_dw, norm_type)
  return loss + reg_coeff * reg
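# Minimal self-contained sketch of the pattern used above: jax.value_and_grad
# returns the loss and dL/dw in one pass, and the norm of the gradient is added
# as a penalty. The quadratic toy loss and the L2 norm below are assumptions
# for illustration only; they stand in for this module's linear_model_loss and
# norm_f.
def _example_grad_norm_penalty(w, x, y, reg_coeff):
  def toy_loss(w, x, y):
    return jnp.mean((w.T @ x - y) ** 2)
  loss, dloss_dw = jax.value_and_grad(toy_loss, argnums=0)(w, x, y)
  return loss + reg_coeff * jnp.linalg.norm(dloss_dw.ravel(), ord=2)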
def linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||w||."""
  loss = linear_model_loss(w, x, y)
  reg = norm_f(w, norm_type)
  return loss + reg_coeff * reg
def get_model_functions(rng_key, dim, arch, nlayers, regularizer, reg_coeff,
                        r):
  """Returns model init/predict/loss functions given the model name."""
  loss_and_prox_op = None
  if arch == 'linear':
    init_f, predict_f = linear_model_init_param, linear_model_predict
    if regularizer == 'none':
      loss_f = linear_model_loss
      prox_op = lambda x, _: x
    elif re.match('w_l.*', regularizer) or re.match('w_dft.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_w_lp, reg_coeff,
                                norm_type)
      prox_op = lambda v, lam: get_prox_op(norm_type)(v, lam * reg_coeff)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dx_lp, reg_coeff,
                                norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dw_lp, reg_coeff,
                                norm_type)
    model_param = init_f(dim)
    loss_adv_f = linear_model_loss
    linearize_param = lambda p: p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p, n))
    loss_and_prox_op = (linear_model_loss, prox_op)
  elif arch == 'deep_linear':
    init_f = deep_linear_model_init_param
    predict_f = deep_linear_model_predict
    if regularizer == 'none':
      loss_f = deep_linear_model_loss
    elif re.match('w_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_w_lp,
                                reg_coeff, norm_type)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dx_lp,
                                reg_coeff, norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dw_lp,
                                reg_coeff, norm_type)
    elif re.match('dw0_l.*_da1_l.*', regularizer):
      norm_type = regularizer.split('_')
      norm_type1, norm_type2 = norm_type[1], norm_type[3]
      loss_f = loss_f_with_args(
          deep_linear_model_loss_regularized_dw0_lp_da1_lq, reg_coeff,
          norm_type1, norm_type2)
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = deep_linear_model_loss
    linearize_param = deep_linear_model_linearize_param
    normalize_param = deep_linear_model_normalize_param
  elif arch == 'two_linear_fixed_w0':
    w0, model_param = two_linear_w0fixed_init_param(dim, r, rng_key)
    predict_f = lambda w, x: two_linear_model_predict(w0, w, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(  # pylint: disable=g-long-lambda
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(  # pylint: disable=g-long-lambda
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w1_lp(  # pylint: disable=g-long-lambda
          w0, w, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    linearize_param = lambda p: w0.T @ p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(w0.T @ p, n))
  elif arch == 'two_linear_fixed_w1' or arch == 'two_linear_fixed_w1_noniso':
    non_isotropic = arch == 'two_linear_fixed_w1_noniso'
    model_param, w1 = two_linear_w1fixed_init_param(dim, r, rng_key,
                                                    non_isotropic)
    predict_f = lambda w, x: two_linear_model_predict(w, w1, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    elif re.match('dy0dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy0dx_lp(  # pylint: disable=g-long-lambda
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(  # pylint: disable=g-long-lambda
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0_lp(  # pylint: disable=g-long-lambda
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(  # pylint: disable=g-long-lambda
          w, w1, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    linearize_param = lambda p: p.T @ w1
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p.T @ w1, n))
  elif arch == 'conv_linear':
    init_f = conv_linear_model_init_param
    predict_f = conv_linear_model_predict
    if regularizer == 'none':
      loss_f = conv_linear_model_loss
      prox_op = lambda x, _: x
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = conv_linear_model_loss
    linearize_param = conv_linear_model_linearize_param
    normalize_param = conv_linear_model_normalize_param
    loss_and_prox_op = (conv_linear_model_loss, prox_op)
  return (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
          normalize_param, loss_and_prox_op)
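# Sketch of how get_model_functions above might be called. The dimension,
# architecture, regularizer string, and coefficient below are illustrative
# assumptions; the call relies on the rest of this module (e.g.
# linear_model_init_param, loss_f_with_args, get_prox_op) being in scope.
def _example_get_model_functions():
  rng_key = jax.random.PRNGKey(0)
  (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
   normalize_param, loss_and_prox_op) = get_model_functions(
       rng_key, dim=10, arch='linear', nlayers=1, regularizer='w_l2',
       reg_coeff=1e-3, r=1)
  del loss_adv_f, linearize_param, normalize_param, loss_and_prox_op  # Unused here.
  return model_param, predict_f, loss_f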
def two_linear_model_loss_regularized_w1_lp(w0, w1, x, y, reg_coeff,
                                            norm_type):
  """Penalize by ||w1||; optimal when the first layer is fixed (faster than the dydx penalty)."""
  loss = two_linear_model_loss(w0, w1, x, y)
  reg = norm_f(w1, norm_type)
  return loss + reg_coeff * reg
def deep_linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||w||, the norm of all weights flattened into one vector."""
  loss = deep_linear_model_loss(w, x, y)
  w_unravel = ravel_pytree(w)[0]
  reg = norm_f(w_unravel, norm_type)
  return loss + reg_coeff * reg
def train(model_param, train_test_data, predict_f, loss_f, loss_adv_f,
          linearize_f, normalize_f, loss_and_prox_op, summary, config,
          rng_key):
  """Train a model and log risks."""
  dloss_dw = jax.grad(loss_f, argnums=0)
  dloss_adv_dx = jax.grad(loss_adv_f, argnums=1)
  train_data = train_test_data[0]
  xtrain, ytrain = train_data

  # Precompute min-norm solutions.
  if config.enable_cvxpy:
    min_norm_w = {}
    for norm_type in config.available_norm_types:
      min_norm_w[norm_type] = compute_min_norm_solution(xtrain, ytrain,
                                                        norm_type)
    if config.adv.eps_from_cvxpy:
      dual_norm = norm_type_dual(config.adv.norm_type)
      wcomp = min_norm_w[dual_norm]
      wnorm = norm_f(wcomp, dual_norm)
      margin = 1. / wnorm
      config.adv.eps_tot = config.adv.eps_iter = float(2 * margin)

  if config['optim']['name'] == 'cvxpy':
    norm_type = config['optim']['norm']
    cvxpy_sol = compute_min_norm_solution(xtrain, ytrain, norm_type)
    model_param = jnp.array(cvxpy_sol)

  # Train loop.
  optim_step, optim_options = optim.get_optimizer_step(config['optim'])
  niters = optim_options['niters']
  for step in range(1, niters):
    # Take one optimization step.
    if config['optim']['name'] != 'cvxpy':
      if config['optim']['adv_train']['enable']:
        # Adversarial training.
        rng_key, rng_subkey = jax.random.split(rng_key)
        x_adv = adversarial.find_adversarial_samples(train_data, dloss_adv_dx,
                                                     model_param, normalize_f,
                                                     config.optim.adv_train,
                                                     rng_key)
        train_data_new = x_adv, ytrain
      else:
        # Standard training.
        train_data_new = train_data
      if config['optim']['name'] == 'fista':
        model_param, optim_options = optim_step(train_data_new,
                                                loss_and_prox_op, model_param,
                                                optim_options)
      else:
        model_param, optim_options = optim_step(train_data_new, loss_f,
                                                model_param, optim_options)

    # Log risks and other statistics.
    if (step + 1) % config.log_interval == 0:
      # Evaluate risk on train/test sets.
      for do_train in [True, False]:
        data = train_test_data[0] if do_train else train_test_data[1]
        prefix = 'risk/train' if do_train else 'risk/test'
        risk = evaluate_risks(data, predict_f, loss_f, model_param)
        for rname, rvalue in risk.items():
          summary.scalar('%s/%s' % (prefix, rname), rvalue, step=step)
        rng_key, rng_subkey = jax.random.split(rng_key)
        risk = evaluate_adversarial_risk(data, predict_f, dloss_adv_dx,
                                         model_param, normalize_f, config,
                                         rng_subkey)
        for rname, rvalue in risk.items():
          summary.array('%s/%s' % (prefix, rname), rvalue, step=step)

      grad = dloss_dw(model_param, xtrain, ytrain)
      grad_ravel, _ = ravel_pytree(grad)
      model_param_ravel, _ = ravel_pytree(model_param)
      for norm_type in config.available_norm_types:
        # Log the norm of the gradient w.r.t. various norms.
        if not norm_type.startswith('dft'):
          summary.scalar('grad/norm/' + norm_type,
                         norm_f(grad_ravel, norm_type), step=step)

        # Log weight norm.
        if not norm_type.startswith('dft'):
          wnorm = norm_f(model_param_ravel, norm_type)
          summary.scalar('weight/norm/' + norm_type, wnorm, step=step)

        # Log margin for the equivalent linearized single-layer model.
        linear_param = linearize_f(model_param)
        min_loss = jnp.min(ytrain * (linear_param.T @ xtrain))
        wcomp = linear_param / min_loss
        wnorm = norm_f(wcomp, norm_type)
        margin = jnp.sign(min_loss) * 1 / wnorm
        summary.scalar('margin/' + norm_type, margin, step=step)
        summary.scalar('weight/linear/norm/' + norm_type, wnorm, step=step)

        # Cosine similarity between the current params and min-norm solution.
        if config.enable_cvxpy:
          def cos_sim(a, b):
            return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

          min_norm_w_ravel, _ = ravel_pytree(min_norm_w[norm_type])
          cs = cos_sim(linear_param.flatten(), min_norm_w_ravel)
          summary.scalar('csim_to_wmin/' + norm_type, cs, step=step)

      if 'step_size' in optim_options:
        summary.scalar('optim/step_size', optim_options['step_size'],
                       step=step)
      logging.info('Epoch: [%d/%d]\t%s', step + 1, niters,
                   summary.last_scalars_to_str(config.log_keys))
      logging.flush()
      summary.flush()
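# Standalone sketch of the margin statistic logged in train above: for a linear
# predictor w, the signed margin min_i y_i w^T x_i / ||w|| equals
# sign(min_loss) / ||w / min_loss||, which is the form used in the loop. The
# toy arrays and the choice of the L2 norm below are assumptions for
# illustration only.
def _example_margin():
  w = jnp.array([[1.0], [2.0]])
  x = jnp.array([[1.0, -1.0], [0.5, -2.0]])  # shape (dim, num_samples)
  y = jnp.array([[1.0, -1.0]])
  min_loss = jnp.min(y * (w.T @ x))
  margin = jnp.sign(min_loss) / jnp.linalg.norm(w.ravel() / min_loss, ord=2)
  return margin  # Equals min_i y_i w^T x_i / ||w||_2 when the data is separable.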