def test_computes(self):
  """AtariCnn should emit one `output_size` vector per (batch, step) slot.

  The model is initialized from a single-step (1, 1) + OBS signature, then
  applied to a (B, T + 1) + OBS batch to check the leading dimensions are
  carried through to the output.
  """
  B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
  output_size = 6
  model = atari_cnn.AtariCnn(hidden_sizes=(4, 4), output_size=output_size)

  # Initialize with a single-step shape only; the call below uses a larger
  # leading (B, T + 1) shape.
  signature = ShapeDtype((1, 1) + OBS)
  _, _ = model.init(signature)

  # Deterministic integer input. NOTE(review): the signature above uses
  # ShapeDtype's default dtype while this array is integer — confirm intended.
  steps = T + 1
  pixels_per_obs = functools.reduce(op.mul, OBS)
  observations = onp.arange(B * steps * pixels_per_obs).reshape(B, steps, *OBS)

  logits = model(observations)
  self.assertEqual((B, steps, output_size), logits.shape)
def test_computes(self):
  """AtariCnn should emit one `output_size` vector per (batch, step) slot.

  Weights are created once from a single-step (1, 1) + OBS float32 shape using
  an RNG key derived from seed 0. The model is then applied to a
  (B, T + 1) + OBS batch, and the output shape is checked.
  """
  seed_key = jax_random.get_prng(0)
  output_size = 6
  model = atari_cnn.AtariCnn(hidden_sizes=(4, 4), output_size=output_size)
  B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name

  # Only the second half of the split is used, for initialization.
  _, init_key = jax_random.split(seed_key)
  _, _ = model.initialize_once((1, 1) + OBS, onp.float32, init_key)

  # Deterministic integer input. NOTE(review): the model was initialized for
  # float32 but this array is integer — confirm intended.
  steps = T + 1
  pixels_per_obs = functools.reduce(op.mul, OBS)
  observations = onp.arange(B * steps * pixels_per_obs).reshape(B, steps, *OBS)

  logits = model(observations)
  self.assertEqual((B, steps, output_size), logits.shape)