def test_autoregressive_sample_transformer(self):
    """Check the output shape of sampling from a small predict-mode Transformer.

    Builds a one-encoder-layer / one-decoder-layer Transformer, initializes
    it, samples with ``max_length=10`` from a (1, 3) int32 prompt of ones,
    and asserts that the first two dimensions of the result are 1 and 10.
    """
    prompt = np.ones((1, 3), dtype=np.int32)
    transformer = models.Transformer(
        10,
        d_model=32,
        d_ff=64,
        n_encoder_layers=1,
        n_decoder_layers=1,
        n_heads=2,
        mode='predict',
    )

    # The init signature pairs the prompt's signature with a (1, 1) int32
    # entry.
    prompt_signature = shapes.signature(prompt)
    single_token = shapes.ShapeDtype((1, 1), dtype=np.int32)
    transformer.init((prompt_signature, single_token))

    # NOTE(review): eos_id=-1 is presumably an id the model never emits, so
    # sampling should run all the way to max_length -- confirm against
    # decoding.autoregressive_sample.
    sampled = decoding.autoregressive_sample(
        transformer, inputs=prompt, eos_id=-1, max_length=10)

    # Check the leading two axes one at a time, in order.
    for axis, expected_size in enumerate((1, 10)):
        self.assertEqual(sampled.shape[axis], expected_size)
def test_autoregressive_sample_transformer_quality(self):
    """Check sampled tokens from a Transformer loaded from test-data weights.

    Loads weights from ``transformer_copy.pkl.gz`` under ``_TESTDATA`` into a
    predict-mode Transformer, samples with ``eos_id=1`` and
    ``temperature=0.0``, and asserts that the first sampled row prints as
    ``[3 7 5 3 2 4 1]`` -- the prompt up to and including the token 1.
    """
    prompt = np.array([[3, 7, 5, 3, 2, 4, 1, 8]], dtype=np.int32)
    checkpoint_path = os.path.join(_TESTDATA, 'transformer_copy.pkl.gz')
    single_token = shapes.ShapeDtype((1, 1), dtype=np.int32)

    transformer = models.Transformer(
        input_vocab_size=13,
        d_model=64,
        d_ff=128,
        dropout=0.05,
        max_len=256,
        n_heads=2,
        n_encoder_layers=2,
        n_decoder_layers=2,
        mode='predict',
    )
    # Only the weights are restored; both signature entries are (1, 1) int32.
    transformer.init_from_file(
        checkpoint_path,
        weights_only=True,
        input_signature=(single_token, single_token),
    )

    # NOTE(review): temperature=0.0 presumably makes decoding deterministic
    # (greedy) -- confirm against decoding.autoregressive_sample.
    sampled = decoding.autoregressive_sample(
        transformer,
        inputs=prompt,
        eos_id=1,
        max_length=10,
        temperature=0.0,
    )

    first_row_text = str(sampled[0])
    self.assertEqual(first_row_text, '[3 7 5 3 2 4 1]')