def test_grulm_forward_shape(self):
    """Check GRULM's output shape for a (3, 28) int32 input signature.

    A GRULM with vocab_size=20 and d_model=16 is initialized from the
    signature. tl.check_shape_agreement then reports the model's output
    shape, which must be (batch, length, vocab_size) == (3, 28, 20).

    NOTE(review): another method named `test_grulm_forward_shape` appears
    right after this one. If both live in the same class, the later
    definition rebinds the name and this test is never collected; consider
    renaming one of them.
    """
    batch_size, seq_len, vocab = 3, 28, 20
    token_sig = ShapeDtype((batch_size, seq_len), dtype=math.numpy.int32)
    gru_lm = rnn.GRULM(vocab_size=vocab, d_model=16)
    gru_lm.init(token_sig)
    reported_shape = tl.check_shape_agreement(gru_lm, token_sig)
    self.assertEqual((batch_size, seq_len, vocab), reported_shape)
def test_grulm_forward_shape(self, backend):
    """Run GRULM on concrete int32 input under `backend` and check its shape.

    Inside `fastmath.use_backend(backend)`, a GRULM with vocab_size=20 and
    d_model=16 is initialized from the signature of a (3, 28) array of ones.
    It is then applied to that array, and the result must have shape
    (3, 28, 20).

    NOTE(review): `backend` is presumably supplied by a parameterized-test
    decorator that is not visible here; confirm it exists. This method also
    shares its name with the preceding `test_grulm_forward_shape`. If both are
    in one class, this definition replaces the earlier one.
    """
    with fastmath.use_backend(backend):
        lm = rnn.GRULM(vocab_size=20, d_model=16)
        tokens = np.ones((3, 28), dtype=np.int32)
        # The init result is unpacked into two discarded values, so this
        # also requires init to return exactly two items.
        _, _ = lm.init(shapes.signature(tokens))
        out = lm(tokens)
        self.assertEqual(out.shape, (3, 28, 20))