def evaluate(self):
  """Evaluate the agent."""
  if not self._separate_eval:
    return
  logging.vlog(1, 'PolicyBasedTrainer epoch [% 6d]: evaluating policy.',
               self.epoch)

  processed_reward_sums = collections.defaultdict(list)
  raw_reward_sums = collections.defaultdict(list)
  for _ in range(self._n_evals):
    for temperature in self._eval_temperatures:
      trajs, _, _, self._model_state = self.collect_trajectories(
          train=False, temperature=temperature)

      processed_reward_sums[temperature].extend(
          sum(traj[2]) for traj in trajs)
      raw_reward_sums[temperature].extend(sum(traj[3]) for traj in trajs)

  # Return the mean and standard deviation for each temperature.
  def compute_stats(reward_dict):
    # pylint: disable=g-complex-comprehension
    return {
        temperature: {
            'mean': onp.mean(rewards),
            'std': onp.std(rewards)
        } for (temperature, rewards) in reward_dict.items()
    }
    # pylint: enable=g-complex-comprehension

  reward_stats = {
      'processed': compute_stats(processed_reward_sums),
      'raw': compute_stats(raw_reward_sums),
  }

  policy_based_utils.write_eval_reward_summaries(
      reward_stats, self._log, epoch=self.epoch)
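

# Illustrative sketch (not called by the trainer): how the per-temperature
# reward statistics reported by `evaluate` above can be computed with plain
# numpy. The input format ({temperature: [per-trajectory reward sums]}) is an
# assumption made for this example.
def _eval_reward_stats_sketch(reward_sums_per_temperature):
  """Returns {temperature: {'mean': ..., 'std': ...}} for logging."""
  import numpy as onp  # Local import to keep the sketch self-contained.
  return {
      temperature: {'mean': onp.mean(sums), 'std': onp.std(sums)}
      for temperature, sums in reward_sums_per_temperature.items()
  }


# Example usage of the sketch:
#   _eval_reward_stats_sketch({0.5: [10.0, 12.0], 1.0: [7.0, 9.0, 11.0]})
#   -> {0.5: {'mean': 11.0, 'std': 1.0}, 1.0: {'mean': 9.0, 'std': ~1.63}}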


def train_epoch(self, evaluate=True):
  """Train one PPO epoch."""
  epoch_start_time = time.time()

  # Evaluate the policy.
  policy_eval_start_time = time.time()
  if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
    key = self._get_rng()
    self.evaluate()

  policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

  trajectory_collection_start_time = time.time()
  logging.vlog(1, 'PPO epoch [% 6d]: collecting trajectories.', self.epoch)
  key = self._get_rng()
  trajs, _, timing_info, self._model_state = self.collect_trajectories(
      train=True, temperature=1.0)
  trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
  self._should_reset_train_env = False
  trajectory_collection_time = policy_based_utils.get_time(
      trajectory_collection_start_time)

  logging.vlog(1, 'Collecting trajectories took %0.2f msec.',
               trajectory_collection_time)

  rewards = np.array([np.sum(traj[2]) for traj in trajs])
  avg_reward = np.mean(rewards)
  std_reward = np.std(rewards)
  max_reward = np.max(rewards)
  min_reward = np.min(rewards)

  self._log('train', 'train/reward_mean_truncated', avg_reward)
  if evaluate and not self._separate_eval:
    metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
    policy_based_utils.write_eval_reward_summaries(
        metrics, self._log, self.epoch)

  logging.vlog(1, 'Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s',
               avg_reward, max_reward, min_reward,
               [float(np.sum(traj[2])) for traj in trajs])

  logging.vlog(
      1, 'Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]',
      float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
      max(len(traj[0]) for traj in trajs),
      min(len(traj[0]) for traj in trajs))
  logging.vlog(2, 'Trajectory Lengths: %s', [len(traj[0]) for traj in trajs])

  preprocessing_start_time = time.time()
  (padded_observations, padded_actions, padded_rewards, reward_mask,
   padded_infos) = self._preprocess_trajectories(trajs)
  preprocessing_time = policy_based_utils.get_time(preprocessing_start_time)

  logging.vlog(1, 'Preprocessing trajectories took %0.2f msec.',
               policy_based_utils.get_time(preprocessing_start_time))
  logging.vlog(1, 'Padded Observations\' shape [%s]',
               str(padded_observations.shape))
  logging.vlog(1, 'Padded Actions\' shape [%s]', str(padded_actions.shape))
  logging.vlog(1, 'Padded Rewards\' shape [%s]', str(padded_rewards.shape))

  # Some assertions.
  (B, T) = reward_mask.shape  # pylint: disable=invalid-name
  assert (B, T) == padded_rewards.shape
  assert B == padded_observations.shape[0]

  log_prob_recompute_start_time = time.time()
  # TODO(pkozakowski): The following commented out code collects the network
  # predictions made while stepping the environment and uses them in PPO
  # training, so that we can use non-deterministic networks (e.g. with
  # dropout). This does not work well with serialization, so instead we
  # recompute all network predictions. Let's figure out a solution that will
  # work with both serialized sequences and non-deterministic networks.

  # assert ('log_prob_actions' in padded_infos and
  #         'value_predictions' in padded_infos)
  # These are the actual log-probabs and value predictions seen while picking
  # the actions.
  # actual_log_probabs_traj = padded_infos['log_prob_actions']
  # actual_value_predictions_traj = padded_infos['value_predictions']
  # assert (B, T, C) == actual_log_probabs_traj.shape[:3]
  # A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
  # assert (B, T, 1) == actual_value_predictions_traj.shape

  del padded_infos

  # NOTE: We don't have the log-probabs and value-predictions for the last
  # observation, so we re-calculate for everything, but use the original ones
  # for all but the last time-step.
  key = self._get_rng()

  # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
  # action sampling.
  dummy_actions = np.zeros_like(padded_actions)
  (log_probabs_traj, value_predictions_traj) = (
      self._policy_and_value_net_apply(
          (padded_observations, dummy_actions),
          weights=self._policy_and_value_net_weights,
          state=self._model_state,
          rng=key,
      ))
  # Cut off the last extra action to obtain shape (B, T, C, A).
  log_probabs_traj_cut = log_probabs_traj[:, :-1]

  assert (B, T) == log_probabs_traj_cut.shape[:2]
  assert (B, T + 1) == value_predictions_traj.shape

  # TODO(pkozakowski): Commented out for the same reason as before.
  # Concatenate the last time-step's log-probabs and value predictions to the
  # actual log-probabs and value predictions and use those going forward.
  # log_probabs_traj = np.concatenate(
  #     (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
  # value_predictions_traj = np.concatenate(
  #     (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
  #     axis=1)

  log_prob_recompute_time = policy_based_utils.get_time(
      log_prob_recompute_start_time)

  # Compute value and ppo losses.
  key1 = self._get_rng()
  logging.vlog(2, 'Starting to compute P&V loss.')
  loss_compute_start_time = time.time()
  (cur_combined_loss, component_losses, summaries, self._model_state) = (
      ppo.combined_loss(
          self._policy_and_value_net_weights,
          log_probabs_traj_cut,
          value_predictions_traj,
          self._policy_and_value_net_apply,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          nontrainable_params=self._nontrainable_params,
          state=self._model_state,
          rng=key1))
  loss_compute_time = policy_based_utils.get_time(loss_compute_start_time)
  (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
  logging.vlog(
      1,
      'Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.',
      cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
      policy_based_utils.get_time(loss_compute_start_time))

  key1 = self._get_rng()
  logging.vlog(1, 'Policy and Value Optimization')
  optimization_start_time = time.time()
  keys = jax_random.split(key1, num=self._n_optimizer_steps)
  opt_step = 0
  opt_batch_size = min(self._optimizer_batch_size, B)
  index_batches = ppo.shuffled_index_batches(
      dataset_size=B, batch_size=opt_batch_size)
  for (index_batch, key) in zip(index_batches, keys):
    k1, k2, k3 = jax_random.split(key, num=3)
    t = time.time()
    # Update the optimizer state on the sampled minibatch.
    self._policy_and_value_opt_state, self._model_state = (
        ppo.policy_and_value_opt_step(
            # We pass the optimizer slots between PPO epochs, so we need to
            # pass the optimization step as well, so for example the
            # bias-correction in Adam is calculated properly. Alternatively we
            # could reset the slots and the step in every PPO epoch, but then
            # the moment estimates in adaptive optimizers would never have
            # enough time to warm up. So it makes sense to reuse the slots,
            # even though we're optimizing a different loss in every new
            # epoch.
            self._total_opt_step,
            self._policy_and_value_opt_state,
            self._policy_and_value_opt_update,
            self._policy_and_value_get_params,
            self._policy_and_value_net_apply,
            log_probabs_traj_cut[index_batch],
            value_predictions_traj[index_batch],
            padded_observations[index_batch],
            padded_actions[index_batch],
            padded_rewards[index_batch],
            reward_mask[index_batch],
            nontrainable_params=self._nontrainable_params,
            state=self._model_state,
            rng=k1))
    opt_step += 1
    self._total_opt_step += 1

    # Compute the approx KL for early stopping. Use the whole dataset - as we
    # only do inference, it should fit in memory.
    # TODO(pkozakowski): Pass the actual actions here, to enable
    # autoregressive action sampling.
    dummy_actions = np.zeros_like(padded_actions)
    (log_probab_actions_new, _) = (
        self._policy_and_value_net_apply(
            (padded_observations, dummy_actions),
            weights=self._policy_and_value_net_weights,
            state=self._model_state,
            rng=k2))
    # Cut off the last extra action to obtain shape (B, T, C, A).
    log_probab_actions_new_cut = log_probab_actions_new[:, :-1]

    approx_kl = ppo.approximate_kl(log_probab_actions_new_cut,
                                   log_probabs_traj_cut, reward_mask)

    early_stopping = approx_kl > 1.5 * self._target_kl
    if early_stopping:
      logging.vlog(
          1, 'Early stopping policy and value optimization after %d steps, '
          'with approx_kl: %0.2f', opt_step, approx_kl)
      # We don't return right away, we want the below to execute on the last
      # iteration.

    t2 = time.time()
    if (opt_step % self._print_every_optimizer_steps == 0 or
        opt_step == self._n_optimizer_steps or early_stopping):
      # Compute and log the loss.
      (combined_loss, component_losses, _, self._model_state) = (
          ppo.combined_loss(
              self._policy_and_value_net_weights,
              log_probabs_traj_cut,
              value_predictions_traj,
              self._policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              nontrainable_params=self._nontrainable_params,
              state=self._model_state,
              rng=k3))
      logging.vlog(1, 'One Policy and Value grad desc took: %0.2f msec',
                   policy_based_utils.get_time(t, t2))
      (ppo_loss, value_loss, entropy_bonus) = component_losses
      logging.vlog(
          1, 'Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->'
          ' [%10.2f(%10.2f,%10.2f,%10.2f)]', cur_combined_loss, combined_loss,
          ppo_loss, value_loss, entropy_bonus)

    if early_stopping:
      break

  optimization_time = policy_based_utils.get_time(optimization_start_time)

  logging.vlog(
      1, 'Total Combined Loss reduction [%0.2f]%%',
      (100 * (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

  summaries.update({
      'n_optimizer_steps': opt_step,
      'approx_kl': approx_kl,
  })
  for (name, value) in summaries.items():
    self._log('train', 'train/{}'.format(name), value)

  logging.info(
      'PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined'
      ' Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]', self.epoch,
      min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss,
      entropy_bonus)

  # Bump the epoch counter before saving a checkpoint, so that a call to
  # save() after the training loop is a no-op if a checkpoint was saved last
  # epoch - otherwise it would bump the epoch counter on the checkpoint.
  last_epoch = self.epoch
  self._epoch += 1

  # Save parameters every time we see the end of at least a fraction of batch
  # number of trajectories that are done (not completed -- completed includes
  # truncated and done).
  # Also don't save too frequently, enforce a minimum gap.
  policy_save_start_time = time.time()
  # TODO(afrozm): Refactor to trax.save_trainer_state.
  if (self._n_trajectories_done_since_last_save >=
      self._done_frac_for_policy_save * self.train_env.batch_size and
      self.epoch % self._save_every_n == 0) or self._async_mode:
    self.save()
  policy_save_time = policy_based_utils.get_time(policy_save_start_time)

  epoch_time = policy_based_utils.get_time(epoch_start_time)

  timing_dict = {
      'epoch': epoch_time,
      'policy_eval': policy_eval_time,
      'trajectory_collection': trajectory_collection_time,
      'preprocessing': preprocessing_time,
      'log_prob_recompute': log_prob_recompute_time,
      'loss_compute': loss_compute_time,
      'optimization': optimization_time,
      'policy_save': policy_save_time,
  }

  timing_dict.update(timing_info)

  if self._should_write_summaries:
    for k, v in timing_dict.items():
      self._timing_sw.scalar('timing/%s' % k, v, step=last_epoch)

  max_key_len = max(len(k) for k in timing_dict)
  timing_info_list = [
      '%s : % 10.2f' % (k.rjust(max_key_len + 1), v)
      for k, v in sorted(timing_dict.items())
  ]
  logging.info('PPO epoch [% 6d], Timings: \n%s', last_epoch,
               '\n'.join(timing_info_list))

  # Flush summary writers once in a while.
  if self.epoch % 1000 == 0:
    self.flush_summaries()
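

# Illustrative sketch (an assumption, not the exact ppo.approximate_kl used in
# the PPO epoch above): a common masked estimator of the KL divergence between
# the behaviour policy and the updated policy, together with the
# `approx_kl > 1.5 * target_kl` early-stopping test applied in the
# optimization loop. Here the log-probabilities are assumed to be those of the
# actions actually taken, with shape (batch, time), masked like the rewards.
def _masked_approx_kl_sketch(logp_old, logp_new, mask):
  """Masked mean of (logp_old - logp_new); a cheap approximation of the KL."""
  import numpy as onp  # Local import to keep the sketch self-contained.
  return onp.sum((logp_old - logp_new) * mask) / onp.maximum(onp.sum(mask), 1)


def _should_stop_early_sketch(logp_old, logp_new, mask, target_kl):
  """Mirrors the early-stopping condition used in the optimization loop."""
  return _masked_approx_kl_sketch(logp_old, logp_new, mask) > 1.5 * target_kl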


def train_epoch(self, evaluate=True):
  """Train one AWR epoch."""
  epoch_start_time = time.time()

  # Evaluate the policy.
  policy_eval_start_time = time.time()
  if evaluate and (self.epoch + 1) % self._eval_every_n == 0:
    self.evaluate()
  policy_eval_time = policy_based_utils.get_time(policy_eval_start_time)

  def write_metric(key, value):
    self._train_sw.scalar(key, value, step=self.epoch)
    self._history.append('train', key, self.epoch, value)

  # Get fresh trajectories every time.
  self._should_reset_train_env = True
  trajectory_collection_start_time = time.time()
  logging.vlog(1, 'AWR epoch [% 6d]: collecting trajectories.', self._epoch)
  trajs, _, timing_info, self._model_state = self.collect_trajectories(
      train=True, temperature=1.0, raw_trajectory=True)
  del timing_info
  trajectory_collection_time = policy_based_utils.get_time(
      trajectory_collection_start_time)

  logging.vlog(1, 'AWR epoch [% 6d]: n_trajectories [%s].', self._epoch,
               len(trajs))

  # Convert these into numpy now.
  def extract_obs_act_rew_dones(traj_np):
    return traj_np[0], traj_np[1], traj_np[2], traj_np[4]

  trajs_np = [extract_obs_act_rew_dones(traj.as_numpy) for traj in trajs]

  # Number of new actions.
  new_sample_count = sum(traj[1].shape[0] for traj in trajs_np)
  self._n_observations_seen += new_sample_count
  logging.vlog(1, 'AWR epoch [% 6d]: new_sample_count [%d].', self._epoch,
               new_sample_count)

  if self._should_write_summaries:
    write_metric('trajs/batch', len(trajs))
    write_metric('trajs/new_sample_count', new_sample_count)

  # The number of trajectories, i.e. `B`, can keep changing from iteration to
  # iteration, since we are capped on the number of observations requested.
  # So let's operate on each trajectory on its own?
  # TODO(afrozm): So should our batches look like (B, T+1, *OBS) or B
  # different examples of (T+1, *OBS) each, since B can keep changing?

  # Add these to the replay buffer.
  for traj in trajs:
    _ = self._replay_buffer.store(traj)

  rewards = jnp.array([jnp.sum(traj[2]) for traj in trajs_np])
  avg_reward = jnp.mean(rewards)
  std_reward = jnp.std(rewards)
  max_reward = jnp.max(rewards)
  min_reward = jnp.min(rewards)

  self._log('train', 'train/reward_mean_truncated', avg_reward)
  if evaluate and not self._separate_eval and self._should_write_summaries:
    metrics = {'raw': {1.0: {'mean': avg_reward, 'std': std_reward}}}
    policy_based_utils.write_eval_reward_summaries(
        metrics, self._log, self.epoch)

  logging.vlog(
      1, 'AWR epoch [% 6d]: Rewards avg=[%0.2f], max=[%0.2f], '
      'min=[%0.2f].', self.epoch, avg_reward, max_reward, min_reward)

  if self._should_write_summaries:
    write_metric('reward/avg', avg_reward)
    write_metric('reward/std', std_reward)
    write_metric('reward/max', max_reward)
    write_metric('reward/min', min_reward)

  # Wrap these observations/rewards inside ReplayBuffer.
  idx, valid_mask, valid_idx = self._replay_buffer.get_valid_indices()

  # pylint: disable=g-complex-comprehension
  observations = [
      self._replay_buffer.get(
          replay_buffer.ReplayBuffer.OBSERVATIONS_KEY,
          idx[start_idx:end_plus_1_idx])
      for (start_idx, end_plus_1_idx) in
      self._replay_buffer.iterate_over_paths(idx)
  ]
  rewards = [
      self._replay_buffer.get(
          replay_buffer.ReplayBuffer.REWARDS_KEY,
          idx[start_idx:end_plus_1_idx][:-1])
      for (start_idx, end_plus_1_idx) in
      self._replay_buffer.iterate_over_paths(idx)
  ]
  # pylint: enable=g-complex-comprehension

  t_final = awr_utils.padding_length(rewards, boundary=self._boundary)
  logging.vlog(1, 'AWR epoch [% 6d]: t_final [%s].', self._epoch, t_final)

  if self._should_write_summaries:
    write_metric('trajs/t_final', t_final)

  # These padded observations are over *all* the non-final observations in
  # the entire replay buffer.
  # Shapes:
  # padded_observations      = (B, T + 1, *OBS)
  # padded_observations_mask = (B, T + 1)
  padded_observations, padded_observations_mask = (
      awr_utils.pad_array_to_length(observations, t_final + 1))

  batch = len(observations)
  self._check_shapes('padded_observations', '(batch, t_final + 1)',
                     padded_observations, (batch, t_final + 1),
                     array_prefix=2)
  self._check_shapes('padded_observations_mask', '(batch, t_final + 1)',
                     padded_observations_mask, (batch, t_final + 1))

  # Shapes:
  # padded_rewards      = (B, T)
  # padded_rewards_mask = (B, T)
  padded_rewards, padded_rewards_mask = awr_utils.pad_array_to_length(
      rewards, t_final)
  self._check_shapes('padded_rewards', '(batch, t_final)',
                     padded_rewards, (batch, t_final))
  self._check_shapes('padded_rewards_mask', '(batch, t_final)',
                     padded_rewards_mask, (batch, t_final))

  # Shapes:
  # lengths = (B,)
  lengths = jnp.sum(padded_rewards_mask, axis=1, dtype=jnp.int32)
  self._check_shapes('lengths', '(batch,)', lengths, (batch,))

  # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
  # action sampling.
  dummy_actions = jnp.zeros(
      (batch, t_final + 1) + self._action_shape,
      self._action_dtype,
  )

  # Shapes:
  # log_probabs_traj       = (B, T + 1, #controls, #actions)
  # value_predictions_traj = (B, T + 1)
  log_probabs_traj, value_predictions_traj, self._model_state, unused_rng = (
      self._policy_fun_all_timesteps(padded_observations, lengths,
                                     self._model_state, self._get_rng()))
  self._check_shapes(
      'log_probabs_traj', '(batch, t_final + 1, n_controls, n_actions)',
      log_probabs_traj,
      (batch, t_final + 1, self._n_controls, self._n_actions))
  self._check_shapes('value_predictions_traj', '(batch, t_final + 1)',
                     value_predictions_traj, (batch, t_final + 1))

  # Zero out the padding's value predictions, since the net may give some
  # prediction to the padding observations.
  value_predictions_traj *= padded_observations_mask

  # Compute td-lambda returns, and reshape to match value_predictions_traj.
  list_td_lambda_returns = awr_utils.batched_compute_td_lambda_return(
      padded_rewards, padded_rewards_mask, value_predictions_traj,
      padded_observations_mask, self._gamma, self._td_lambda)

  if logging.vlog_is_on(1) and list_td_lambda_returns:
    l = len(list_td_lambda_returns)
    logging.vlog(1, f'Len of list_td_lambda_returns: {l}.')
    self._log_shape('td_lambda_returns[0]', list_td_lambda_returns[0])

  # Pad an extra 0 for each to match lengths of value predictions.
  list_target_values = [
      np.pad(l, (0, 1), 'constant') for l in list_td_lambda_returns
  ]

  if batch != len(list_target_values):
    raise ValueError(f'batch != len(list_target_values) : '
                     f'{batch} vs {len(list_target_values)}')

  # Shape: (len(idx),)
  target_values = np.concatenate(list_target_values)
  self._check_shapes('target_values', '(len(idx),)',
                     target_values, (len(idx),))

  # Shape: (len(idx),)
  vals = self.flatten_vals(value_predictions_traj, padded_observations_mask)
  self._check_shapes('vals', '(len(idx),)', vals, (len(idx),))

  # Calculate advantages.
  adv, norm_adv, adv_mean, adv_std = self._calc_adv(
      target_values, vals, valid_mask)
  self._check_shapes('norm_adv', '(len(idx),)', norm_adv, (len(idx),))

  adv_weights, adv_weights_mean, adv_weights_min, adv_weights_max = (
      self._calc_adv_weights(norm_adv, valid_mask))
  self._check_shapes('adv_weights', '(len(idx),)', adv_weights, (len(idx),))

  del adv, adv_mean, adv_std
  del adv_weights_min, adv_weights_max, adv_weights_mean

  combined_steps = int(
      jnp.ceil(self._optimization_steps * new_sample_count /
               self._num_samples_to_collect))
  optimization_start_time = time.time()
  combined_losses = self._update_combined(combined_steps, valid_idx,
                                          target_values, adv_weights)
  optimization_time = policy_based_utils.get_time(optimization_start_time)

  self._epoch += 1
  if self._should_write_summaries:
    write_metric('combined/optimization_steps', combined_steps)

  epoch_time = policy_based_utils.get_time(epoch_start_time)
  timing_dict = {
      'epoch': epoch_time,
      'trajectory_collection': trajectory_collection_time,
      'optimization': optimization_time,
      'policy_eval': policy_eval_time,
  }

  if self._should_write_summaries:
    for k, v in timing_dict.items():
      write_metric('timing/{}'.format(k), v)

    # Only dump the average post losses.
    if combined_losses:
      for k, v in combined_losses.items():
        if 'post_entropy' in k:
          write_metric(k.replace('post_entropy', 'entropy'), v)
        if 'post_loss' in k:
          write_metric(k.replace('post_loss', 'loss'), v)

  self.flush_summaries()
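

# Illustrative sketch (an assumption about what `_calc_adv` and
# `_calc_adv_weights` compute in the AWR epoch above; the actual helpers may
# differ): the standard AWR recipe is to form advantages as TD(lambda) targets
# minus value predictions, normalize them over the valid (non-padding)
# entries, and turn them into exponentiated weights clipped at a maximum to
# keep the weighted regression stable. `beta` and `weight_max` are
# hypothetical parameters introduced for this example.
def _awr_advantage_weights_sketch(target_values, value_predictions, valid_mask,
                                  beta=1.0, weight_max=20.0, eps=1e-8):
  """Returns exponentiated-advantage weights; padding entries get weight 0."""
  import numpy as onp  # Local import to keep the sketch self-contained.
  adv = (target_values - value_predictions) * valid_mask
  n_valid = onp.maximum(onp.sum(valid_mask), 1)
  adv_mean = onp.sum(adv) / n_valid
  adv_std = onp.sqrt(onp.sum(valid_mask * (adv - adv_mean) ** 2) / n_valid)
  norm_adv = (adv - adv_mean) / (adv_std + eps)
  weights = onp.minimum(onp.exp(norm_adv / beta), weight_max)
  return weights * valid_mask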