예제 #1
0
 def test_autoregressive_sample_transformer(self):
   """Checks the output shape of sampling from a freshly initialized Transformer.

   Builds a small predict-mode Transformer, initializes it with the signature
   of a (1, 3) int32 input plus a (1, 1) int32 signature, and verifies that
   `decoding.autoregressive_sample` returns an array whose first two
   dimensions are (1, max_length).
   """
   max_length = 10
   transformer = models.Transformer(
       10, d_model=32, d_ff=64, n_encoder_layers=1, n_decoder_layers=1,
       n_heads=2, mode='predict')
   source_tokens = np.ones((1, 3), dtype=np.int32)
   target_signature = shapes.ShapeDtype((1, 1), dtype=np.int32)
   transformer.init((shapes.signature(source_tokens), target_signature))
   # NOTE(review): eos_id=-1 presumably never matches a sampled token, so
   # sampling should run for the full max_length steps -- confirm against
   # decoding.autoregressive_sample.
   sampled = decoding.autoregressive_sample(
       transformer, inputs=source_tokens, eos_id=-1, max_length=max_length)
   n_rows, n_tokens = sampled.shape[0], sampled.shape[1]
   self.assertEqual(n_rows, 1)
   self.assertEqual(n_tokens, max_length)
예제 #2
0
 def test_autoregressive_sample_transformer_quality(self):
   """Checks the decoded sequence of a Transformer loaded from test data.

   Loads weights from 'transformer_copy.pkl.gz' under `_TESTDATA` into a
   predict-mode Transformer, samples with eos_id=1 and temperature=0.0, and
   verifies that the first output row prints as the input prefix up to and
   including the first token equal to eos_id.
   """
   transformer = models.Transformer(
       input_vocab_size=13, d_model=64, d_ff=128, n_encoder_layers=2,
       n_decoder_layers=2, n_heads=2, dropout=0.05, max_len=256,
       mode='predict')
   token_signature = shapes.ShapeDtype((1, 1), dtype=np.int32)
   checkpoint_path = os.path.join(_TESTDATA, 'transformer_copy.pkl.gz')
   transformer.init_from_file(
       checkpoint_path, weights_only=True,
       input_signature=(token_signature, token_signature))
   source_tokens = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32)
   # NOTE(review): temperature=0.0 presumably makes sampling deterministic,
   # which the exact string comparison below relies on -- confirm.
   sampled = decoding.autoregressive_sample(
       transformer, inputs=source_tokens, eos_id=1, max_length=10,
       temperature=0.0)
   first_row_text = str(sampled[0])
   self.assertEqual(first_row_text, '[3 7 5 3 2 4 1]')