def test_computes(self):
  """Checks FrameStackMLP maps a (B, T + 1, OBS) input to (B, T + 1, output_size)."""
  B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
  output_size = 6
  model = atari_cnn.FrameStackMLP(hidden_sizes=(4, 4), output_size=output_size)

  # Initialize from a single-step signature; the call below then uses
  # different leading dimensions with the same trailing OBS axis.
  signature = ShapeDtype((1, 1, OBS))
  _, _ = model.init(signature)

  num_steps = T + 1
  inputs = onp.arange(B * num_steps * OBS).reshape(B, num_steps, OBS)
  outputs = model(inputs)
  self.assertEqual((B, num_steps, output_size), outputs.shape)
def test_computes(self):
  """Checks FrameStackMLP maps a (B, T + 1, OBS) input to (B, T + 1, output_size)."""
  rng_key = jax_random.get_prng(0)
  hidden_size = (4, 4)
  output_size = 6
  model = atari_cnn.FrameStackMLP(
      hidden_sizes=hidden_size, output_size=output_size)
  B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name

  # Only the second half of the split is consumed; the first is not needed
  # again in this test, so it is discarded instead of rebinding rng_key.
  _, key = jax_random.split(rng_key)
  _, _ = model.initialize_once((1, 1, OBS), onp.float32, key)

  # Feed the same dtype the model was initialized with, instead of the
  # integer array onp.arange produces by default.
  x = onp.arange(B * (T + 1) * OBS, dtype=onp.float32).reshape(B, T + 1, OBS)
  y = model(x)
  self.assertEqual((B, T + 1, output_size), y.shape)