def testMtfImageTransformerDataParallel(self):
  """Runs the image transformer on a 2-way data-parallel placement mesh.

  Lowers the model graph onto an "all:2" mesh with the batch dimension
  laid out across it, then checks that the evaluated logits have the
  expected (batch, height, width, channels, vocab) shape.
  """
  base_hparams = mtf_image_transformer.mtf_image_transformer_single()
  model, features, base_hparams = get_model(base_hparams)
  # Split the batch dimension across a single 2-way mesh dimension.
  base_hparams.mesh_shape = "all:2"
  base_hparams.layout = "batch:all"
  placement_mesh, impl = get_placement_mesh(base_hparams)

  logits, _ = model.mtf_model_fn(features, placement_mesh)
  low = mtf.Lowering(placement_mesh.graph, {placement_mesh: impl})
  init_group = low.copy_masters_to_slices()
  logits_tensor = low.export_to_tf_tensor(logits)

  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    session.run(init_group)
    result = session.run(logits_tensor)

  expected_shape = (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH,
                    base_hparams.num_channels, VOCAB_SIZE)
  self.assertEqual(result.shape, expected_shape)
def get_model(hparams=None,
              mode=tf.estimator.ModeKeys.TRAIN,
              model_cls=mtf_image_transformer.MtfImageTransformer):
  """Builds a test model plus a random-target feature dict.

  Args:
    hparams: optional HParams; defaults to the single-image-transformer set.
    mode: tf.estimator mode the model is constructed in.
    model_cls: model class to instantiate.

  Returns:
    A tuple (model, features, hparams) where `features["targets"]` is a
    random int32 tensor of shape (BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, 1, 1)
    with values in [0, VOCAB_SIZE).
  """
  if hparams is None:
    hparams = mtf_image_transformer.mtf_image_transformer_single()
  hparams.max_length = IMG_LENGTH * IMG_LENGTH
  hparams.batch_size = BATCH_SIZE
  hparams.img_len = IMG_LENGTH
  hparams.num_channels = 1

  p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE,
                                                   VOCAB_SIZE,
                                                   hparams)
  p_hparams.input_modality = {}
  hparams.problem_hparams = p_hparams

  # np.random.random_integers is deprecated/removed in modern NumPy;
  # randint(VOCAB_SIZE) draws from [0, VOCAB_SIZE), matching the old
  # "-1 + random_integers(VOCAB_SIZE)" range exactly.
  targets = np.random.randint(
      VOCAB_SIZE, size=(BATCH_SIZE, IMG_LENGTH, IMG_LENGTH, 1, 1))
  features = {
      "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
  }

  return model_cls(hparams, mode, p_hparams), features, hparams