def get_model(hparams=None, mode=tf_estimator.ModeKeys.TRAIN, has_input=True, model_cls=mtf_transformer.MtfTransformer):
  """Build a model instance plus a synthetic feature dict for testing.

  Args:
    hparams: optional HParams; defaults to `mtf_transformer_single()`.
    mode: an estimator ModeKeys value.
    has_input: if False, the "inputs" modality and feature are omitted.
    model_cls: the model class to instantiate.

  Returns:
    A `(model, features, hparams)` triple where `features` holds int32
    constant tensors of random token ids.
  """
  if hparams is None:
    hparams = mtf_transformer.mtf_transformer_single()
  hparams.max_length = INPUT_LENGTH
  hparams.batch_size = BATCH_SIZE

  problem = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE,
                                                 hparams)
  if not has_input:
    del problem.modality["inputs"]
  hparams.problem_hparams = problem

  # Random token ids in [0, VOCAB_SIZE); generated unconditionally so the
  # RNG stream is consumed identically whether or not inputs are used.
  input_ids = np.random.randint(VOCAB_SIZE,
                                size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  target_ids = np.random.randint(VOCAB_SIZE,
                                 size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

  features = {
      "targets": tf.constant(target_ids, dtype=tf.int32, name="targets"),
      "target_space_id": tf.constant(1, dtype=tf.int32),
  }
  if has_input:
    features["inputs"] = tf.constant(input_ids, dtype=tf.int32, name="inputs")

  return model_cls(hparams, mode, problem), features, hparams
def get_model(hparams=None, mode=tf.estimator.ModeKeys.TRAIN, has_input=True, model_cls=mtf_transformer.MtfTransformer):
  """Build a model instance plus a synthetic feature dict for testing.

  Args:
    hparams: optional HParams; defaults to `mtf_transformer_single()`.
    mode: an estimator ModeKeys value.
    has_input: if False, the "inputs" modality and feature are omitted.
    model_cls: the model class to instantiate.

  Returns:
    A `(model, features, hparams)` triple where `features` holds int32
    constant tensors of random token ids.
  """
  if hparams is None:
    hparams = mtf_transformer.mtf_transformer_single()
  hparams.max_length = INPUT_LENGTH
  hparams.batch_size = BATCH_SIZE

  p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE,
                                                   hparams)
  if not has_input:
    del p_hparams.modality["inputs"]
  hparams.problem_hparams = p_hparams

  # np.random.random_integers is deprecated (removed in modern NumPy).
  # randint(VOCAB_SIZE) draws uniformly from [0, VOCAB_SIZE), which is
  # exactly what the old `-1 + random_integers(VOCAB_SIZE)` produced.
  inputs = np.random.randint(VOCAB_SIZE,
                             size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
  targets = np.random.randint(VOCAB_SIZE,
                              size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))

  features = {
      "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
      "target_space_id": tf.constant(1, dtype=tf.int32),
  }
  if has_input:
    features["inputs"] = tf.constant(inputs, dtype=tf.int32, name="inputs")

  return model_cls(hparams, mode, p_hparams), features, hparams
def testMtfTransformerDataParallel(self):
  """Splits the batch over a 2-way placement mesh and checks logit shape."""
  hp = mtf_transformer.mtf_transformer_single()
  model, features, hp = get_model(hp)
  hp.mesh_shape = "all:2"
  hp.layout = "batch:all"
  mesh, impl = get_placement_mesh(hp)

  # Build the MTF graph and lower it onto the placement mesh.
  logits, _ = model.mtf_model_fn(features, mesh)
  lowering = mtf.Lowering(mesh.graph, {mesh: impl})
  init_group = lowering.copy_masters_to_slices()
  exported_logits = lowering.export_to_tf_tensor(logits)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(init_group)
    out = sess.run(exported_logits)

  self.assertEqual(out.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))
def testMtfTransformerDataParallel(self):
  """Data-parallel smoke test: logits keep the full dense output shape."""
  hparams = mtf_transformer.mtf_transformer_single()
  model, features, hparams = get_model(hparams)
  # Two-way mesh with the batch dimension laid out across it.
  hparams.mesh_shape = "all:2"
  hparams.layout = "batch:all"
  mesh, mesh_impl = get_placement_mesh(hparams)

  logits, _ = model.mtf_model_fn(features, mesh)
  lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
  copy_group = lowering.copy_masters_to_slices()
  tf_logits = lowering.export_to_tf_tensor(logits)

  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    session.run(copy_group)
    result = session.run(tf_logits)

  self.assertEqual(result.shape, (BATCH_SIZE, TARGET_LENGTH, VOCAB_SIZE))