Example 1
# Imports assumed by this excerpt (not shown in the original):
from absl.testing import parameterized
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (mtf.Dimension(name="x", size=5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    dimension = mtf.convert_to_dimension(None)
    self.assertEqual(dimension, None)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension(name="x", size=4),
                  mtf.Dimension(name="y", size=8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(shape, mtf.Shape([mtf.Dimension(name="x", size=4),
                                       mtf.Dimension(name="y", size=8)]))

  def testConvertToShapeGenericInputs(self):
    shape = mtf.convert_to_shape(None)
    self.assertEqual(shape, None)
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")

  def testMeshImpl(self):
    shape = mtf.Shape([mtf.Dimension("batch", 4),
                       mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"),
                                    ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertEqual(mesh_impl.ndims, len(shape))
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    self.assertEqual(mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
                     mtf.TensorLayout([0, None, 1]))
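
The parameterized cases above show that mtf.convert_to_shape accepts ";", ".", " ", and "," as pair delimiters. A minimal standalone sketch of the same round trip, assuming mesh_tensorflow is importable as mtf:

import mesh_tensorflow as mtf

expected = mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)])
# All four delimiter styles parse to the same Shape.
for spec in ("x:4;y:8", "x:4.y:8", "x:4 y:8", "x:4,y:8"):
    assert mtf.convert_to_shape(spec) == expected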
Example 4
# Imports assumed by this excerpt (not shown in the original):
from absl.testing import parameterized
import mesh_tensorflow as mtf
from mesh_tensorflow import placement_mesh_impl
import tensorflow.compat.v1 as tf


class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        dimension = mtf.convert_to_dimension(None)
        self.assertEqual(dimension, None)
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        shape = mtf.convert_to_shape(None)
        self.assertEqual(shape, None)
        with self.assertRaises(ValueError):
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        layout_rules = mtf.convert_to_layout_rules(inputs)
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def testGraph(self):
        graph = mtf.Graph()
        self.assertLen(graph.operations, 0)
        self.assertLen(graph.tensors, 0)
        self.assertLen(graph.trainable_variables, 0)
        self.assertLen(graph.all_variables, 0)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        self.assertLen(graph.operations, 1)
        self.assertLen(graph.tensors, 1)
        self.assertLen(graph.trainable_variables, 0)
        self.assertLen(graph.all_variables, 0)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self.assertLen(graph.operations, 2)
        self.assertLen(graph.tensors, 2)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        self.assertLen(graph.operations, 3)
        self.assertLen(graph.tensors, 3)
        self.assertLen(graph.trainable_variables, 1)
        self.assertLen(graph.all_variables, 2)

    def testLowering(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        with self.test_session() as sess:
            inputs_value, outputs_value = sess.run([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertEqual(mesh_impl.ndims, len(shape))
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
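
As testMeshImpl above illustrates, a tensor dimension is split across a mesh axis only when a layout rule maps it there; dimensions with no matching rule stay replicated. A small sketch of the fully replicated case, under the same API assumptions:

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("batch:4;model:8")
layout_rules = mtf.convert_to_layout_rules("batch:batch;d_ff:model")
mesh_impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=layout_rules)

# Neither "length" nor "d_model" matches a layout rule, so nothing is split.
layout = mesh_impl.tensor_layout(mtf.convert_to_shape("length:500;d_model:512"))
assert layout.is_fully_replicated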
Example 5
    # Excerpt from a larger model class; module-level names (copy, six, mtf,
    # simd_mesh_impl, placement_mesh_impl, mtf_optimize, learning_rate,
    # mtf_utils, tpu_estimator, _remove_summaries) are defined elsewhere in
    # the original file.
    @classmethod
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False,
                           xla_compile=False):
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.LayoutRules(hparams.layout)
        if use_tpu:
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices,
                params["context"].device_assignment)
        else:
            if len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf_utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            _remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
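
estimator_model_fn is written to be adapted to the (features, labels, mode, params, config) signature that tf.estimator.Estimator expects. A hypothetical wiring sketch; MyMtfModel, hparams, decode_hparams, and make_input_fn are illustrative names, not part of the snippet:

import tensorflow as tf

def wrapped_model_fn(features, labels, mode, params=None, config=None):
    # Delegate to the classmethod above; hparams and decode_hparams are
    # captured from the enclosing scope.
    return MyMtfModel.estimator_model_fn(
        hparams, features, labels, mode,
        config=config, params=params, decode_hparams=decode_hparams)

estimator = tf.estimator.Estimator(
    model_fn=wrapped_model_fn, model_dir=hparams.model_dir)
estimator.train(input_fn=make_input_fn(tf.estimator.ModeKeys.TRAIN),
                max_steps=1000)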
Example 6
# Imports assumed by this excerpt (module paths may vary by version):
import mesh_tensorflow as mtf
from mesh_tensorflow import optimize as mtf_optimize
from mesh_tensorflow import placement_mesh_impl
import tensorflow.compat.v1 as tf

# FLAGS and mnist_model are defined elsewhere in the original file.
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(features, labels, mesh)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.LayoutRules(FLAGS.layout)
  mesh_size = mesh_shape.size
  mesh_devices = [""] * mesh_size
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf_optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf.summary.scalar("loss", tf_loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False, save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", accuracy[1])

    # restore_hook must come before saver_hook
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={
            "accuracy":
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
        })
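
A minimal driver for the model_fn above, following the standard TF1 Estimator pattern; train_input_fn and eval_input_fn are assumed helpers (e.g. built from the MNIST dataset) and are not part of the excerpt:

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=FLAGS.model_dir)
estimator.train(input_fn=train_input_fn, steps=1000)
eval_results = estimator.evaluate(input_fn=eval_input_fn)
tf.logging.info("eval results: %s", eval_results)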