Example #1
def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str):
    """Perform a test of the policy in Atari environment.

  Args:
    n_episodes: number of full Atari episodes to test on
    model: the actor-critic model being tested
    game: defines the Atari game to test on

  Returns:
    total_reward: obtained score
  """
    test_env = env_utils.create_env(game, clip_rewards=False)
    for _ in range(n_episodes):
        obs = test_env.reset()
        state = obs[None, ...]  # add batch dimension
        total_reward = 0.0
        for t in itertools.count():
            log_probs, _ = agent.policy_action(model, state)
            probs = onp.exp(onp.array(log_probs, dtype=onp.float32))
            probabilities = probs[0] / probs[0].sum()
            action = onp.random.choice(probs.shape[1], p=probabilities)
            obs, reward, done, _ = test_env.step(action)
            total_reward += reward
            next_state = obs[None, ...] if not done else None
            state = next_state
            if done:
                break
    return total_reward
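
The sampling step above turns the model's log-probabilities into a categorical distribution over actions. A minimal standalone sketch of that step follows; the three-way action distribution is made up for illustration, and renormalizing probs[0] guards against the small floating-point drift introduced by exp, so the probabilities passed to onp.random.choice sum to exactly 1.

import numpy as onp

# Hypothetical policy output for a 3-action game, stored in log space.
log_probs = onp.log(onp.array([[0.2, 0.3, 0.5]], dtype=onp.float32))
probs = onp.exp(log_probs)                    # back to probabilities
probabilities = probs[0] / probs[0].sum()     # renormalize against rounding drift
action = onp.random.choice(probs.shape[1], p=probabilities)  # sample one action index
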
Example #2
def get_experience(
    state: train_state.TrainState,
    simulators: List[agent.RemoteSimulator],
    steps_per_actor: int):
  """Collect experience from agents.

  Runs `steps_per_actor` time steps of the game for each of the `simulators`.
  """
  all_experience = []
  # Range up to steps_per_actor + 1 to get one more value needed for GAE.
  for _ in range(steps_per_actor + 1):
    sim_states = []
    for sim in simulators:
      sim_state = sim.conn.recv()
      sim_states.append(sim_state)
    sim_states = np.concatenate(sim_states, axis=0)
    log_probs, values = agent.policy_action(state.apply_fn, state.params, sim_states)
    log_probs, values = jax.device_get((log_probs, values))
    probs = np.exp(np.array(log_probs))
    for i, sim in enumerate(simulators):
      probabilities = probs[i]
      action = np.random.choice(probs.shape[1], p=probabilities)
      sim.conn.send(action)
    experiences = []
    for i, sim in enumerate(simulators):
      sim_state, action, reward, done = sim.conn.recv()
      value = values[i, 0]
      log_prob = log_probs[i][action]
      sample = agent.ExpTuple(sim_state, action, reward, value, log_prob, done)
      experiences.append(sample)
    all_experience.append(experiences)
  return all_experience
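
The loop above runs steps_per_actor + 1 time steps because generalized advantage estimation needs the value of the state that follows the last transition. The helper below is only a sketch of that computation on plain NumPy arrays; the function name and the discount/gae_param arguments are illustrative, not part of the original module.

import numpy as np

def gae_advantages_sketch(rewards, dones, values, discount=0.99, gae_param=0.95):
  # rewards, dones: length T; values: length T + 1 (the extra bootstrap value).
  T = len(rewards)
  advantages = np.zeros(T, dtype=np.float32)
  gae = 0.0
  for t in reversed(range(T)):
    nonterminal = 1.0 - dones[t]
    # TD residual, bootstrapped with the value estimate of the next state.
    delta = rewards[t] + discount * values[t + 1] * nonterminal - values[t]
    gae = delta + discount * gae_param * nonterminal * gae
    advantages[t] = gae
  return advantages
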
Example #3
def policy_test(n_episodes: int, apply_fn: Callable[..., Any],
                params: flax.core.frozen_dict.FrozenDict, game: str):
    """Perform a test of the policy in Atari environment.

  Args:
    n_episodes: number of full Atari episodes to test on
    apply_fn: the actor-critic apply function
    params: actor-critic model parameters, they define the policy being tested
    game: defines the Atari game to test on

  Returns:
    total_reward: obtained score
  """
    test_env = env_utils.create_env(game, clip_rewards=False)
    for _ in range(n_episodes):
        obs = test_env.reset()
        state = obs[None, ...]  # add batch dimension
        total_reward = 0.0
        for t in itertools.count():
            log_probs, _ = agent.policy_action(apply_fn, params, state)
            probs = np.exp(np.array(log_probs, dtype=np.float32))
            probabilities = probs[0] / probs[0].sum()
            action = np.random.choice(probs.shape[1], p=probabilities)
            obs, reward, done, _ = test_env.step(action)
            total_reward += reward
            next_state = obs[None, ...] if not done else None
            state = next_state
            if done:
                break
    return total_reward
Example #4
  def test_model(self):
    outputs = self.choose_random_outputs()
    module = models.ActorCritic(num_outputs=outputs)
    params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module)
    test_batch_size, obs_shape = 10, (84, 84, 4)
    random_input = np.random.random(size=(test_batch_size,) + obs_shape)
    log_probs, values = agent.policy_action(
        module.apply, params, random_input)
    self.assertEqual(values.shape, (test_batch_size, 1))
    sum_probs = np.sum(np.exp(log_probs), axis=1)
    self.assertEqual(sum_probs.shape, (test_batch_size,))
    np_testing.assert_allclose(
        sum_probs, np.ones((test_batch_size,)), atol=1e-6)
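
The test relies on ppo_lib.get_initial_params to build parameters for an (84, 84, 4) Atari observation. A plausible sketch of such a helper, using the standard Flax Linen initialization pattern (the dummy input shape is inferred from the test, not copied from ppo_lib):

import jax.numpy as jnp

def get_initial_params_sketch(key, model):
  # Initialize with a single dummy frame stack: batch of 1, 84x84 pixels, 4 channels.
  dummy_input = jnp.ones((1, 84, 84, 4), jnp.float32)
  return model.init(key, dummy_input)['params']
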
Example #5
def loss_fn(
    params: flax.core.FrozenDict,
    apply_fn: Callable[..., Any],
    minibatch: Tuple,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
  """Evaluate the loss function.

  Compute loss as a sum of three components: the negative of the PPO clipped
  surrogate objective, the value function loss and the negative of the entropy
  bonus.

  Args:
    params: the parameters of the actor-critic model
    apply_fn: the actor-critic model's apply function
    minibatch: Tuple of five elements forming one experience batch:
               states: shape (batch_size, 84, 84, 4)
               actions: shape (batch_size,)
               old_log_probs: shape (batch_size,)
               returns: shape (batch_size,)
               advantages: shape (batch_size,)
    clip_param: the PPO clipping parameter used to clamp ratios in loss function
    vf_coeff: weighs value function loss in total loss
    entropy_coeff: weighs entropy bonus in the total loss

  Returns:
    loss: the PPO loss, scalar quantity
  """
  states, actions, old_log_probs, returns, advantages = minibatch
  log_probs, values = agent.policy_action(apply_fn, params, states)
  values = values[:, 0]  # Convert shapes: (batch, 1) to (batch, ).
  probs = jnp.exp(log_probs)

  value_loss = jnp.mean(jnp.square(returns - values), axis=0)

  entropy = jnp.sum(-probs*log_probs, axis=1).mean()

  log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
  ratios = jnp.exp(log_probs_act_taken - old_log_probs)
  # Advantage normalization (following the OpenAI baselines).
  advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
  pg_loss = ratios * advantages
  clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                            1. + clip_param)
  ppo_loss = -jnp.mean(jnp.minimum(pg_loss, clipped_loss), axis=0)

  return ppo_loss + vf_coeff*value_loss - entropy_coeff*entropy
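
One plausible way to wire loss_fn into a training step is through jax.value_and_grad, differentiating with respect to params (the first positional argument). The sketch below assumes a flax.training.train_state.TrainState like the one in the second example; the hyperparameter values are illustrative.

import jax

CLIP_PARAM, VF_COEFF, ENTROPY_COEFF = 0.1, 0.5, 0.01  # illustrative values

@jax.jit
def train_step_sketch(state, minibatch):
  # Differentiate loss_fn with respect to its first argument (the parameters).
  grad_fn = jax.value_and_grad(loss_fn)
  loss, grads = grad_fn(state.params, state.apply_fn, minibatch,
                        CLIP_PARAM, VF_COEFF, ENTROPY_COEFF)
  # Apply the gradients through the TrainState's optimizer and return the loss.
  return state.apply_gradients(grads=grads), loss
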
Example #6
  def test_model(self):
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    outputs = self.choose_random_outputs()
    module = models.ActorCritic(num_outputs=outputs)
    initial_params = models.get_initial_params(subkey, module)
    lr = 2.5e-4
    optimizer = models.create_optimizer(initial_params, lr)
    self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
    test_batch_size, obs_shape = 10, (84, 84, 4)
    random_input = np.random.random(size=(test_batch_size,) + obs_shape)
    log_probs, values = agent.policy_action(
        optimizer.target, module, random_input)
    self.assertEqual(values.shape, (test_batch_size, 1))
    sum_probs = np.sum(np.exp(log_probs), axis=1)
    self.assertEqual(sum_probs.shape, (test_batch_size,))
    np_testing.assert_allclose(
        sum_probs, np.ones((test_batch_size,)), atol=1e-6)
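
flax.optim, which this last test checks against, has since been deprecated in favor of optax; the earlier examples already track parameters in a flax.training.train_state.TrainState. Below is a hypothetical equivalent of the optimizer setup, reusing the models.get_initial_params helper from the test; Adam here is an assumption standing in for whatever models.create_optimizer configures.

import optax
from flax.training import train_state

def create_train_state_sketch(rng, module, learning_rate=2.5e-4):
  params = models.get_initial_params(rng, module)  # same helper as in the test above
  tx = optax.adam(learning_rate)                   # assumed optimizer choice
  return train_state.TrainState.create(
      apply_fn=module.apply, params=params, tx=tx)
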