Example #1
def get_pretrained_variables_to_restore(checkpoint_path,
                                        load_moving_average=False):
    """Gets veriables_to_restore mapping from pretrained checkpoint.

  Args:
    checkpoint_path: String. Path of checkpoint.
    load_moving_average: Boolean, whether load moving average variables to
      replace variables.

  Returns:
    Mapping of variables to restore.
  """
    checkpoint_reader = tf.train.load_checkpoint(checkpoint_path)
    variable_shape_map = checkpoint_reader.get_variable_to_shape_map()

    variables_to_restore = {}
    ema_vars = mnas_utils.get_ema_vars()
    for v in tf.global_variables():
        # Skip variables if they are in excluded scopes.
        is_excluded = False
        for scope in ['global_step', 'ExponentialMovingAverage']:
            if scope in v.op.name:
                is_excluded = True
                break
        if is_excluded:
            tf.logging.info('Exclude [%s] from loading from checkpoint.',
                            v.op.name)
            continue
        variable_name_ckpt = v.op.name
        if load_moving_average and v in ema_vars:
            # To load moving average variables into their non-moving versions
            # for fine-tuning, map the variables here manually.
            variable_name_ckpt = v.op.name + '/ExponentialMovingAverage'

        if variable_name_ckpt not in variable_shape_map:
            tf.logging.info(
                'Skip init [%s] from [%s] as it is not in the checkpoint',
                v.op.name, variable_name_ckpt)
            continue

        variables_to_restore[variable_name_ckpt] = v
        tf.logging.info('Init variable [%s] from [%s] in ckpt', v.op.name,
                        variable_name_ckpt)
    return variables_to_restore
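
A minimal usage sketch (not part of the original snippet): the returned mapping of {checkpoint_variable_name: variable} can be passed directly to a tf.train.Saver inside a Scaffold, mirroring the export path in build_model_fn below. The checkpoint path is a hypothetical placeholder.

# Usage sketch, assuming a TF 1.x graph with the model variables already built.
# '/tmp/mnasnet_ckpt' is a hypothetical placeholder path.
restore_checkpoint = '/tmp/mnasnet_ckpt'
variables_to_restore = get_pretrained_variables_to_restore(
    restore_checkpoint, load_moving_average=True)


def scaffold_fn():
  # tf.train.Saver accepts a dict mapping checkpoint names to variables.
  saver = tf.train.Saver(variables_to_restore)
  return tf.train.Scaffold(saver=saver)
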
Example #2
def build_model_fn(features, labels, mode, params):
  """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  # This is essential, if using a keras-derived model.
  tf.keras.backend.set_learning_phase(is_training)

  if isinstance(features, dict):
    features = features['feature']

  if mode == tf.estimator.ModeKeys.PREDICT:
    # Adds an identity node to help TFLite export.
    features = tf.identity(features, 'float_image_input')

  # On GPU, the NCHW data format usually gives a significant performance boost
  # over NHWC. NHWC should be used only if the network needs to run on CPU,
  # since the pooling operations are only supported in NHWC. On TPU, the XLA
  # compiler figures out the best layout.
  if params['data_format'] == 'channels_first':
    assert not params['transpose_input']    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])
    stats_shape = [3, 1, 1]
  else:
    stats_shape = [1, 1, 3]

  if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(
      imagenet_input.MEAN_RGB, shape=stats_shape, dtype=features.dtype)
  features /= tf.constant(
      imagenet_input.STDDEV_RGB, shape=stats_shape, dtype=features.dtype)

  has_moving_average_decay = (params['moving_average_decay'] > 0)

  tf.logging.info('Using open-source implementation for MnasNet definition.')
  override_params = {}
  if params['batch_norm_momentum']:
    override_params['batch_norm_momentum'] = params['batch_norm_momentum']
  if params['batch_norm_epsilon']:
    override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
  if params['dropout_rate']:
    override_params['dropout_rate'] = params['dropout_rate']
  if params['data_format']:
    override_params['data_format'] = params['data_format']
  if params['num_label_classes']:
    override_params['num_classes'] = params['num_label_classes']
  if params['depth_multiplier']:
    override_params['depth_multiplier'] = params['depth_multiplier']
  if params['depth_divisor']:
    override_params['depth_divisor'] = params['depth_divisor']
  if params['min_depth']:
    override_params['min_depth'] = params['min_depth']
  override_params['use_keras'] = params['use_keras']

  def _build_model(model_name):
    """Build the model for a given model name."""
    if model_name.startswith('mnasnet'):
      return mnasnet_models.build_mnasnet_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    elif model_name.startswith('mixnet'):
      return mixnet_builder.build_model(
          features,
          model_name=model_name,
          training=is_training,
          override_params=override_params)
    else:
      raise ValueError('Unknown model name {}'.format(model_name))

  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      logits, _ = _build_model(params['model_name'])
    logits = tf.cast(logits, tf.float32)
  else:  # params['precision'] == 'float32'
    logits, _ = _build_model(params['model_name'])

  if params['quantized_training']:
    try:
      from tensorflow.contrib import quantize  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      logging.exception('Quantized training is not supported in TensorFlow 2.x')
      raise e

    if is_training:
      tf.logging.info('Adding fake quantization ops for training.')
      quantize.create_training_graph(
          quant_delay=int(params['steps_per_epoch'] *
                          FLAGS.quantization_delay_epochs))
    else:
      tf.logging.info('Adding fake quantization ops for evaluation.')
      quantize.create_eval_graph()

  if mode == tf.estimator.ModeKeys.PREDICT:
    scaffold_fn = None
    if FLAGS.export_moving_average:
      # If the model is trained with moving average decay, to match evaluation
      # metrics, we need to export the model using moving average variables.
      restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
      variables_to_restore = get_pretrained_variables_to_restore(
          restore_checkpoint, load_moving_average=True)
      tf.logging.info('Restoring from the latest checkpoint: %s',
                      restore_checkpoint)
      tf.logging.info(str(variables_to_restore))

      def restore_scaffold():
        saver = tf.train.Saver(variables_to_restore)
        return tf.train.Scaffold(saver=saver)

      scaffold_fn = restore_scaffold

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        },
        scaffold_fn=scaffold_fn)

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch-size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=params['label_smoothing'])

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + params['weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=params['moving_average_decay'], num_updates=global_step)
    ema_vars = mnas_utils.get_ema_vars()

  host_call = None
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

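    # Scale the base learning rate linearly with the training batch size,
    # relative to a reference batch size of 256.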
    scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0)  # pylint: disable=line-too-long
    learning_rate = mnas_utils.build_learning_rate(scaled_lr, global_step,
                                                   params['steps_per_epoch'])
    optimizer = mnas_utils.build_optimizer(learning_rate)
    if params['use_tpu']:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

      if params['add_summaries']:
        summary_writer = tf2.summary.create_file_writer(
            FLAGS.model_dir, max_queue=params['iterations_per_loop'])
        with summary_writer.as_default():
          should_record = tf.equal(global_step % params['iterations_per_loop'],
                                   0)
          with tf2.summary.record_if(should_record):
            tf2.summary.scalar('loss', loss, step=global_step)
            tf2.summary.scalar('learning_rate', learning_rate, step=global_step)
            tf2.summary.scalar('current_epoch', current_epoch, step=global_step)

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops + tf.summary.all_v2_summary_ops()):
      train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide them as part of `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    eval_metrics = (metric_fn, [labels, logits])

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  # Prepares scaffold_fn if needed.
  scaffold_fn = None
  if is_training and FLAGS.init_checkpoint:
    variables_to_restore = get_pretrained_variables_to_restore(
        FLAGS.init_checkpoint, has_moving_average_decay)
    tf.logging.info('Initializing from pretrained checkpoint: %s',
                    FLAGS.init_checkpoint)
    if FLAGS.use_tpu:

      def init_scaffold():
        tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                      variables_to_restore)
        return tf.train.Scaffold()

      scaffold_fn = init_scaffold
    else:
      tf.train.init_from_checkpoint(FLAGS.init_checkpoint, variables_to_restore)

  restore_vars_dict = None
  if not is_training and has_moving_average_decay:
    # Load moving average variables for eval.
    restore_vars_dict = ema.variables_to_restore(ema_vars)

    def eval_scaffold():
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)

    scaffold_fn = eval_scaffold

  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
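
For completeness, a hedged wiring sketch (not from the original example) of how build_model_fn could be handed to a TPUEstimator. The model directory, batch sizes, and the truncated params dict are illustrative placeholders; the real training script supplies many more params keys (learning rate, weight decay, steps_per_epoch, and so on).

import tensorflow as tf  # assumes TensorFlow 1.x

run_config = tf.estimator.tpu.RunConfig(
    model_dir='/tmp/mnasnet_model',  # hypothetical model directory
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=1000))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=build_model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=1024,
    eval_batch_size=1024,
    params={
        # build_model_fn reads its hyperparameters from `params`; only a few
        # of the keys it expects are shown here as placeholders.
        'model_name': 'mnasnet-a1',
        'data_format': 'channels_last',
        'transpose_input': False,
        'use_tpu': True,
    })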