def create_specification(label, num_classes, logits,
                         specification_type='one_vs_all', collapse=True):
  """Creates a specification of the desired type."""
  def _num_targets(name):
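    # E.g. 'random_3' -> 3 targets; names without a numeric suffix default to 1.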
    tokens = name.rsplit('_', 1)
    return int(tokens[1]) if len(tokens) > 1 and tokens[1].isdigit() else 1
  if specification_type == 'one_vs_all':
    return specification.ClassificationSpecification(label, num_classes,
                                                     collapse=collapse)
  elif specification_type.startswith('random'):
    return specification.RandomClassificationSpecification(
        label, num_classes, _num_targets(specification_type), collapse=collapse)
  elif specification_type.startswith('least_likely'):
    return specification.LeastLikelyClassificationSpecification(
        label, num_classes, logits, _num_targets(specification_type),
        collapse=collapse)
  else:
    raise ValueError('Unknown specification type: "{}"'.format(
        specification_type))
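
A minimal usage sketch of the type-string parsing above (the `label` and `logits` tensors and the `specification` module come from the surrounding library; the values are illustrative):

# 'one_vs_all'      -> one-vs-all ClassificationSpecification.
# 'random_3'        -> RandomClassificationSpecification with 3 sampled targets.
# 'least_likely_2'  -> LeastLikelyClassificationSpecification with 2 targets.
spec = create_specification(label, num_classes=10, logits=logits,
                            specification_type='random_3')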
Example #2
def get_attack_builder(logits,
                       label,
                       name='UntargetedPGDAttack',
                       random_seed=None,
                       manual_target_class=None):
    """Returns a callable with the same arguments as PGDAttack.

  In addition to the callable, this function also returns the targeted class
  indices as a Tensor with the same shape as label.

  Usage is as follows:
    logits = model(inputs)
    attack_cls, specification, target_class = get_attack_builder(logits, labels)
    # target_class is None, if attack_cls is not a targeted attack.
    attack_instance = attack_cls(model, specification, epsilon)
    perturbed_inputs = attack_instance(inputs, labels)

  Args:
    logits: Tensor of nominal logits of shape [batch_size, num_classes].
    label: Tensor of labels of shape [batch_size].
    name: Name of a PGDAttack class or any of "RandomMoreLikelyPGDAttack",
      "RandomMostLikelyPGDAttack", "LeastLikelyMoreLikelyPGDAttack",
      "LeastLikelyMostLikelyPGDAttack", "ManualMoreLikelyPGDAttack",
      "ManualMostLikelyPGDAttack". Any attack name can be postfixed by
      "Xent" to use the cross-entropy loss rather than margin loss.
    random_seed: Sets the random seed for "Random*" attacks.
    manual_target_class: For "Manual*" attacks, Tensor of target class indices
      of shape [batch_size].

  Returns:
    A callable, a Specification and a Tensor of target label (or None if the
    attack is not targeted).
  """
    if name.endswith('Xent'):
        use_xent = True
        name = name[:-4]
    else:
        use_xent = False
    if name.endswith('Linf'):
        use_l2 = False
        name = name[:-4]  # Just for syntactic sugar.
    elif name.endswith('L2'):
        use_l2 = True
        name = name[:-2]
    else:
        use_l2 = False
    num_classes = logits.shape[1].value
    if num_classes is None:
        raise ValueError('Cannot determine the number of classes from logits.')

    # Special case for multi-targeted attacks.
    m = re.match(
        r'((?:MemoryEfficient)?MultiTargetedPGDAttack)'
        r'(?:(Top|Random)(\d+)?)?', name)
    if m is not None:
        # Request for a multi-targeted attack.
        is_multitargeted = True
        name = m.group(1)
        is_random = (m.group(2) == 'Random')
        max_specs = m.group(3)
        max_specs = int(max_specs) if max_specs is not None else 0
    else:
        is_multitargeted = False

    # All attack classes readily available in the attacks library use the
    # standard one-vs-all classification specification and are untargeted.
    if hasattr(attacks, name):
        attack_cls = getattr(attacks, name)
        parameters = {}
        if use_xent:
            parameters['objective_fn'] = _maximize_cross_entropy
        if use_l2:
            parameters['project_perturbation'] = _get_projection(2)
        if is_multitargeted:
            parameters['max_specifications'] = max_specs
            parameters['random_specifications'] = is_random
        if parameters:
            attack_cls = _change_parameters(attack_cls, **parameters)
        attack_specification = specification.ClassificationSpecification(
            label, num_classes)
        return attack_cls, attack_specification, None

    # Attacks can use an adaptive scheme.
    if name.endswith('AdaptivePGDAttack'):
        name = name[:-len('AdaptivePGDAttack')] + 'PGDAttack'
        is_adaptive = True
    else:
        is_adaptive = False

    # Attacks can be preceded by a number to indicate the number of target
    # classes. For efficiency, this is only available for *MoreLikely attacks.
    m = re.match(r'(\d+)(.*MoreLikelyPGDAttack)', name)
    if m is not None:
        num_targets = int(m.group(1))
        name = m.group(2)
    else:
        num_targets = 1

    # All attacks that are not directly listed in the attacks library are
    # targeted attacks that need to be manually constructed.
    if name not in ('RandomMoreLikelyPGDAttack', 'RandomMostLikelyPGDAttack',
                    'LeastLikelyMoreLikelyPGDAttack',
                    'LeastLikelyMostLikelyPGDAttack',
                    'ManualMoreLikelyPGDAttack', 'ManualMostLikelyPGDAttack'):
        raise ValueError('Unknown attack "{}".'.format(name))

    base_attack_cls = (attacks.AdaptiveUntargetedPGDAttack
                       if is_adaptive else attacks.UntargetedPGDAttack)
    if 'More' in name:
        if use_xent:
            raise ValueError('Using cross-entropy is not supported by '
                             '"*MoreLikelyPGDAttack".')
        attack_cls = base_attack_cls
    else:
        # We need to reverse the attack direction w.r.t. the specifications.
        attack_cls = _change_parameters(
            base_attack_cls,
            objective_fn=(_minimize_cross_entropy
                          if use_xent else _minimize_margin),
            success_fn=_all_smaller)
    if use_l2:
        attack_cls = _change_parameters(
            attack_cls, project_perturbation=_get_projection(2))

    # Set attack specification and target class.
    if name == 'RandomMoreLikelyPGDAttack':
        # A random target class should become more likely than the true class.
        attack_specification = specification.RandomClassificationSpecification(
            label, num_classes, num_targets=num_targets, seed=random_seed)
        target_class = (tf.squeeze(attack_specification.target_class, 1)
                        if num_targets == 1 else None)

    elif name == 'LeastLikelyMoreLikelyPGDAttack':
        attack_specification = specification.LeastLikelyClassificationSpecification(
            label, num_classes, logits, num_targets=num_targets)
        target_class = (tf.squeeze(attack_specification.target_class, 1)
                        if num_targets == 1 else None)

    elif name == 'ManualMoreLikelyPGDAttack':
        attack_specification = specification.TargetedClassificationSpecification(
            label, num_classes, manual_target_class)
        target_class = (tf.squeeze(attack_specification.target_class, 1)
                        if num_targets == 1 else None)

    elif name == 'RandomMostLikelyPGDAttack':
        # For this attack to be successful, the random target class must end
        # up with the highest logit.
        target_class = _get_random_class(label, num_classes, seed=random_seed)
        attack_specification = specification.ClassificationSpecification(
            target_class, num_classes)

    elif name == 'LeastLikelyMostLikelyPGDAttack':
        # For this attack to be successful, the least likely class must end up
        # with the highest logit.
        target_class = _get_least_likely_class(label, num_classes, logits)
        attack_specification = specification.ClassificationSpecification(
            target_class, num_classes)

    else:
        assert name == 'ManualMostLikelyPGDAttack'
        target_class = manual_target_class
        attack_specification = specification.ClassificationSpecification(
            target_class, num_classes)

    return attack_cls, attack_specification, target_class
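
A hedged end-to-end sketch of the builder (model, inputs, labels and epsilon are placeholders from the docstring above, not defined here):

logits = model(inputs)
attack_cls, attack_spec, target_class = get_attack_builder(
    logits, labels, name='LeastLikelyMostLikelyPGDAttackXent')
attack = attack_cls(model, attack_spec, epsilon)
perturbed_inputs = attack(inputs, labels)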
Example #3
def create_classification_losses(
    global_step,
    inputs,
    label,
    predictor_network,
    epsilon,
    loss_weights,
    warmup_steps=0,
    rampup_steps=-1,
    input_bounds=(0., 1.),
    options=None):
  """Create the training loss."""
  elide = True
  loss_type = 'xent'
  loss_margin = 10.
  is_training_off_after_warmup = False
  smooth_epsilon_schedule = False
  if options is not None:
    elide = options.get('elide_last_layer', elide)
    loss_type = options.get('verified_loss_type', loss_type)
    loss_margin = options.get('verified_loss_margin', loss_margin)
    is_training_off_after_warmup = options.get(
        'is_training_off_after_warmup', is_training_off_after_warmup)
    smooth_epsilon_schedule = options.get(
        'smooth_epsilon_schedule', smooth_epsilon_schedule)

  # Loss weights.
  def _get_schedule(init, final):
    if init == final:
      return init
    return linear_schedule(
        global_step, warmup_steps, warmup_steps + rampup_steps, init, final)
  def _is_active(init, final):
    return init > 0. or final > 0.
  nominal_xent = _get_schedule(**loss_weights.get('nominal'))
  attack_xent = _get_schedule(**loss_weights.get('attack'))
  use_attack = _is_active(**loss_weights.get('attack'))
  verified_loss = _get_schedule(**loss_weights.get('verified'))
  use_verification = _is_active(**loss_weights.get('verified'))
  weight_mixture = loss.ScalarLosses(
      nominal_cross_entropy=nominal_xent,
      attack_cross_entropy=attack_xent,
      verified_loss=verified_loss)

  if rampup_steps < 0:
    train_epsilon = tf.constant(epsilon)
    is_training = not is_training_off_after_warmup
  else:
    if smooth_epsilon_schedule:
      train_epsilon = smooth_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    else:
      train_epsilon = linear_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    if is_training_off_after_warmup:
      is_training = global_step < warmup_steps
    else:
      is_training = True

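  # Run the forward pass once so that output_size is defined and interval
  # bounds can later be propagated through the network.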
  predictor_network(inputs, is_training=is_training)
  num_classes = predictor_network.output_size
  if use_verification:
    logging.info('Verification active.')
    input_interval_bounds = bounds.IntervalBounds(
        tf.maximum(inputs - train_epsilon, input_bounds[0]),
        tf.minimum(inputs + train_epsilon, input_bounds[1]))
    predictor_network.propagate_bounds(input_interval_bounds)
    spec = specification.ClassificationSpecification(label, num_classes)
    spec_builder = lambda *args, **kwargs: spec(*args, collapse=elide, **kwargs)  # pylint: disable=unnecessary-lambda
  else:
    logging.info('Verification disabled.')
    spec = None
    spec_builder = None
  if use_attack:
    logging.info('Attack active.')
    s = spec
    if s is None:
      s = specification.ClassificationSpecification(label, num_classes)
    pgd_attack = attacks.UntargetedPGDAttack(
        predictor_network, s, epsilon, num_steps=7, input_bounds=input_bounds,
        optimizer_builder=attacks.UnrolledAdam)
  else:
    logging.info('Attack disabled.')
    pgd_attack = None
  losses = loss.Losses(predictor_network, spec_builder, pgd_attack,
                       interval_bounds_loss_type=loss_type,
                       interval_bounds_hinge_margin=loss_margin)
  losses(label)
  train_loss = sum(l * w for l, w in zip(losses.scalar_losses,
                                         weight_mixture))
  # Add a regularization loss.
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  train_loss = train_loss + tf.reduce_sum(regularizers)
  return losses, train_loss, train_epsilon
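
For reference, each entry of `loss_weights` is expected to be a dict of `init`/`final` values matching the keyword arguments of `_get_schedule` and `_is_active` above. A minimal sketch with illustrative, assumed values:

loss_weights = {
    'nominal': {'init': 1., 'final': 0.5},   # clean cross-entropy weight
    'attack': {'init': 0., 'final': 0.},     # PGD attack loss disabled
    'verified': {'init': 0., 'final': 0.5},  # verified (IBP) loss ramps up
}
losses, train_loss, train_epsilon = create_classification_losses(
    global_step, inputs, label, predictor_network, epsilon=0.3,
    loss_weights=loss_weights, warmup_steps=2000, rampup_steps=10000)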