Ejemplo n.º 1
0
 def testTransformerRelative(self):
   """Checks the shape of the logits from the relative-attention tiny model.

   The model is built from the configuration returned by
   `transformer.transformer_relative_tiny()`, its logits are evaluated once,
   and their shape is compared with
   (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE).
   """
   config = transformer.transformer_relative_tiny()
   model, features = get_model(config)
   logits, _ = model(features)
   expected_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)
   with self.test_session():
     # Inside this block the test session is the default session, so the
     # op/tensor convenience methods can be used without naming it.
     tf.global_variables_initializer().run()
     logits_value = logits.eval()
   self.assertEqual(logits_value.shape, expected_shape)
Ejemplo n.º 2
0
 def testTransformerRelative(self):
   """Checks the shape of the logits from the relative-attention tiny model.

   The model is built from the configuration returned by
   `transformer.transformer_relative_tiny()`, its logits are evaluated once,
   and their shape is compared with
   (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE).
   """
   config = transformer.transformer_relative_tiny()
   model, features = get_model(config)
   logits, _ = model(features)
   expected_shape = (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)
   with self.test_session():
     # Inside this block the test session is the default session, so the
     # op/tensor convenience methods can be used without naming it.
     tf.global_variables_initializer().run()
     logits_value = logits.eval()
   self.assertEqual(logits_value.shape, expected_shape)
Ejemplo n.º 3
0
  def testBeamDecodeWithRelativeAttention(self):
    """Smoke-tests beam decoding with the relative-attention tiny model.

    Puts the model in PREDICT mode, builds the beam-decode "outputs" tensor
    (beam_size=4, top_beams=1, alpha=1.0, decode length 2) and evaluates it
    once. Nothing is asserted about the result; the test passes as long as
    graph construction and evaluation do not raise.
    """
    decode_length = 2
    config = transformer.transformer_relative_tiny()
    model, features = get_model(config)
    model.set_mode(tf.estimator.ModeKeys.PREDICT)

    decoded = model._beam_decode(
        features,
        decode_length,
        beam_size=4,
        top_beams=1,
        alpha=1.0)
    beam_result = decoded["outputs"]

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      # Evaluated only for its side effect of running the decode graph.
      session.run(beam_result)
Ejemplo n.º 4
0
  def testBeamDecodeWithRelativeAttention(self):
    """Smoke-tests beam decoding with the relative-attention tiny model.

    Puts the model in PREDICT mode, builds the beam-decode "outputs" tensor
    (beam_size=4, top_beams=1, alpha=1.0, decode length 2) and evaluates it
    once. Nothing is asserted about the result; the test passes as long as
    graph construction and evaluation do not raise.
    """
    decode_length = 2
    config = transformer.transformer_relative_tiny()
    model, features = get_model(config)
    model.set_mode(tf.estimator.ModeKeys.PREDICT)

    decoded = model._beam_decode(
        features,
        decode_length,
        beam_size=4,
        top_beams=1,
        alpha=1.0)
    beam_result = decoded["outputs"]

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      # Evaluated only for its side effect of running the decode graph.
      session.run(beam_result)
Ejemplo n.º 5
0
  def testBeamDecodeWithRelativeAttention(self):
    """Checks the output shape of beam decoding with relative attention.

    Runs a forward pass, switches the model to PREDICT mode, builds the
    beam-decode "outputs" tensor under the current variable scope with
    reuse=True, evaluates it, and asserts that its shape is
    (BATCH_SIZE, INPUT_LENGTH + decode_length).
    """
    decode_length = 2
    config = transformer.transformer_relative_tiny()
    model, features = self.getModel(config)
    # Forward pass comes first; the decode graph below is built with
    # reuse=True, presumably to share the variables created by this call.
    model(features)
    model.set_mode(tf.estimator.ModeKeys.PREDICT)

    current_scope = tf.get_variable_scope()
    with tf.variable_scope(current_scope, reuse=True):
      decoded = model._beam_decode(
          features, decode_length, beam_size=4, top_beams=1, alpha=1.0)
      beam_result = decoded["outputs"]

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      beam_res = session.run(beam_result)

    expected_shape = (BATCH_SIZE, INPUT_LENGTH + decode_length)
    self.assertEqual(beam_res.shape, expected_shape)