Example #1
def model_fn(features, mode, params):
    '''The model_fn to be used with TPUEstimator.

    Args:
      features: `dict` of batched input tensors; `features['image']` holds the
          images and, except in PREDICT mode, `features['label']` holds the
          labels.
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
      params: `dict` of parameters passed to the model from the TPUEstimator;
          `params['batch_size']` is always provided and should be used as the
          effective batch size.

    Returns:
      A `TPUEstimatorSpec` for the model.
    '''
    def preprocess_image(image):
        # In most cases, the default data format NCHW instead of NHWC should be
        # used for a significant performance boost on GPU. NHWC should be used
        # only if the network needs to be run on CPU, since the pooling
        # operations are only supported on NHWC. On TPU, the XLA compiler
        # figures out the best layout.
        if FLAGS.data_format == 'channels_first':
            assert not FLAGS.transpose_input  # channels_first only for GPU
            image = tf.transpose(image, [0, 3, 1, 2])

        if FLAGS.transpose_input and mode == tf.estimator.ModeKeys.TRAIN:
            image = tf.transpose(image, [3, 0, 1, 2])  # HWCN to NHWC
        return image

    def normalize_image(image):
        # Normalize the image to zero mean and unit variance.
        if FLAGS.data_format == 'channels_first':
            stats_shape = [3, 1, 1]
        else:
            stats_shape = [1, 1, 3]
        mean, std = task_info.get_mean_std(FLAGS.task_name)
        image -= tf.constant(mean, shape=stats_shape, dtype=image.dtype)
        image /= tf.constant(std, shape=stats_shape, dtype=image.dtype)
        return image

    image = features['image']
    image = preprocess_image(image)

    image_shape = image.get_shape().as_list()
    tf.logging.info('image shape: {}'.format(image_shape))
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = features['label']
    else:
        labels = None

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch-size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    if FLAGS.unlabel_ratio and is_training:
        unl_bsz = features['unl_probs'].shape[0]
    else:
        unl_bsz = 0

    lab_bsz = image.shape[0] - unl_bsz
    assert lab_bsz == batch_size

    metric_dict = {}
    global_step = tf.train.get_global_step()

    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential if using a Keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    tf.logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.stochastic_depth_rate is not None:
        override_params['stochastic_depth_rate'] = FLAGS.stochastic_depth_rate
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def build_model(scope=None,
                    reuse=tf.AUTO_REUSE,
                    model_name=None,
                    model_is_training=None,
                    input_image=None,
                    use_adv_bn=False,
                    is_teacher=False):
        model_name = model_name or FLAGS.model_name
        if model_is_training is None:
            model_is_training = is_training
        if input_image is None:
            input_image = image
        input_image = normalize_image(input_image)

        scope_model_name = model_name

        if scope:
            scope = scope + '/'
        else:
            scope = ''
        with tf.variable_scope(scope + scope_model_name, reuse=reuse):
            if model_name.startswith('efficientnet'):
                logits, _ = efficientnet_builder.build_model(
                    input_image,
                    model_name=model_name,
                    training=model_is_training,
                    override_params=override_params,
                    model_dir=FLAGS.model_dir,
                    use_adv_bn=use_adv_bn,
                    is_teacher=is_teacher)
            else:
                assert False, 'model {} not implemented'.format(model_name)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if FLAGS.teacher_model_name:
        teacher_image = preprocess_image(features['teacher_image'])
        if params['use_bfloat16']:
            with tf.tpu.bfloat16_scope():
                teacher_logits = tf.cast(
                    build_model(scope='teacher_model',
                                model_name=FLAGS.teacher_model_name,
                                model_is_training=False,
                                input_image=teacher_image,
                                is_teacher=True), tf.float32)
        else:
            teacher_logits = build_model(scope='teacher_model',
                                         model_name=FLAGS.teacher_model_name,
                                         model_is_training=False,
                                         input_image=teacher_image,
                                         is_teacher=True)
        teacher_logits = tf.stop_gradient(teacher_logits)
        if FLAGS.teacher_softmax_temp != -1:
            teacher_prob = tf.nn.softmax(teacher_logits /
                                         FLAGS.teacher_softmax_temp)
        else:
            teacher_prob = None
            teacher_one_hot_pred = tf.argmax(teacher_logits,
                                             axis=1,
                                             output_type=labels.dtype)

    if mode == tf.estimator.ModeKeys.PREDICT:
        if has_moving_average_decay:
            ema = tf.train.ExponentialMovingAverage(
                decay=FLAGS.moving_average_decay)
            ema_vars = utils.get_all_variable()
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            tf.logging.info(
                'restored variables:\n%s',
                json.dumps(sorted(restore_vars_dict.keys()), indent=4))

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            scaffold_fn=functools.partial(_scaffold_fn,
                                          restore_vars_dict=restore_vars_dict)
            if has_moving_average_decay else None)

    if has_moving_average_decay:
        ema_step = global_step
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=ema_step)
        ema_vars = utils.get_all_variable()

    lab_labels = labels[:lab_bsz]
    lab_logits = logits[:lab_bsz]
    lab_pred = tf.argmax(lab_logits, axis=-1, output_type=labels.dtype)
    lab_prob = tf.nn.softmax(lab_logits)
    lab_acc = tf.to_float(tf.equal(lab_pred, lab_labels))
    metric_dict['lab/acc'] = tf.reduce_mean(lab_acc)
    metric_dict['lab/pred_prob'] = tf.reduce_mean(
        tf.reduce_max(lab_prob, axis=-1))

    if FLAGS.unlabel_ratio:
        unl_labels = labels[lab_bsz:]
        unl_logits = logits[lab_bsz:]
        unl_pred = tf.argmax(unl_logits, axis=-1, output_type=labels.dtype)
        unl_prob = tf.nn.softmax(unl_logits)
        unl_acc = tf.to_float(tf.equal(unl_pred, unl_labels))
        metric_dict['unl/acc_to_dump'] = tf.reduce_mean(unl_acc)
        metric_dict['unl/pred_prob'] = tf.reduce_mean(
            tf.reduce_max(unl_prob, axis=-1))

    # compute lab_loss
    one_hot_labels = tf.one_hot(lab_labels, FLAGS.num_label_classes)
    lab_loss = tf.losses.softmax_cross_entropy(
        logits=lab_logits,
        onehot_labels=one_hot_labels,
        label_smoothing=FLAGS.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    if FLAGS.label_data_sample_prob != 1:
        # mask out part of the labeled data
        random_mask = tf.floor(
            FLAGS.label_data_sample_prob +
            tf.random_uniform(tf.shape(lab_loss), dtype=lab_loss.dtype))
        lab_loss = tf.reduce_mean(lab_loss * random_mask)
    else:
        lab_loss = tf.reduce_mean(lab_loss)
    metric_dict['lab/loss'] = lab_loss

    if FLAGS.unlabel_ratio:
        if FLAGS.teacher_softmax_temp == -1:  # Hard labels
            # Get one-hot labels
            if FLAGS.teacher_model_name:
                ext_teacher_pred = teacher_one_hot_pred[lab_bsz:]
                one_hot_labels = tf.one_hot(ext_teacher_pred,
                                            FLAGS.num_label_classes)
            else:
                one_hot_labels = tf.one_hot(unl_labels,
                                            FLAGS.num_label_classes)
            # Compute cross entropy
            unl_loss = tf.losses.softmax_cross_entropy(
                logits=unl_logits,
                onehot_labels=one_hot_labels,
                label_smoothing=FLAGS.label_smoothing)
        else:  # Soft labels
            # Get teacher prob
            if FLAGS.teacher_model_name:
                unl_teacher_prob = teacher_prob[lab_bsz:]
            else:
                scaled_prob = tf.pow(features['unl_probs'],
                                     1 / FLAGS.teacher_softmax_temp)
                unl_teacher_prob = scaled_prob / tf.reduce_sum(
                    scaled_prob, axis=-1, keepdims=True)
            metric_dict['unl/target_prob'] = tf.reduce_mean(
                tf.reduce_max(unl_teacher_prob, axis=-1))
            unl_loss = cross_entropy(unl_teacher_prob,
                                     unl_logits,
                                     return_mean=True)

        metric_dict['ext/loss'] = unl_loss
    else:
        unl_loss = 0

    real_lab_bsz = tf.to_float(lab_bsz) * FLAGS.label_data_sample_prob
    real_unl_bsz = batch_size * FLAGS.label_data_sample_prob * FLAGS.unlabel_ratio
    data_loss = lab_loss * real_lab_bsz + unl_loss * real_unl_bsz
    data_loss = data_loss / real_lab_bsz

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = data_loss + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    metric_dict['train/data_loss'] = data_loss
    metric_dict['train/loss'] = loss

    host_call = None
    restore_vars_dict = None

    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])
        real_train_batch_size = FLAGS.train_batch_size
        real_train_batch_size *= FLAGS.label_data_sample_prob
        scaled_lr = FLAGS.base_learning_rate * (real_train_batch_size / 256.0)
        if FLAGS.final_base_lr:
            # total number of training epochs
            total_epochs = FLAGS.train_steps * FLAGS.train_batch_size * 1. / FLAGS.num_train_images - 5
            decay_times = math.log(FLAGS.final_base_lr /
                                   FLAGS.base_learning_rate) / math.log(0.97)
            decay_epochs = total_epochs / decay_times
            tf.logging.info(
                'setting decay_epochs to {:.2f}'.format(decay_epochs) +
                '\n' * 3)
        else:
            decay_epochs = 2.4 * FLAGS.train_ratio
        learning_rate = utils.build_learning_rate(
            scaled_lr,
            global_step,
            params['steps_per_epoch'],
            decay_epochs=decay_epochs,
            start_from_step=FLAGS.train_steps - FLAGS.train_last_step_num,
            warmup_epochs=5,
        )
        metric_dict['train/lr'] = learning_rate
        metric_dict['train/epoch'] = current_epoch
        optimizer = utils.build_optimizer(learning_rate)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        tvars = tf.trainable_variables()
        g_vars = []
        tvars = sorted(tvars, key=lambda var: var.name)
        for var in tvars:
            if 'teacher_model' not in var.name:
                g_vars += [var]
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step, var_list=g_vars)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:
            host_call = utils.construct_scalar_host_call(metric_dict)
        scaffold_fn = None
        if FLAGS.teacher_model_name or FLAGS.init_model:
            scaffold_fn = utils.init_from_ckpt(scaffold_fn)
    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        scaffold_fn = functools.partial(_scaffold_fn,
                                        restore_vars_dict=restore_vars_dict
                                        ) if has_moving_average_decay else None

        def metric_fn(labels, logits):
            '''Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide them as part of
            `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
              labels: `Tensor` with shape `[batch]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation.
            '''

            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            result_dict = {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

            return result_dict

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
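Both the PREDICT and EVAL branches above hand functools.partial(_scaffold_fn, restore_vars_dict=...) to the TPUEstimatorSpec, but the _scaffold_fn helper itself is not part of this excerpt. Below is a minimal sketch of what such a helper typically looks like, assuming its only job is to restore the exponential-moving-average shadow values into the model weights; the body is an illustration, not the project's actual implementation.

import tensorflow.compat.v1 as tf  # TF1-style API, as used in the example above


def _scaffold_fn(restore_vars_dict):
    # Build a Saver over the name -> variable mapping returned by
    # ema.variables_to_restore(), so checkpoints are loaded with the
    # averaged weights instead of the raw training weights.
    saver = tf.train.Saver(restore_vars_dict)
    return tf.train.Scaffold(saver=saver)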
Example #2
File: main.py  Project: JDZW2014/uda
  def model_fn(features, labels, mode, params):
    sup_labels = tf.reshape(features["label"], [-1])

    #### Configuring the optimizer
    global_step = tf.train.get_global_step()
    metric_dict = {}
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    if FLAGS.unsup_ratio > 0 and is_training:
      all_images = tf.concat([features["image"],
                              features["ori_image"],
                              features["aug_image"]], 0)
    else:
      all_images = features["image"]

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
      all_logits = build_model(
          inputs=all_images,
          num_classes=FLAGS.num_classes,
          is_training=is_training,
          update_bn=True and is_training,
          hparams=hparams,
      )

      sup_bsz = tf.shape(features["image"])[0]
      sup_logits = all_logits[:sup_bsz]

      sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=sup_labels,
          logits=sup_logits)
      sup_prob = tf.nn.softmax(sup_logits, axis=-1)
      metric_dict["sup/pred_prob"] = tf.reduce_mean(
          tf.reduce_max(sup_prob, axis=-1))
    if FLAGS.tsa:
      sup_loss, avg_sup_loss = anneal_sup_loss(sup_logits, sup_labels, sup_loss,
                                               global_step, metric_dict)
    else:
      avg_sup_loss = tf.reduce_mean(sup_loss)
    total_loss = avg_sup_loss

    if FLAGS.unsup_ratio > 0 and is_training:
      aug_bsz = tf.shape(features["ori_image"])[0]

      ori_logits = all_logits[sup_bsz : sup_bsz + aug_bsz]
      aug_logits = all_logits[sup_bsz + aug_bsz:]
      if FLAGS.uda_softmax_temp != -1:
        ori_logits_tgt = ori_logits / FLAGS.uda_softmax_temp
      else:
        ori_logits_tgt = ori_logits
      ori_prob = tf.nn.softmax(ori_logits, axis=-1)
      aug_prob = tf.nn.softmax(aug_logits, axis=-1)
      metric_dict["unsup/ori_prob"] = tf.reduce_mean(
          tf.reduce_max(ori_prob, axis=-1))
      metric_dict["unsup/aug_prob"] = tf.reduce_mean(
          tf.reduce_max(aug_prob, axis=-1))

      aug_loss = _kl_divergence_with_logits(
          p_logits=tf.stop_gradient(ori_logits_tgt),
          q_logits=aug_logits)

      if FLAGS.uda_confidence_thresh != -1:
        ori_prob = tf.nn.softmax(ori_logits, axis=-1)
        largest_prob = tf.reduce_max(ori_prob, axis=-1)
        loss_mask = tf.cast(tf.greater(
            largest_prob, FLAGS.uda_confidence_thresh), tf.float32)
        metric_dict["unsup/high_prob_ratio"] = tf.reduce_mean(loss_mask)
        loss_mask = tf.stop_gradient(loss_mask)
        aug_loss = aug_loss * loss_mask
        metric_dict["unsup/high_prob_loss"] = tf.reduce_mean(aug_loss)

      if FLAGS.ent_min_coeff > 0:
        ent_min_coeff = FLAGS.ent_min_coeff
        metric_dict["unsup/ent_min_coeff"] = ent_min_coeff
        per_example_ent = get_ent(ori_logits)
        ent_min_loss = tf.reduce_mean(per_example_ent)
        total_loss = total_loss + ent_min_coeff * ent_min_loss

      avg_unsup_loss = tf.reduce_mean(aug_loss)
      total_loss += FLAGS.unsup_coeff * avg_unsup_loss
      metric_dict["unsup/loss"] = avg_unsup_loss

    total_loss = utils.decay_weights(
        total_loss,
        FLAGS.weight_decay_rate)

    #### Check model parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))

    if FLAGS.verbose:
      format_str = "{{:<{0}s}}\t{{}}".format(
          max([len(v.name) for v in tf.trainable_variables()]))
      for v in tf.trainable_variables():
        tf.logging.info(format_str.format(v.name, v.get_shape()))
    if FLAGS.moving_average_decay > 0.:
      ema = tf.train.ExponentialMovingAverage(
          decay=FLAGS.moving_average_decay)
      ema_vars = utils.get_all_variable()

    #### Evaluation mode
    if mode == tf.estimator.ModeKeys.EVAL:
      if FLAGS.moving_average_decay > 0:
        restore_vars_dict = ema.variables_to_restore(ema_vars)
        scaffold_fn = functools.partial(
            _scaffold_fn, restore_vars_dict=restore_vars_dict)
      else:
        scaffold_fn = None

      #### Metric function for classification
      def metric_fn(per_example_loss, label_ids, logits):
        # classification loss & accuracy
        loss = tf.metrics.mean(per_example_loss)

        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(label_ids, predictions)

        ret_dict = {
            "eval/classify_loss": loss,
            "eval/classify_accuracy": accuracy
        }

        return ret_dict

      eval_metrics = (metric_fn, [sup_loss, sup_labels, sup_logits])

      #### Constructing evaluation TPUEstimatorSpec.
      eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn,
      )

      return eval_spec

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
      warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
                  * FLAGS.learning_rate
    else:
      warmup_lr = 0.0

    # decay the learning rate using the cosine schedule
    lrate = tf.clip_by_value(
        tf.to_float(global_step - FLAGS.warmup_steps) /
        (FLAGS.train_steps - FLAGS.warmup_steps), 0, 1)
    decay_lr = FLAGS.learning_rate * tf.cos(lrate * (7. / 8) * np.pi / 2)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=0.9,
        use_nesterov=True)

    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.apply_gradients(
          zip(gradients, variables), global_step=tf.train.get_global_step())
    if FLAGS.moving_average_decay > 0:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

    #### Creating training logging hook
    # compute accuracy
    sup_pred = tf.argmax(sup_logits, axis=-1, output_type=sup_labels.dtype)
    is_correct = tf.to_float(tf.equal(sup_pred, sup_labels))
    acc = tf.reduce_mean(is_correct)
    metric_dict["sup/sup_loss"] = avg_sup_loss
    metric_dict["training/loss"] = total_loss
    metric_dict["sup/acc"] = acc
    metric_dict["training/lr"] = learning_rate
    metric_dict["training/step"] = global_step

    if not FLAGS.use_tpu:
      log_info = ("step [{training/step}] lr {training/lr:.6f} "
                  "loss {training/loss:.4f} "
                  "sup/acc {sup/acc:.4f} sup/loss {sup/sup_loss:.6f} ")
      if FLAGS.unsup_ratio > 0:
        log_info += "unsup/loss {unsup/loss:.6f} "
      formatter = lambda kwargs: log_info.format(**kwargs)
      logging_hook = tf.train.LoggingTensorHook(
          tensors=metric_dict,
          every_n_iter=FLAGS.iterations,
          formatter=formatter)
      training_hooks = [logging_hook]
      #### Constructing training TPUEstimatorSpec.
      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op,
          training_hooks=training_hooks)
    else:
      #### Constructing training TPUEstimatorSpec.
      host_call = utils.construct_scalar_host_call(
          metric_dict=metric_dict,
          model_dir=params["model_dir"],
          prefix="",
          reduce_fn=tf.reduce_mean)
      train_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, loss=total_loss, train_op=train_op,
          host_call=host_call)

    return train_spec
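The consistency loss in this example calls _kl_divergence_with_logits, which is also not shown in the excerpt. A minimal sketch is given below, written under the assumption that it returns one KL value per example (so the confidence mask loss_mask can be applied element-wise before the final tf.reduce_mean); treat it as an illustration rather than the project's exact code.

import tensorflow.compat.v1 as tf  # TF1-style API, as used in the example above


def _kl_divergence_with_logits(p_logits, q_logits):
  # Per-example KL(p || q) computed directly from logits. p_logits is the
  # (stop-gradient) prediction on the original image, q_logits the prediction
  # on the augmented image.
  p = tf.nn.softmax(p_logits, axis=-1)
  log_p = tf.nn.log_softmax(p_logits, axis=-1)
  log_q = tf.nn.log_softmax(q_logits, axis=-1)
  kl = tf.reduce_sum(p * (log_p - log_q), axis=-1)  # shape [batch]
  return kl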