def get_model(hparams=None,
              mode=tf_estimator.ModeKeys.TRAIN,
              has_input=True,
              model_cls=mtf_transformer.MtfTransformer):
    """Build a model instance plus randomly-filled int32 feature tensors.

    NOTE(review): an almost identical `get_model` appears again further down
    in this file with different indentation -- looks like a merge artifact;
    confirm only one copy should remain.

    Args:
      hparams: optional HParams; defaults to `mtf_transformer_single()`.
      mode: an estimator ModeKeys value passed through to the model.
      has_input: when False, the "inputs" modality and feature are omitted.
      model_cls: model class to instantiate.

    Returns:
      A (model, features, hparams) triple.
    """
    if hparams is None:
        hparams = mtf_transformer.mtf_transformer_single()
    hparams.max_length = INPUT_LENGTH
    hparams.batch_size = BATCH_SIZE

    p_hparams = problem_hparams.test_problem_hparams(
        VOCAB_SIZE, VOCAB_SIZE, hparams)
    if not has_input:
        del p_hparams.modality["inputs"]
    hparams.problem_hparams = p_hparams

    # Both arrays are always drawn (in this order) so RNG consumption does
    # not depend on `has_input`.
    input_ids = np.random.randint(
        VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
    target_ids = np.random.randint(
        VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

    features = {
        "targets": tf.constant(target_ids, dtype=tf.int32, name="targets"),
        "target_space_id": tf.constant(1, dtype=tf.int32),
    }
    if has_input:
        features["inputs"] = tf.constant(
            input_ids, dtype=tf.int32, name="inputs")

    return model_cls(hparams, mode, p_hparams), features, hparams
def get_model(hparams=None, mode=tf.estimator.ModeKeys.TRAIN,
              has_input=True, model_cls=mtf_transformer.MtfTransformer):
  """Build a model instance plus randomly-filled int32 feature tensors.

  NOTE(review): a near-duplicate `get_model` exists earlier in this file --
  looks like a merge artifact; confirm only one copy should remain.

  Args:
    hparams: optional HParams; defaults to `mtf_transformer_single()`.
    mode: an estimator ModeKeys value passed through to the model.
    has_input: when False, the "inputs" modality and feature are omitted.
    model_cls: model class to instantiate.

  Returns:
    A (model, features, hparams) triple.
  """
  if hparams is None:
    hparams = mtf_transformer.mtf_transformer_single()
  hparams.max_length = INPUT_LENGTH
  hparams.batch_size = BATCH_SIZE

  p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE,
                                                   VOCAB_SIZE,
                                                   hparams)
  if not has_input:
    del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  # np.random.random_integers is deprecated (removed in NumPy >= 1.25).
  # The old `-1 + random_integers(VOCAB_SIZE)` drew from [0, VOCAB_SIZE);
  # randint(VOCAB_SIZE) produces exactly that range.
  inputs = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  targets = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))
  features = {
      "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
      "target_space_id": tf.constant(1, dtype=tf.int32)
  }
  if has_input:
    features["inputs"] = tf.constant(inputs, dtype=tf.int32, name="inputs")

  return model_cls(hparams, mode, p_hparams), features, hparams
    def testMtfTransformerDataParallel(self):
        """Smoke-test the MTF transformer lowered onto a 2-entry placement mesh.

        NOTE(review): this method appears again later in this file with
        different indentation -- looks like a merge artifact; confirm only
        one copy should remain.
        """
        hparams = mtf_transformer.mtf_transformer_single()

        model, features, hparams = get_model(hparams)
        # A one-dimensional mesh "all" of size 2; the layout maps the batch
        # dimension onto it (data parallelism across the two mesh slots).
        hparams.mesh_shape = "all:2"
        hparams.layout = "batch:all"
        mesh, mesh_impl = get_placement_mesh(hparams)

        # Build the MTF graph, then lower it to plain TF ops.
        logits, _ = model.mtf_model_fn(features, mesh)
        lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
        tf_group = lowering.copy_masters_to_slices()
        tf_logits = lowering.export_to_tf_tensor(logits)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            # Distribute master variable values to the per-slice copies
            # before evaluating the exported tensor.
            session.run(tf_group)
            res = session.run(tf_logits)
        self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))
  def testMtfTransformerDataParallel(self):
    """Smoke-test the MTF transformer lowered onto a 2-entry placement mesh.

    NOTE(review): a near-duplicate of this method exists earlier in this
    file -- looks like a merge artifact; confirm only one copy should remain.
    """
    base_hparams = mtf_transformer.mtf_transformer_single()
    transformer, feature_dict, base_hparams = get_model(base_hparams)

    # One mesh dimension "all" of size 2, with the batch dimension laid
    # out across it (data parallelism).
    base_hparams.mesh_shape = "all:2"
    base_hparams.layout = "batch:all"
    mesh, impl = get_placement_mesh(base_hparams)

    # Build the MTF graph and lower it to ordinary TF ops.
    mtf_logits, _ = transformer.mtf_model_fn(feature_dict, mesh)
    lowering = mtf.Lowering(mesh.graph, {mesh: impl})
    init_group = lowering.copy_masters_to_slices()
    exported_logits = lowering.export_to_tf_tensor(mtf_logits)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # Push master variable values out to the per-slice copies first.
      sess.run(init_group)
      out = sess.run(exported_logits)
    self.assertEqual(out.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))