コード例 #1
0
    def testMtfImageTransformerDataParallel(self):
        """Run the MTF image transformer on a batch-split mesh and check logits.

        Starts from ``mtf_image_transformer_single()`` hparams and builds the
        model via ``get_model``. It then switches the mesh to shape "all:2" with
        layout "batch:all" (presumably 2-way data parallelism over the batch
        dimension; the mesh semantics live outside this file). Finally it
        lowers the graph, runs it, and asserts the exported logits' 5-D shape.
        """
        base_hparams = mtf_image_transformer.mtf_image_transformer_single()
        mtf_model, inputs, hp = get_model(base_hparams)

        # Mesh configuration is set after model construction (as before);
        # get_placement_mesh reads it from the returned hparams object.
        hp.mesh_shape = "all:2"
        hp.layout = "batch:all"
        placement_mesh, placement_impl = get_placement_mesh(hp)

        mtf_logits, _ = mtf_model.mtf_model_fn(inputs, placement_mesh)
        lowered = mtf.Lowering(placement_mesh.graph,
                               {placement_mesh: placement_impl})
        # Op group from copy_masters_to_slices(); it is run right after
        # variable initialization below.
        copy_op = lowered.copy_masters_to_slices()
        logits_tensor = lowered.export_to_tf_tensor(mtf_logits)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(copy_op)
            logits_value = sess.run(logits_tensor)

        expected_shape = (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH,
                          hp.num_channels, VOCAB_SIZE)
        self.assertEqual(logits_value.shape, expected_shape)
コード例 #2
0
def get_model(hparams=None,
              mode=tf.estimator.ModeKeys.TRAIN,
              model_cls=mtf_image_transformer.MtfImageTransformer):
    """Build a small MTF image-transformer model plus random target features.

    Args:
      hparams: HParams to configure. If None, a fresh
        ``mtf_image_transformer.mtf_image_transformer_single()`` is used.
        The object is modified in place: ``max_length``, ``batch_size``,
        ``img_len``, ``num_channels`` and ``problem_hparams`` are overwritten.
      mode: mode value forwarded to ``model_cls``.
      model_cls: model class, instantiated as
        ``model_cls(hparams, mode, p_hparams)``.

    Returns:
      A tuple ``(model, features, hparams)``:
        - ``features`` is a dict with a single ``"targets"`` entry. It is an
          int32 ``tf.constant`` of shape
          ``(BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, hparams.num_channels, 1)``
          with values drawn uniformly from ``[0, VOCAB_SIZE)``.
        - ``hparams`` is the same (mutated) object that was configured above.
    """
    if hparams is None:
        hparams = mtf_image_transformer.mtf_image_transformer_single()
    hparams.max_length = IMG_LENGTH * IMG_LENGTH
    hparams.batch_size = BATCH_SIZE
    hparams.img_len = IMG_LENGTH
    hparams.num_channels = 1

    p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE,
                                                     hparams)
    # The features built below carry only "targets", so no input modalities.
    p_hparams.input_modality = {}
    hparams.problem_hparams = p_hparams

    # np.random.random_integers is deprecated (NumPy >= 1.11). randint's upper
    # bound is exclusive, so this yields the same range [0, VOCAB_SIZE) as the
    # former ``-1 + np.random.random_integers(VOCAB_SIZE, ...)`` expression.
    # The channel axis is tied to hparams.num_channels (set to 1 above).
    targets = np.random.randint(
        0, VOCAB_SIZE,
        size=(BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, hparams.num_channels, 1))
    features = {
        "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
    }

    return model_cls(hparams, mode, p_hparams), features, hparams