def model_fn(data, mode): """Produces a loss for the rotation task with semi-supervision. Args: data: Dict of inputs containing, among others, "image" and "label." mode: model's mode: training, eval or prediction Returns: EstimatorSpec """ num_angles = 4 # In this mode (called once at the end of training), we create the tf.Hub # module in order to export the model, and use that to do one last prediction. if mode == tf.estimator.ModeKeys.PREDICT: # This defines a function called by the hub module to create the model. def model_building_fn(img, is_training): # This is an example of calling `apply_model_semi` with only one of the # inputs provided. The outputs will simply use the given names: end_points = ss_utils.apply_model_semi( img, None, is_training, outputs={ 'rotations': num_angles, 'classes': datasets.get_auxiliary_num_classes(), }) return end_points, end_points['classes'] return trainer.make_estimator(mode, predict_fn=model_building_fn, predict_input=data['image'], polyak_averaging=FLAGS.get_flag_value( 'polyak_averaging', False)) # In all other cases, we are in train/eval mode. images_unsup = None if FLAGS.rot_loss_unsup: # Potentially flatten the rotation "R" dimension (B,R,H,W,C) into the batch # "B" dimension so we get (BR,H,W,C) images_unsup = data[0]['image'] images_unsup = utils.into_batch_dim(images_unsup) images_sup = data[1]['image'] images_sup = utils.into_batch_dim(images_sup) labels_class = data[1]['copy_label'] # Forward them both through the model. The scope is needed for tf.Hub export. with tf.variable_scope('module'): # Here, we pass both inputs to `apply_model_semi`, and so we now get # outputs corresponding to each in `end_points` as "rotations_unsup" and # similar, which we will use below. end_points = ss_utils.apply_model_semi( images_unsup, images_sup, is_training=mode == tf.estimator.ModeKeys.TRAIN, outputs={ 'rotations': num_angles, 'classes': datasets.get_auxiliary_num_classes(), }) # Compute the rotation self-supervision loss. # ===== losses_rot = [] # Compute the rotation loss on the unsupervised images. if FLAGS.rot_loss_unsup: labels_rot_unsup = tf.reshape(data[0]['label'], [-1]) loss_rot_unsup = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=end_points['rotations_unsup'], labels=labels_rot_unsup) losses_rot.append(tf.reduce_mean(loss_rot_unsup)) # And on the supervised images too. if FLAGS.rot_loss_sup: labels_rot_sup = tf.reshape(data[1]['label'], [-1]) loss_rot_sup = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=end_points['rotations_sup'], labels=labels_rot_sup) losses_rot.append(tf.reduce_mean(loss_rot_sup)) loss_rot = tf.reduce_mean(losses_rot) if losses_rot else 0.0 # Compute the classification loss on supervised images. # ===== logits_class = end_points['classes_sup'] # Replicate the supervised label for each rotated version. labels_class_repeat = tf.tile(labels_class[:, None], [1, num_angles]) labels_class_repeat = tf.reshape(labels_class_repeat, [-1]) loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_class_repeat, logits=logits_class) loss_class = tf.reduce_mean(loss_class) # Combine losses and define metrics. 
# ===== w = FLAGS.rot_loss_weight loss = loss_class + w * loss_rot # At eval time, we compute accuracy of both the unrotated image, # and the average prediction across all four rotations logits_class = utils.split_batch_dim(logits_class, [-1, num_angles]) logits_class_orig = logits_class[:, 0] logits_class_avg = tf.reduce_mean(logits_class, axis=1) eval_metrics = ( lambda labels_class, logits_class_orig, logits_class_avg: { # pylint: disable=g-long-lambda 'classification/unrotated top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class_orig), 'classification/unrotated top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class_orig), 'classification/rot_avg top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class_avg), 'classification/rot_avg top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class_avg) }, [labels_class, logits_class_orig, logits_class_avg]) return trainer.make_estimator(mode, loss, eval_metrics, polyak_averaging=FLAGS.get_flag_value( 'polyak_averaging', False))
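
# For reference: `utils.into_batch_dim` and `utils.split_batch_dim` are
# helpers from this codebase's `utils` module. The snippet below is only a
# hypothetical sketch of their assumed semantics (fold the per-image copies
# into the batch dimension and split them back out), not the repository's
# actual implementation; it assumes static shapes and the file's usual
# `import tensorflow as tf`.
def into_batch_dim_sketch(x):
  # (B, R, H, W, C) -> (B*R, H, W, C): fold per-image copies into the batch.
  return tf.reshape(x, [-1] + x.shape.as_list()[2:])


def split_batch_dim_sketch(x, split_shape):
  # (B*R, ...) -> (B, R, ...), e.g. with split_shape=[-1, num_angles].
  return tf.reshape(x, split_shape + x.shape.as_list()[1:])
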
def model_fn(data, mode): """Produces a loss for the rotation task with semi-supervision. Args: data: Dict of inputs containing, among others, "image" and "label." mode: model's mode: training, eval or prediction Returns: EstimatorSpec """ num_angles = 4 # In this mode (called once at the end of training), we create the tf.Hub # module in order to export the model, and use that to do one last prediction. if mode == tf.estimator.ModeKeys.PREDICT: # This defines a function called by the hub module to create the model. def model_building_fn(img, is_training): # This is an example of calling `apply_model_semi` with only one of the # inputs provided. The outputs will simply use the given names: end_points = ss_utils.apply_model_semi( img, None, is_training, outputs={ 'rotations': num_angles, 'classes': datasets.get_auxiliary_num_classes(), }, normalization_fn=tpu_ops.cross_replica_batch_norm) return end_points, end_points['classes'] return trainer.make_estimator(mode, predict_fn=model_building_fn, predict_input=data['image'], polyak_averaging=FLAGS.get_flag_value( 'polyak_averaging', False)) # In all other cases, we are in train/eval mode. # Potentially flatten the rotation "R" dimension (B,R,H,W,C) into the batch # "B" dimension so we get (BR,H,W,C) images_unsup = data[0]['image'] images_unsup = utils.into_batch_dim(images_unsup) # For the supervised branch, we also apply rotation on them. images_sup = data[1]['image'] images_sup = utils.into_batch_dim(images_sup) labels_class = data[1]['copy_label'] # Forward them both through the model. The scope is needed for tf.Hub export. with tf.variable_scope('module'): # Here, we pass both inputs to `apply_model_semi`, and so we now get # outputs corresponding to each in `end_points` as "rotations_unsup" and # similar, which we will use below. end_points = ss_utils.apply_model_semi( images_unsup, images_sup, is_training=(mode == tf.estimator.ModeKeys.TRAIN), outputs={ 'rotations': num_angles, 'classes': datasets.get_auxiliary_num_classes(), }, normalization_fn=tpu_ops.cross_replica_batch_norm) # Compute virtual adversarial perturbation # ===== def classification_net_fn(x): # pylint: disable=missing-docstring with tf.variable_scope('module', reuse=True): end_points_x = ss_utils.apply_model_semi( x, None, is_training=(mode == tf.estimator.ModeKeys.TRAIN), outputs={'classes': datasets.get_auxiliary_num_classes()}, # Don't update batch norm stats as we're running this on perturbed # (corrupted) inputs. Setting decay=1 is what does the trick. normalization_fn=functools.partial( tpu_ops.cross_replica_batch_norm, decay=1.0)) return end_points_x['classes'] vat_eps = FLAGS.get_flag_value('vat_eps', 1.0) vat_num_power_method_iters = FLAGS.get_flag_value( 'vat_num_power_method_iters', 1) vat_perturbation = vat_utils.virtual_adversarial_perturbation_direction( images_unsup, end_points['classes_unsup'], net=classification_net_fn, num_power_method_iters=vat_num_power_method_iters, ) * vat_eps loss_vat = tf.reduce_mean( vat_utils.kl_divergence_from_logits( classification_net_fn(images_unsup + vat_perturbation), tf.stop_gradient(end_points['classes_unsup']))) # Compute the rotation self-supervision loss. # ===== # Compute the rotation loss on the unsupervised images. labels_rot_unsup = tf.reshape(data[0]['label'], [-1]) loss_rot_unsup = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=end_points['rotations_unsup'], labels=labels_rot_unsup) loss_rot = tf.reduce_mean(loss_rot_unsup) # And on the supervised images too. 
labels_rot_sup = tf.reshape(data[1]['label'], [-1]) loss_rot_sup = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=end_points['rotations_sup'], labels=labels_rot_sup) loss_rot_sup = tf.reduce_mean(loss_rot_sup) loss_rot = 0.5 * loss_rot + 0.5 * loss_rot_sup # Compute the classification loss on supervised images. # ===== logits_class = end_points['classes_sup'] # Replicate the supervised label for each rotated version. labels_class_repeat = tf.tile(labels_class[:, None], [1, num_angles]) labels_class_repeat = tf.reshape(labels_class_repeat, [-1]) loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_class_repeat, logits=logits_class) loss_class = tf.reduce_mean(loss_class) # Compute the EntMin regularization loss. # ===== logits_unsup = end_points['classes_unsup'] conditional_ent = -tf.reduce_sum( tf.nn.log_softmax(logits_unsup) * tf.nn.softmax(logits_unsup), axis=-1) loss_entmin = tf.reduce_mean(conditional_ent) # Combine losses and define metrics. # ===== # Combine the two losses as a weighted average. wc = FLAGS.get_flag_value('sup_weight', 0.3) assert 0.0 <= wc <= 1.0, 'Loss weight should be in [0, 1] range.' wv = FLAGS.get_flag_value('vat_weight', 0.3) assert 0.0 <= wv <= 1.0, 'Loss weight should be in [0, 1] range.' # Combine VAT, classification and rotation loss as a weighted average, then # add weighted conditional entropy loss. loss = ((1.0 - wc - wv) * loss_rot + wc * loss_class + wv * loss_vat + FLAGS.entmin_factor * loss_entmin) train_scalar_summaries = { 'vat_eps': vat_eps, 'vat_weight': wv, 'vat_num_power_method_iters': vat_num_power_method_iters, 'loss_class': loss_class, 'loss_class_weighted': wc * loss_class, 'class_weight': wc, 'loss_vat': loss_vat, 'loss_vat_weighted': wv * loss_vat, 'rot_weight': 1.0 - wc - wv, 'loss_rot': loss_rot, 'loss_rot_weighted': (1.0 - wc - wv) * loss_rot, 'loss_entmin': loss_entmin, 'loss_entmin_weighted': FLAGS.entmin_factor * loss_entmin } # For evaluation, we want to see the result of using only the un-rotated, and # also the average of four rotated class-predictions. logits_class = utils.split_batch_dim(logits_class, [-1, num_angles]) logits_class_orig = logits_class[:, 0] logits_class_avg = tf.reduce_mean(logits_class, axis=1) eval_metrics = ( lambda labels_rot_unsup, logits_rot_unsup, labels_class, logits_class_orig, logits_class_avg: { # pylint: disable=g-long-lambda,line-too-long 'rotation top1 accuracy': utils.top_k_accuracy(1, labels_rot_unsup, logits_rot_unsup), 'classification/unrotated top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class_orig), 'classification/unrotated top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class_orig), 'classification/rot_avg top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class_avg), 'classification/rot_avg top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class_avg), }, [ labels_rot_unsup, end_points['rotations_unsup'], labels_class, logits_class_orig, logits_class_avg, ]) return trainer.make_estimator( mode, loss, eval_metrics, train_scalar_summaries=train_scalar_summaries, polyak_averaging=FLAGS.get_flag_value('polyak_averaging', False))
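
# For reference: `vat_utils.kl_divergence_from_logits(q_logits, p_logits)` is
# assumed to return the per-example KL divergence KL(p || q) between the
# softmax distributions of the clean logits (second argument, under
# stop_gradient above) and the perturbed logits (first argument). A minimal
# sketch under that assumption; the repository's actual helper may differ:
def kl_divergence_from_logits_sketch(q_logits, p_logits):
  # KL(p || q) = sum_i p_i * (log p_i - log q_i), computed per example.
  p = tf.nn.softmax(p_logits)
  return tf.reduce_sum(
      p * (tf.nn.log_softmax(p_logits) - tf.nn.log_softmax(q_logits)),
      axis=-1)
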
def model_fn(data, mode): """Produces a loss for the rotation task with semi-supervision. Args: data: Dict of inputs containing, among others, "image" and "label." mode: model's mode: training, eval or prediction Returns: EstimatorSpec """ # In this mode (called once at the end of training), we create the tf.Hub # module in order to export the model, and use that to do one last prediction. if mode == tf.estimator.ModeKeys.PREDICT: # This defines a function called by the hub module to create the model. def model_building_fn(img, is_training): # This is an example of calling `apply_model_semi` with only one of the # inputs provided. The outputs will simply use the given names: end_points = ss_utils.apply_model_semi( img, None, is_training, outputs={ 'classes': datasets.get_auxiliary_num_classes(), }) return end_points, end_points['classes'] return trainer.make_estimator(mode, predict_fn=model_building_fn, predict_input=data['image']) # In all other cases, we are in train/eval mode. # Note that here we only use data[1], i.e. the part with labels. # Forward them both through the model. The scope is needed for tf.Hub export. with tf.variable_scope('module'): # Here, we pass both inputs to `apply_model_semi`, and so we now get # outputs corresponding to each in `end_points` as "rotations_unsup" and # similar, which we will use below. end_points = ss_utils.apply_model_semi( None, data[1]['image'], is_training=mode == tf.estimator.ModeKeys.TRAIN, outputs={'classes': datasets.get_auxiliary_num_classes()}) # Compute the classification loss on supervised images. logits_class = end_points['classes'] labels_class = data[1]['label'] loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_class, logits=logits_class) loss = tf.reduce_mean(loss_class) # Define metrics. eval_metrics = ( lambda labels_class, logits_class: { # pylint: disable=g-long-lambda 'top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class), 'top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class), }, [labels_class, logits_class]) return trainer.make_estimator(mode, loss, eval_metrics)
def model_fn(data, mode): """Produces a loss for the exemplar task. Args: data: Dict of inputs ("image" being the image) mode: model's mode: training, eval or prediction Returns: EstimatorSpec """ # In this mode (called once at the end of training), we create the tf.Hub # module in order to export the model, and use that to do one last prediction. if mode == tf.estimator.ModeKeys.PREDICT: def model_building_fn(img, is_training): end_points = ss_utils.apply_model_semi( img, None, is_training, outputs={ 'embeddings': FLAGS.triplet_embed_dim, 'classes': datasets.get_auxiliary_num_classes(), }) return end_points, end_points['classes'] return trainer.make_estimator(mode, predict_fn=model_building_fn, predict_input=data['image']) # In all other cases, we are in train/eval mode. images_unsup = data[0]['image'] images_sup = data[1]['image'] # There is one special case, typically in eval mode, when we don't want to use # multiple examples, but a single one. In that case, add the fake length-1 # example dimension to the input so that everything still works. # i.e. turn BHWC into B1HWC if images_unsup.shape.ndims == 4: images_unsup = images_unsup[:, None, ...] if images_sup.shape.ndims == 4: images_sup = images_sup[:, None, ...] # Find out the number of examples that have been created per image, which # may be different for sup/unsup, and use that for creating the labels. ninstances_unsup, nexamples_unsup = images_unsup.shape[:2] ninstances_sup, nexamples_sup = images_sup.shape[:2] # Then, fold the examples into the batch. images_unsup = utils.into_batch_dim(images_unsup) images_sup = utils.into_batch_dim(images_sup) # If we're not doing exemplar on the unsupervised data, skip it! if not FLAGS.triplet_loss_unsup: images_unsup = None # Forward them both through the model. The scope is needed for tf.Hub export. with tf.variable_scope('module'): # Here, we pass both inputs to `apply_model_semi`, and so we now get # outputs corresponding to each in `end_points` as "classes_unsup" and # similar, which we will use below. end_points = ss_utils.apply_model_semi( images_unsup, images_sup, is_training=(mode == tf.estimator.ModeKeys.TRAIN), outputs={ 'embeddings': FLAGS.triplet_embed_dim, 'classes': datasets.get_auxiliary_num_classes(), }) # Labelled classification loss # ===== # Compute the supervision loss for each example of the supervised branch. labels_class = utils.repeat(data[1]['label'], nexamples_sup) logits_class = end_points['classes_sup'] losses_class = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_class, logits=logits_class) loss_class = tf.reduce_mean(losses_class) if mode == tf.estimator.ModeKeys.EVAL: eval_metrics = ( lambda labels_class, logits_class, losses_class: { # pylint: disable=g-long-lambda 'classification/top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class), 'classification/top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class), 'classification/loss': tf.metrics.mean(losses_class), }, [labels_class, logits_class, losses_class]) return trainer.make_estimator(mode, loss_class, eval_metrics) # Exemplar triplet loss # ===== losses_ex = [] def do_triplet(embeddings, nexamples, ninstances): """Applies the triplet loss to the given embeddings.""" # Empirically, normalizing the embeddings is more robust. embeddings = tf.nn.l2_normalize(embeddings, axis=-1) # Generate the labels as [0 0 0 1 1 1 ...] labels = utils.repeat(tf.range(ninstances), nexamples) # Apply batch-hard loss with a soft-margin. 
losses_tri = batch_hard(embeddings, labels, margin=0.0, soft=True, sample_pos=False, sample_neg=False) return tf.reduce_mean(losses_tri) # Compute exemplar triplet loss on the unsupervised images if FLAGS.triplet_loss_unsup: loss_ex_unsup = do_triplet(end_points['embeddings_unsup'], ninstances_unsup, nexamples_unsup) losses_ex.append(tf.reduce_mean(loss_ex_unsup)) # Compute exemplar triplet loss on the supervised images. if FLAGS.triplet_loss_sup: loss_ex_sup = do_triplet(end_points['embeddings_sup'], ninstances_sup, nexamples_sup) losses_ex.append(tf.reduce_mean(loss_ex_sup)) loss_ex = tf.reduce_mean(losses_ex) if losses_ex else 0.0 # Combine the two losses as a weighted average. loss = loss_class + FLAGS.triplet_loss_weight * loss_ex return trainer.make_estimator(mode, loss)
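
# For reference: `batch_hard` refers to the batch-hard triplet loss of
# Hermans et al., "In Defense of the Triplet Loss" (arXiv:1703.07737). The
# sketch below is a hypothetical, simplified version of the soft-margin,
# no-sampling variant used above; the repository's implementation may differ.
def batch_hard_sketch(embeddings, labels):
  # Pairwise squared Euclidean distances between all embeddings in the batch.
  dots = tf.matmul(embeddings, embeddings, transpose_b=True)
  sq_norms = tf.linalg.diag_part(dots)
  dists = sq_norms[:, None] - 2.0 * dots + sq_norms[None, :]
  same = tf.cast(tf.equal(labels[:, None], labels[None, :]), dists.dtype)
  # Hardest positive: the furthest example sharing the anchor's label.
  d_pos = tf.reduce_max(dists * same, axis=1)
  # Hardest negative: the closest example with a different label (mask out
  # same-label pairs with a large constant).
  d_neg = tf.reduce_min(dists + same * 1e9, axis=1)
  # Soft margin: softplus(d_pos - d_neg) = log(1 + exp(d_pos - d_neg)),
  # instead of a hard hinge max(0, d_pos - d_neg + margin).
  return tf.nn.softplus(d_pos - d_neg)
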
def model_fn(data, mode): """Produces a loss for the VAT semi-supervised task. Args: data: Dict of inputs containing, among others, "image" and "label." mode: model's mode: training, eval or prediction Returns: EstimatorSpec """ # In this mode (called once at the end of training), we create the tf.Hub # module in order to export the model, and use that to do one last prediction. if mode == tf.estimator.ModeKeys.PREDICT: # This defines a function called by the hub module to create the model. def model_building_fn(img, is_training): # This is an example of calling `apply_model_semi` with only one of the # inputs provided. The outputs will simply use the given names: end_points = ss_utils.apply_model_semi( img, None, is_training, outputs={ 'classes': datasets.get_auxiliary_num_classes(), }) return end_points, end_points['classes'] return trainer.make_estimator(mode, predict_fn=model_building_fn, predict_input=data['image']) # In all other cases, we are in train/eval mode. images_unsup = data[0]['image'] images_sup = data[1]['image'] # Forward them both through the model. The scope is needed for tf.Hub export. with tf.variable_scope('module'): # Here, we pass both inputs to `apply_model_semi`, and so we now get # outputs corresponding to each in `end_points` as "classes_unsup" and # "classes_sup", which we can use below. end_points = ss_utils.apply_model_semi( images_unsup, images_sup, is_training=(mode == tf.estimator.ModeKeys.TRAIN), outputs={'classes': datasets.get_auxiliary_num_classes()}) # Compute virtual adversarial perturbation def classification_net_fn(x): # pylint: disable=missing-docstring with tf.variable_scope('module', reuse=True): end_points_x = ss_utils.apply_model_semi( x, None, is_training=(mode == tf.estimator.ModeKeys.TRAIN), outputs={'classes': datasets.get_auxiliary_num_classes()}, # Don't update batch norm stats as we're running this on perturbed # (corrupted) inputs. Setting momentum = 1 is what does the trick. normalization_fn=functools.partial( tf.layers.batch_normalization, momentum=1.0)) return end_points_x['classes'] ## Compute VAT perturbation vat_eps = FLAGS.get_flag_value('vat_eps', 1.0) vat_num_power_method_iters = FLAGS.get_flag_value( 'vat_num_power_method_iters', 1) if FLAGS.apply_vat_to_labeled: images_vat_baseline = tf.concat((images_sup, images_unsup), axis=0) predictions_vat_baseline = tf.concat( (end_points['classes_sup'], end_points['classes_unsup']), axis=0) else: images_vat_baseline = images_unsup predictions_vat_baseline = end_points['classes_unsup'] vat_perturbation = vat_utils.virtual_adversarial_perturbation_direction( images_vat_baseline, predictions_vat_baseline, net=classification_net_fn, num_power_method_iters=vat_num_power_method_iters, ) * vat_eps loss_vat = tf.reduce_mean( vat_utils.kl_divergence_from_logits( classification_net_fn(images_vat_baseline + vat_perturbation), tf.stop_gradient(predictions_vat_baseline))) ## Compute the classification loss on supervised, clean images. labels_class = data[1]['label'] logits_class = end_points['classes_sup'] loss_class = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_class, logits=logits_class) loss_class = tf.reduce_mean(loss_class) ## Compute conditional entropy loss conditional_ent = -tf.reduce_sum( tf.nn.log_softmax(predictions_vat_baseline) * tf.nn.softmax(predictions_vat_baseline), axis=-1) loss_entmin = tf.reduce_mean(conditional_ent) # Combine VAT and classification loss as a weighted average, then add # weighted conditional entropy loss. 
loss = (loss_class + FLAGS.vat_weight * loss_vat + FLAGS.entmin_factor * loss_entmin) train_scalar_summaries = { 'vat_eps': vat_eps, 'vat_weight': FLAGS.vat_weight, 'entmin_weight': FLAGS.entmin_factor, 'vat_num_power_method_iters': vat_num_power_method_iters, 'loss_class': loss_class, 'loss_vat': loss_vat, 'loss_vat_weighted': FLAGS.vat_weight * loss_vat, 'loss_entmin': loss_entmin, 'loss_entmin_weighted': FLAGS.entmin_factor * loss_entmin } eval_metrics = ( lambda labels_class, logits_class: { # pylint: disable=g-long-lambda 'classification top1 accuracy': utils.top_k_accuracy(1, labels_class, logits_class), 'classification top5 accuracy': utils.top_k_accuracy(5, labels_class, logits_class) }, [labels_class, logits_class]) return trainer.make_estimator( mode, loss, eval_metrics, train_scalar_summaries=train_scalar_summaries)
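
# For reference: `vat_utils.virtual_adversarial_perturbation_direction` is
# assumed to implement the power-method approximation of Miyato et al.,
# "Virtual Adversarial Training" (arXiv:1704.03976): start from a random
# direction and repeatedly replace it with the gradient of the KL divergence
# between clean and slightly-perturbed predictions. The sketch below is
# hypothetical; the real helper, its signature, the NHWC assumption and the
# `xi` constant are all assumptions, not the repository's code.
def vat_direction_sketch(x, logits, net, num_power_method_iters=1, xi=1e-6):
  d = tf.random_normal(tf.shape(x))
  p = tf.nn.softmax(tf.stop_gradient(logits))
  log_p = tf.nn.log_softmax(tf.stop_gradient(logits))
  for _ in range(num_power_method_iters):
    # Normalize over the H, W, C axes of an NHWC image batch.
    d = xi * tf.nn.l2_normalize(d, axis=[1, 2, 3])
    # KL between clean predictions and predictions on the perturbed input;
    # its gradient w.r.t. d points toward the most sensitive direction.
    kl = tf.reduce_mean(
        tf.reduce_sum(p * (log_p - tf.nn.log_softmax(net(x + d))), axis=-1))
    d = tf.stop_gradient(tf.gradients(kl, [d])[0])
  # Return the unit-norm adversarial direction; the caller scales it by eps.
  return tf.nn.l2_normalize(d, axis=[1, 2, 3])
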