Ejemplo n.º 1
0
def deep_linear_model_loss_regularized_dw0_lp_da1_lq(w, x, y, reg_coeff,
                                                     norm_type1, norm_type2):
  """Deep linear loss plus a first-layer weight norm and a dL/da1 norm.

  The penalty is ||w0||_{norm_type1} + ||dL/da1||_{norm_type2}, where w0 = w[0]
  is the first-layer weight and a1 = w0^T x is the first-layer output. The
  gradient dL/da1 is that of `deep_linear_model_loss` evaluated on the
  remaining layers w[1:] with a1 fed in as their input.

  NOTE(review): the function name (and the previous docstring) advertise a
  penalty on ||dL/dw0||, but the code penalizes the weight itself, ||w0||.
  Confirm which one is intended before relying on this regularizer.

  Args:
    w: sequence of layer weights, indexed as w[0] (first layer) and w[1:].
    x: inputs; the first layer is applied as w[0].T @ x.
    y: labels passed to `deep_linear_model_loss`.
    reg_coeff: scalar multiplier of the penalty.
    norm_type1: norm applied to w[0].
    norm_type2: norm applied to dL/da1.

  Returns:
    loss + reg_coeff * penalty.
  """
  dloss_da = jax.grad(deep_linear_model_loss, argnums=1)
  loss = deep_linear_model_loss(w, x, y)
  # Output of the first layer; the remaining layers consume it as their input.
  a1 = w[0].T @ x
  reg = norm_f(w[0], norm_type1) + norm_f(dloss_da(w[1:], a1, y), norm_type2)
  return loss + reg_coeff * reg
Ejemplo n.º 2
0
def compute_min_norm_solution(x, y, norm_type):
  """Compute the min-norm solution using a convex-program solver.

  Solves   min_w ||w||   s.t.   y * (w^T x) >= 1   (elementwise).

  Args:
    x: data matrix; w has x.shape[0] entries and the margin is w^T x.
    y: labels, multiplied elementwise with w^T x in the margin constraint.
    norm_type: 'linf', 'l2', 'l1', 'l<p>' for a general p-norm (e.g. 'l1.5'),
      or 'dft1' for the L1 norm of the unitary DFT of w.

  Returns:
    The real part of the optimal w, shaped (x.shape[0], 1).

  Raises:
    ValueError: if `norm_type` is not one of the supported forms.
    RuntimeError: if the solver returns no solution (e.g. infeasible data).
  """
  dim = x.shape[0]
  if norm_type == 'dft1':
    # Complex variable so the objective can be the L1 norm of its spectrum
    # (||F(w)||_1); only its real part enters the margin constraint.
    w = cp.Variable((dim, 1), complex=True)
    dft = np.asarray(scipy.linalg.dft(dim, scale='sqrtn'))
    objective = cp.norm1(dft @ w)
    margin = cp.real(w).T @ x
  else:
    w = cp.Variable((dim, 1))
    if norm_type == 'linf':
      objective = cp.norm_inf(w)
    elif norm_type == 'l2':
      objective = cp.norm2(w)
    elif norm_type == 'l1':
      objective = cp.norm1(w)
    elif norm_type.startswith('l'):
      # General L_p norm, e.g. 'l1.5' -> p = 1.5.
      p = float(norm_type[1:])
      objective = cp.pnorm(w, p)
    else:
      raise ValueError('Unsupported norm_type: %r' % (norm_type,))
    margin = w.T @ x

  constraints = [cp.multiply(y, margin) >= 1]
  prob = cp.Problem(cp.Minimize(objective), constraints)
  prob.solve()
  if w.value is None:
    raise RuntimeError('No %s-norm solution found (solver status: %s)' %
                       (norm_type, prob.status))
  logging.info('Min %s-norm solution found (norm=%.4f)', norm_type,
               float(norm_f(w.value, norm_type)))
  return cp.real(w).value
Ejemplo n.º 3
0
def two_linear_model_loss_regularized_dy0dx_lp(w0, w1, x, y, reg_coeff,
                                               norm_type):
  """Two-layer linear loss plus a norm penalty on d(y0)/dx.

  The penalty is the `norm_type` norm of the input gradient of
  `two_linear_model_y0_mean(w0, x)`. It depends only on w0 and x, not on w1.
  Described as optimal when the first layer is fixed.

  Args:
    w0: first-layer weights.
    w1: second-layer weights.
    x: inputs.
    y: labels.
    reg_coeff: scalar multiplier of the penalty.
    norm_type: norm used for the input-gradient penalty.

  Returns:
    two_linear_model_loss(w0, w1, x, y) + reg_coeff * ||d(y0)/dx||.
  """
  input_grad = jax.grad(two_linear_model_y0_mean, argnums=1)(w0, x)
  data_loss = two_linear_model_loss(w0, w1, x, y)
  penalty = norm_f(input_grad, norm_type)
  return data_loss + reg_coeff * penalty
Ejemplo n.º 4
0
def _min_norm_margin(min_norm_w, norm_type):
  """Returns 1 / ||w*||_q, where q is the dual norm of `norm_type`.

  Args:
    min_norm_w: dict mapping a norm type to its min-norm solution w*.
    norm_type: norm of the perturbation ball; its dual selects w*.
  """
  dual_norm = norm_type_dual(norm_type)
  return 1. / norm_f(min_norm_w[dual_norm], dual_norm)


def _cos_sim(a, b):
  """Cosine similarity between two flat vectors."""
  return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


def train(model_param, train_test_data, predict_f, loss_f, loss_adv_f,
          linearize_f, normalize_f, loss_and_prox_op, summary, config, rng_key
          ):
  """Train a model and log risks.

  Optionally precomputes min-norm solutions with cvxpy. These are used to set
  attack / adversarial-training epsilons and to log cosine similarities. The
  function then runs the optimizer. Every `config.log_interval` steps it logs
  train/test risks, gradient and weight norms, margins and the step size.

  Args:
    model_param: initial model parameters (replaced by the cvxpy solution when
      config['optim']['name'] == 'cvxpy').
    train_test_data: (train_data, test_data), each an (x, y) pair.
    predict_f: prediction function passed to the risk evaluators.
    loss_f: training loss; its gradient w.r.t. the params is logged.
    loss_adv_f: loss used by adversarial attacks (gradient taken w.r.t. x).
    linearize_f: maps params to an equivalent single-layer linear model.
    normalize_f: weight normalization used by the attacks.
    loss_and_prox_op: (loss, prox) pair used by the 'fista' optimizer.
    summary: writer providing scalar/array/flush/last_scalars_to_str.
    config: experiment config. Epsilon fields under config.adv and
      config.optim.adv_train may be overwritten in place.
    rng_key: JAX PRNG key for the randomized attacks.
  """
  dloss_dw = jax.grad(loss_f, argnums=0)
  dloss_adv_dx = jax.grad(loss_adv_f, argnums=1)

  train_data = train_test_data[0]
  xtrain, ytrain = train_data

  # Precompute min-norm solutions
  if config.enable_cvxpy:
    min_norm_w = {
        norm_type: compute_min_norm_solution(xtrain, ytrain, norm_type)
        for norm_type in config.available_norm_types
    }
    if config.adv.eps_from_cvxpy:
      margin = _min_norm_margin(min_norm_w, config.adv.norm_type)
      # The loop over the range of epsilons will cover 0.5 * (2 * margin)
      config.adv.eps_tot = config.adv.eps_iter = float(2 * margin)

    if config.optim.adv_train.eps_from_cvxpy:
      margin = _min_norm_margin(min_norm_w, config.optim.adv_train.norm_type)
      # Exactly with the optimal margin and varepsilon
      # minus 0.1 is needed otherwise the convergence is slow. Also sensitive
      # to 0.1. 1e-2 would not work.
      config.optim.adv_train.eps_tot = float(margin) - .1
      config.optim.adv_train.eps_iter = float(margin) - .1

  if config['optim']['name'] == 'cvxpy':
    norm_type = config['optim']['norm']
    cvxpy_sol = compute_min_norm_solution(xtrain, ytrain, norm_type)
    model_param = jnp.array(cvxpy_sol)

  # Train loop
  optim_step, optim_options = optim.get_optimizer_step(config['optim'])
  niters = optim_options['niters']
  # NOTE(review): range(1, niters) performs niters - 1 optimization steps while
  # the log reports step + 1 out of niters; confirm this is intended.
  for step in range(1, niters):
    # Take one optimization step
    if config['optim']['name'] != 'cvxpy':
      if config['optim']['adv_train']['enable']:
        # Adversarial training. Pass the fresh subkey so the carried key is
        # never consumed directly (JAX keys must not be reused).
        rng_key, rng_subkey = jax.random.split(rng_key)
        x_adv = adversarial.find_adversarial_samples(train_data, loss_adv_f,
                                                     dloss_adv_dx,
                                                     linearize_f,
                                                     model_param, normalize_f,
                                                     config.optim.adv_train,
                                                     rng_subkey)
        train_data_new = x_adv, ytrain
      else:
        # Standard training
        train_data_new = train_data
      if config['optim']['name'] == 'fista':
        model_param, optim_options = optim_step(train_data_new,
                                                loss_and_prox_op,
                                                model_param,
                                                optim_options)
      else:
        model_param, optim_options = optim_step(train_data_new,
                                                loss_f,
                                                model_param, optim_options)

    # Log risks and other statistics
    if (step + 1) % config.log_interval == 0:
      # Evaluate risk on train/test sets
      for do_train in [True, False]:
        data = train_test_data[0] if do_train else train_test_data[1]
        prefix = 'risk/train' if do_train else 'risk/test'
        risk = evaluate_risks(data, predict_f, loss_f, model_param)
        for rname, rvalue in risk.items():
          summary.scalar('%s/%s' % (prefix, rname), rvalue, step=step)
        rng_key, rng_subkey = jax.random.split(rng_key)
        risk = evaluate_adversarial_risk(data, predict_f, loss_adv_f,
                                         dloss_adv_dx,
                                         linearize_f,
                                         model_param, normalize_f, config,
                                         rng_subkey)
        for rname, rvalue in risk.items():
          summary.array('%s/%s' % (prefix, rname), rvalue, step=step)

      grad = dloss_dw(model_param, xtrain, ytrain)
      grad_ravel, _ = ravel_pytree(grad)
      model_param_ravel, _ = ravel_pytree(model_param)

      # Equivalent linearized single-layer model; independent of the norm type,
      # so it is computed once per logging step.
      linear_param = linearize_f(model_param)
      min_loss = jnp.min(ytrain * (linear_param.T @ xtrain))
      linear_wcomp = linear_param / min_loss
      linear_param_flat = linear_param.flatten()

      for norm_type in config.available_norm_types:
        if not norm_type.startswith('dft'):
          # Log the norm of the gradient w.r.t. various norms
          summary.scalar(
              'grad/norm/' + norm_type,
              norm_f(grad_ravel, norm_type),
              step=step)
          # Log weight norm
          wnorm = norm_f(model_param_ravel, norm_type)
          summary.scalar('weight/norm/' + norm_type, wnorm, step=step)

        # Log margin for the equivalent linearized single layer model
        wnorm = norm_f(linear_wcomp, norm_type)
        margin = jnp.sign(min_loss) * 1 / wnorm
        summary.scalar('margin/' + norm_type, margin, step=step)
        summary.scalar('weight/linear/norm/' + norm_type, wnorm, step=step)

        # Cosine similarity between the current params and min-norm solution
        if config.enable_cvxpy:
          min_norm_w_ravel, _ = ravel_pytree(min_norm_w[norm_type])
          cs = _cos_sim(linear_param_flat, min_norm_w_ravel)
          summary.scalar('csim_to_wmin/' + norm_type, cs, step=step)

      if 'step_size' in optim_options:
        summary.scalar('optim/step_size', optim_options['step_size'], step=step)

      logging.info('Epoch: [%d/%d]\t%s', step + 1, niters,
                   summary.last_scalars_to_str(config.log_keys))
      logging.flush()

  summary.flush()
Ejemplo n.º 5
0
def deep_linear_model_normalize_param(params, norm_type):
  """Normalizes the last layer weights by the norm of the product of weights.

  The linearized model (`deep_linear_model_linearize_param(params)`) is
  measured in `norm_type`. The last layer is then divided by that norm,
  floored at 1e-7 to avoid division by zero.

  Args:
    params: sequence of layer weights; only params[-1] is rescaled.
    norm_type: norm used to measure the linearized model.

  Returns:
    A new list of layer weights. The earlier layers are the same objects as in
    `params`; the last layer is a new rescaled array.
  """
  norm_p = norm_f(deep_linear_model_linearize_param(params), norm_type)
  scaled_last = params[-1] / jnp.maximum(1e-7, norm_p)
  # Only the last layer changes, so rebuild the container instead of
  # deep-copying every layer's buffer on each call. This assumes layers are
  # immutable (e.g. JAX arrays) and are not mutated in place by callers.
  return [*params[:-1], scaled_last]
Ejemplo n.º 6
0
def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||dL/dw||, the norm of the loss gradient w.r.t. the weights.

  Args:
    w: linear-model weights.
    x: inputs.
    y: labels.
    reg_coeff: scalar multiplier of the penalty.
    norm_type: norm applied to dL/dw.

  Returns:
    linear_model_loss(w, x, y) + reg_coeff * ||dL/dw||.
  """
  # One pass yields both the loss value and its gradient w.r.t. w.
  loss_and_dloss_dw_f = jax.value_and_grad(linear_model_loss, argnums=0)
  loss, dloss_dw = loss_and_dloss_dw_f(w, x, y)
  reg = norm_f(dloss_dw, norm_type)
  return loss + reg_coeff * reg
Ejemplo n.º 7
0
def linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Linear-model loss plus a weight-norm penalty: L(w) + reg_coeff * ||w||.

  Args:
    w: linear-model weights.
    x: inputs.
    y: labels.
    reg_coeff: scalar multiplier of the penalty.
    norm_type: norm applied to w.
  """
  data_term = linear_model_loss(w, x, y)
  return data_term + reg_coeff * norm_f(w, norm_type)
Ejemplo n.º 8
0
def get_model_functions(rng_key, dim, arch, nlayers, regularizer, reg_coeff, r):
  """Returns model init/predict/loss functions given the model name.

  Args:
    rng_key: JAX PRNG key for parameter initialization (not used by 'linear').
    dim: dimension forwarded to the parameter initializers.
    arch: 'linear', 'deep_linear', 'two_linear_fixed_w0', any arch containing
      'two_linear_fixed_w1' ('two_linear_fixed_w1_noniso' selects the
      non-isotropic init), or any arch containing 'conv_linear'
      ('conv_linear_exp' selects the exponential loss).
    nlayers: number of layers for 'deep_linear' and conv-linear models.
    regularizer: 'none' or a penalty spec such as 'w_l2', 'dx_linf' or
      'dw0_l1_da1_l2'. The supported set depends on `arch`.
    reg_coeff: multiplier of the regularization term.
    r: forwarded to the initializers.

  Returns:
    (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
     normalize_param, loss_and_prox_op). loss_and_prox_op is an
    (unregularized loss, proximal operator) pair when a proximal operator
    exists for the chosen regularizer, else None.

  Raises:
    ValueError: if `arch` is unknown or `regularizer` is unsupported for it.
  """
  loss_f = None
  # Proximal operator for FISTA; only some regularizers define one.
  prox_op = None
  loss_and_prox_op = None
  if arch == 'linear':
    init_f, predict_f = linear_model_init_param, linear_model_predict
    if regularizer == 'none':
      loss_f = linear_model_loss
      prox_op = lambda x, _: x
    elif re.match('w_l.*', regularizer) or re.match('w_dft.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_w_lp, reg_coeff,
                                norm_type)
      prox_op = lambda v, lam: get_prox_op(norm_type)(v, lam * reg_coeff)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dx_lp, reg_coeff,
                                norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(linear_model_loss_regularized_dw_lp, reg_coeff,
                                norm_type)
    model_param = init_f(dim)
    loss_adv_f = linear_model_loss
    linearize_param = lambda p: p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p, n))
    if prox_op is not None:
      loss_and_prox_op = (linear_model_loss, prox_op)
  elif arch == 'deep_linear':
    init_f = deep_linear_model_init_param
    predict_f = deep_linear_model_predict
    if regularizer == 'none':
      loss_f = deep_linear_model_loss
    elif re.match('w_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_w_lp,
                                reg_coeff, norm_type)
    elif re.match('dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dx_lp,
                                reg_coeff, norm_type)
    elif re.match('dw_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dw_lp,
                                reg_coeff, norm_type)
    elif re.match('dw0_l.*_da1_l.*', regularizer):
      norm_type = regularizer.split('_')
      norm_type1, norm_type2 = norm_type[1], norm_type[3]
      loss_f = loss_f_with_args(
          deep_linear_model_loss_regularized_dw0_lp_da1_lq, reg_coeff,
          norm_type1, norm_type2)
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = deep_linear_model_loss
    linearize_param = deep_linear_model_linearize_param
    normalize_param = deep_linear_model_normalize_param
  elif arch == 'two_linear_fixed_w0':
    w0, model_param = two_linear_w0fixed_init_param(dim, r, rng_key)
    predict_f = lambda w, x: two_linear_model_predict(w0, w, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(
          w0, w, x, y, reg_coeff, norm_type)
    elif re.match('w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w1_lp(
          w0, w, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
    linearize_param = lambda p: w0.T @ p
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(w0.T @ p, n))
  elif 'two_linear_fixed_w1' in arch:
    non_isotropic = arch == 'two_linear_fixed_w1_noniso'
    model_param, w1 = two_linear_w1fixed_init_param(dim, r, rng_key,
                                                    non_isotropic)
    predict_f = lambda w, x: two_linear_model_predict(w, w1, x)
    if regularizer == 'none':
      loss_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    elif re.match('dy0dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy0dx_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0w1_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('w0_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0_lp(
          w, w1, x, y, reg_coeff, norm_type)
    elif re.match('dy1dx_l.*', regularizer):
      norm_type = regularizer.split('_')[1]
      loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(
          w, w1, x, y, reg_coeff, norm_type)
    loss_adv_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
    linearize_param = lambda p: p.T @ w1
    normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p.T @ w1, n))
  elif 'conv_linear' in arch:
    init_f = conv_linear_model_init_param
    predict_f = conv_linear_model_predict
    if arch == 'conv_linear_exp':
      conv_linear_model_loss = conv_linear_model_exp_loss
    else:
      conv_linear_model_loss = conv_linear_model_logistic_loss
    if regularizer == 'none':
      loss_f = conv_linear_model_loss
      prox_op = lambda x, _: x
    model_param = init_f(dim, nlayers, r, rng_key)
    loss_adv_f = conv_linear_model_loss
    linearize_param = conv_linear_model_linearize_param
    normalize_param = conv_linear_model_normalize_param
    if prox_op is not None:
      loss_and_prox_op = (conv_linear_model_loss, prox_op)
  else:
    raise ValueError('Unknown arch: %r' % (arch,))

  # Every branch above leaves loss_f unset for regularizers it does not know.
  if loss_f is None:
    raise ValueError('Regularizer %r is not supported for arch %r.' %
                     (regularizer, arch))

  return (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
          normalize_param, loss_and_prox_op)
Ejemplo n.º 9
0
def two_linear_model_loss_regularized_w1_lp(w0, w1, x, y, reg_coeff, norm_type):
  """Two-layer linear loss plus a norm penalty on the second-layer weights.

  Intended for a fixed first layer (w0). In that setting this penalty is
  described as optimal and as converging faster than a dy/dx penalty.

  Returns:
    two_linear_model_loss(w0, w1, x, y) + reg_coeff * ||w1||.
  """
  data_term = two_linear_model_loss(w0, w1, x, y)
  w1_norm = norm_f(w1, norm_type)
  return data_term + reg_coeff * w1_norm
Ejemplo n.º 10
0
def deep_linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
  """Penalize by ||w||, the norm of all layer weights flattened together.

  Args:
    w: pytree of layer weights.
    x: inputs.
    y: labels.
    reg_coeff: scalar multiplier of the penalty.
    norm_type: norm applied to the concatenated weight vector.

  Returns:
    deep_linear_model_loss(w, x, y) + reg_coeff * ||ravel(w)||.
  """
  loss = deep_linear_model_loss(w, x, y)
  # Flatten every layer into one vector so a single norm spans all layers.
  w_flat, _ = ravel_pytree(w)
  reg = norm_f(w_flat, norm_type)
  return loss + reg_coeff * reg
Ejemplo n.º 11
0
def find_adversarial_samples(data, loss_f, dloss_dx, linearize_f, model_param0,
                             normalize_f, config, rng_key):
    """Generates an adversarial example in the epsilon-ball centered at data.

    Args:
      data: A pair (x0, y). x0 is an array of size dim x num, with input
        vectors as the initialization for the adversarial attack; y holds the
        labels.
      loss_f: Loss function for attack. Not called directly here; the attack
        only uses `dloss_dx`.
      dloss_dx: The gradient function of the adversarial loss w.r.t. the input,
        called as dloss_dx(model_param, x, y).
      linearize_f: Linearize function used for the DFT attack.
        NOTE(review): not referenced in this function body.
      model_param0: Current model parameters.
      normalize_f: A function to normalize the weights of the model.
      config: Hyperparameters. Reads eps_iter, eps_tot, norm_type,
        pre_normalize, rand_init, niters, step_dir ('sign_grad', 'grad',
        'grad_max' or 'dftinf_sd'), lr and post_normalize.
      rng_key: JAX random number generator key.

    Returns:
      An array of size dim x num, one adversarial example per input data point.

    Raises:
      ValueError: if config.step_dir is not a supported step direction.

    - Can gradient wrt parameters be zero but not the gradient wrt inputs? No.
    f = w*x
    df/dw = x (for a linear model)

    dL/dw = dL/df1 df1/dw = G x
    if dL/dw ~= 0, and x != 0 => G~=0

    dL/dx = dL/df1 df1/dx = G w
    G ~= 0 => dL/dx ~= 0
    """
    x0, y = data
    eps_iter, eps_tot = config.eps_iter, config.eps_tot
    norm_type = config.norm_type
    # - Normalize model params, prevents small gradients
    # For both linear and non-linear models the norm of the gradient can be
    # artificially very small. This might be less of an issue for linear models
    # and when we use sign(grad) to adv.
    # For separable data, the norm of the weights of the separating classifier
    # can increase to increase the confidence and decrease the gradient.
    # But adversarial examples within the epsilon ball still exist.
    # For linear models, divide one layer by the norm of the product of weights
    model_param = model_param0
    if config.pre_normalize:
        model_param = normalize_f(model_param0, norm_type)

    # - Reason for starting from a random delta instead of zero:
    # A non-linear model can have zero dL/dx at x0 but huge d^2L/dx^2 which means
    # a gradient-based attack fails if it always starts the optimization from x0
    # but succeed if starts from a point nearby with non-zero gradient.
    # It is not trivial what the distribution for the initial perturbation should
    # be. Uniform within the epsilon ball has its merits but then we have to use
    # different distributions for different norm-balls. We instead config for
    # sampling from a normal distribution and clipping delta to lie within the
    # norm ball.
    delta = jax.random.normal(rng_key, x0.shape)
    if not config.rand_init:
        # Makes it harder to find the optimal adversarial direction for linear
        # models
        delta *= 0
    assert eps_iter <= eps_tot, 'eps_iter > eps_tot'
    delta = norm_projection(delta, norm_type, eps_iter)
    for _ in range(config.niters):
        x_adv = x0 + delta
        # Untargeted attack: increases the loss for the correct label
        grad = dloss_dx(model_param, x_adv, y)
        if config.step_dir == 'sign_grad':
            # Linf attack, FGSM and PGD attacks use only sign
            adv_step = config.lr * jnp.sign(grad)
        elif config.step_dir == 'grad':
            # For L2 attack
            adv_step = config.lr * grad
        elif config.step_dir == 'grad_max':
            # L1-style step: for each sample (column), keep only the
            # coordinate(s) of largest gradient magnitude. The max is taken
            # per column, not over the whole batch; a global max would zero
            # the step for every sample but one.
            abs_grad = jnp.abs(grad)
            is_max = abs_grad == abs_grad.max(axis=0, keepdims=True)
            adv_step = config.lr * (grad * is_max)
        elif config.step_dir == 'dftinf_sd':
            # Step direction from sd_dir (presumably the steepest-ascent
            # direction for the DFT-based norm); only its real part is used.
            adv_step = config.lr * jnp.real(sd_dir(grad, eps_iter))
        else:
            raise ValueError('Unknown step_dir: %r' % (config.step_dir,))
        # - Reason for having both a per-step epsilon and a total epsilon:
        # Needed for non-linear models. Increases attack success if dL/dx at x0 is
        # huge and f(x) is correct on the entire shell of the norm-ball but wrong
        # inside the norm ball.
        delta_i = norm_projection(adv_step, norm_type, eps_iter)
        delta = norm_projection(delta + delta_i, norm_type, eps_tot)
    delta = norm_projection(delta, norm_type, eps_tot)

    if config.post_normalize:
        # Rescale each column of delta to have norm eps_tot.
        delta_norm = jax.vmap(lambda x: norm_f(x, norm_type), (1, ), 0)(delta)
        delta = delta / jnp.maximum(1e-12, delta_norm) * eps_tot
        delta = norm_projection(delta, norm_type, eps_tot)
    x_adv = x0 + delta
    return x_adv