Example #1
    def evaluate(self):
        """Evaluate the agent."""
        if not self._separate_eval:
            return

        logging.vlog(1, 'PolicyBasedTrainer epoch [% 6d]: evaluating policy.',
                     self.epoch)

        processed_reward_sums = collections.defaultdict(list)
        raw_reward_sums = collections.defaultdict(list)
        for _ in range(self._n_evals):
            for temperature in self._eval_temperatures:
                trajs, _, _, self._model_state = self.collect_trajectories(
                    train=False, temperature=temperature)

                processed_reward_sums[temperature].extend(
                    sum(traj[2]) for traj in trajs)
                raw_reward_sums[temperature].extend(
                    sum(traj[3]) for traj in trajs)

        # Return the mean and standard deviation for each temperature.
        def compute_stats(reward_dict):
            # pylint: disable=g-complex-comprehension
            return {
                temperature: {
                    'mean': onp.mean(rewards),
                    'std': onp.std(rewards)
                }
                for (temperature, rewards) in reward_dict.items()
            }
            # pylint: enable=g-complex-comprehension

        reward_stats = {
            'processed': compute_stats(processed_reward_sums),
            'raw': compute_stats(raw_reward_sums),
        }

        policy_based_utils.write_eval_reward_summaries(reward_stats,
                                                       self._log,
                                                       epoch=self.epoch)
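
The aggregation above (collect reward sums per evaluation temperature, then reduce them to mean/std) can be exercised on its own. Below is a minimal sketch with hand-made reward sums; the temperatures and numbers are hypothetical, and only numpy is assumed.

# Standalone sketch of the per-temperature reward statistics computed in
# evaluate(). The reward sums below are made up for illustration.
import collections

import numpy as onp


def compute_stats(reward_dict):
    """Mean and standard deviation of the reward sums for each temperature."""
    return {
        temperature: {'mean': onp.mean(rewards), 'std': onp.std(rewards)}
        for (temperature, rewards) in reward_dict.items()
    }


reward_sums = collections.defaultdict(list)
reward_sums[0.5].extend([10.0, 12.0, 11.0])  # Lower-temperature (greedier) rollouts.
reward_sums[1.0].extend([8.0, 15.0, 9.0])    # Temperature-1.0 rollouts.

reward_stats = {'processed': compute_stats(reward_sums)}
# reward_stats['processed'][0.5] -> {'mean': 11.0, 'std': ~0.82}
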
Example #2
    def train_epoch(self, evaluate=True):
        """Train one PPO epoch."""
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
            # The returned key is unused here; the call presumably just advances
            # the trainer's RNG stream.
            self._get_rng()
            self.evaluate()

        policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'PPO epoch [% 6d]: collecting trajectories.',
                     self.epoch)
        # The returned key is unused here; the call presumably just advances
        # the trainer's RNG stream.
        self._get_rng()
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0)
        # Drop the raw rewards (index 3) from each trajectory tuple.
        trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
        self._should_reset_train_env = False
        trajectory_collection_time = policy_based_utils.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'Collecting trajectories took %0.2f msec.',
                     trajectory_collection_time)

        rewards = np.array([np.sum(traj[2]) for traj in trajs])
        avg_reward = np.mean(rewards)
        std_reward = np.std(rewards)
        max_reward = np.max(rewards)
        min_reward = np.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            policy_based_utils.write_eval_reward_summaries(
                metrics, self._log, self.epoch)

        logging.vlog(1,
                     'Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s',
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, 'Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]',
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, 'Trajectory Lengths: %s',
                     [len(traj[0]) for traj in trajs])

        preprocessing_start_time = time.time()
        (padded_observations, padded_actions, padded_rewards, reward_mask,
         padded_infos) = self._preprocess_trajectories(trajs)
        preprocessing_time = policy_based_utils.get_time(
            preprocessing_start_time)

        logging.vlog(1, 'Preprocessing trajectories took %0.2f msec.',
                     preprocessing_time)
        logging.vlog(1, 'Padded Observations\' shape [%s]',
                     str(padded_observations.shape))
        logging.vlog(1, 'Padded Actions\' shape [%s]',
                     str(padded_actions.shape))
        logging.vlog(1, 'Padded Rewards\' shape [%s]',
                     str(padded_rewards.shape))

        # Some assertions.
        (B, T) = reward_mask.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert B == padded_observations.shape[0]

        log_prob_recompute_start_time = time.time()
        # TODO(pkozakowski): The following commented out code collects the network
        # predictions made while stepping the environment and uses them in PPO
        # training, so that we can use non-deterministic networks (e.g. with
        # dropout). This does not work well with serialization, so instead we
        # recompute all network predictions. Let's figure out a solution that will
        # work with both serialized sequences and non-deterministic networks.

        # assert ('log_prob_actions' in padded_infos and
        #         'value_predictions' in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        # actual_log_probabs_traj = padded_infos['log_prob_actions']
        # actual_value_predictions_traj = padded_infos['value_predictions']

        # assert (B, T, C) == actual_log_probabs_traj.shape[:3]
        # A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
        # assert (B, T, 1) == actual_value_predictions_traj.shape

        del padded_infos

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        key = self._get_rng()

        # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
        # action sampling.
        dummy_actions = np.zeros_like(padded_actions)
        (log_probabs_traj,
         value_predictions_traj) = (self._policy_and_value_net_apply(
             (padded_observations, dummy_actions),
             weights=self._policy_and_value_net_weights,
             state=self._model_state,
             rng=key,
         ))
        # Drop the prediction for the extra final observation so the log-probs
        # have shape (B, T, C, A).
        log_probabs_traj_cut = log_probabs_traj[:, :-1]

        assert (B, T) == log_probabs_traj_cut.shape[:2]
        assert (B, T + 1) == value_predictions_traj.shape

        # TODO(pkozakowski): Commented out for the same reason as before.

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        # log_probabs_traj = np.concatenate(
        #     (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        # value_predictions_traj = np.concatenate(
        #     (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
        #     axis=1)

        log_prob_recompute_time = policy_based_utils.get_time(
            log_prob_recompute_start_time)

        # Compute value and ppo losses.
        key1 = self._get_rng()
        logging.vlog(2, 'Starting to compute P&V loss.')
        loss_compute_start_time = time.time()
        (cur_combined_loss, component_losses, summaries,
         self._model_state) = (ppo.combined_loss(
             self._policy_and_value_net_weights,
             log_probabs_traj_cut,
             value_predictions_traj,
             self._policy_and_value_net_apply,
             padded_observations,
             padded_actions,
             padded_rewards,
             reward_mask,
             nontrainable_params=self._nontrainable_params,
             state=self._model_state,
             rng=key1))
        loss_compute_time = policy_based_utils.get_time(
            loss_compute_start_time)
        (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
        logging.vlog(
            1,
            'Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.',
            cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
            loss_compute_time)

        key1 = self._get_rng()
        logging.vlog(1, 'Policy and Value Optimization')
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=self._n_optimizer_steps)
        opt_step = 0
        opt_batch_size = min(self._optimizer_batch_size, B)
        index_batches = ppo.shuffled_index_batches(dataset_size=B,
                                                   batch_size=opt_batch_size)
        for (index_batch, key) in zip(index_batches, keys):
            k1, k2, k3 = jax_random.split(key, num=3)
            t = time.time()
            # Update the optimizer state on the sampled minibatch.
            self._policy_and_value_opt_state, self._model_state = (
                ppo.policy_and_value_opt_step(
                    # We pass the optimizer slots between PPO epochs, so we need to
                    # pass the optimization step as well, so for example the
                    # bias-correction in Adam is calculated properly. Alternatively we
                    # could reset the slots and the step in every PPO epoch, but then
                    # the moment estimates in adaptive optimizers would never have
                    # enough time to warm up. So it makes sense to reuse the slots,
                    # even though we're optimizing a different loss in every new
                    # epoch.
                    self._total_opt_step,
                    self._policy_and_value_opt_state,
                    self._policy_and_value_opt_update,
                    self._policy_and_value_get_params,
                    self._policy_and_value_net_apply,
                    log_probabs_traj_cut[index_batch],
                    value_predictions_traj[index_batch],
                    padded_observations[index_batch],
                    padded_actions[index_batch],
                    padded_rewards[index_batch],
                    reward_mask[index_batch],
                    nontrainable_params=self._nontrainable_params,
                    state=self._model_state,
                    rng=k1))
            opt_step += 1
            self._total_opt_step += 1

            # Compute the approx KL for early stopping. Use the whole dataset --
            # since we only do inference here, it should fit in memory.
            # TODO(pkozakowski): Pass the actual actions here, to enable
            # autoregressive action sampling.
            dummy_actions = np.zeros_like(padded_actions)
            (log_probab_actions_new, _) = (self._policy_and_value_net_apply(
                (padded_observations, dummy_actions),
                weights=self._policy_and_value_net_weights,
                state=self._model_state,
                rng=k2))
            # Drop the prediction for the extra final observation so the
            # log-probs have shape (B, T, C, A).
            log_probab_actions_new_cut = log_probab_actions_new[:, :-1]

            approx_kl = ppo.approximate_kl(log_probab_actions_new_cut,
                                           log_probabs_traj_cut, reward_mask)

            early_stopping = approx_kl > 1.5 * self._target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    'Early stopping policy and value optimization after %d steps, '
                    'with approx_kl: %0.2f', opt_step, approx_kl)
                # We don't break right away; we want the code below to run on
                # this last iteration.

            t2 = time.time()
            if (opt_step % self._print_every_optimizer_steps == 0
                    or opt_step == self._n_optimizer_steps or early_stopping):
                # Compute and log the loss.
                (combined_loss, component_losses, _,
                 self._model_state) = (ppo.combined_loss(
                     self._policy_and_value_net_weights,
                     log_probabs_traj_cut,
                     value_predictions_traj,
                     self._policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     nontrainable_params=self._nontrainable_params,
                     state=self._model_state,
                     rng=k3))
                logging.vlog(
                    1, 'One Policy and Value grad desc took: %0.2f msec',
                    policy_based_utils.get_time(t, t2))
                (ppo_loss, value_loss, entropy_bonus) = component_losses
                logging.vlog(
                    1, 'Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->'
                    ' [%10.2f(%10.2f,%10.2f,%10.2f)]', cur_combined_loss,
                    combined_loss, ppo_loss, value_loss, entropy_bonus)

            if early_stopping:
                break

        optimization_time = policy_based_utils.get_time(
            optimization_start_time)

        logging.vlog(
            1, 'Total Combined Loss reduction [%0.2f]%%',
            (100 *
             (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

        summaries.update({
            'n_optimizer_steps': opt_step,
            'approx_kl': approx_kl,
        })
        for (name, value) in summaries.items():
            self._log('train', 'train/{}'.format(name), value)

        logging.info(
            'PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined'
            ' Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]',
            self.epoch, min_reward, max_reward, avg_reward, combined_loss,
            ppo_loss, value_loss, entropy_bonus)

        # Bump the epoch counter before saving a checkpoint, so that a call to
        # save() after the training loop is a no-op if a checkpoint was saved last
        # epoch - otherwise it would bump the epoch counter on the checkpoint.
        last_epoch = self.epoch
        self._epoch += 1

        # Save parameters whenever the number of trajectories that have finished
        # with 'done' since the last save reaches a given fraction of the batch
        # size ('done' is stricter than 'completed' -- completed also includes
        # truncated trajectories).
        # Also don't save too frequently; enforce a minimum gap.
        policy_save_start_time = time.time()
        # TODO(afrozm): Refactor to trax.save_trainer_state.
        if (self._n_trajectories_done_since_last_save >=
                self._done_frac_for_policy_save * self.train_env.batch_size
                and self.epoch % self._save_every_n == 0) or self._async_mode:
            self.save()
        policy_save_time = policy_based_utils.get_time(policy_save_start_time)

        epoch_time = policy_based_utils.get_time(epoch_start_time)

        timing_dict = {
            'epoch': epoch_time,
            'policy_eval': policy_eval_time,
            'trajectory_collection': trajectory_collection_time,
            'preprocessing': preprocessing_time,
            'log_prob_recompute': log_prob_recompute_time,
            'loss_compute': loss_compute_time,
            'optimization': optimization_time,
            'policy_save': policy_save_time,
        }

        timing_dict.update(timing_info)

        if self._should_write_summaries:
            for k, v in timing_dict.items():
                self._timing_sw.scalar('timing/%s' % k, v, step=last_epoch)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            '%s : % 10.2f' % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info('PPO epoch [% 6d], Timings: \n%s', last_epoch,
                     '\n'.join(timing_info_list))

        # Flush summary writers once in a while.
        if self.epoch % 1000 == 0:
            self.flush_summaries()
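
The optimizer loop above stops early once an approximate KL divergence between the sampling ('old') policy and the updated ('new') policy exceeds 1.5 * self._target_kl. The sketch below shows one way such a masked KL estimate over (B, T, C, A) log-probabilities could look and applies the same threshold; it is only an illustration under those assumptions, not the ppo.approximate_kl implementation, and the toy shapes and target_kl value are made up.

# Standalone sketch: masked mean KL between old and new per-time-step action
# distributions, plus the 1.5 * target_kl early-stopping test.
import numpy as np


def masked_mean_kl(log_p_old, log_p_new, mask):
    """Mean KL(pi_old || pi_new) over the valid (batch, time) steps.

    log_p_old, log_p_new: (B, T, C, A) log-probabilities; mask: (B, T).
    """
    per_control = np.sum(np.exp(log_p_old) * (log_p_old - log_p_new), axis=-1)
    per_step = np.sum(per_control, axis=-1)  # Sum over controls -> (B, T).
    return np.sum(per_step * mask) / np.sum(mask)


# Toy data: 2 trajectories, 5 time-steps, 1 control, 4 discrete actions.
B, T, C, A = 2, 5, 1, 4
rng = np.random.RandomState(0)
logits_old = rng.randn(B, T, C, A)
logits_new = logits_old + 0.05 * rng.randn(B, T, C, A)
log_p_old = logits_old - np.log(
    np.sum(np.exp(logits_old), axis=-1, keepdims=True))
log_p_new = logits_new - np.log(
    np.sum(np.exp(logits_new), axis=-1, keepdims=True))
reward_mask = np.ones((B, T))

target_kl = 0.01  # Hypothetical value; the trainer reads self._target_kl.
if masked_mean_kl(log_p_old, log_p_new, reward_mask) > 1.5 * target_kl:
    print('Early stopping: policy moved too far from the sampling policy.')
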
Example #3
    def train_epoch(self, evaluate=True):
        epoch_start_time = time.time()

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
            self.evaluate()
        policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

        def write_metric(key, value):
            self._train_sw.scalar(key, value, step=self.epoch)
            self._history.append('train', key, self.epoch, value)

        # Get fresh trajectories every time.
        self._should_reset_train_env = True

        trajectory_collection_start_time = time.time()
        logging.vlog(1, 'AWR epoch [% 6d]: collecting trajectories.',
                     self._epoch)
        trajs, _, timing_info, self._model_state = self.collect_trajectories(
            train=True, temperature=1.0, raw_trajectory=True)
        del timing_info
        trajectory_collection_time = policy_based_utils.get_time(
            trajectory_collection_start_time)

        logging.vlog(1, 'AWR epoch [% 6d]: n_trajectories [%s].', self._epoch,
                     len(trajs))

        # Convert these into numpy now.
        def extract_obs_act_rew_dones(traj_np):
            return traj_np[0], traj_np[1], traj_np[2], traj_np[4]

        trajs_np = [extract_obs_act_rew_dones(traj.as_numpy) for traj in trajs]

        # Number of new actions.
        new_sample_count = sum(traj[1].shape[0] for traj in trajs_np)
        self._n_observations_seen += new_sample_count
        logging.vlog(1, 'AWR epoch [% 6d]: new_sample_count [%d].',
                     self._epoch, new_sample_count)

        if self._should_write_summaries:
            write_metric('trajs/batch', len(trajs))
            write_metric('trajs/new_sample_count', new_sample_count)

        # The number of trajectories, i.e. `B`, can keep changing from iteration
        # to iteration, since we are capped on the number of observations
        # requested. So let's operate on each trajectory on its own.

        # TODO(afrozm): Should our batches look like (B, T+1, *OBS) or B
        # different examples of (T+1, *OBS) each, since B can keep changing?

        # Add these to the replay buffer.
        for traj in trajs:
            _ = self._replay_buffer.store(traj)

        rewards = jnp.array([jnp.sum(traj[2]) for traj in trajs_np])
        avg_reward = jnp.mean(rewards)
        std_reward = jnp.std(rewards)
        max_reward = jnp.max(rewards)
        min_reward = jnp.min(rewards)

        self._log('train', 'train/reward_mean_truncated', avg_reward)
        if evaluate and not self._separate_eval and self._should_write_summaries:
            metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
            policy_based_utils.write_eval_reward_summaries(
                metrics, self._log, self.epoch)

        logging.vlog(
            1, 'AWR epoch [% 6d]: Rewards avg=[%0.2f], max=[%0.2f], '
            'min=[%0.2f].', self.epoch, avg_reward, max_reward, min_reward)

        if self._should_write_summaries:
            write_metric('reward/avg', avg_reward)
            write_metric('reward/std', std_reward)
            write_metric('reward/max', max_reward)
            write_metric('reward/min', min_reward)

        # Pull the stored observations/rewards back out of the ReplayBuffer.
        idx, valid_mask, valid_idx = self._replay_buffer.get_valid_indices()

        # pylint: disable=g-complex-comprehension
        observations = [
            self._replay_buffer.get(
                replay_buffer.ReplayBuffer.OBSERVATIONS_KEY,
                idx[start_idx:end_plus_1_idx])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]

        rewards = [
            self._replay_buffer.get(replay_buffer.ReplayBuffer.REWARDS_KEY,
                                    idx[start_idx:end_plus_1_idx][:-1])
            for (start_idx,
                 end_plus_1_idx) in self._replay_buffer.iterate_over_paths(idx)
        ]
        # pylint: enable=g-complex-comprehension

        t_final = awr_utils.padding_length(rewards, boundary=self._boundary)
        logging.vlog(1, 'AWR epoch [% 6d]: t_final [%s].', self._epoch,
                     t_final)

        if self._should_write_summaries:
            write_metric('trajs/t_final', t_final)

        # These padded observations are over *all* the non-final observations in
        # the entire replay buffer.
        # Shapes:
        # padded_observations      = (B, T + 1, *OBS)
        # padded_observations_mask = (B, T + 1)
        padded_observations, padded_observations_mask = (
            awr_utils.pad_array_to_length(observations, t_final + 1))

        batch = len(observations)
        self._check_shapes('padded_observations',
                           '(batch, t_final + 1)',
                           padded_observations, (batch, t_final + 1),
                           array_prefix=2)
        self._check_shapes('padded_observations_mask', '(batch, t_final + 1)',
                           padded_observations_mask, (batch, t_final + 1))

        # Shapes:
        # padded_rewards      = (B, T)
        # padded_rewards_mask = (B, T)
        padded_rewards, padded_rewards_mask = awr_utils.pad_array_to_length(
            rewards, t_final)
        self._check_shapes('padded_rewards', '(batch, t_final)',
                           padded_rewards, (batch, t_final))
        self._check_shapes('padded_rewards_mask', '(batch, t_final)',
                           padded_rewards_mask, (batch, t_final))

        # Shapes:
        # lengths = (B,)
        lengths = jnp.sum(padded_rewards_mask, axis=1, dtype=jnp.int32)
        self._check_shapes('lengths', '(batch,)', lengths, (batch, ))

        # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
        # action sampling.
        dummy_actions = jnp.zeros(
            (batch, t_final + 1) + self._action_shape,
            self._action_dtype,
        )

        # Shapes:
        # log_probabs_traj       = (B, T + 1, #controls, #actions)
        # value_predictions_traj = (B, T + 1)
        log_probabs_traj, value_predictions_traj, self._model_state, unused_rng = (
            self._policy_fun_all_timesteps(padded_observations, lengths,
                                           self._model_state, self._get_rng()))
        self._check_shapes(
            'log_probabs_traj', '(batch, t_final + 1, n_controls, n_actions)',
            log_probabs_traj,
            (batch, t_final + 1, self._n_controls, self._n_actions))
        self._check_shapes('value_predictions_traj', '(batch, t_final + 1)',
                           value_predictions_traj, (batch, t_final + 1))

        # Zero out the padding's value predictions, since the net may give some
        # prediction to the padding observations.
        value_predictions_traj *= padded_observations_mask

        # Compute td-lambda returns, and reshape to match value_predictions_traj.
        list_td_lambda_returns = awr_utils.batched_compute_td_lambda_return(
            padded_rewards, padded_rewards_mask, value_predictions_traj,
            padded_observations_mask, self._gamma, self._td_lambda)

        if logging.vlog_is_on(1) and list_td_lambda_returns:
            logging.vlog(1, 'Len of list_td_lambda_returns: %d.',
                         len(list_td_lambda_returns))
            self._log_shape('td_lambda_returns[0]', list_td_lambda_returns[0])

        # Pad an extra 0 for each to match the lengths of the value predictions.
        list_target_values = [
            np.pad(l, (0, 1), 'constant') for l in list_td_lambda_returns
        ]

        if batch != len(list_target_values):
            raise ValueError(f'batch != len(list_target_values) : '
                             f'{batch} vs {len(list_target_values)}')

        # Shape: (len(idx),)
        target_values = np.concatenate(list_target_values)
        self._check_shapes('target_values', '(len(idx),)', target_values,
                           (len(idx), ))

        # Shape: (len(idx),)
        vals = self.flatten_vals(value_predictions_traj,
                                 padded_observations_mask)
        self._check_shapes('vals', '(len(idx),)', vals, (len(idx), ))

        # Calculate advantages.
        adv, norm_adv, adv_mean, adv_std = self._calc_adv(
            target_values, vals, valid_mask)
        self._check_shapes('norm_adv', '(len(idx),)', norm_adv, (len(idx), ))

        adv_weights, adv_weights_mean, adv_weights_min, adv_weights_max = (
            self._calc_adv_weights(norm_adv, valid_mask))
        self._check_shapes('adv_weights', '(len(idx),)', adv_weights,
                           (len(idx), ))

        del adv, adv_mean, adv_std
        del adv_weights_min, adv_weights_max, adv_weights_mean

        combined_steps = int(
            jnp.ceil(self._optimization_steps * new_sample_count /
                     self._num_samples_to_collect))
        optimization_start_time = time.time()
        combined_losses = self._update_combined(combined_steps, valid_idx,
                                                target_values, adv_weights)
        optimization_time = policy_based_utils.get_time(
            optimization_start_time)

        self._epoch += 1

        if self._should_write_summaries:
            write_metric('combined/optimization_steps', combined_steps)
            epoch_time = policy_based_utils.get_time(epoch_start_time)
            timing_dict = {
                'epoch': epoch_time,
                'trajectory_collection': trajectory_collection_time,
                'optimization': optimization_time,
                'policy_eval': policy_eval_time,
            }

            for k, v in timing_dict.items():
                write_metric('timing/{}'.format(k), v)

            # Only dump the average post losses.
            if combined_losses:
                for k, v in combined_losses.items():
                    if 'post_entropy' in k:
                        write_metric(k.replace('post_entropy', 'entropy'), v)
                    if 'post_loss' in k:
                        write_metric(k.replace('post_loss', 'loss'), v)

        self.flush_summaries()
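
The AWR update above consumes two quantities whose helpers (awr_utils.batched_compute_td_lambda_return, self._calc_adv, self._calc_adv_weights) are not shown: TD(lambda) value targets and exponentiated advantage weights. The sketch below follows the standard AWR formulation; the beta and max_weight values, the normalization, and the helper names are assumptions for illustration, not this trainer's exact code.

# Standalone sketch: TD(lambda) targets and clipped exponential advantage
# weights in the spirit of standard AWR. All constants below are illustrative.
import numpy as np


def td_lambda_returns(rewards, values, gamma, td_lambda):
    """G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}).

    rewards: (T,); values: (T + 1,), including the final state's value.
    """
    num_steps = len(rewards)
    returns = np.zeros(num_steps)
    next_return = values[num_steps]  # Bootstrap from the last value prediction.
    for t in reversed(range(num_steps)):
        returns[t] = rewards[t] + gamma * (
            (1.0 - td_lambda) * values[t + 1] + td_lambda * next_return)
        next_return = returns[t]
    return returns


def awr_weights(target_values, value_predictions, beta=1.0, max_weight=20.0):
    """Normalized advantages mapped to clipped exponential weights."""
    adv = target_values - value_predictions
    norm_adv = (adv - adv.mean()) / (adv.std() + 1e-8)
    return np.minimum(np.exp(norm_adv / beta), max_weight)


rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.6, 0.3])  # V(s_0) .. V(s_3).
targets = td_lambda_returns(rewards, values, gamma=0.99, td_lambda=0.95)
weights = awr_weights(targets, values[:-1])
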