def _preprocess_trajectories(self, trajectories):
    (_, reward_mask, observations, actions, rewards,
     infos) = (ppo.pad_trajectories(trajectories,
                                    boundary=self._max_timestep))
    assert self.train_env.observation_space.shape == observations.shape[2:]
    if not self._serialized_sequence_policy:
        # Add one timestep at the end, so it's compatible with
        # self._rewards_to_actions.
        pad_width = ((0, 0), (0, 1)) + ((0, 0),) * (actions.ndim - 2)
        actions = np.pad(actions, pad_width)
        actions = np.reshape(actions, (actions.shape[0], -1))
    else:
        (observations,
         actions) = self._serialize_trajectories(observations, actions,
                                                 reward_mask)
    return (observations, actions, rewards, reward_mask, infos)
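
The non-serialized branch above appends one dummy timestep to the actions and then flattens them per trajectory. A minimal standalone sketch of that padding step; the toy shapes are chosen for illustration and are not taken from the trainer:

import numpy as np

# Assumed toy shapes: batch of 2 trajectories, 3 timesteps, scalar discrete actions.
actions = np.arange(6).reshape(2, 3)

# Pad one timestep at the end of the time axis; all other axes are left alone.
pad_width = ((0, 0), (0, 1)) + ((0, 0),) * (actions.ndim - 2)
padded = np.pad(actions, pad_width, mode="constant")  # shape (2, 4); new column is zeros
flat = np.reshape(padded, (padded.shape[0], -1))      # shape (2, 4) for scalar actions

assert padded.shape == (2, 4)
assert flat.shape == (2, 4)
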
Example #2
    def train_epoch(self):
        """Train one PPO epoch."""
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if (self._epoch + 1) % self._eval_every_n == 0:
            self._rng, key = jax_random.split(self._rng, num=2)
            self.evaluate()

        policy_eval_time = ppo.get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.",
                     self._epoch)
        self._rng, key = jax_random.split(self._rng)
        trajs, n_done, timing_info, self._model_state = ppo.collect_trajectories(
            self.train_env,
            policy_fn=self._get_predictions,
            n_trajectories=self.train_env.batch_size,
            max_timestep=self._max_timestep,
            state=self._model_state,
            rng=key,
            len_history_for_policy=self._len_history_for_policy,
            boundary=self._boundary,
            reset=self._should_reset,
        )
        self._should_reset = False
        trajectory_collection_time = ppo.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        self._train_sw.scalar("train/reward_mean_truncated",
                              avg_reward,
                              step=self._epoch)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions, padded_rewards,
         padded_infos) = ppo.pad_trajectories(trajs, boundary=self._boundary)
        padding_time = ppo.get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     ppo.get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert ((B, T + 1) + self.train_env.observation_space.shape ==
                padded_observations.shape)

        log_prob_recompute_start_time = time.time()
        assert ("log_prob_actions" in padded_infos
                and "value_predictions" in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        actual_log_probabs_traj = padded_infos["log_prob_actions"]
        actual_value_predictions_traj = padded_infos["value_predictions"]

        assert (B, T) == actual_log_probabs_traj.shape[:2]
        A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
        assert (B, T, 1) == actual_value_predictions_traj.shape

        # TODO(afrozm): log-probabs don't need to be (B, T+1, A); (B, T, A) is
        # enough, so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        self._rng, key = jax_random.split(self._rng)

        log_probabs_traj, value_predictions_traj, self._model_state, _ = (
            self._get_predictions(padded_observations,
                                  self._model_state,
                                  rng=key))

        assert (B, T + 1, A) == log_probabs_traj.shape
        assert (B, T + 1, 1) == value_predictions_traj.shape

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        log_probabs_traj = np.concatenate(
            (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        value_predictions_traj = np.concatenate(
            (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
            axis=1)

        log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

        # Compute value and ppo losses.
        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        (cur_combined_loss, component_losses, summaries,
         self._model_state) = (ppo.combined_loss(
             self._policy_and_value_net_params,
             log_probabs_traj,
             value_predictions_traj,
             self._policy_and_value_net_apply,
             padded_observations,
             padded_actions,
             padded_rewards,
             reward_mask,
             gamma=self._gamma,
             lambda_=self._lambda_,
             c1=self._c1,
             c2=self._c2,
             state=self._model_state,
             rng=key1))
        loss_compute_time = ppo.get_time(loss_compute_start_time)
        (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
            ppo.get_time(loss_compute_start_time))

        self._rng, key1 = jax_random.split(self._rng, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=self._n_optimizer_steps)
        opt_step = 0
        for key in keys:
            k1, k2, k3 = jax_random.split(key, num=3)
            t = time.time()
            # Update the optimizer state.
            self._policy_and_value_opt_state, self._model_state = (
                ppo.policy_and_value_opt_step(
                    # We pass the optimizer slots between PPO epochs, so we also
                    # need to pass the optimization step, so that, for example,
                    # the bias-correction in Adam is calculated properly.
                    # Alternatively we could reset the slots and the step in every
                    # PPO epoch, but then the moment estimates in adaptive
                    # optimizers would never have enough time to warm up. So it
                    # makes sense to reuse the slots, even though we're optimizing
                    # a different loss in every new epoch.
                    self._total_opt_step,
                    self._policy_and_value_opt_state,
                    self._policy_and_value_opt_update,
                    self._policy_and_value_get_params,
                    self._policy_and_value_net_apply,
                    log_probabs_traj,
                    value_predictions_traj,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    c1=self._c1,
                    c2=self._c2,
                    gamma=self._gamma,
                    lambda_=self._lambda_,
                    state=self._model_state,
                    rng=k1))
            opt_step += 1
            self._total_opt_step += 1

            # Compute the approx KL for early stopping.
            (log_probab_actions_new,
             _), self._model_state = (self._policy_and_value_net_apply(
                 padded_observations,
                 self._policy_and_value_net_params,
                 self._model_state,
                 rng=k2))

            approx_kl = ppo.approximate_kl(log_probab_actions_new,
                                           log_probabs_traj, reward_mask)

            early_stopping = approx_kl > 1.5 * self._target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization after %d steps, "
                    "with approx_kl: %0.2f", opt_step, approx_kl)
                # We don't break right away; we want the logging below to run on
                # this final iteration.

            t2 = time.time()
            if (opt_step % self._print_every_optimizer_steps == 0
                    or opt_step == self._n_optimizer_steps or early_stopping):
                # Compute and log the loss.
                (combined_loss, component_losses, _,
                 self._model_state) = (ppo.combined_loss(
                     self._policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     self._policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=self._gamma,
                     lambda_=self._lambda_,
                     c1=self._c1,
                     c2=self._c2,
                     state=self._model_state,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    ppo.get_time(t, t2))
                (ppo_loss, value_loss, entropy_bonus) = component_losses
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    combined_loss, ppo_loss, value_loss, entropy_bonus)

            if early_stopping:
                break

        optimization_time = ppo.get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

        summaries.update({
            "n_optimizer_steps": opt_step,
            "approx_kl": approx_kl,
        })
        for (name, value) in summaries.items():
            self._train_sw.scalar("train/{}".format(name),
                                  value,
                                  step=self._epoch)

        # Save parameters whenever at least a `done_frac_for_policy_save`
        # fraction of the batch's trajectories have finished with `done` (not
        # just completed -- completed also includes truncated trajectories).
        # Also don't save too frequently; enforce a minimum gap between saves.
        # Or if this is the last iteration.
        policy_save_start_time = time.time()
        self._n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if ((self._n_trajectories_done >=
             self._done_frac_for_policy_save * self.train_env.batch_size)
                and (self._epoch - self._last_saved_at > self._eval_every_n)
                and (((self._epoch + 1) % self._eval_every_n == 0))):
            self.save()
        policy_save_time = ppo.get_time(policy_save_start_time)

        epoch_time = ppo.get_time(epoch_start_time)

        logging.info(
            "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]",
            self._epoch, min_reward, max_reward, avg_reward, combined_loss,
            ppo_loss, value_loss, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            self._timing_sw.scalar("timing/%s" % k, v, step=self._epoch)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("PPO epoch [% 6d], Timings: \n%s", self._epoch,
                     "\n".join(timing_info_list))

        self._epoch += 1

        # Flush summary writers once in a while.
        if (self._epoch + 1) % 1000 == 0:
            self.flush_summaries()
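
The optimization loop in the epoch above stops early once an approximate KL divergence between the old and new policies exceeds 1.5x the target KL. Below is a minimal sketch of such a masked KL estimate; the exact formula, shapes, and target value are assumptions for illustration, not a copy of ppo.approximate_kl:

import numpy as np

def masked_categorical_kl(log_p_old, log_p_new, mask):
    # Per-timestep KL(old || new) over the action distribution, averaged over
    # the unmasked (real, non-padding) timesteps.
    kl_t = np.sum(np.exp(log_p_old) * (log_p_old - log_p_new), axis=-1)
    return float(np.sum(kl_t * mask) / np.sum(mask))

# Illustrative shapes: batch B=2, time T=3, A=4 actions.
rng = np.random.default_rng(0)
log_p_old = np.log(rng.dirichlet(np.ones(4), size=(2, 3)))
log_p_new = np.log(rng.dirichlet(np.ones(4), size=(2, 3)))
reward_mask = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

approx_kl = masked_categorical_kl(log_p_old, log_p_new, reward_mask)
target_kl = 0.01                      # assumed setting for self._target_kl
early_stopping = approx_kl > 1.5 * target_kl
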
Example #3
    def test_pad_trajectories(self):
        observation_shape = (2, 3, 4)
        trajectories = []
        n_trajectories = 7
        n_actions = 10

        # Time-steps are between [min_allowable_time_step, max_allowable_time_step]
        max_allowable_time_step = 19
        min_allowable_time_step = 5

        # The actual max we see in the data.
        max_time_step = -1

        # Bucket length.
        bucket_length = 15

        # Make `n_trajectories` random trajectories.
        for i in range(n_trajectories):
            time_steps = np.random.randint(min_allowable_time_step,
                                           max_allowable_time_step + 1)
            if time_steps > max_time_step:
                max_time_step = time_steps
            observations = np.random.randint(
                0, 255,
                size=(time_steps + 1, ) + observation_shape).astype(np.uint8)
            rewards = np.random.uniform(size=(time_steps, )).astype(np.float32)
            actions = np.random.randint(0, n_actions,
                                        size=(time_steps, )).astype(np.int32)
            infos = {
                "a": np.random.uniform(size=(time_steps, )).astype(np.float32),
                "b": np.random.uniform(size=(time_steps, )).astype(np.float32)
            }
            trajectories.append((observations, rewards, actions, infos))

        # Now pad these trajectories.
        padded_trajectories = ppo.pad_trajectories(trajectories,
                                                   boundary=bucket_length)

        # Expected padding.
        i = 1
        while i * bucket_length < max_time_step:
            i += 1
        expected_padding = i * bucket_length

        # Get the padded objects.
        (pad_lengths, reward_mask, padded_observations, padded_actions,
         padded_rewards, padded_infos) = padded_trajectories

        # Expectations on the padded shapes.
        self.assertEqual(padded_observations.shape, (
            n_trajectories,
            expected_padding + 1,
        ) + observation_shape)
        self.assertEqual(padded_actions.shape,
                         (n_trajectories, expected_padding))
        self.assertEqual(padded_rewards.shape,
                         (n_trajectories, expected_padding))
        self.assertEqual(reward_mask.shape, (n_trajectories, expected_padding))

        self.assertEqual(padded_infos["a"].shape,
                         (n_trajectories, expected_padding))
        self.assertEqual(padded_infos["b"].shape,
                         (n_trajectories, expected_padding))

        # Assert that the padding lengths and reward mask are consistent.
        self.assertAllEqual(
            np.full((n_trajectories, ), expected_padding),
            np.array(np.sum(reward_mask, axis=1)) + pad_lengths)
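
The while loop in the test rounds the longest observed trajectory length up to the next multiple of the bucket length (lengths that are already exact multiples are left as-is). The same expected padding can be written in closed form; a small check against the test's own boundary values:

import math

def expected_padding(max_time_step, bucket_length):
    # Smallest multiple of bucket_length that is >= max_time_step (for max_time_step >= 1).
    return int(math.ceil(max_time_step / bucket_length)) * bucket_length

assert expected_padding(19, 15) == 30   # longest allowable trajectory in the test
assert expected_padding(15, 15) == 15   # exact multiples get no extra padding
assert expected_padding(5, 15) == 15    # shortest allowable trajectory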