import itertools
from typing import Any, Callable

import flax
import numpy as np

import agent
import env_utils


def policy_test(
    n_episodes: int,
    apply_fn: Callable[..., Any],
    params: flax.core.frozen_dict.FrozenDict,
    game: str):
  """Perform a test of the policy in the Atari environment.

  Args:
    n_episodes: number of full Atari episodes to test on
    apply_fn: the actor-critic apply function
    params: actor-critic model parameters that define the policy being tested
    game: defines the Atari game to test on

  Returns:
    total_reward: score obtained in the last tested episode
  """
  test_env = env_utils.create_env(game, clip_rewards=False)
  for _ in range(n_episodes):
    obs = test_env.reset()
    state = obs[None, ...]  # add batch dimension
    total_reward = 0.0
    for _ in itertools.count():
      log_probs, _ = agent.policy_action(apply_fn, params, state)
      probs = np.exp(np.array(log_probs, dtype=np.float32))
      probabilities = probs[0] / probs[0].sum()
      action = np.random.choice(probs.shape[1], p=probabilities)
      obs, reward, done, _ = test_env.step(action)
      total_reward += reward
      state = obs[None, ...] if not done else None
      if done:
        break
  return total_reward
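
# NOTE: A minimal usage sketch for policy_test. The ActorCritic module below
# is illustrative, not the example's actual model; it assumes the network
# returns a (log_probs, value) pair and that agent.policy_action forwards
# `params` to `apply_fn` in the usual Flax Linen fashion, i.e. as
# apply_fn({'params': params}, state).
import jax
import jax.numpy as jnp
import flax.linen as nn


class ActorCritic(nn.Module):
  """Hypothetical actor-critic network; the real example defines its own."""
  num_actions: int

  @nn.compact
  def __call__(self, x):
    x = x.astype(jnp.float32) / 255.
    x = x.reshape((x.shape[0], -1))  # flatten the (84, 84, 4) frame stack
    x = nn.relu(nn.Dense(features=256)(x))
    log_probs = nn.log_softmax(nn.Dense(features=self.num_actions)(x))
    value = nn.Dense(features=1)(x)
    return log_probs, value


model = ActorCritic(num_actions=4)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 84, 84, 4)))['params']
# 'Pong' is a placeholder; use whatever name env_utils.create_env expects.
score = policy_test(n_episodes=1, apply_fn=model.apply, params=params,
                    game='Pong')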
def test_step(self):
  frame_shape = (84, 84, 4)
  game = self.choose_random_game()
  env = env_utils.create_env(game, clip_rewards=True)
  obs = env.reset()
  actions = [1, 2, 3, 0]
  for a in actions:
    obs, reward, done, info = env.step(a)
    self.assertEqual(obs.shape, frame_shape)
    # Rewards are clipped to [-1, 1] when clip_rewards=True.
    self.assertGreaterEqual(reward, -1.)
    self.assertLessEqual(reward, 1.)
    self.assertIsInstance(done, bool)
    self.assertIsInstance(info, dict)
def rcv_action_send_exp(conn, game: str):
  """Run a remote agent.

  Receives actions from the main learner, performs one step of the simulation
  per action, and sends back the collected experience.
  """
  env = env_utils.create_env(game, clip_rewards=True)
  while True:
    obs = env.reset()
    done = False
    # Observations fetched from the Atari env need an additional batch
    # dimension.
    state = obs[None, ...]
    while not done:
      conn.send(state)
      action = conn.recv()
      obs, reward, done, _ = env.step(action)
      next_state = obs[None, ...] if not done else None
      experience = (state, action, reward, done)
      conn.send(experience)
      if done:
        break
      state = next_state
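
# NOTE: For context, a hedged sketch of the learner side of this protocol,
# driving a single worker over a multiprocessing.Pipe. The random action
# choice stands in for the learned policy, and the step count and action-space
# size are arbitrary assumptions.
import multiprocessing as mp

import numpy as np


def drive_one_worker(game: str, num_steps: int = 100, num_actions: int = 4):
  """Spawn one simulation worker and step it from the learner side."""
  learner_end, worker_end = mp.Pipe()
  worker = mp.Process(
      target=rcv_action_send_exp, args=(worker_end, game), daemon=True)
  worker.start()
  for _ in range(num_steps):
    state = learner_end.recv()               # current (batched) observation
    action = np.random.randint(num_actions)  # placeholder for the policy
    learner_end.send(action)
    state, action, reward, done = learner_end.recv()  # experience tuple
  worker.terminate()  # the worker loops forever, so stop it explicitly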
def test_creation(self):
  frame_shape = (84, 84, 4)
  game = self.choose_random_game()
  env = env_utils.create_env(game, clip_rewards=True)
  obs = env.reset()
  self.assertEqual(obs.shape, frame_shape)
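
# NOTE: Both test methods above rely on a choose_random_game helper from
# their enclosing test case. A minimal sketch of such a harness; the use of
# absltest and the game list are assumptions, not part of the snippets above.
import numpy as np
from absl.testing import absltest


class AtariEnvTest(absltest.TestCase):
  """Hypothetical home for test_creation and test_step above."""

  def choose_random_game(self):
    # Assumed game list; any name accepted by env_utils.create_env works.
    games = ['BeamRider', 'Breakout', 'Pong', 'SpaceInvaders']
    return games[np.random.choice(len(games))]


if __name__ == '__main__':
  absltest.main()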