# Imports assumed by the functions in this file; the submodule paths follow
# the interval_bound_propagation package layout and are an assumption.
from absl import logging
import tensorflow as tf

from interval_bound_propagation.src import attacks
from interval_bound_propagation.src import bounds
from interval_bound_propagation.src import loss
from interval_bound_propagation.src import specification


def create_classification_losses(
    global_step,
    inputs,
    label,
    predictor_network,
    epsilon,
    loss_weights,
    warmup_steps=0,
    rampup_steps=-1,
    input_bounds=(0., 1.),
    options=None):
  """Create the training loss."""
  elide = True
  loss_type = 'xent'
  loss_margin = 10.
  is_training_off_after_warmup = False
  smooth_epsilon_schedule = False
  if options is not None:
    elide = options.get('elide_last_layer', elide)
    loss_type = options.get('verified_loss_type', loss_type)
    loss_margin = options.get('verified_loss_margin', loss_margin)
    is_training_off_after_warmup = options.get(
        'is_training_off_after_warmup', is_training_off_after_warmup)
    smooth_epsilon_schedule = options.get(
        'smooth_epsilon_schedule', smooth_epsilon_schedule)

  # Loss weights.
  def _get_schedule(init, final):
    if init == final:
      return init
    return linear_schedule(
        global_step, warmup_steps, warmup_steps + rampup_steps, init, final)

  def _is_active(init, final):
    return init > 0. or final > 0.

  nominal_xent = _get_schedule(**loss_weights.get('nominal'))
  attack_xent = _get_schedule(**loss_weights.get('attack'))
  use_attack = _is_active(**loss_weights.get('attack'))
  verified_loss = _get_schedule(**loss_weights.get('verified'))
  use_verification = _is_active(**loss_weights.get('verified'))
  weight_mixture = loss.ScalarLosses(
      nominal_cross_entropy=nominal_xent,
      attack_cross_entropy=attack_xent,
      verified_loss=verified_loss)

  if rampup_steps < 0:
    train_epsilon = tf.constant(epsilon)
    is_training = not is_training_off_after_warmup
  else:
    if smooth_epsilon_schedule:
      train_epsilon = smooth_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    else:
      train_epsilon = linear_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    if is_training_off_after_warmup:
      is_training = global_step < warmup_steps
    else:
      is_training = True

  predictor_network(inputs, is_training=is_training)
  num_classes = predictor_network.output_size
  if use_verification:
    logging.info('Verification active.')
    input_interval_bounds = bounds.IntervalBounds(
        tf.maximum(inputs - train_epsilon, input_bounds[0]),
        tf.minimum(inputs + train_epsilon, input_bounds[1]))
    predictor_network.propagate_bounds(input_interval_bounds)
    spec = specification.ClassificationSpecification(label, num_classes)
    spec_builder = lambda *args, **kwargs: spec(  # pylint: disable=unnecessary-lambda
        *args, collapse=elide, **kwargs)
  else:
    logging.info('Verification disabled.')
    spec = None
    spec_builder = None
  if use_attack:
    logging.info('Attack active.')
    s = spec
    if s is None:
      s = specification.ClassificationSpecification(label, num_classes)
    pgd_attack = attacks.UntargetedPGDAttack(
        predictor_network, s, epsilon, num_steps=7,
        input_bounds=input_bounds, optimizer_builder=attacks.UnrolledAdam)
  else:
    logging.info('Attack disabled.')
    pgd_attack = None
  losses = loss.Losses(predictor_network, spec_builder, pgd_attack,
                       interval_bounds_loss_type=loss_type,
                       interval_bounds_hinge_margin=loss_margin)
  losses(label)
  train_loss = sum(l * w for l, w in zip(losses.scalar_losses,
                                         weight_mixture))
  # Add a regularization loss.
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  train_loss = train_loss + tf.reduce_sum(regularizers)
  return losses, train_loss, train_epsilon
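
# A minimal usage sketch for the function above. Each entry of `loss_weights`
# must provide the kwargs of _get_schedule, i.e. {'init': ..., 'final': ...}.
# The `images`, `labels`, `predictor_network` and epsilon values below are
# hypothetical placeholders, not part of the original code.
global_step = tf.train.get_or_create_global_step()
loss_weights = {
    'nominal': {'init': 1., 'final': .5},   # cross-entropy on clean inputs
    'attack': {'init': 0., 'final': 0.},    # PGD loss disabled here
    'verified': {'init': 0., 'final': .5},  # IBP verified loss ramps up
}
losses, train_loss, train_epsilon = create_classification_losses(
    global_step, images, labels, predictor_network, epsilon=.1,
    loss_weights=loss_weights, warmup_steps=2000, rampup_steps=10000)
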
def create_classification_losses(
    global_step,
    inputs,
    label,
    predictor_network,
    epsilon,
    loss_weights,
    warmup_steps=0,
    rampup_steps=-1,
    input_bounds=(0., 1.),
    options=None):
  """Create the training loss."""
  # Whether to elide the last linear layer with the specification.
  elide = True
  # Which loss to use for the IBP loss.
  loss_type = 'xent'
  # If the loss_type is 'hinge', which margin to use.
  loss_margin = 10.
  # Amount of label smoothing.
  label_smoothing = 0.
  # If True, batch normalization stops training after warm-up.
  is_training_off_after_warmup = False
  # If True, epsilon changes more smoothly.
  smooth_epsilon_schedule = False
  # Either 'one_vs_all', 'random_n', 'least_likely_n' or 'none'.
  verified_specification = 'one_vs_all'
  # Attack options.
  attack_specification = 'UntargetedPGDAttack_7x1x1_UnrolledAdam_.1'
  attack_scheduled = False
  attack_random_init = 1.
  # Whether the final loss from the attack should be standard cross-entropy
  # or the TRADES loss (https://arxiv.org/abs/1901.08573).
  pgd_attack_use_trades = False
  if options is not None:
    elide = options.get('elide_last_layer', elide)
    loss_type = options.get('verified_loss_type', loss_type)
    loss_margin = options.get('verified_loss_margin', loss_margin)
    label_smoothing = options.get('label_smoothing', label_smoothing)
    is_training_off_after_warmup = options.get(
        'is_training_off_after_warmup', is_training_off_after_warmup)
    smooth_epsilon_schedule = options.get(
        'smooth_epsilon_schedule', smooth_epsilon_schedule)
    verified_specification = options.get(
        'verified_specification', verified_specification)
    attack_specification = options.get(
        'attack_specification', attack_specification)
    attack_scheduled = options.get('attack_scheduled', attack_scheduled)
    attack_random_init = options.get('attack_random_init', attack_random_init)
    pgd_attack_use_trades = options.get(
        'pgd_attack_use_trades', pgd_attack_use_trades)

  # Loss weights.
  def _get_schedule(init, final):
    if init == final:
      return init
    if rampup_steps < 0:
      return final
    return linear_schedule(
        global_step, warmup_steps, warmup_steps + rampup_steps, init, final)

  def _is_active(init, final):
    return init > 0. or final > 0.
  nominal_xent = _get_schedule(**loss_weights.get('nominal'))
  attack_xent = _get_schedule(**loss_weights.get('attack'))
  use_attack = _is_active(**loss_weights.get('attack'))
  verified_loss = _get_schedule(**loss_weights.get('verified'))
  use_verification = _is_active(**loss_weights.get('verified'))
  if verified_specification == 'none':
    use_verification = False
  weight_mixture = loss.ScalarLosses(
      nominal_cross_entropy=nominal_xent,
      attack_cross_entropy=attack_xent,
      verified_loss=verified_loss)

  if rampup_steps < 0:
    train_epsilon = tf.constant(epsilon)
    is_training = not is_training_off_after_warmup
  else:
    if smooth_epsilon_schedule:
      train_epsilon = smooth_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    else:
      train_epsilon = linear_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    if is_training_off_after_warmup:
      is_training = global_step < warmup_steps
    else:
      is_training = True

  logits = predictor_network(inputs, is_training=is_training)
  num_classes = predictor_network.output_size
  if use_verification:
    logging.info('Verification active.')
    input_interval_bounds = bounds.IntervalBounds(
        tf.maximum(inputs - train_epsilon, input_bounds[0]),
        tf.minimum(inputs + train_epsilon, input_bounds[1]))
    predictor_network.propagate_bounds(input_interval_bounds)
    spec = create_specification(label, num_classes, logits,
                                verified_specification)
    spec_builder = lambda *args, **kwargs: spec(  # pylint: disable=unnecessary-lambda
        *args, collapse=elide, **kwargs)
  else:
    logging.info('Verification disabled.')
    spec_builder = None
  if use_attack:
    logging.info('Attack active.')
    pgd_attack = create_attack(
        attack_specification, predictor_network, label,
        train_epsilon if attack_scheduled else epsilon,
        input_bounds=input_bounds, random_init=attack_random_init)
  else:
    logging.info('Attack disabled.')
    pgd_attack = None
  losses = loss.Losses(predictor_network, spec_builder, pgd_attack,
                       interval_bounds_loss_type=loss_type,
                       interval_bounds_hinge_margin=loss_margin,
                       label_smoothing=label_smoothing,
                       pgd_attack_use_trades=pgd_attack_use_trades)
  losses(label)
  train_loss = sum(l * w for l, w in zip(losses.scalar_losses,
                                         weight_mixture))
  # Add a regularization loss.
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  train_loss = train_loss + tf.reduce_sum(regularizers)
  return losses, train_loss, train_epsilon
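
# Both versions above call linear_schedule and smooth_schedule with the
# signature (step, init_step, final_step, init_value, final_value). The
# sketches below match that call site; they are assumptions about the
# helpers' behavior, not necessarily the library's exact implementations.
def linear_schedule(step, init_step, final_step, init_value, final_value):
  """Linear interpolation from init_value to final_value, clipped to stay
  between the two endpoint values outside [init_step, final_step]."""
  if init_step == final_step:
    return tf.constant(final_value)
  rate = tf.cast(step - init_step, tf.float32) / float(final_step - init_step)
  value = rate * (final_value - init_value) + init_value
  return tf.clip_by_value(value, min(init_value, final_value),
                          max(init_value, final_value))


def smooth_schedule(step, init_step, final_step, init_value, final_value):
  """Like linear_schedule, but eases in and out of the ramp (smoothstep);
  assumes final_step > init_step."""
  t = tf.clip_by_value(
      tf.cast(step - init_step, tf.float32) / float(final_step - init_step),
      0., 1.)
  smooth_t = t * t * (3. - 2. * t)  # smoothstep easing in [0, 1]
  return init_value + smooth_t * (final_value - init_value)
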
def create_classification_losses(
    global_step,
    inputs,
    label,
    predictor_network,
    epsilon,
    loss_weights,
    warmup_steps=0,
    rampup_steps=-1,
    input_bounds=(0., 1.),
    loss_builder=loss.Losses,
    options=None):
  """Create the training loss."""
  # Whether to elide the last linear layer with the specification.
  elide = True
  # Which loss to use for the IBP loss.
  loss_type = 'xent'
  # If the loss_type is 'hinge', which margin to use.
  loss_margin = 10.
  # Amount of label smoothing.
  label_smoothing = 0.
  # If non-negative, the global step after which batch normalization stops
  # training.
  is_training_off_after = -1
  # If True, epsilon changes more smoothly.
  smooth_epsilon_schedule = False
  # Either 'one_vs_all', 'random_n', 'least_likely_n' or 'none'.
  verified_specification = 'one_vs_all'
  # Attack options.
  attack_specification = 'UntargetedPGDAttack_7x1x1_UnrolledAdam_.1'
  attack_scheduled = False
  attack_random_init = 1.
  # Model arguments.
  nominal_args = dict(is_training=True, test_local_stats=False, reuse=False)
  attack_args = {
      'intermediate': dict(is_training=False, test_local_stats=False,
                           reuse=True),
      'final': dict(is_training=False, test_local_stats=False, reuse=True),
  }
  if options is not None:
    elide = options.get('elide_last_layer', elide)
    loss_type = options.get('verified_loss_type', loss_type)
    loss_margin = options.get('verified_loss_margin', loss_margin)
    label_smoothing = options.get('label_smoothing', label_smoothing)
    is_training_off_after = options.get(
        'is_training_off_after', is_training_off_after)
    smooth_epsilon_schedule = options.get(
        'smooth_epsilon_schedule', smooth_epsilon_schedule)
    verified_specification = options.get(
        'verified_specification', verified_specification)
    attack_specification = options.get(
        'attack_specification', attack_specification)
    attack_scheduled = options.get('attack_scheduled', attack_scheduled)
    attack_random_init = options.get('attack_random_init', attack_random_init)
    nominal_args = dict(options.get('nominal_args', nominal_args))
    attack_args = dict(options.get('attack_args', attack_args))

  def _get_schedule(init, final, warmup=None):
    return build_loss_schedule(global_step, warmup_steps, rampup_steps, init,
                               final, warmup)

  def _is_loss_active(init, final, warmup=None):
    return init > 0. or final > 0. or (warmup is not None and warmup > 0.)

  nominal_xent = _get_schedule(**loss_weights.get('nominal'))
  attack_xent = _get_schedule(**loss_weights.get('attack'))
  use_attack = _is_loss_active(**loss_weights.get('attack'))
  verified_loss = _get_schedule(**loss_weights.get('verified'))
  use_verification = _is_loss_active(**loss_weights.get('verified'))
  if verified_specification == 'none':
    use_verification = False
  weight_mixture = loss.ScalarLosses(
      nominal_cross_entropy=nominal_xent,
      attack_cross_entropy=attack_xent,
      verified_loss=verified_loss)

  # Ramp-up.
  if rampup_steps < 0:
    train_epsilon = tf.constant(epsilon)
  else:
    if smooth_epsilon_schedule:
      train_epsilon = smooth_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)
    else:
      train_epsilon = linear_schedule(
          global_step, warmup_steps, warmup_steps + rampup_steps, 0., epsilon)

  # Set is_training according to options.
  if is_training_off_after >= 0:
    is_training = global_step < is_training_off_after
  else:
    is_training = True
  # If the build arguments want training off, we set is_training to False.
  # Otherwise, we respect the is_training_off_after option.
  def _update_is_training(kwargs):
    if 'is_training' in kwargs:
      kwargs['is_training'] &= is_training

  _update_is_training(nominal_args)
  _update_is_training(attack_args['intermediate'])
  _update_is_training(attack_args['final'])

  logits = predictor_network(inputs, override=True, **nominal_args)
  num_classes = predictor_network.output_size
  if use_verification:
    logging.info('Verification active.')
    input_interval_bounds = bounds.IntervalBounds(
        tf.maximum(inputs - train_epsilon, input_bounds[0]),
        tf.minimum(inputs + train_epsilon, input_bounds[1]))
    predictor_network.propagate_bounds(input_interval_bounds)
    spec = create_specification(label, num_classes, logits,
                                verified_specification, collapse=elide)
  else:
    logging.info('Verification disabled.')
    spec = None
  if use_attack:
    logging.info('Attack active.')
    pgd_attack = create_attack(
        attack_specification, predictor_network, label,
        train_epsilon if attack_scheduled else epsilon,
        input_bounds=input_bounds, random_init=attack_random_init,
        predictor_kwargs=attack_args)
  else:
    logging.info('Attack disabled.')
    pgd_attack = None
  losses = loss_builder(predictor_network, spec, pgd_attack,
                        interval_bounds_loss_type=loss_type,
                        interval_bounds_hinge_margin=loss_margin,
                        label_smoothing=label_smoothing)
  losses(label)
  train_loss = sum(l * w for l, w in zip(losses.scalar_losses,
                                         weight_mixture))
  # Add a regularization loss.
  regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  train_loss = train_loss + tf.reduce_sum(regularizers)
  return losses, train_loss, train_epsilon
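
# The third version delegates loss-weight scheduling to build_loss_schedule,
# called as build_loss_schedule(global_step, warmup_steps, rampup_steps,
# init, final, warmup). The sketch below is one plausible implementation
# consistent with that call site: an optional constant `warmup` weight while
# warming up, then a linear ramp from `init` to `final`. This is an
# assumption; the library's actual schedule may differ.
def build_loss_schedule(step, warmup_steps, rampup_steps, init, final,
                        warmup=None):
  if init == final and warmup is None:
    return init
  if rampup_steps < 0:
    after_warmup = tf.constant(final)
  else:
    after_warmup = linear_schedule(
        step, warmup_steps, warmup_steps + rampup_steps, init, final)
  if warmup is None:
    return after_warmup
  # During warm-up, hold the weight at the `warmup` value.
  return tf.where(step < warmup_steps,
                  tf.constant(warmup, dtype=tf.float32),
                  tf.cast(after_warmup, tf.float32))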