Beispiel #1
0
def model_fn(features, labels, mode, params):
    """Build the Mesh-TensorFlow MNIST graph and wrap it in an EstimatorSpec.

    The model is placed on a mesh configured from ``FLAGS.mesh_shape`` and
    ``FLAGS.layout``, lowered to TensorFlow, and returned as the spec that
    matches ``mode``.

    Args:
        features: Model inputs, forwarded unchanged to ``mnist_model``.
        labels: Targets, forwarded to ``mnist_model`` and used for the
            accuracy metrics in TRAIN and EVAL.
        mode: One of the ``tf.estimator.ModeKeys`` values.
        params: Estimator params; only logged here.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, PREDICT or EVAL. Any
        other ``mode`` falls off the end of the function and yields ``None``.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    mode_keys = tf.estimator.ModeKeys
    is_training = mode == mode_keys.TRAIN
    is_predicting = mode == mode_keys.PREDICT

    step = tf.train.get_global_step()
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mtf_mesh)

    # One empty device string per unit of the mesh-shape product.
    shape = mtf.parse_mesh_shape(FLAGS.mesh_shape)
    device_count = mtf.list_product(shape)
    devices = [""] * device_count
    layout = mtf.parse_layout(FLAGS.layout)
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape, layout, devices)

    if is_training:
        # The optimizer ops have to be added to the mtf graph before the
        # Lowering below is constructed.
        grads = mtf.gradients(
            [loss],
            [var.outputs[0] for var in mtf_graph.trainable_variables])
        adafactor = mtf_optimize.AdafactorOptimizer()
        mtf_updates = [
            op
            for grad, var in zip(grads, mtf_graph.trainable_variables)
            for op in adafactor.apply_grad(grad, var)
        ]

    lowering = mtf.Lowering(mtf_graph, {mtf_mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.outfeed(logits)
    if not is_predicting:
        # The loss is exported (and summarized) for TRAIN and EVAL only.
        tf_loss = lowering.outfeed(loss)
        tf.summary.scalar("loss", tf_loss)

    if is_training:
        lowered_updates = [
            lowering.lowered_operation(op) for op in mtf_updates
        ]
        # Advance the global step together with the variable updates.
        train_op = tf.group(lowered_updates + [tf.assign_add(step, 1)])

        checkpoint_saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, checkpoint_saver)
        checkpoint_listener = mtf.MtfCheckpointSaverListener(lowering)
        checkpoint_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=checkpoint_saver,
            listeners=[checkpoint_listener])

        predicted_classes = tf.argmax(tf_logits, axis=1)
        _, accuracy_update_op = tf.metrics.accuracy(
            labels=labels, predictions=predicted_classes)

        # Give the tensors stable names so a LoggingTensorHook can find them.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy_update_op, name="train_accuracy")

        # Expose the training accuracy in TensorBoard as well.
        tf.summary.scalar("train_accuracy", accuracy_update_op)

        # Hook order matters: restore_hook must precede the saver hook.
        return tf.estimator.EstimatorSpec(
            mode_keys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, checkpoint_hook])

    if is_predicting:
        class_ids = tf.argmax(tf_logits, axis=1)
        class_probabilities = tf.nn.softmax(tf_logits)
        predictions = {
            "classes": class_ids,
            "probabilities": class_probabilities,
        }
        serving_outputs = {
            "classify": tf.estimator.export.PredictOutput(predictions),
        }
        return tf.estimator.EstimatorSpec(
            mode=mode_keys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs=serving_outputs)

    if mode == mode_keys.EVAL:
        eval_accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))
        return tf.estimator.EstimatorSpec(
            mode=mode_keys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={"accuracy": eval_accuracy})
def model_fn(features, labels, mode, params):
  """Build the Mesh-TensorFlow toy model for a TPUEstimator.

  The model is built on a `SimdMeshImpl` configured from `FLAGS.mesh_shape`
  and `FLAGS.layout`, lowered to TensorFlow, and wrapped in a
  `TPUEstimatorSpec`.

  Args:
    features: Model inputs, passed straight to `toy_model`.
    labels: Unused; discarded immediately.
    mode: A `tf.estimator.ModeKeys` value. Only TRAIN and EVAL are supported.
    params: Mapping supplied by the estimator; `params['context']` must
      expose a `device_assignment` attribute.

  Returns:
    A `tpu_estimator.TPUEstimatorSpec` for the requested mode.

  Raises:
    ValueError: If `mode` is neither TRAIN nor EVAL. Previously such modes
      fell off the end of the function and returned `None`.
  """
  del labels  # Unused: toy_model is given only the features.
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  # One empty device string per mesh position.
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)
  with mtf_utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Gradient and optimizer ops must be added to the mtf graph before the
    # Lowering below is constructed.
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf_optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Advance the global step together with the variable updates.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf_utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      # restore_hook is listed ahead of saver_hook on purpose.
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        """Return the eval metric dict: the streaming mean of the logits."""
        # NOTE(review): 'mean_logitss' looks like a typo for 'mean_logits';
        # it is the emitted metric name, so it is kept unchanged.
        mean_logitss = tf.metrics.mean(tf_logits)
        return {'mean_logitss': mean_logitss}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
    else:
      # Fail loudly rather than handing the estimator an implicit None.
      raise ValueError(
          'model_fn supports only TRAIN and EVAL modes, got: {!r}'.format(
              mode))