Ejemplo n.º 1
0
def create_classification_losses(global_step,
                                 inputs,
                                 label,
                                 predictor_network,
                                 epsilon,
                                 loss_weights,
                                 warmup_steps=0,
                                 rampup_steps=-1,
                                 input_bounds=(0., 1.),
                                 options=None):
    """Create the training loss.

    Builds the nominal, attack and verified losses for `predictor_network`
    and combines them into a single weighted training objective.

    Args:
      global_step: Training step that drives the loss-weight and epsilon
        schedules.
      inputs: Batch of network inputs.
      label: Ground-truth labels for `inputs`.
      predictor_network: Network being trained. It is called with
        `is_training=...`, must expose `output_size` and, when verification
        is active, `propagate_bounds`.
      epsilon: Final perturbation radius.
      loss_weights: Mapping with keys 'nominal', 'attack' and 'verified'. Each
        value is a mapping with keys 'init' and 'final' giving the start and
        end weight of that loss term.
      warmup_steps: Step at which the schedules start ramping. If
        `is_training_off_after_warmup` is set, batch-norm training stops here.
      rampup_steps: Length of the ramp. A negative value disables scheduling:
        weights take their 'final' value immediately and epsilon is constant.
      input_bounds: (lower, upper) bounds that perturbed inputs are clipped to.
      options: Optional dict overriding the defaults declared at the top of
        this function (keys are read in the `options is not None` branch).

    Returns:
      A tuple `(losses, train_loss, train_epsilon)`:
      - the `loss.Losses` object, already called on `label`;
      - the scalar training loss, i.e. the weighted sum of the scalar losses
        plus all regularization losses;
      - the epsilon used for training (scheduled or constant).
    """
    # Whether to elide the last linear layer with the specification.
    elide = True
    # Which loss to use for the IBP loss.
    loss_type = 'xent'
    # If the loss_type is 'hinge', which margin to use.
    loss_margin = 10.
    # Amount of label smoothing.
    label_smoothing = 0.
    # If True, batch normalization stops training after warm-up.
    is_training_off_after_warmup = False
    # If True, epsilon changes more smoothly.
    smooth_epsilon_schedule = False
    # Either 'one_vs_all', 'random_n', 'least_likely_n' or 'none'.
    verified_specification = 'one_vs_all'
    # Attack options.
    attack_specification = 'UntargetedPGDAttack_7x1x1_UnrolledAdam_.1'
    attack_scheduled = False
    attack_random_init = 1.
    # Whether the final loss from the attack should be standard cross-entropy
    # or the TRADES loss (https://arxiv.org/abs/1901.08573).
    pgd_attack_use_trades = False
    if options is not None:
        elide = options.get('elide_last_layer', elide)
        loss_type = options.get('verified_loss_type', loss_type)
        # Fall back to the default margin, not the loss type (a string).
        loss_margin = options.get('verified_loss_margin', loss_margin)
        label_smoothing = options.get('label_smoothing', label_smoothing)
        is_training_off_after_warmup = options.get(
            'is_training_off_after_warmup', is_training_off_after_warmup)
        smooth_epsilon_schedule = options.get('smooth_epsilon_schedule',
                                              smooth_epsilon_schedule)
        verified_specification = options.get('verified_specification',
                                             verified_specification)
        attack_specification = options.get('attack_specification',
                                           attack_specification)
        attack_scheduled = options.get('attack_scheduled', attack_scheduled)
        attack_random_init = options.get('attack_random_init',
                                         attack_random_init)
        pgd_attack_use_trades = options.get('pgd_attack_use_trades',
                                            pgd_attack_use_trades)

    # Loss weights.
    def _get_schedule(init, final):
        # Constant weight: no schedule needed.
        if init == final:
            return init
        # No ramp-up requested: use the final weight from the start.
        if rampup_steps < 0:
            return final
        return linear_schedule(global_step, warmup_steps,
                               warmup_steps + rampup_steps, init, final)

    def _is_active(init, final):
        # A loss term is built only if it has a positive weight at some point.
        return init > 0. or final > 0.

    nominal_xent = _get_schedule(**loss_weights.get('nominal'))
    attack_xent = _get_schedule(**loss_weights.get('attack'))
    use_attack = _is_active(**loss_weights.get('attack'))
    verified_loss = _get_schedule(**loss_weights.get('verified'))
    use_verification = _is_active(**loss_weights.get('verified'))
    if verified_specification == 'none':
        use_verification = False
    weight_mixture = loss.ScalarLosses(nominal_cross_entropy=nominal_xent,
                                       attack_cross_entropy=attack_xent,
                                       verified_loss=verified_loss)

    if rampup_steps < 0:
        train_epsilon = tf.constant(epsilon)
        is_training = not is_training_off_after_warmup
    else:
        if smooth_epsilon_schedule:
            train_epsilon = smooth_schedule(global_step, warmup_steps,
                                            warmup_steps + rampup_steps, 0.,
                                            epsilon)
        else:
            train_epsilon = linear_schedule(global_step, warmup_steps,
                                            warmup_steps + rampup_steps, 0.,
                                            epsilon)
        if is_training_off_after_warmup:
            is_training = global_step < warmup_steps
        else:
            is_training = True

    logits = predictor_network(inputs, is_training=is_training)
    num_classes = predictor_network.output_size
    if use_verification:
        logging.info('Verification active.')
        # Interval around each input, clipped to the valid input range.
        input_interval_bounds = bounds.IntervalBounds(
            tf.maximum(inputs - train_epsilon, input_bounds[0]),
            tf.minimum(inputs + train_epsilon, input_bounds[1]))
        predictor_network.propagate_bounds(input_interval_bounds)
        spec = create_specification(label, num_classes, logits,
                                    verified_specification)
        spec_builder = lambda *args, **kwargs: spec(
            *args, collapse=elide, **kwargs)  # pylint: disable=unnecessary-lambda
    else:
        logging.info('Verification disabled.')
        spec_builder = None
    if use_attack:
        logging.info('Attack active.')
        # The attack uses either the scheduled epsilon or the final one.
        pgd_attack = create_attack(
            attack_specification,
            predictor_network,
            label,
            train_epsilon if attack_scheduled else epsilon,
            input_bounds=input_bounds,
            random_init=attack_random_init)
    else:
        logging.info('Attack disabled.')
        pgd_attack = None
    losses = loss.Losses(predictor_network,
                         spec_builder,
                         pgd_attack,
                         interval_bounds_loss_type=loss_type,
                         interval_bounds_hinge_margin=loss_margin,
                         label_smoothing=label_smoothing,
                         pgd_attack_use_trades=pgd_attack_use_trades)
    losses(label)
    # Weighted sum of the scalar losses, paired field-by-field with weights.
    train_loss = sum(l * w
                     for l, w in zip(losses.scalar_losses, weight_mixture))
    # Add a regularization loss.
    regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    train_loss = train_loss + tf.reduce_sum(regularizers)
    return losses, train_loss, train_epsilon
Ejemplo n.º 2
0
def create_classification_losses(
    global_step,
    inputs,
    label,
    predictor_network,
    epsilon,
    loss_weights,
    warmup_steps=0,
    rampup_steps=-1,
    input_bounds=(0., 1.),
    options=None):
  """Create the training loss.

  Builds the nominal, attack and verified losses for `predictor_network` and
  combines them into a single weighted training objective.

  Args:
    global_step: Training step that drives the loss-weight and epsilon
      schedules.
    inputs: Batch of network inputs.
    label: Ground-truth labels for `inputs`.
    predictor_network: Network being trained. It is called with
      `is_training=...`, must expose `output_size` and, when verification is
      active, `propagate_bounds`.
    epsilon: Final perturbation radius. The PGD attack always uses this value.
    loss_weights: Mapping with keys 'nominal', 'attack' and 'verified'. Each
      value is a mapping with keys 'init' and 'final' giving the start and end
      weight of that loss term.
    warmup_steps: Step at which the schedules start ramping. If
      `is_training_off_after_warmup` is set, batch-norm training stops here.
    rampup_steps: Length of the ramp. A negative value disables scheduling:
      weights take their 'final' value immediately and epsilon is constant.
    input_bounds: (lower, upper) bounds that perturbed inputs are clipped to.
    options: Optional dict overriding 'elide_last_layer',
      'verified_loss_type', 'verified_loss_margin',
      'is_training_off_after_warmup' and 'smooth_epsilon_schedule'.

  Returns:
    A tuple `(losses, train_loss, train_epsilon)`:
    - the `loss.Losses` object, already called on `label`;
    - the scalar training loss, i.e. the weighted sum of the scalar losses
      plus all regularization losses;
    - the epsilon used for training (scheduled or constant).
  """
  elide = True
  loss_type = 'xent'
  loss_margin = 10.
  is_training_off_after_warmup = False
  smooth_epsilon_schedule = False
  if options is not None:
    elide = options.get('elide_last_layer', elide)
    loss_type = options.get('verified_loss_type', loss_type)
    # Fall back to the default margin, not the loss type (a string).
    loss_margin = options.get('verified_loss_margin', loss_margin)
    is_training_off_after_warmup = options.get(
        'is_training_off_after_warmup', is_training_off_after_warmup)
    smooth_epsilon_schedule = options.get(
        'smooth_epsilon_schedule', smooth_epsilon_schedule)

  # Loss weights.
  def _get_schedule(init, final):
    # Constant weight: no schedule needed.
    if init == final:
      return init
    # No ramp-up requested (the default): use the final weight right away.
    # Otherwise the schedule's end step would come before its start step.
    if rampup_steps < 0:
      return final
    return linear_schedule(
        global_step, warmup_steps, warmup_steps + rampup_steps, init, final)
  def _is_active(init, final):
    # A loss term is built only if it has a positive weight at some point.
    return init > 0. or final > 0.
  nominal_xent = _get_schedule(**loss_weights.get('nominal'))
  attack_xent = _get_schedule(**loss_weights.get('attack'))
  use_attack = _is_active(**loss_weights.get('attack'))
  verified_loss = _get_schedule(**loss_weights.get('verified'))
  use_verification = _is_active(**loss_weights.get('verified'))
  weight_mixture = loss.ScalarLosses(
      nominal_cross_entropy=nominal_xent,
      attack_cross_entropy=attack_xent,
      verified_loss=verified_loss)

  if rampup_steps < 0:
    train_epsilon = tf.constant(epsilon)
    is_training = not is_training_off_after_warmup
  else:
    if smooth_epsilon_schedule:
      train_epsilon = smooth_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    else:
      train_epsilon = linear_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    if is_training_off_after_warmup:
      is_training = global_step < warmup_steps
    else:
      is_training = True

  # Called for its side effect of building the network; output is unused here.
  predictor_network(inputs, is_training=is_training)
  num_classes = predictor_network.output_size
  if use_verification:
    logging.info('Verification active.')
    # Interval around each input, clipped to the valid input range.
    input_interval_bounds = bounds.IntervalBounds(
        tf.maximum(inputs - train_epsilon, input_bounds[0]),
        tf.minimum(inputs + train_epsilon, input_bounds[1]))
    predictor_network.propagate_bounds(input_interval_bounds)
    spec = specification.ClassificationSpecification(label, num_classes)
    spec_builder = lambda *args, **kwargs: spec(*args, collapse=elide, **kwargs)  # pylint: disable=unnecessary-lambda
  else:
    logging.info('Verification disabled.')
    spec = None
    spec_builder = None
  if use_attack:
    logging.info('Attack active.')
    # Reuse the verification spec when available; otherwise build one.
    s = spec
    if s is None:
      s = specification.ClassificationSpecification(label, num_classes)
    pgd_attack = attacks.UntargetedPGDAttack(
        predictor_network, s, epsilon, num_steps=7, input_bounds=input_bounds,
        optimizer_builder=attacks.UnrolledAdam)
  else:
    logging.info('Attack disabled.')
    pgd_attack = None
  losses = loss.Losses(predictor_network, spec_builder, pgd_attack,
                       interval_bounds_loss_type=loss_type,
                       interval_bounds_hinge_margin=loss_margin)
  losses(label)
  # Weighted sum of the scalar losses, paired field-by-field with weights.
  train_loss = sum(l * w for l, w in zip(losses.scalar_losses,
                                         weight_mixture))
  # Add a regularization loss.
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  train_loss = train_loss + tf.reduce_sum(regularizers)
  return losses, train_loss, train_epsilon