예제 #1
0
파일: rnn_test.py 프로젝트: zsunpku/trax
 def test_grulm_forward_shape(self):
     """Checks the output shape of a GRU LM via shape agreement.

     An int32 signature of shape (3, 28) is used to initialize a GRULM with a
     20-token vocabulary. The shape reported by `tl.check_shape_agreement` must
     equal the input shape extended by the vocabulary size.
     """
     vocab_size = 20
     input_shape = (3, 28)
     signature = ShapeDtype(input_shape, dtype=math.numpy.int32)
     grulm = rnn.GRULM(vocab_size=vocab_size, d_model=16)
     grulm.init(signature)
     expected = input_shape + (vocab_size,)
     self.assertEqual(expected, tl.check_shape_agreement(grulm, signature))
예제 #2
0
 def test_grulm_forward_shape(self, backend):
     """Runs a GRU LM forward under `backend` and checks the output shape.

     An all-ones int32 batch of shape (3, 28) is fed through a GRULM with a
     20-token vocabulary. The result's shape must be the input shape extended
     by the vocabulary size.
     """
     vocab_size = 20
     input_shape = (3, 28)
     with fastmath.use_backend(backend):
         grulm = rnn.GRULM(vocab_size=vocab_size, d_model=16)
         tokens = np.ones(input_shape).astype(np.int32)
         # The unpacking requires init to return exactly two values; both are unused.
         _, _ = grulm.init(shapes.signature(tokens))
         output = grulm(tokens)
         self.assertEqual(output.shape, input_shape + (vocab_size,))