def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str):
  """Perform a test of the policy in Atari environment.

  Args:
    n_episodes: number of full Atari episodes to test on
    model: the actor-critic model being tested
    game: defines the Atari game to test on

  Returns:
    total_reward: obtained score
  """
  test_env = env_utils.create_env(game, clip_rewards=False)
  for _ in range(n_episodes):
    obs = test_env.reset()
    state = obs[None, ...]  # Add batch dimension.
    total_reward = 0.0
    for t in itertools.count():
      log_probs, _ = agent.policy_action(model, state)
      probs = onp.exp(onp.array(log_probs, dtype=onp.float32))
      probabilities = probs[0] / probs[0].sum()
      action = onp.random.choice(probs.shape[1], p=probabilities)
      obs, reward, done, _ = test_env.step(action)
      total_reward += reward
      next_state = obs[None, ...] if not done else None
      state = next_state
      if done:
        break
  return total_reward
def get_experience(
    state: train_state.TrainState,
    simulators: List[agent.RemoteSimulator],
    steps_per_actor: int):
  """Collect experience from agents.

  Runs `steps_per_actor` time steps of the game for each of the `simulators`.
  """
  all_experience = []
  # Range up to steps_per_actor + 1 to get one more value needed for GAE.
  for _ in range(steps_per_actor + 1):
    sim_states = []
    for sim in simulators:
      sim_state = sim.conn.recv()
      sim_states.append(sim_state)
    sim_states = np.concatenate(sim_states, axis=0)
    log_probs, values = agent.policy_action(state.apply_fn, state.params,
                                            sim_states)
    log_probs, values = jax.device_get((log_probs, values))
    probs = np.exp(np.array(log_probs))
    for i, sim in enumerate(simulators):
      probabilities = probs[i]
      action = np.random.choice(probs.shape[1], p=probabilities)
      sim.conn.send(action)
    experiences = []
    for i, sim in enumerate(simulators):
      sim_state, action, reward, done = sim.conn.recv()
      value = values[i, 0]
      log_prob = log_probs[i][action]
      sample = agent.ExpTuple(sim_state, action, reward, value, log_prob, done)
      experiences.append(sample)
    all_experience.append(experiences)
  return all_experience
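
# --- Illustrative sketch, not part of the example code above. ---
# `get_experience` runs `steps_per_actor + 1` steps because generalized
# advantage estimation (GAE) needs the value estimate of the state that
# follows the final transition. A minimal GAE computation over one actor's
# per-step arrays could look like the following; `gamma`, `lambda_` and the
# exact array layout are assumptions for illustration, not the code above.
def gae_advantages_sketch(rewards, terminal_masks, values, gamma, lambda_):
  """Compute GAE; `values` has one more entry than `rewards` (length T + 1)."""
  assert values.shape[0] == rewards.shape[0] + 1
  advantages = np.zeros_like(rewards)
  gae = 0.0
  for t in reversed(range(rewards.shape[0])):
    # TD residual; terminal_masks[t] is 0. when the episode ended at step t.
    delta = rewards[t] + gamma * values[t + 1] * terminal_masks[t] - values[t]
    gae = delta + gamma * lambda_ * terminal_masks[t] * gae
    advantages[t] = gae
  return advantages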
def policy_test(n_episodes: int, apply_fn: Callable[..., Any],
                params: flax.core.frozen_dict.FrozenDict, game: str):
  """Perform a test of the policy in Atari environment.

  Args:
    n_episodes: number of full Atari episodes to test on
    apply_fn: the actor-critic apply function
    params: actor-critic model parameters, they define the policy being tested
    game: defines the Atari game to test on

  Returns:
    total_reward: obtained score
  """
  test_env = env_utils.create_env(game, clip_rewards=False)
  for _ in range(n_episodes):
    obs = test_env.reset()
    state = obs[None, ...]  # Add batch dimension.
    total_reward = 0.0
    for t in itertools.count():
      log_probs, _ = agent.policy_action(apply_fn, params, state)
      probs = np.exp(np.array(log_probs, dtype=np.float32))
      probabilities = probs[0] / probs[0].sum()
      action = np.random.choice(probs.shape[1], p=probabilities)
      obs, reward, done, _ = test_env.step(action)
      total_reward += reward
      next_state = obs[None, ...] if not done else None
      state = next_state
      if done:
        break
  return total_reward
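
# Usage sketch for `policy_test` (hypothetical values: the module setup
# mirrors the unit test below, and 'Breakout' is an arbitrary example game):
#
#   module = models.ActorCritic(num_outputs=4)
#   params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module)
#   score = policy_test(n_episodes=1, apply_fn=module.apply, params=params,
#                       game='Breakout')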
def test_model(self):
  outputs = self.choose_random_outputs()
  module = models.ActorCritic(num_outputs=outputs)
  params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module)
  test_batch_size, obs_shape = 10, (84, 84, 4)
  random_input = np.random.random(size=(test_batch_size,) + obs_shape)
  log_probs, values = agent.policy_action(module.apply, params, random_input)
  self.assertEqual(values.shape, (test_batch_size, 1))
  sum_probs = np.sum(np.exp(log_probs), axis=1)
  self.assertEqual(sum_probs.shape, (test_batch_size,))
  np_testing.assert_allclose(sum_probs, np.ones((test_batch_size,)),
                             atol=1e-6)
def loss_fn(
    params: flax.core.FrozenDict,
    apply_fn: Callable[..., Any],
    minibatch: Tuple,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
  """Evaluate the loss function.

  Compute loss as a sum of three components: the negative of the PPO clipped
  surrogate objective, the value function loss and the negative of the entropy
  bonus.

  Args:
    params: the parameters of the actor-critic model
    apply_fn: the actor-critic model's apply function
    minibatch: Tuple of five elements forming one experience batch:
               states: shape (batch_size, 84, 84, 4)
               actions: shape (batch_size,)
               old_log_probs: shape (batch_size,)
               returns: shape (batch_size,)
               advantages: shape (batch_size,)
    clip_param: the PPO clipping parameter used to clamp ratios in loss function
    vf_coeff: weighs value function loss in total loss
    entropy_coeff: weighs entropy bonus in the total loss

  Returns:
    loss: the PPO loss, scalar quantity
  """
  states, actions, old_log_probs, returns, advantages = minibatch
  log_probs, values = agent.policy_action(apply_fn, params, states)
  values = values[:, 0]  # Convert shapes: (batch, 1) to (batch,).
  probs = jnp.exp(log_probs)

  value_loss = jnp.mean(jnp.square(returns - values), axis=0)

  entropy = jnp.sum(-probs * log_probs, axis=1).mean()

  log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
  ratios = jnp.exp(log_probs_act_taken - old_log_probs)
  # Advantage normalization (following the OpenAI baselines).
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
  pg_loss = ratios * advantages
  clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                            1. + clip_param)
  ppo_loss = -jnp.mean(jnp.minimum(pg_loss, clipped_loss), axis=0)

  return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy
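
# --- Illustrative sketch, not the repository's actual train step. ---
# `loss_fn` takes `params` as its first argument precisely so it can be
# differentiated directly with `jax.value_and_grad`. Assuming `state` is the
# same `flax.training.train_state.TrainState` used in `get_experience`, one
# gradient update could look like this (hyperparameter values are
# placeholders):
@jax.jit
def train_step_sketch(state, minibatch, clip_param=0.1, vf_coeff=0.5,
                      entropy_coeff=0.01):
  loss, grads = jax.value_and_grad(loss_fn)(
      state.params, state.apply_fn, minibatch,
      clip_param, vf_coeff, entropy_coeff)
  # `apply_gradients` runs the optimizer bundled in the TrainState.
  new_state = state.apply_gradients(grads=grads)
  return new_state, loss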
def test_model(self):
  key = jax.random.PRNGKey(0)
  key, subkey = jax.random.split(key)
  outputs = self.choose_random_outputs()
  module = models.ActorCritic(num_outputs=outputs)
  initial_params = models.get_initial_params(subkey, module)
  lr = 2.5e-4
  optimizer = models.create_optimizer(initial_params, lr)
  self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
  test_batch_size, obs_shape = 10, (84, 84, 4)
  random_input = np.random.random(size=(test_batch_size,) + obs_shape)
  log_probs, values = agent.policy_action(optimizer.target, module,
                                          random_input)
  self.assertEqual(values.shape, (test_batch_size, 1))
  sum_probs = np.sum(np.exp(log_probs), axis=1)
  self.assertEqual(sum_probs.shape, (test_batch_size,))
  np_testing.assert_allclose(sum_probs, np.ones((test_batch_size,)),
                             atol=1e-6)