def test_computes(self):
    """Check that AtariCnn maps a (2, 3, 28, 28, 3) input to (2, 3, 6)."""
    hidden_sizes = (4, 4)
    num_outputs = 6
    batch, steps, obs_shape = 2, 2, (28, 28, 3)

    seed_key = jax_random.get_prng(0)
    model = atari_cnn.AtariCnn(
        hidden_sizes=hidden_sizes, output_size=num_outputs)

    _, init_key = jax_random.split(seed_key)
    # Both values returned by initialize_once are deliberately discarded;
    # the model is called below without any explicit parameters.
    _, _ = model.initialize_once((1, 1) + obs_shape, onp.float32, init_key)

    # The input carries steps + 1 entries along its second axis.
    input_shape = (batch, steps + 1) + obs_shape
    num_values = functools.reduce(op.mul, input_shape)
    x = onp.arange(num_values).reshape(input_shape)

    y = model(x)
    self.assertEqual(input_shape[:2] + (num_outputs,), y.shape)
def test_computes(self):
    """Check that AtariCnn, called with explicit params, yields (2, 3, 6)."""
    hidden_sizes = (4, 4)
    num_outputs = 6
    batch, steps, obs_shape = 2, 2, (28, 28, 3)

    seed_key = jax_random.get_prng(0)
    policy = atari_cnn.AtariCnn(
        hidden_sizes=hidden_sizes, output_size=num_outputs)

    # First split: a key for parameter initialization.
    seed_key, init_key = jax_random.split(seed_key)
    params = policy.initialize((-1, -1) + obs_shape, init_key)

    # The input carries steps + 1 entries along its second axis.
    input_shape = (batch, steps + 1) + obs_shape
    num_values = functools.reduce(op.mul, input_shape)
    x = onp.arange(num_values).reshape(input_shape)

    # Second split: a fresh key for the forward call.
    _, call_key = jax_random.split(seed_key)
    y = policy(x, params, rng=call_key)
    self.assertEqual(input_shape[:2] + (num_outputs,), y.shape)
def atari_layers():
    """Return a one-element list holding an AtariCnn built with no arguments."""
    layers = [atari_cnn.AtariCnn()]
    return layers