示例#1
0
 def test_computes(self):
     """Check that AtariCnn maps (B, T + 1) + OBS input to (B, T + 1, output_size)."""
     seed_key = jax_random.get_prng(0)
     n_outputs = 6
     net = atari_cnn.AtariCnn(hidden_sizes=(4, 4), output_size=n_outputs)
     batch, steps = 2, 2
     obs_shape = (28, 28, 3)
     # Only the second value returned by the split is used below.
     _, init_key = jax_random.split(seed_key)
     # The two-element result of initialization is unpacked and discarded.
     _, _ = net.initialize_once((1, 1) + obs_shape, onp.float32, init_key)
     # Build the input from consecutive integers, one per element of the
     # full (batch, steps + 1) + obs_shape array.
     input_shape = (batch, steps + 1) + obs_shape
     n_elements = functools.reduce(op.mul, input_shape)
     frames = onp.arange(n_elements).reshape(input_shape)
     result = net(frames)
     self.assertEqual((batch, steps + 1, n_outputs), result.shape)
示例#2
0
 def test_computes(self):
   """Check that AtariCnn maps (B, T + 1) + OBS input to (B, T + 1, output_size)."""
   seed_key = jax_random.get_prng(0)
   n_outputs = 6
   net = atari_cnn.AtariCnn(hidden_sizes=(4, 4), output_size=n_outputs)
   batch, steps = 2, 2
   obs_shape = (28, 28, 3)
   # First split: one key for initialization, the other carried forward.
   seed_key, init_key = jax_random.split(seed_key)
   weights = net.initialize((-1, -1) + obs_shape, init_key)
   # Build the input from consecutive integers, one per element of the
   # full (batch, steps + 1) + obs_shape array.
   input_shape = (batch, steps + 1) + obs_shape
   n_elements = functools.reduce(op.mul, input_shape)
   frames = onp.arange(n_elements).reshape(input_shape)
   # Second split: a fresh key passed to the forward call.
   _, call_key = jax_random.split(seed_key)
   result = net(frames, weights, rng=call_key)
   self.assertEqual((batch, steps + 1, n_outputs), result.shape)
示例#3
0
def atari_layers():
    """Return a one-element list holding a default-constructed AtariCnn."""
    default_cnn = atari_cnn.AtariCnn()
    return [default_cnn]