Example #1
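
Both listings reference modules defined outside this excerpt. A plausible set of imports is sketched below; the `tunas` package paths are an assumption about where mobile_classifier_factory, mobile_search_space_v3, custom_layers, and schema_io live, and the test method itself presumably sits inside a tf.test.TestCase subclass.

import collections
import os

import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2

# Assumed module paths; adjust to the actual repository layout.
from tunas import custom_layers
from tunas import mobile_classifier_factory
from tunas import mobile_search_space_v3
from tunas import schema_io
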
  def test_get_standalone_model_v3(self):
    ssd = mobile_search_space_v3.MOBILENET_V3_LARGE
    model_spec = mobile_search_space_v3.get_search_space_spec(ssd)
    model = mobile_classifier_factory.get_standalone_model(model_spec)

    inputs = tf.ones([2, 224, 224, 3])
    model.build(inputs.shape)

    outputs, _ = model.apply(inputs, training=False)
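    # Logits over 1001 classes, presumably the 1000 ImageNet classes plus a
    # background class (a common convention for TF image classifiers).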
    self.assertEqual(outputs.shape, [2, 1001])
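
Example #2

The model_fn below constructs a TPUEstimatorSpec that handles training, evaluation, and prediction for a standalone image classifier.
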
def model_fn(features, labels, mode, params):
  """Construct a TPUEstimatorSpec for a model."""
  training = (mode == tf.estimator.ModeKeys.TRAIN)

  if mode == tf.estimator.ModeKeys.EVAL:
    # At evaluation time, the function argument `features` is really a
    # 2-element tuple containing:
    # * A tensor of features w/ shape [batch_size, image_height, image_width, 3]
    # * A tensor of masks w/ shape [batch_size]. Each element of the tensor is
    #   1 (if the element is a normal image) or 0 (if it's a dummy input that
    #   should be ignored). We use this tensor to simulate dynamic batch sizes
    #   during model evaluation. It allows us to handle cases where the
    #   validation set size is not a multiple of the eval batch size.
    features, mask = features

  # Data was transposed from NHWC to HWCN on the host side. Transpose it back.
  # This transposition will be optimized away by the XLA compiler. It serves
  # as a hint to the compiler that it should expect the input data to come
  # in HWCN format rather than NHWC.
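  # The permutation [3, 0, 1, 2] maps HWCN (height, width, channels, batch)
  # back to NHWC (batch, height, width, channels).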
  features = tf.transpose(features, [3, 0, 1, 2])

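  # Resolve the search space `ssd` and the architecture choices `indices`
  # into a concrete spec for a single model.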
  model_spec = mobile_classifier_factory.get_model_spec(
      ssd=params['ssd'],
      indices=params['indices'],
      filters_multipliers=params['filters_multipliers'],
      path_dropout_rate=params['path_dropout_rate'],
      training=training)

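  # Serialize the architecture spec next to the checkpoints so the trained
  # model can be reconstructed later.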
  tf.io.gfile.makedirs(params['checkpoint_dir'])
  model_spec_filename = os.path.join(
      params['checkpoint_dir'], 'model_spec.json')
  with tf.io.gfile.GFile(model_spec_filename, 'w') as handle:
    handle.write(schema_io.serialize(model_spec))

  # We divide the weight_decay by 2 for backwards compatibility with the
  # tf.contrib version of the kernel regularizer, which was used in the
  # experiments from our published paper.
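  # (tf.keras.regularizers.l2(l) computes l * sum(w**2), whereas the old
  # tf.contrib regularizer computed scale * sum(w**2) / 2.)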
  kernel_regularizer = tf.keras.regularizers.l2(params['weight_decay'] / 2)
  model = mobile_classifier_factory.get_standalone_model(
      model_spec=model_spec,
      kernel_regularizer=kernel_regularizer,
      dropout_rate=params['dropout_rate'])

  model.build(features.shape)
  logits, _ = model.apply(
      inputs=features,
      training=training)
  regularization_loss = model.regularization_loss()
  # Cast logits back to float32 (only has an effect when use_bfloat16 is true).
  logits = tf.cast(logits, tf.float32)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        },
        export_outputs={
            'logits': tf.estimator.export.PredictOutput({'logits': logits}),
        })

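  # With label_smoothing=0.1, each one-hot target is softened to
  # 0.9 * onehot + 0.1 / num_classes before the cross-entropy is computed.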
  empirical_loss = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=labels,
      label_smoothing=0.1)
  loss = empirical_loss + regularization_loss

  # Optionally define an op for model training.
  global_step = tf.train.get_global_step()
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Linearly warm up the learning rate, then decay it on a cosine schedule.
    learning_rate = custom_layers.cosine_decay_with_linear_warmup(
        peak_learning_rate=params['learning_rate'],
        global_step=global_step,
        max_global_step=params['max_global_step'],
        warmup_steps=params['warmup_steps'])

    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=0.9,
        momentum=params['momentum'],
        epsilon=1.0)

    scaffold_fn = None

    optimizer = tf.tpu.CrossShardOptimizer(optimizer)

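    # model.updates() returns the model's update ops (e.g. batch-norm
    # moving-average updates); making them control dependencies of the
    # train op ensures they run on every training step.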
    with tf.control_dependencies(model.updates()):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None
    scaffold_fn = None

  # Optionally define evaluation metrics.
  if mode == tf.estimator.ModeKeys.EVAL:
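    # TPUEstimator expects eval_metrics as a (metric_fn, tensors) pair: the
    # tensors are gathered from the TPU cores and passed to metric_fn, which
    # runs on the host CPU. The mask zeroes out the padded dummy examples.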
    def metric_fn(labels, logits, mask):
      label_values = tf.argmax(labels, axis=1)
      predictions = tf.argmax(logits, axis=1)
      accuracy = tf.metrics.accuracy(label_values, predictions, weights=mask)
      return {'accuracy': accuracy}
    eval_metrics = (metric_fn, [labels, logits, mask])
  else:
    eval_metrics = None

  # NOTE: host_call only works on rank-1 tensors. There's also a fairly
  # large performance penalty if we try to pass too many distinct tensors
  # from the TPU to the host at once. We avoid these problems by (i) calling
  # tf.stack to merge all of the float32 scalar values into a single rank-1
  # tensor that can be sent to the host relatively cheaply and (ii) reshaping
  # the remaining values from scalars to rank-1 tensors.
  if mode == tf.estimator.ModeKeys.TRAIN:
    tensorboard_scalars = collections.OrderedDict()
    tensorboard_scalars['model/loss'] = loss
    tensorboard_scalars['model/empirical_loss'] = empirical_loss
    tensorboard_scalars['model/regularization_loss'] = regularization_loss
    tensorboard_scalars['model/learning_rate'] = learning_rate

    def host_call_fn(step, scalar_values):
      values = tf.unstack(scalar_values)
      with tf2.summary.create_file_writer(
          params['checkpoint_dir']).as_default():
        with tf2.summary.record_if(
            tf.equal(step[0] % params['tpu_iterations_per_loop'], 0)):
          for key, value in zip(list(tensorboard_scalars.keys()), values):
            tf2.summary.scalar(key, value, step=step[0])
          return tf.summary.all_v2_summary_ops()

    host_call_values = tf.stack(list(tensorboard_scalars.values()))
    host_call = (host_call_fn, [tf.reshape(global_step, [1]), host_call_values])
  else:
    host_call = None

  # Construct the estimator specification.
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn,
      host_call=host_call)
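
For context, here is a minimal sketch of how a model_fn like this is typically handed to a TPUEstimator. Everything concrete in it (directories, batch sizes, the input_fn, the training step count) is an illustrative assumption, not part of the original code:

# Minimal sketch, assuming TF 1.x TPUEstimator APIs. All concrete values
# below are illustrative placeholders; a real setup would also pass a
# TPUClusterResolver via RunConfig's `cluster` argument.
checkpoint_dir = '/tmp/model_dir'    # hypothetical
tpu_iterations_per_loop = 100        # hypothetical

run_config = tf.estimator.tpu.RunConfig(
    model_dir=checkpoint_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=tpu_iterations_per_loop))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    train_batch_size=1024,
    eval_batch_size=1024,
    params={
        'checkpoint_dir': checkpoint_dir,
        'tpu_iterations_per_loop': tpu_iterations_per_loop,
        # ...plus the keys model_fn reads above: 'ssd', 'indices',
        # 'filters_multipliers', 'path_dropout_rate', 'weight_decay',
        # 'dropout_rate', 'learning_rate', 'max_global_step',
        # 'warmup_steps', 'momentum'.
    })

# train_input_fn is a hypothetical input_fn yielding (features, labels).
estimator.train(input_fn=train_input_fn, max_steps=100000)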