예제 #1
0
 def test_rnnlm_forward_shape(self, backend):
     """Check the RNN LM output shape under the given backend.

     Feeds an all-ones int32 array of shape (3, 28) through an
     ``rnn.RNNLM`` built with ``vocab_size=20`` and ``d_model=16``, and
     asserts the result has shape (3, 28, 20).

     Args:
         backend: Backend identifier handed to ``fastmath.use_backend``.
     """
     input_shape = (3, 28)
     vocab_size = 20
     # The expected output keeps the input axes and appends one axis whose
     # size matches the vocab_size passed to the model.
     expected_shape = input_shape + (vocab_size,)

     with fastmath.use_backend(backend):
         lm = rnn.RNNLM(vocab_size=vocab_size, d_model=16)
         tokens = np.ones(input_shape).astype(np.int32)
         # Initialize from the input signature before the forward call; the
         # two values returned by init are not needed here.
         _, _ = lm.init(shapes.signature(tokens))
         output = lm(tokens)
         self.assertEqual(output.shape, expected_shape)
예제 #2
0
 def test_rnnlm_forward_shape(self):
   """Checks the shape reported for an RNN LM given a (3, 28) int32 input.

   Builds an input signature and an `rnn.RNNLM` (vocab_size=20, d_model=16),
   passes both to `tl.check_shape_agreement`, and asserts the returned shape
   equals (3, 28, 20).
   """
   in_shape = (3, 28)
   vocab_size = 20
   signature = ShapeDtype(in_shape, dtype=math.numpy.int32)
   lm = rnn.RNNLM(vocab_size=vocab_size, d_model=16)
   # Expected shape: the input axes followed by one axis of size vocab_size.
   expected_shape = in_shape + (vocab_size,)
   self.assertEqual(expected_shape, tl.check_shape_agreement(lm, signature))