Exemplo n.º 1
0
 def test_computes(self):
     """Check that FrameStackMLP maps (B, T + 1, OBS) input to (B, T + 1, output_size).

     The model is initialized from a single-step (1, 1, OBS) signature and
     then applied to a larger (B, T + 1, OBS) batch. This also exercises
     calling the model on shapes other than the one used for init.
     """
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=hidden_size,
                                     output_size=output_size)
     B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
     # Spell out the dtype so the input built below is guaranteed to match it.
     input_signature = ShapeDtype((1, 1, OBS), onp.float32)
     # The return value is not needed; unpacking still asserts it is a pair.
     _, _ = model.init(input_signature)
     # Build the input with the same dtype as the init signature (the original
     # fed an integer array into a model initialized for float32).
     x = onp.arange(B * (T + 1) * OBS, dtype=onp.float32).reshape(
         B, T + 1, OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)
Exemplo n.º 2
0
 def test_computes(self):
     """Check that FrameStackMLP maps (B, T + 1, OBS) input to (B, T + 1, output_size).

     The model is initialized once from a (1, 1, OBS) float32 input spec and
     then applied to a larger (B, T + 1, OBS) float32 batch.
     """
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=hidden_size,
                                     output_size=output_size)
     B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
     # Only one subkey is needed; the other half of the split is never used.
     _, key = jax_random.split(rng_key)
     # The return value is not needed; unpacking still asserts it is a pair.
     _, _ = model.initialize_once((1, 1, OBS), onp.float32, key)
     # Feed float32 data to match the dtype passed to initialize_once above
     # (the original fed an integer array).
     x = onp.arange(B * (T + 1) * OBS, dtype=onp.float32).reshape(
         B, T + 1, OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)