Example #1
def _build_model(model_name):
    """Build the model for a given model name.

    Note: `features`, `is_training`, and `override_params` are free variables
    captured from the enclosing `model_fn` scope.
    """
    if model_name.startswith('mnasnet'):
        return mnasnet_models.build_mnasnet_model(
            features,
            model_name=model_name,
            training=is_training,
            override_params=override_params)
    elif model_name.startswith('mixnet'):
        return mixnet_builder.build_model(
            features,
            model_name=model_name,
            training=is_training,
            override_params=override_params)
    else:
        raise ValueError('Unknown model name {}'.format(model_name))
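In the source, this helper is nested inside a larger `model_fn`, so `features`, `is_training`, and `override_params` come from the enclosing scope. Below is a minimal sketch of how a surrounding scope might supply those names and call the helper, assuming TF 1.x; the input shape, model name, and override values are illustrative placeholders, not values from the example.

import tensorflow as tf

# Assumes mnasnet_models (and mixnet_builder) are imported as in the example.
# Hypothetical values for the names _build_model reads from its enclosing
# scope; in the real code they come from the surrounding model_fn.
features = tf.placeholder(tf.float32, [None, 224, 224, 3])  # NHWC image batch
is_training = False
override_params = {'num_classes': 1000}

# The MnasNet builder returns a (logits, ...) pair, as used in Example #2.
logits, _ = _build_model('mnasnet-a1')  # 'mnasnet-a1' is an assumed model name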
Example #2
def mnasnet_model_fn(features, labels, mode, params):
    """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    # This is essential if using a Keras-derived model.
    K.set_learning_phase(is_training)

    if isinstance(features, dict):
        features = features['feature']

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Adds an identity node to help TFLite export.
        features = tf.identity(features, 'float_image_input')

    # In most cases NCHW, rather than the default NHWC, should be used for a
    # significant performance boost on GPU. NHWC should be used only if the
    # network needs to run on CPU, since the pooling operations are only
    # supported in NHWC. On TPU, the XLA compiler figures out the best layout.
    if params['data_format'] == 'channels_first':
        assert not params['transpose_input']  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(imagenet_input.MEAN_RGB,
                            shape=stats_shape,
                            dtype=features.dtype)
    features /= tf.constant(imagenet_input.STDDEV_RGB,
                            shape=stats_shape,
                            dtype=features.dtype)

    has_moving_average_decay = (params['moving_average_decay'] > 0)

    tf.logging.info('Using open-source implementation for MnasNet definition.')
    override_params = {}
    if params['batch_norm_momentum']:
        override_params['batch_norm_momentum'] = params['batch_norm_momentum']
    if params['batch_norm_epsilon']:
        override_params['batch_norm_epsilon'] = params['batch_norm_epsilon']
    if params['dropout_rate']:
        override_params['dropout_rate'] = params['dropout_rate']
    if params['data_format']:
        override_params['data_format'] = params['data_format']
    if params['num_label_classes']:
        override_params['num_classes'] = params['num_label_classes']
    if params['depth_multiplier']:
        override_params['depth_multiplier'] = params['depth_multiplier']
    if params['depth_divisor']:
        override_params['depth_divisor'] = params['depth_divisor']
    if params['min_depth']:
        override_params['min_depth'] = params['min_depth']
    override_params['use_keras'] = params['use_keras']

    if params['precision'] == 'bfloat16':
        with tf.contrib.tpu.bfloat16_scope():
            logits, _ = mnasnet_models.build_mnasnet_model(
                features,
                model_name=params['model_name'],
                training=is_training,
                override_params=override_params)
        logits = tf.cast(logits, tf.float32)
    else:  # params['precision'] == 'float32'
        logits, _ = mnasnet_models.build_mnasnet_model(
            features,
            model_name=params['model_name'],
            training=is_training,
            override_params=override_params)

    if params['quantized_training']:
        if is_training:
            tf.logging.info('Adding fake quantization ops for training.')
            tf.contrib.quantize.create_training_graph(
                quant_delay=int(params['steps_per_epoch'] *
                                FLAGS.quantization_delay_epochs))
        else:
            tf.logging.info('Adding fake quantization ops for evaluation.')
            tf.contrib.quantize.create_eval_graph()

    if mode == tf.estimator.ModeKeys.PREDICT:
        scaffold_fn = None
        if FLAGS.export_moving_average:
            # If the model was trained with moving average decay, export it
            # using the moving-average variables so that the exported model
            # matches the evaluation metrics.
            restore_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
            variables_to_restore = get_pretrained_variables_to_restore(
                restore_checkpoint, load_moving_average=True)
            tf.logging.info('Restoring from the latest checkpoint: %s',
                            restore_checkpoint)
            tf.logging.info(str(variables_to_restore))

            def restore_scaffold():
                saver = tf.train.Saver(variables_to_restore)
                return tf.train.Scaffold(saver=saver)

            scaffold_fn = restore_scaffold

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            },
            scaffold_fn=scaffold_fn)

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch-size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=params['label_smoothing'])

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=params['moving_average_decay'], num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = params['base_learning_rate'] * (params['train_batch_size'] / 256.0)  # pylint: disable=line-too-long
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate)
        if params['use_tpu']:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not params['skip_host_call']:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call.

        Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor with shape `[batch]` for the global_step
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                # Host call fns are executed params['iterations_per_loop'] times
                # after one TPU loop is finished; setting max_queue to the same
                # number of iterations makes the summary writer flush data to
                # storage only once per loop.
                with tf.contrib.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=params['iterations_per_loop']).as_default():
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar('loss', loss[0], step=gs)
                        tf.contrib.summary.scalar('learning_rate',
                                                  lr[0],
                                                  step=gs)
                        tf.contrib.summary.scalar('current_epoch',
                                                  ce[0],
                                                  step=gs)

                        return tf.contrib.summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('number of trainable parameters: {}'.format(num_params))

    # Prepares scaffold_fn if needed.
    scaffold_fn = None
    if is_training and FLAGS.init_checkpoint:
        variables_to_restore = get_pretrained_variables_to_restore(
            FLAGS.init_checkpoint, has_moving_average_decay)
        tf.logging.info('Initializing from pretrained checkpoint: %s',
                        FLAGS.init_checkpoint)
        if FLAGS.use_tpu:

            def init_scaffold():
                tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                              variables_to_restore)
                return tf.train.Scaffold()

            scaffold_fn = init_scaffold
        else:
            tf.train.init_from_checkpoint(FLAGS.init_checkpoint,
                                          variables_to_restore)

    restore_vars_dict = None
    if not is_training and has_moving_average_decay:
        # Load moving average variables for eval.
        restore_vars_dict = ema.variables_to_restore(ema_vars)

        def eval_scaffold():
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)

        scaffold_fn = eval_scaffold

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           host_call=host_call,
                                           eval_metrics=eval_metrics,
                                           scaffold_fn=scaffold_fn)
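For context, here is a hedged sketch of how a `model_fn` like this one is usually handed to `TPUEstimator` (TF 1.x contrib API). The TPU name, bucket path, batch sizes, loop/step counts, and the subset of `params` shown are illustrative assumptions; the real `params` dict must contain every key the `model_fn` reads (`data_format`, `transpose_input`, `precision`, `moving_average_decay`, `steps_per_epoch`, `skip_host_call`, and so on), and `train_input_fn` is a placeholder for an input pipeline returning `(images, labels)` batches.

import tensorflow as tf

tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my-tpu')  # hypothetical TPU name
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster,
    model_dir='gs://my-bucket/mnasnet',  # hypothetical output path
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1251))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=mnasnet_model_fn,
    use_tpu=True,
    config=run_config,
    train_batch_size=1024,
    eval_batch_size=1024,
    params={
        'use_tpu': True,
        'data_format': 'channels_last',
        'transpose_input': True,
        'precision': 'bfloat16',
        'moving_average_decay': 0.9999,
        'steps_per_epoch': 1251,
        # ... plus the remaining keys mnasnet_model_fn reads ...
    })

# train_input_fn is a placeholder; max_steps is an illustrative value.
estimator.train(input_fn=train_input_fn, max_steps=437898)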
Example #3
def mnasnet_model_fn(features, labels, mode, params):
  """The model_fn for MnasNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`.
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  # This is essential if using a Keras-derived model.
  K.set_learning_phase(is_training)

  if isinstance(features, dict):
    features = features['feature']

  # In most cases NCHW, rather than the default NHWC, should be used for a
  # significant performance boost on GPU/TPU. NHWC should be used only if the
  # network needs to run on CPU, since the pooling operations are only
  # supported in NHWC.
  if FLAGS.data_format == 'channels_first':
    assert not FLAGS.transpose_input    # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  has_moving_average_decay = (FLAGS.moving_average_decay > 0)

  tf.logging.info('Using open-source implementation for MnasNet definition.')
  override_params = {}
  if FLAGS.batch_norm_momentum:
    override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
  if FLAGS.batch_norm_epsilon:
    override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
  if FLAGS.dropout_rate:
    override_params['dropout_rate'] = FLAGS.dropout_rate
  if FLAGS.data_format:
    override_params['data_format'] = FLAGS.data_format
  if FLAGS.num_label_classes:
    override_params['num_classes'] = FLAGS.num_label_classes
  if FLAGS.depth_multiplier:
    override_params['depth_multiplier'] = FLAGS.depth_multiplier
  if FLAGS.depth_divisor:
    override_params['depth_divisor'] = FLAGS.depth_divisor
  if FLAGS.min_depth:
    override_params['min_depth'] = FLAGS.min_depth
  override_params['use_keras'] = FLAGS.use_keras

  if params['use_bfloat16']:
    """with tf.contrib.tpu.bfloat16_scope():
      logits, _ = mnasnet_models.build_mnasnet_model(
          features,
          model_name=FLAGS.model_name,
          training=is_training,
          override_params=override_params)"""
    #features = tf.cast(features,tf.float16)
    logits, _ = mnasnet_models.build_mnasnet_model(
        features,
        model_name=FLAGS.model_name,
        training=is_training,
        override_params=override_params)
    logits = tf.cast(logits, tf.float32)
  else:
    logits, _ = mnasnet_models.build_mnasnet_model(
        features,
        model_name=FLAGS.model_name,
        training=is_training,
        override_params=override_params)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  # If necessary, in the model_fn, use params['batch_size'] instead of the
  # batch-size flags (--train_batch_size or --eval_batch_size).
  batch_size = params['batch_size']  # pylint: disable=unused-variable

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=FLAGS.label_smoothing)

  # Add weight decay to the loss for non-batch-normalization variables.
  loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
      tf.nn.l2_loss(tf.cast(v, tf.float32))
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  global_step = tf.train.get_global_step()
  if has_moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=FLAGS.moving_average_decay, num_updates=global_step)
    ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
    for v in tf.global_variables():
      # We maintain mva for batch norm moving mean and variance as well.
      if 'moving_mean' in v.name or 'moving_variance' in v.name:
        ema_vars.append(v)
    ema_vars = list(set(ema_vars))

  host_call = None
  restore_vars_dict = None
  if is_training:
    # Compute the current epoch and associated learning rate from global_step.
    current_epoch = (
        tf.cast(global_step, tf.float32) / params['steps_per_epoch'])

    scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
    learning_rate = mnasnet_utils.build_learning_rate(scaled_lr, global_step,
                                                      params['steps_per_epoch'])
    optimizer = mnasnet_utils.build_optimizer(learning_rate)
    """if FLAGS.use_tpu:
      # When using TPU, wrap the optimizer with CrossShardOptimizer which
      # handles synchronization details between different TPU cores. To the
      # user, this should look like regular synchronous training.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)"""

    # Batch normalization requires UPDATE_OPS to be added as a dependency to
    # the train operation.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    if params['use_bfloat16']:
      # Manual loss scaling for reduced precision: scale the loss before
      # computing gradients and unscale the gradients before applying them,
      # so that small gradients do not underflow. An alternative is
      # tf.contrib.mixed_precision.LossScaleOptimizer with a
      # FixedLossScaleManager, as in the code this replaced.
      scaled_grad_vars = optimizer.compute_gradients(loss * 128)
      unscaled_grad_vars = [(grad / 128, var)
                            for grad, var in scaled_grad_vars
                            if grad is not None]
      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
      train_op = tf.group(minimize_op, update_ops)
    else: 
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step)

    if has_moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)
    tf.summary.scalar('cross_entropy', cross_entropy)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('learning_rate', learning_rate)
    if not FLAGS.skip_host_call:

      def host_call_fn(gs, loss, lr, ce):
        """Training host call.

        Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
        gs = gs[0]
        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished; setting max_queue to the same number of
        # iterations makes the summary writer flush data to storage only once
        # per loop.
        with tf.contrib.summary.create_file_writer(
            FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default():
          with tf.contrib.summary.always_record_summaries():
            tf.contrib.summary.scalar('loss', loss[0], step=gs)
            tf.contrib.summary.scalar('learning_rate', lr[0], step=gs)
            tf.contrib.summary.scalar('current_epoch', ce[0], step=gs)

            return tf.contrib.summary.all_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      gs_t = tf.reshape(global_step, [1])
      loss_t = tf.reshape(loss, [1])
      lr_t = tf.reshape(learning_rate, [1])
      ce_t = tf.reshape(current_epoch, [1])

      host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

  else:
    train_op = None
    if has_moving_average_decay:
      # Load moving average variables for eval.
      restore_vars_dict = ema.variables_to_restore(ema_vars)

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits):
      """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
      predictions = tf.argmax(logits, axis=1)
      top_1_accuracy = tf.metrics.accuracy(labels, predictions)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      top_5_accuracy = tf.metrics.mean(in_top_5)
      # tf.metrics returns (value, update_op) pairs; log via the update op so
      # the summary reflects the current accumulated value.
      tf.summary.scalar('top_1_accuracy', top_1_accuracy[1])
      tf.summary.scalar('top_5_accuracy', top_5_accuracy[1])

      return {
          'top_1_accuracy': top_1_accuracy,
          'top_5_accuracy': top_5_accuracy,
      }

    # tf.estimator.EstimatorSpec expects a dict of metric ops here, not the
    # (metric_fn, args) tuple used by TPUEstimatorSpec.
    eval_metrics = metric_fn(labels, logits)

  num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
  tf.logging.info('number of trainable parameters: {}'.format(num_params))

  def _scaffold_fn():
    saver = tf.train.Saver(restore_vars_dict)
    return tf.train.Scaffold(saver=saver)

  """return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=_scaffold_fn if has_moving_average_decay else None)"""
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics,
      #scaffold=_scaffold_fn if has_moving_average_decay else None)
      scaffold=tf.train.Scaffold(saver=tf.train.Saver(restore_vars_dict)))
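Since this variant returns a plain `tf.estimator.EstimatorSpec`, it runs under the standard `Estimator` rather than `TPUEstimator`. Below is a minimal sketch of that wiring, assuming the `FLAGS` the `model_fn` reads (`FLAGS.model_name`, `FLAGS.num_label_classes`, and so on) are already defined; the directory, batch size, step counts, and `train_input_fn`/`eval_input_fn` are illustrative placeholders.

import tensorflow as tf

run_config = tf.estimator.RunConfig(
    model_dir='/tmp/mnasnet',  # hypothetical checkpoint directory
    save_checkpoints_steps=1000)

estimator = tf.estimator.Estimator(
    model_fn=mnasnet_model_fn,
    config=run_config,
    params={
        'batch_size': 256,        # read as params['batch_size'] in the model_fn
        'steps_per_epoch': 5004,  # drives the learning-rate schedule
        'use_bfloat16': False,    # False selects the plain float32 branch
    })

# train_input_fn / eval_input_fn stand for input pipelines returning
# (images, labels) batches; step counts are illustrative.
estimator.train(input_fn=train_input_fn, max_steps=100000)
eval_results = estimator.evaluate(input_fn=eval_input_fn, steps=196)
print(eval_results)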