def testTransformerRelative(self):
  """Smoke-tests the forward pass of the relative-attention transformer.

  Builds the tiny relative-attention model, runs one forward pass, and
  checks that the logits tensor has the expected 5-D shape.
  """
  model, features = get_model(transformer.transformer_relative_tiny())
  logits, _ = model(features)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    logits_val = sess.run(logits)
  expected_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)
  self.assertEqual(logits_val.shape, expected_shape)
def testBeamDecodeSmokeWithRelativeAttention(self):
  """Smoke-tests that beam decoding with relative attention runs at all.

  Only evaluates the beam-search output tensor; makes no shape or value
  assertions (the stricter shape check lives in
  testBeamDecodeWithRelativeAttention).

  NOTE(review): this method previously shared the name
  testBeamDecodeWithRelativeAttention with a later definition in the same
  class; the later definition shadowed this one, so it was never collected
  or run. Renamed so both tests execute.
  """
  decode_length = 2
  model, features = get_model(transformer.transformer_relative_tiny())
  model.set_mode(tf.estimator.ModeKeys.PREDICT)
  beam_result = model._beam_decode(
      features, decode_length, beam_size=4, top_beams=1, alpha=1.0)["outputs"]
  with self.test_session():
    tf.global_variables_initializer().run()
    beam_result.eval()
def testBeamDecodeWithRelativeAttention(self):
  """Checks that beam decoding with relative attention yields the right shape."""
  decode_length = 2
  # NOTE(review): this test calls self.getModel while sibling tests call the
  # module-level get_model — confirm both helpers exist and agree.
  model, features = self.getModel(transformer.transformer_relative_tiny())
  # Forward pass before switching modes — presumably to create the model's
  # variables so they can be reused below; confirm against the model class.
  model(features)
  model.set_mode(tf.estimator.ModeKeys.PREDICT)
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    decoded = model._beam_decode(
        features, decode_length, beam_size=4, top_beams=1, alpha=1.0)["outputs"]
  with self.test_session():
    tf.global_variables_initializer().run()
    decoded_val = decoded.eval()
  self.assertEqual(decoded_val.shape,
                   (BATCH_SIZE, INPUT_LENGTH + decode_length))