Esempio n. 1
0
 def model_building_fn(img, is_training):
     """Builds the hub-export model with embedding and classification heads.

     Only the first input slot of `ss_utils.apply_model_semi` is filled (the
     second is None); the class logits are read back under the plain
     'classes' key.

     Args:
       img: Input images fed to the model.
       is_training: Forwarded unchanged to `ss_utils.apply_model_semi`.

     Returns:
       A `(end_points, class_logits)` pair.
     """
     head_sizes = {
         'embeddings': FLAGS.triplet_embed_dim,
         'classes': datasets.get_auxiliary_num_classes(),
     }
     endpoints = ss_utils.apply_model_semi(img, None, is_training,
                                           outputs=head_sizes)
     class_logits = endpoints['classes']
     return endpoints, class_logits
Esempio n. 2
0
 def model_building_fn(img, is_training):
     """Builds the hub-export model with a single classification head.

     Demonstrates calling `ss_utils.apply_model_semi` with only its first
     input (the second is None); the logits are then read back under the
     plain output name 'classes'.

     Args:
       img: Input images fed to the model.
       is_training: Forwarded unchanged to `ss_utils.apply_model_semi`.

     Returns:
       A `(end_points, class_logits)` pair.
     """
     head_sizes = {'classes': datasets.get_auxiliary_num_classes()}
     endpoints = ss_utils.apply_model_semi(img, None, is_training,
                                           outputs=head_sizes)
     class_logits = endpoints['classes']
     return endpoints, class_logits
Esempio n. 3
0
 def classification_net_fn(x):
     """Returns class logits for `x` using the existing 'module' variables.

     The network is built under `tf.variable_scope('module', reuse=True)`, so
     the variables are expected to exist already. `mode` is read from the
     enclosing scope.
     """
     # Batch-norm statistics must not be updated: this runs on perturbed
     # (corrupted) inputs, and decay=1.0 keeps the moving averages unchanged.
     frozen_bn = functools.partial(tpu_ops.cross_replica_batch_norm, decay=1.0)
     training = mode == tf.estimator.ModeKeys.TRAIN
     with tf.variable_scope('module', reuse=True):
         outputs = ss_utils.apply_model_semi(
             x,
             None,
             is_training=training,
             outputs={'classes': datasets.get_auxiliary_num_classes()},
             normalization_fn=frozen_bn)
     return outputs['classes']
Esempio n. 4
0
 def model_building_fn(img, is_training):
     """Builds the hub-export model with rotation and classification heads.

     Only the first input slot of `ss_utils.apply_model_semi` is filled (the
     second is None), so the class logits are read back under the plain
     'classes' key. Normalization uses `tpu_ops.cross_replica_batch_norm`.
     `num_angles` is taken from the enclosing scope.

     Args:
       img: Input images fed to the model.
       is_training: Forwarded unchanged to `ss_utils.apply_model_semi`.

     Returns:
       A `(end_points, class_logits)` pair.
     """
     head_sizes = {
         'rotations': num_angles,
         'classes': datasets.get_auxiliary_num_classes(),
     }
     endpoints = ss_utils.apply_model_semi(
         img,
         None,
         is_training,
         outputs=head_sizes,
         normalization_fn=tpu_ops.cross_replica_batch_norm)
     return endpoints, endpoints['classes']
Esempio n. 5
0
def model_fn(data, mode):
    """Produces a loss for the rotation task with semi-supervision.

    The objective is the supervised classification loss on the (rotated)
    labelled images plus `FLAGS.rot_loss_weight` times the mean rotation loss
    over whichever branches `FLAGS.rot_loss_unsup` / `FLAGS.rot_loss_sup`
    enable.

    Args:
      data: In PREDICT mode, a dict holding the input batch under "image". In
        train/eval mode, a pair: `data[0]` is the unlabelled branch ("image",
        rotation "label") and `data[1]` the labelled branch ("image", rotation
        "label", class "copy_label").
      mode: model's mode: training, eval or prediction

    Returns:
      The result of `trainer.make_estimator` (an EstimatorSpec).
    """
    num_angles = 4

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Called once at the end of training: the tf.Hub module is created to
        # export the model, and one last prediction is run through it.
        def model_building_fn(img, is_training):
            # Only one input is given, so outputs keep their plain names.
            heads = {
                'rotations': num_angles,
                'classes': datasets.get_auxiliary_num_classes(),
            }
            points = ss_utils.apply_model_semi(img, None, is_training,
                                               outputs=heads)
            return points, points['classes']

        return trainer.make_estimator(
            mode,
            predict_fn=model_building_fn,
            predict_input=data['image'],
            polyak_averaging=FLAGS.get_flag_value('polyak_averaging', False))

    # Train / eval. Potentially fold a rotation "R" dimension (B,R,H,W,C) into
    # the batch so the network sees (BR,H,W,C). The unlabelled branch is only
    # fed through the network when its rotation loss is requested.
    unsup_batch = (utils.into_batch_dim(data[0]['image'])
                   if FLAGS.rot_loss_unsup else None)
    sup_batch = utils.into_batch_dim(data[1]['image'])
    class_labels = data[1]['copy_label']

    # The 'module' scope is needed for tf.Hub export. With both inputs given,
    # the end points come back suffixed ('rotations_unsup', 'classes_sup', ...).
    with tf.variable_scope('module'):
        end_points = ss_utils.apply_model_semi(
            unsup_batch,
            sup_batch,
            is_training=mode == tf.estimator.ModeKeys.TRAIN,
            outputs={
                'rotations': num_angles,
                'classes': datasets.get_auxiliary_num_classes(),
            })

    def mean_rotation_xent(branch, logits):
        """Mean sparse cross-entropy of `logits` vs. the branch's flat labels."""
        flat_labels = tf.reshape(branch['label'], [-1])
        return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=flat_labels))

    # Rotation self-supervision, averaged over the enabled branches.
    rotation_terms = []
    if FLAGS.rot_loss_unsup:
        rotation_terms.append(
            mean_rotation_xent(data[0], end_points['rotations_unsup']))
    if FLAGS.rot_loss_sup:
        rotation_terms.append(
            mean_rotation_xent(data[1], end_points['rotations_sup']))
    rotation_loss = tf.reduce_mean(rotation_terms) if rotation_terms else 0.0

    # Classification on the labelled branch: every image occurs once per
    # rotation, so its class label is tiled num_angles times to line up.
    sup_logits = end_points['classes_sup']
    tiled_labels = tf.reshape(
        tf.tile(class_labels[:, None], [1, num_angles]), [-1])
    class_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tiled_labels, logits=sup_logits))

    total_loss = class_loss + FLAGS.rot_loss_weight * rotation_loss

    # Eval: accuracy of the unrotated view (index 0) and of the logits
    # averaged over all rotations.
    per_rotation = utils.split_batch_dim(sup_logits, [-1, num_angles])
    unrotated_logits = per_rotation[:, 0]
    averaged_logits = tf.reduce_mean(per_rotation, axis=1)

    def metric_fn(labels_class, logits_class_orig, logits_class_avg):
        return {
            'classification/unrotated top1 accuracy':
                utils.top_k_accuracy(1, labels_class, logits_class_orig),
            'classification/unrotated top5 accuracy':
                utils.top_k_accuracy(5, labels_class, logits_class_orig),
            'classification/rot_avg top1 accuracy':
                utils.top_k_accuracy(1, labels_class, logits_class_avg),
            'classification/rot_avg top5 accuracy':
                utils.top_k_accuracy(5, labels_class, logits_class_avg)
        }

    return trainer.make_estimator(
        mode,
        total_loss,
        (metric_fn, [class_labels, unrotated_logits, averaged_logits]),
        polyak_averaging=FLAGS.get_flag_value('polyak_averaging', False))
Esempio n. 6
0
def model_fn(data, mode):
    """Produces a loss for rotation self-supervision combined with VAT and EntMin.

    The total loss is a weighted average of the rotation, supervised
    classification and VAT losses (weights `1 - wc - wv`, `wc` and `wv`, read
    from the `sup_weight` / `vat_weight` flags), plus a conditional-entropy
    (EntMin) term scaled by `FLAGS.entmin_factor`.

    Args:
      data: In PREDICT mode, a dict holding the input batch under "image". In
        train/eval mode, a pair: `data[0]` is the unlabelled branch ("image",
        rotation "label") and `data[1]` the labelled branch ("image", rotation
        "label", class "copy_label").
      mode: model's mode: training, eval or prediction

    Returns:
      The result of `trainer.make_estimator` (an EstimatorSpec).

    Raises:
      AssertionError: (when assertions are enabled) if `sup_weight` or
        `vat_weight` lies outside [0, 1], or if their sum exceeds 1, which
        would give the rotation loss a negative weight.
    """
    num_angles = 4

    # In this mode (called once at the end of training), we create the tf.Hub
    # module in order to export the model, and use that to do one last prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:
        # This defines a function called by the hub module to create the model.
        def model_building_fn(img, is_training):
            # Only one input is provided, so the outputs simply use the given
            # names (no _unsup/_sup suffix).
            end_points = ss_utils.apply_model_semi(
                img,
                None,
                is_training,
                outputs={
                    'rotations': num_angles,
                    'classes': datasets.get_auxiliary_num_classes(),
                },
                normalization_fn=tpu_ops.cross_replica_batch_norm)
            return end_points, end_points['classes']

        return trainer.make_estimator(mode,
                                      predict_fn=model_building_fn,
                                      predict_input=data['image'],
                                      polyak_averaging=FLAGS.get_flag_value(
                                          'polyak_averaging', False))

    # In all other cases, we are in train/eval mode.

    # Read and validate the loss weights before building the graph. The
    # rotation loss receives the remaining weight 1 - wc - wv, so wc + wv must
    # not exceed 1: a negative rotation weight would make training maximize
    # the rotation loss. The tiny tolerance only absorbs float rounding.
    wc = FLAGS.get_flag_value('sup_weight', 0.3)
    assert 0.0 <= wc <= 1.0, 'Loss weight should be in [0, 1] range.'
    wv = FLAGS.get_flag_value('vat_weight', 0.3)
    assert 0.0 <= wv <= 1.0, 'Loss weight should be in [0, 1] range.'
    assert wc + wv <= 1.0 + 1e-6, (
        'sup_weight + vat_weight should not exceed 1, otherwise the rotation '
        'loss gets a negative weight.')
    wr = 1.0 - wc - wv

    # Potentially flatten the rotation "R" dimension (B,R,H,W,C) into the batch
    # "B" dimension so we get (BR,H,W,C)
    images_unsup = data[0]['image']
    images_unsup = utils.into_batch_dim(images_unsup)

    # For the supervised branch, we also apply rotation on them.
    images_sup = data[1]['image']
    images_sup = utils.into_batch_dim(images_sup)
    labels_class = data[1]['copy_label']

    # Forward them both through the model. The scope is needed for tf.Hub export.
    with tf.variable_scope('module'):
        # Here, we pass both inputs to `apply_model_semi`, and so we now get
        # outputs corresponding to each in `end_points` as "rotations_unsup" and
        # similar, which we will use below.
        end_points = ss_utils.apply_model_semi(
            images_unsup,
            images_sup,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            outputs={
                'rotations': num_angles,
                'classes': datasets.get_auxiliary_num_classes(),
            },
            normalization_fn=tpu_ops.cross_replica_batch_norm)

    # Compute virtual adversarial perturbation
    # =====

    def classification_net_fn(x):
        """Class logits for `x`, reusing the 'module' variables."""
        with tf.variable_scope('module', reuse=True):
            end_points_x = ss_utils.apply_model_semi(
                x,
                None,
                is_training=(mode == tf.estimator.ModeKeys.TRAIN),
                outputs={'classes': datasets.get_auxiliary_num_classes()},
                # Don't update batch norm stats as we're running this on perturbed
                # (corrupted) inputs. Setting decay=1 is what does the trick.
                normalization_fn=functools.partial(
                    tpu_ops.cross_replica_batch_norm, decay=1.0))
            return end_points_x['classes']

    vat_eps = FLAGS.get_flag_value('vat_eps', 1.0)
    vat_num_power_method_iters = FLAGS.get_flag_value(
        'vat_num_power_method_iters', 1)
    # VAT is applied to the (rotated, flattened) unlabelled images.
    vat_perturbation = vat_utils.virtual_adversarial_perturbation_direction(
        images_unsup,
        end_points['classes_unsup'],
        net=classification_net_fn,
        num_power_method_iters=vat_num_power_method_iters,
    ) * vat_eps

    loss_vat = tf.reduce_mean(
        vat_utils.kl_divergence_from_logits(
            classification_net_fn(images_unsup + vat_perturbation),
            tf.stop_gradient(end_points['classes_unsup'])))

    # Compute the rotation self-supervision loss.
    # =====

    # Rotation loss on the unsupervised images.
    labels_rot_unsup = tf.reshape(data[0]['label'], [-1])
    loss_rot_unsup = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=end_points['rotations_unsup'], labels=labels_rot_unsup))

    # And on the supervised images too.
    labels_rot_sup = tf.reshape(data[1]['label'], [-1])
    loss_rot_sup = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=end_points['rotations_sup'], labels=labels_rot_sup))

    loss_rot = 0.5 * loss_rot_unsup + 0.5 * loss_rot_sup

    # Compute the classification loss on supervised images.
    # =====
    logits_class = end_points['classes_sup']

    # Replicate the supervised label for each rotated version.
    labels_class_repeat = tf.tile(labels_class[:, None], [1, num_angles])
    labels_class_repeat = tf.reshape(labels_class_repeat, [-1])

    loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels_class_repeat, logits=logits_class)
    loss_class = tf.reduce_mean(loss_class)

    # Compute the EntMin regularization loss (mean conditional entropy of the
    # class predictions on unlabelled images).
    # =====
    logits_unsup = end_points['classes_unsup']
    conditional_ent = -tf.reduce_sum(
        tf.nn.log_softmax(logits_unsup) * tf.nn.softmax(logits_unsup), axis=-1)
    loss_entmin = tf.reduce_mean(conditional_ent)

    # Combine losses and define metrics.
    # =====

    # Combine VAT, classification and rotation loss as a weighted average, then
    # add weighted conditional entropy loss.
    loss = (wr * loss_rot + wc * loss_class + wv * loss_vat +
            FLAGS.entmin_factor * loss_entmin)

    train_scalar_summaries = {
        'vat_eps': vat_eps,
        'vat_weight': wv,
        'vat_num_power_method_iters': vat_num_power_method_iters,
        'loss_class': loss_class,
        'loss_class_weighted': wc * loss_class,
        'class_weight': wc,
        'loss_vat': loss_vat,
        'loss_vat_weighted': wv * loss_vat,
        'rot_weight': wr,
        'loss_rot': loss_rot,
        'loss_rot_weighted': wr * loss_rot,
        'loss_entmin': loss_entmin,
        'loss_entmin_weighted': FLAGS.entmin_factor * loss_entmin
    }

    # For evaluation, we want to see the result of using only the un-rotated, and
    # also the average of four rotated class-predictions.
    logits_class = utils.split_batch_dim(logits_class, [-1, num_angles])
    logits_class_orig = logits_class[:, 0]
    logits_class_avg = tf.reduce_mean(logits_class, axis=1)

    eval_metrics = (
        lambda labels_rot_unsup, logits_rot_unsup, labels_class, logits_class_orig, logits_class_avg: {  # pylint: disable=g-long-lambda,line-too-long
            'rotation top1 accuracy':
                utils.top_k_accuracy(1, labels_rot_unsup, logits_rot_unsup),
            'classification/unrotated top1 accuracy':
                utils.top_k_accuracy(1, labels_class, logits_class_orig),
            'classification/unrotated top5 accuracy':
                utils.top_k_accuracy(5, labels_class, logits_class_orig),
            'classification/rot_avg top1 accuracy':
                utils.top_k_accuracy(1, labels_class, logits_class_avg),
            'classification/rot_avg top5 accuracy':
                utils.top_k_accuracy(5, labels_class, logits_class_avg),
        },
        [
            labels_rot_unsup, end_points['rotations_unsup'],
            labels_class, logits_class_orig, logits_class_avg,
        ])

    return trainer.make_estimator(
        mode,
        loss,
        eval_metrics,
        train_scalar_summaries=train_scalar_summaries,
        polyak_averaging=FLAGS.get_flag_value('polyak_averaging', False))
Esempio n. 7
0
def model_fn(data, mode):
    """Produces a purely supervised classification loss (baseline).

    In train/eval mode only the labelled branch `data[1]` is used.

    Args:
      data: In PREDICT mode, a dict holding the input batch under "image". In
        train/eval mode, an indexable whose element 1 is a dict with "image"
        and "label" entries; other elements are ignored.
      mode: model's mode: training, eval or prediction

    Returns:
      The result of `trainer.make_estimator` (an EstimatorSpec).
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Called once at the end of training: the tf.Hub module is created to
        # export the model, and one last prediction is run through it.
        def model_building_fn(img, is_training):
            # Only one input is given, so the output keeps its plain name.
            points = ss_utils.apply_model_semi(
                img, None, is_training,
                outputs={'classes': datasets.get_auxiliary_num_classes()})
            return points, points['classes']

        return trainer.make_estimator(
            mode, predict_fn=model_building_fn, predict_input=data['image'])

    labelled = data[1]
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Only the second (labelled) input slot is filled, so the logits come back
    # under the plain 'classes' key. The scope is needed for tf.Hub export.
    with tf.variable_scope('module'):
        end_points = ss_utils.apply_model_semi(
            None,
            labelled['image'],
            is_training=training,
            outputs={'classes': datasets.get_auxiliary_num_classes()})

    logits = end_points['classes']
    labels = labelled['label']
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))

    def metric_fn(labels_class, logits_class):
        return {
            'top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class),
            'top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class),
        }

    return trainer.make_estimator(mode, loss, (metric_fn, [labels, logits]))
Esempio n. 8
0
def model_fn(data, mode):
    """Produces a loss for the exemplar task.

    In train mode the loss is the supervised classification loss plus
    `FLAGS.triplet_loss_weight` times an exemplar triplet loss. In eval mode
    only the classification loss and its metrics are returned.

    Args:
      data: In PREDICT mode, a dict holding the input batch under "image". In
        train/eval mode, a pair: `data[0]` is the unlabelled branch (with
        "image") and `data[1]` the labelled branch (with "image" and "label").
        Images are either BHWC (one example per image) or BEHWC with E
        examples per image.
      mode: model's mode: training, eval or prediction

    Returns:
      The result of `trainer.make_estimator` (an EstimatorSpec).
    """

    # In this mode (called once at the end of training), we create the tf.Hub
    # module in order to export the model, and use that to do one last prediction.
    if mode == tf.estimator.ModeKeys.PREDICT:

        def model_building_fn(img, is_training):
            end_points = ss_utils.apply_model_semi(
                img,
                None,
                is_training,
                outputs={
                    'embeddings': FLAGS.triplet_embed_dim,
                    'classes': datasets.get_auxiliary_num_classes(),
                })
            return end_points, end_points['classes']

        return trainer.make_estimator(mode,
                                      predict_fn=model_building_fn,
                                      predict_input=data['image'])

    # In all other cases, we are in train/eval mode.
    images_unsup = data[0]['image']
    images_sup = data[1]['image']

    # There is one special case, typically in eval mode, when we don't want to use
    # multiple examples, but a single one. In that case, add the fake length-1
    # example dimension to the input so that everything still works.
    # i.e. turn BHWC into B1HWC
    if images_unsup.shape.ndims == 4:
        images_unsup = images_unsup[:, None, ...]
    if images_sup.shape.ndims == 4:
        images_sup = images_sup[:, None, ...]

    # Find out the number of examples that have been created per image, which
    # may be different for sup/unsup, and use that for creating the labels.
    ninstances_unsup, nexamples_unsup = images_unsup.shape[:2]
    ninstances_sup, nexamples_sup = images_sup.shape[:2]

    # Then, fold the examples into the batch.
    images_unsup = utils.into_batch_dim(images_unsup)
    images_sup = utils.into_batch_dim(images_sup)

    # If we're not doing exemplar on the unsupervised data, skip it!
    if not FLAGS.triplet_loss_unsup:
        images_unsup = None

    # Forward them both through the model. The scope is needed for tf.Hub export.
    with tf.variable_scope('module'):
        # Here, we pass both inputs to `apply_model_semi`, and so we now get
        # outputs corresponding to each in `end_points` as "classes_unsup" and
        # similar, which we will use below.
        end_points = ss_utils.apply_model_semi(
            images_unsup,
            images_sup,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            outputs={
                'embeddings': FLAGS.triplet_embed_dim,
                'classes': datasets.get_auxiliary_num_classes(),
            })

    # Labelled classification loss
    # =====

    # Compute the supervision loss for each example of the supervised branch.
    labels_class = utils.repeat(data[1]['label'], nexamples_sup)
    logits_class = end_points['classes_sup']
    losses_class = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels_class, logits=logits_class)
    loss_class = tf.reduce_mean(losses_class)

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics = (
            lambda labels_class, logits_class, losses_class: {  # pylint: disable=g-long-lambda
                'classification/top1 accuracy':
                    utils.top_k_accuracy(1, labels_class, logits_class),
                'classification/top5 accuracy':
                    utils.top_k_accuracy(5, labels_class, logits_class),
                'classification/loss': tf.metrics.mean(losses_class),
            }, [labels_class, logits_class, losses_class])

        return trainer.make_estimator(mode, loss_class, eval_metrics)

    # Exemplar triplet loss
    # =====
    losses_ex = []

    def do_triplet(embeddings, nexamples, ninstances):
        """Returns the mean batch-hard triplet loss over `embeddings`.

        Args:
          embeddings: `ninstances * nexamples` embedding rows, laid out
            instance-major to match the `[0 0 0 1 1 1 ...]` labels below.
          nexamples: Number of examples (rows) per instance.
          ninstances: Number of distinct instances.
        """
        # Empirically, normalizing the embeddings is more robust.
        embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

        # Generate the labels as [0 0 0 1 1 1 ...]
        labels = utils.repeat(tf.range(ninstances), nexamples)

        # Apply batch-hard loss with a soft-margin.
        losses_tri = batch_hard(embeddings,
                                labels,
                                margin=0.0,
                                soft=True,
                                sample_pos=False,
                                sample_neg=False)
        return tf.reduce_mean(losses_tri)

    # The counts are passed by keyword: `do_triplet` takes (nexamples,
    # ninstances) in that order, and swapping them silently assigns the
    # wrong instance labels whenever the two counts differ.

    # Compute exemplar triplet loss on the unsupervised images
    if FLAGS.triplet_loss_unsup:
        losses_ex.append(do_triplet(end_points['embeddings_unsup'],
                                    nexamples=nexamples_unsup,
                                    ninstances=ninstances_unsup))

    # Compute exemplar triplet loss on the supervised images.
    if FLAGS.triplet_loss_sup:
        losses_ex.append(do_triplet(end_points['embeddings_sup'],
                                    nexamples=nexamples_sup,
                                    ninstances=ninstances_sup))

    loss_ex = tf.reduce_mean(losses_ex) if losses_ex else 0.0

    # Combine the two losses as a weighted average.
    loss = loss_class + FLAGS.triplet_loss_weight * loss_ex

    return trainer.make_estimator(mode, loss)
Esempio n. 9
0
def model_fn(data, mode):
    """Produces a loss for the VAT semi-supervised task.

    loss = classification + FLAGS.vat_weight * VAT
           + FLAGS.entmin_factor * conditional entropy

    VAT and the conditional entropy are computed on the unlabelled batch, or
    on the labelled and unlabelled batches together when
    `FLAGS.apply_vat_to_labeled` is set.

    Args:
      data: In PREDICT mode, a dict holding the input batch under "image". In
        train/eval mode, a pair: `data[0]` is the unlabelled branch (with
        "image") and `data[1]` the labelled branch (with "image" and "label").
      mode: model's mode: training, eval or prediction

    Returns:
      The result of `trainer.make_estimator` (an EstimatorSpec).
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Called once at the end of training: the tf.Hub module is created to
        # export the model, and one last prediction is run through it.
        def model_building_fn(img, is_training):
            # Only one input is given, so the output keeps its plain name.
            points = ss_utils.apply_model_semi(
                img, None, is_training,
                outputs={'classes': datasets.get_auxiliary_num_classes()})
            return points, points['classes']

        return trainer.make_estimator(
            mode, predict_fn=model_building_fn, predict_input=data['image'])

    unsup_images = data[0]['image']
    sup_images = data[1]['image']
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Both inputs are given, so logits come back as 'classes_unsup' and
    # 'classes_sup'. The scope is needed for tf.Hub export.
    with tf.variable_scope('module'):
        end_points = ss_utils.apply_model_semi(
            unsup_images,
            sup_images,
            is_training=training,
            outputs={'classes': datasets.get_auxiliary_num_classes()})

    def classification_net_fn(x):
        """Class logits for `x`, reusing the 'module' variables."""
        # momentum=1.0 keeps the batch-norm moving averages unchanged, since
        # this runs on perturbed (corrupted) inputs.
        frozen_bn = functools.partial(tf.layers.batch_normalization,
                                      momentum=1.0)
        with tf.variable_scope('module', reuse=True):
            points = ss_utils.apply_model_semi(
                x,
                None,
                is_training=training,
                outputs={'classes': datasets.get_auxiliary_num_classes()},
                normalization_fn=frozen_bn)
        return points['classes']

    vat_eps = FLAGS.get_flag_value('vat_eps', 1.0)
    vat_num_power_method_iters = FLAGS.get_flag_value(
        'vat_num_power_method_iters', 1)

    # Choose the inputs (and their clean logits) that VAT / EntMin act on.
    if FLAGS.apply_vat_to_labeled:
        vat_inputs = tf.concat((sup_images, unsup_images), axis=0)
        vat_logits = tf.concat(
            (end_points['classes_sup'], end_points['classes_unsup']), axis=0)
    else:
        vat_inputs, vat_logits = unsup_images, end_points['classes_unsup']

    direction = vat_utils.virtual_adversarial_perturbation_direction(
        vat_inputs,
        vat_logits,
        net=classification_net_fn,
        num_power_method_iters=vat_num_power_method_iters,
    )
    vat_perturbation = direction * vat_eps

    perturbed_logits = classification_net_fn(vat_inputs + vat_perturbation)
    vat_loss = tf.reduce_mean(vat_utils.kl_divergence_from_logits(
        perturbed_logits, tf.stop_gradient(vat_logits)))

    # Supervised loss on the clean labelled images.
    sup_labels = data[1]['label']
    sup_logits = end_points['classes_sup']
    class_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sup_labels, logits=sup_logits))

    # Mean conditional entropy of the predictions VAT was computed against.
    plogp = tf.nn.log_softmax(vat_logits) * tf.nn.softmax(vat_logits)
    entmin_loss = tf.reduce_mean(-tf.reduce_sum(plogp, axis=-1))

    vat_weight = FLAGS.vat_weight
    entmin_weight = FLAGS.entmin_factor
    weighted_vat = vat_weight * vat_loss
    weighted_entmin = entmin_weight * entmin_loss
    total_loss = class_loss + weighted_vat + weighted_entmin

    train_scalar_summaries = {
        'vat_eps': vat_eps,
        'vat_weight': vat_weight,
        'entmin_weight': entmin_weight,
        'vat_num_power_method_iters': vat_num_power_method_iters,
        'loss_class': class_loss,
        'loss_vat': vat_loss,
        'loss_vat_weighted': weighted_vat,
        'loss_entmin': entmin_loss,
        'loss_entmin_weighted': weighted_entmin
    }

    def metric_fn(labels_class, logits_class):
        return {
            'classification top1 accuracy':
                utils.top_k_accuracy(1, labels_class, logits_class),
            'classification top5 accuracy':
                utils.top_k_accuracy(5, labels_class, logits_class)
        }

    return trainer.make_estimator(
        mode,
        total_loss,
        (metric_fn, [sup_labels, sup_logits]),
        train_scalar_summaries=train_scalar_summaries)