def train_epoch(self):
  """Train one PPO epoch.

  The epoch consists of the following stages, each of which is timed:

  1. Optionally evaluate the policy (every `self._eval_every_n` epochs).
  2. Collect a batch of trajectories from `self.train_env`.
  3. Pad the trajectories and sanity-check the resulting shapes.
  4. Recompute predictions for the padded observations and splice the last
     time-step onto the predictions recorded during collection.
  5. Compute the combined loss before optimization.
  6. Run up to `self._n_optimizer_steps` optimizer steps, stopping early when
     the approximate KL exceeds `1.5 * self._target_kl`.
  7. Write summaries, optionally save a checkpoint, log timings, and
     increment `self._epoch`.

  Returns:
    None. The method updates `self._rng`, `self._model_state`,
    `self._policy_and_value_opt_state`, `self._total_opt_step`,
    `self._n_trajectories_done`, `self._should_reset` and `self._epoch`.
  """
  epoch_start_time = time.time()

  # Evaluate the policy.
  policy_eval_start_time = time.time()
  if (self._epoch + 1) % self._eval_every_n == 0:
    # The split-off key is not used, but the split itself is kept so that the
    # sequence of RNG states stays the same.
    self._rng, _ = jax_random.split(self._rng, num=2)
    self.evaluate()
  policy_eval_time = ppo.get_time(policy_eval_start_time)

  trajectory_collection_start_time = time.time()
  logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.", self._epoch)
  self._rng, key = jax_random.split(self._rng)
  trajs, n_done, timing_info, self._model_state = ppo.collect_trajectories(
      self.train_env,
      policy_fn=self._get_predictions,
      n_trajectories=self.train_env.batch_size,
      max_timestep=self._max_timestep,
      state=self._model_state,
      rng=key,
      len_history_for_policy=self._len_history_for_policy,
      boundary=self._boundary,
      reset=self._should_reset,
  )
  self._should_reset = False
  trajectory_collection_time = ppo.get_time(trajectory_collection_start_time)

  logging.vlog(1, "Collecting trajectories took %0.2f msec.",
               trajectory_collection_time)

  # Sum the rewards (element 2 of each trajectory) once and reuse the sums for
  # every statistic below.
  traj_rewards = [np.sum(traj[2]) for traj in trajs]
  avg_reward = float(sum(traj_rewards)) / len(trajs)
  max_reward = max(traj_rewards)
  min_reward = min(traj_rewards)

  self._train_sw.scalar(
      "train/reward_mean_truncated", avg_reward, step=self._epoch)

  logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
               avg_reward, max_reward, min_reward,
               [float(r) for r in traj_rewards])

  # Trajectory lengths are measured on element 0 of each trajectory.
  traj_lengths = [len(traj[0]) for traj in trajs]
  logging.vlog(
      1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
      float(sum(traj_lengths)) / len(trajs), max(traj_lengths),
      min(traj_lengths))
  logging.vlog(2, "Trajectory Lengths: %s", traj_lengths)

  padding_start_time = time.time()
  (_, reward_mask, padded_observations, padded_actions, padded_rewards,
   padded_infos) = ppo.pad_trajectories(trajs, boundary=self._boundary)
  padding_time = ppo.get_time(padding_start_time)

  # Log the same measurement that is recorded in the timing summaries.
  logging.vlog(1, "Padding trajectories took %0.2f msec.", padding_time)
  logging.vlog(1, "Padded Observations' shape [%s]",
               str(padded_observations.shape))
  logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
  logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

  # Some assertions.
  B, T = padded_actions.shape  # pylint: disable=invalid-name
  assert (B, T) == padded_rewards.shape
  assert (B, T) == reward_mask.shape
  assert (B, T + 1) == padded_observations.shape[:2]
  assert ((B, T + 1) + self.train_env.observation_space.shape ==
          padded_observations.shape)

  log_prob_recompute_start_time = time.time()
  assert ("log_prob_actions" in padded_infos and
          "value_predictions" in padded_infos)
  # These are the actual log-probabs and value predictions seen while picking
  # the actions.
  actual_log_probabs_traj = padded_infos["log_prob_actions"]
  actual_value_predictions_traj = padded_infos["value_predictions"]

  assert (B, T) == actual_log_probabs_traj.shape[:2]
  A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
  assert (B, T, 1) == actual_value_predictions_traj.shape

  # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A) it can do with
  # (B, T, A), so make that change throughout.

  # NOTE: We don't have the log-probabs and value-predictions for the last
  # observation, so we re-calculate for everything, but use the original ones
  # for all but the last time-step.
  self._rng, key = jax_random.split(self._rng)
  log_probabs_traj, value_predictions_traj, self._model_state, _ = (
      self._get_predictions(padded_observations, self._model_state, rng=key))

  assert (B, T + 1, A) == log_probabs_traj.shape
  assert (B, T + 1, 1) == value_predictions_traj.shape

  # Concatenate the last time-step's log-probabs and value predictions to the
  # actual log-probabs and value predictions and use those going forward.
  log_probabs_traj = np.concatenate(
      (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
  value_predictions_traj = np.concatenate(
      (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
      axis=1)

  log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

  # Compute value and ppo losses.
  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(2, "Starting to compute P&V loss.")
  loss_compute_start_time = time.time()
  (cur_combined_loss, component_losses, summaries, self._model_state) = (
      ppo.combined_loss(
          self._policy_and_value_net_params,
          log_probabs_traj,
          value_predictions_traj,
          self._policy_and_value_net_apply,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=self._gamma,
          lambda_=self._lambda_,
          c1=self._c1,
          c2=self._c2,
          state=self._model_state,
          rng=key1))
  loss_compute_time = ppo.get_time(loss_compute_start_time)
  (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
  logging.vlog(
      1,
      "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
      cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
      loss_compute_time)

  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(1, "Policy and Value Optimization")
  optimization_start_time = time.time()
  keys = jax_random.split(key1, num=self._n_optimizer_steps)
  opt_step = 0

  # Defaults so that the reporting after the loop does not fail with an
  # unbound local if the loop body never runs (zero optimizer steps). In that
  # case no update was applied, so the pre-optimization losses are reported
  # and no KL was measured.
  combined_loss = cur_combined_loss
  ppo_loss = cur_ppo_loss
  value_loss = cur_value_loss
  entropy_bonus = cur_entropy_bonus
  approx_kl = 0.0

  for key in keys:
    k1, k2, k3 = jax_random.split(key, num=3)
    t = time.time()
    # Update the optimizer state.
    self._policy_and_value_opt_state, self._model_state = (
        ppo.policy_and_value_opt_step(
            # We pass the optimizer slots between PPO epochs, so we need to
            # pass the optimization step as well, so for example the
            # bias-correction in Adam is calculated properly. Alternatively we
            # could reset the slots and the step in every PPO epoch, but then
            # the moment estimates in adaptive optimizers would never have
            # enough time to warm up. So it makes sense to reuse the slots,
            # even though we're optimizing a different loss in every new
            # epoch.
            self._total_opt_step,
            self._policy_and_value_opt_state,
            self._policy_and_value_opt_update,
            self._policy_and_value_get_params,
            self._policy_and_value_net_apply,
            log_probabs_traj,
            value_predictions_traj,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=self._c1,
            c2=self._c2,
            gamma=self._gamma,
            lambda_=self._lambda_,
            state=self._model_state,
            rng=k1))
    opt_step += 1
    self._total_opt_step += 1

    # Compute the approx KL for early stopping.
    (log_probab_actions_new, _), self._model_state = (
        self._policy_and_value_net_apply(
            padded_observations,
            self._policy_and_value_net_params,
            self._model_state,
            rng=k2))

    approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                                   reward_mask)

    early_stopping = approx_kl > 1.5 * self._target_kl
    if early_stopping:
      logging.vlog(
          1, "Early stopping policy and value optimization after %d steps, "
          "with approx_kl: %0.2f", opt_step, approx_kl)
      # We don't return right-away, we want the below to execute on the last
      # iteration.

    # NOTE: t2 is taken after the KL computation, so the per-step time logged
    # below covers the optimizer step plus the KL evaluation.
    t2 = time.time()
    if (opt_step % self._print_every_optimizer_steps == 0 or
        opt_step == self._n_optimizer_steps or early_stopping):
      # Compute and log the loss.
      (combined_loss, component_losses, _, self._model_state) = (
          ppo.combined_loss(
              self._policy_and_value_net_params,
              log_probabs_traj,
              value_predictions_traj,
              self._policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=self._gamma,
              lambda_=self._lambda_,
              c1=self._c1,
              c2=self._c2,
              state=self._model_state,
              rng=k3))
      logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                   ppo.get_time(t, t2))
      (ppo_loss, value_loss, entropy_bonus) = component_losses
      # The label order matches the argument order (ppo, value, entropy).
      logging.vlog(
          1, "Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->"
          " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, combined_loss,
          ppo_loss, value_loss, entropy_bonus)

    if early_stopping:
      break

  optimization_time = ppo.get_time(optimization_start_time)

  logging.vlog(
      1, "Total Combined Loss reduction [%0.2f]%%",
      (100 * (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

  summaries.update({
      "n_optimizer_steps": opt_step,
      "approx_kl": approx_kl,
  })
  for (name, value) in summaries.items():
    self._train_sw.scalar("train/{}".format(name), value, step=self._epoch)

  # Save parameters only when all of the following hold:
  #  - at least a fraction of batch number of trajectories are done (not
  #    completed -- completed includes truncated and done),
  #  - more than `self._eval_every_n` epochs have passed since the last save
  #    (don't save too frequently, enforce a minimum gap),
  #  - this is an evaluation epoch.
  policy_save_start_time = time.time()
  self._n_trajectories_done += n_done
  # TODO(afrozm): Refactor to trax.save_state.
  if ((self._n_trajectories_done >=
       self._done_frac_for_policy_save * self.train_env.batch_size) and
      (self._epoch - self._last_saved_at > self._eval_every_n) and
      ((self._epoch + 1) % self._eval_every_n == 0)):
    self.save()
  policy_save_time = ppo.get_time(policy_save_start_time)

  epoch_time = ppo.get_time(epoch_start_time)

  logging.info(
      "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
      " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", self._epoch,
      min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss,
      entropy_bonus)

  timing_dict = {
      "epoch": epoch_time,
      "policy_eval": policy_eval_time,
      "trajectory_collection": trajectory_collection_time,
      "padding": padding_time,
      "log_prob_recompute": log_prob_recompute_time,
      "loss_compute": loss_compute_time,
      "optimization": optimization_time,
      "policy_save": policy_save_time,
  }

  # Entries returned by trajectory collection are merged in last, so they
  # take precedence over same-named keys above.
  timing_dict.update(timing_info)

  for k, v in timing_dict.items():
    self._timing_sw.scalar("timing/%s" % k, v, step=self._epoch)

  max_key_len = max(len(k) for k in timing_dict)
  timing_info_list = [
      "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
      for k, v in sorted(timing_dict.items())
  ]
  logging.info("PPO epoch [% 6d], Timings: \n%s", self._epoch,
               "\n".join(timing_info_list))

  self._epoch += 1

  # Flush summary writers once in a while.
  if (self._epoch + 1) % 1000 == 0:
    self.flush_summaries()
def train_epoch(self, evaluate=True):
  """Train one PPO epoch.

  The epoch consists of the following stages, each of which is timed:

  1. Optionally evaluate the policy.
  2. Collect training trajectories via `self.collect_trajectories`.
  3. Preprocess the trajectories via `self._preprocess_trajectories`.
  4. Recompute the network predictions for the preprocessed observations.
  5. Compute the combined loss before optimization.
  6. Run up to `self._n_optimizer_steps` optimizer steps on shuffled
     minibatches, stopping early when the approximate KL (computed on the
     whole dataset) exceeds `1.5 * self._target_kl`.
  7. Write summaries, increment `self._epoch`, optionally save a checkpoint,
     and log timings.

  Args:
    evaluate: If true, `self.evaluate()` is run on epochs where
      `(self._epoch + 1) % self._eval_every_n == 0`; additionally, when
      `self._separate_eval` is falsy, the mean and standard deviation of the
      training rewards are written through `ppo.write_eval_reward_summaries`.

  Returns:
    None. The method updates `self._rng`, `self._model_state`,
    `self._policy_and_value_opt_state`, `self._total_opt_step`,
    `self._should_reset` and `self._epoch`.
  """
  epoch_start_time = time.time()

  # Evaluate the policy.
  policy_eval_start_time = time.time()
  if evaluate and (self._epoch + 1) % self._eval_every_n == 0:
    # The split-off key is not used, but the split itself is kept so that the
    # sequence of RNG states stays the same.
    self._rng, _ = jax_random.split(self._rng, num=2)
    self.evaluate()
  policy_eval_time = ppo.get_time(policy_eval_start_time)

  trajectory_collection_start_time = time.time()
  logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.", self._epoch)
  # As above: the key is unused here, the split only advances `self._rng`.
  self._rng, _ = jax_random.split(self._rng)
  # NOTE(review): the second return value (discarded here) is not added to
  # `self._n_trajectories_done` in this method, although the save condition
  # below reads that counter - confirm it is updated elsewhere.
  trajs, _, timing_info, self._model_state = self.collect_trajectories(
      train=True, temperature=1.0)
  # Keep elements 0, 1, 2 and 4 of every trajectory tuple; element 3 is
  # dropped.
  trajs = [(t[0], t[1], t[2], t[4]) for t in trajs]
  self._should_reset = False
  trajectory_collection_time = ppo.get_time(trajectory_collection_start_time)

  logging.vlog(1, "Collecting trajectories took %0.2f msec.",
               trajectory_collection_time)

  # Per-trajectory sums of element 2 (the rewards), computed once.
  rewards = np.array([np.sum(traj[2]) for traj in trajs])
  avg_reward = np.mean(rewards)
  std_reward = np.std(rewards)
  max_reward = np.max(rewards)
  min_reward = np.min(rewards)

  self._train_sw.scalar(
      "train/reward_mean_truncated", avg_reward, step=self._epoch)
  if evaluate and not self._separate_eval:
    metrics = {"raw": {1.0: {"mean": avg_reward, "std": std_reward}}}
    ppo.write_eval_reward_summaries(metrics, self._eval_sw, self._epoch)

  logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
               avg_reward, max_reward, min_reward,
               [float(r) for r in rewards])

  # Trajectory lengths are measured on element 0 of each trajectory.
  traj_lengths = [len(traj[0]) for traj in trajs]
  logging.vlog(
      1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
      float(sum(traj_lengths)) / len(trajs), max(traj_lengths),
      min(traj_lengths))
  logging.vlog(2, "Trajectory Lengths: %s", traj_lengths)

  preprocessing_start_time = time.time()
  (padded_observations, padded_actions, padded_rewards, reward_mask,
   padded_infos) = self._preprocess_trajectories(trajs)
  preprocessing_time = ppo.get_time(preprocessing_start_time)

  # Log the same measurement that is recorded in the timing summaries.
  logging.vlog(1, "Preprocessing trajectories took %0.2f msec.",
               preprocessing_time)
  logging.vlog(1, "Padded Observations' shape [%s]",
               str(padded_observations.shape))
  logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
  logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

  # Some assertions.
  # NOTE(review): the second assignment overwrites B with the batch size of
  # the actions; the two batch sizes are not compared directly.
  B, RT = padded_rewards.shape  # pylint: disable=invalid-name
  B, AT = padded_actions.shape  # pylint: disable=invalid-name
  assert (B, RT) == reward_mask.shape
  assert B == padded_observations.shape[0]

  log_prob_recompute_start_time = time.time()
  # TODO(pkozakowski): The following commented out code collects the network
  # predictions made while stepping the environment and uses them in PPO
  # training, so that we can use non-deterministic networks (e.g. with
  # dropout). This does not work well with serialization, so instead we
  # recompute all network predictions. Let's figure out a solution that will
  # work with both serialized sequences and non-deterministic networks.

  # assert ("log_prob_actions" in padded_infos and
  #         "value_predictions" in padded_infos)
  # These are the actual log-probabs and value predictions seen while picking
  # the actions.
  # actual_log_probabs_traj = padded_infos["log_prob_actions"]
  # actual_value_predictions_traj = padded_infos["value_predictions"]

  # assert (B, T, C) == actual_log_probabs_traj.shape[:3]
  # A = actual_log_probabs_traj.shape[3]  # pylint: disable=invalid-name
  # assert (B, T, 1) == actual_value_predictions_traj.shape

  del padded_infos

  # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with
  # (B, T, C, A), so make that change throughout.

  # NOTE: All log-probabs and value predictions are recomputed here; the ones
  # recorded while stepping the environment are not used (see the TODO above).
  self._rng, key = jax_random.split(self._rng)
  log_probabs_traj, value_predictions_traj, self._model_state, _ = (
      self._get_predictions(padded_observations, self._model_state, rng=key))

  assert (B, AT) == log_probabs_traj.shape[:2]
  assert (B, AT) == value_predictions_traj.shape

  # TODO(pkozakowski): Commented out for the same reason as before.

  # Concatenate the last time-step's log-probabs and value predictions to the
  # actual log-probabs and value predictions and use those going forward.
  # log_probabs_traj = np.concatenate(
  #     (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
  # value_predictions_traj = np.concatenate(
  #     (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
  #     axis=1)

  log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time)

  # Compute value and ppo losses.
  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(2, "Starting to compute P&V loss.")
  loss_compute_start_time = time.time()
  (cur_combined_loss, component_losses, summaries, self._model_state) = (
      ppo.combined_loss(
          self._policy_and_value_net_params,
          log_probabs_traj,
          value_predictions_traj,
          self._policy_and_value_net_apply,
          padded_observations,
          padded_actions,
          self._rewards_to_actions,
          padded_rewards,
          reward_mask,
          gamma=self._gamma,
          lambda_=self._lambda_,
          c1=self._c1,
          c2=self._c2,
          state=self._model_state,
          rng=key1))
  loss_compute_time = ppo.get_time(loss_compute_start_time)
  (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses
  logging.vlog(
      1,
      "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
      cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus,
      loss_compute_time)

  self._rng, key1 = jax_random.split(self._rng, num=2)
  logging.vlog(1, "Policy and Value Optimization")
  optimization_start_time = time.time()
  keys = jax_random.split(key1, num=self._n_optimizer_steps)
  opt_step = 0
  opt_batch_size = min(self._optimizer_batch_size, B)
  index_batches = ppo.shuffled_index_batches(
      dataset_size=B, batch_size=opt_batch_size)

  # The mask used for the KL computation depends only on values that do not
  # change inside the loop, so it is computed once. The reward mask is padded
  # with one extra column on the right before it is multiplied by
  # `self._rewards_to_actions`.
  action_mask = np.dot(
      np.pad(reward_mask, ((0, 0), (0, 1))), self._rewards_to_actions)

  # Defaults so that the reporting after the loop does not fail with an
  # unbound local if the loop body never runs (zero optimizer steps). In that
  # case no update was applied, so the pre-optimization losses are reported
  # and no KL was measured.
  combined_loss = cur_combined_loss
  ppo_loss = cur_ppo_loss
  value_loss = cur_value_loss
  entropy_bonus = cur_entropy_bonus
  approx_kl = 0.0

  for (index_batch, key) in zip(index_batches, keys):
    k1, k2, k3 = jax_random.split(key, num=3)
    t = time.time()
    # Update the optimizer state on the sampled minibatch.
    self._policy_and_value_opt_state, self._model_state = (
        ppo.policy_and_value_opt_step(
            # We pass the optimizer slots between PPO epochs, so we need to
            # pass the optimization step as well, so for example the
            # bias-correction in Adam is calculated properly. Alternatively we
            # could reset the slots and the step in every PPO epoch, but then
            # the moment estimates in adaptive optimizers would never have
            # enough time to warm up. So it makes sense to reuse the slots,
            # even though we're optimizing a different loss in every new
            # epoch.
            self._total_opt_step,
            self._policy_and_value_opt_state,
            self._policy_and_value_opt_update,
            self._policy_and_value_get_params,
            self._policy_and_value_net_apply,
            log_probabs_traj[index_batch],
            value_predictions_traj[index_batch],
            padded_observations[index_batch],
            padded_actions[index_batch],
            self._rewards_to_actions,
            padded_rewards[index_batch],
            reward_mask[index_batch],
            c1=self._c1,
            c2=self._c2,
            gamma=self._gamma,
            lambda_=self._lambda_,
            state=self._model_state,
            rng=k1))
    opt_step += 1
    self._total_opt_step += 1

    # Compute the approx KL for early stopping. Use the whole dataset - as we
    # only do inference, it should fit in the memory.
    (log_probab_actions_new, _) = (
        self._policy_and_value_net_apply(
            padded_observations,
            params=self._policy_and_value_net_params,
            state=self._model_state,
            rng=k2))

    approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj,
                                   action_mask)

    early_stopping = approx_kl > 1.5 * self._target_kl
    if early_stopping:
      logging.vlog(
          1, "Early stopping policy and value optimization after %d steps, "
          "with approx_kl: %0.2f", opt_step, approx_kl)
      # We don't return right-away, we want the below to execute on the last
      # iteration.

    # NOTE: t2 is taken after the KL computation, so the per-step time logged
    # below covers the optimizer step plus the KL evaluation.
    t2 = time.time()
    if (opt_step % self._print_every_optimizer_steps == 0 or
        opt_step == self._n_optimizer_steps or early_stopping):
      # Compute and log the loss.
      (combined_loss, component_losses, _, self._model_state) = (
          ppo.combined_loss(
              self._policy_and_value_net_params,
              log_probabs_traj,
              value_predictions_traj,
              self._policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              self._rewards_to_actions,
              padded_rewards,
              reward_mask,
              gamma=self._gamma,
              lambda_=self._lambda_,
              c1=self._c1,
              c2=self._c2,
              state=self._model_state,
              rng=k3))
      logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                   ppo.get_time(t, t2))
      (ppo_loss, value_loss, entropy_bonus) = component_losses
      # The label order matches the argument order (ppo, value, entropy).
      logging.vlog(
          1, "Combined Loss(ppo, value, entropy_bonus) [%10.2f] ->"
          " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, combined_loss,
          ppo_loss, value_loss, entropy_bonus)

    if early_stopping:
      break

  optimization_time = ppo.get_time(optimization_start_time)

  logging.vlog(
      1, "Total Combined Loss reduction [%0.2f]%%",
      (100 * (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss)))

  summaries.update({
      "n_optimizer_steps": opt_step,
      "approx_kl": approx_kl,
  })
  for (name, value) in summaries.items():
    self._train_sw.scalar("train/{}".format(name), value, step=self._epoch)

  logging.info(
      "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
      " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", self._epoch,
      min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss,
      entropy_bonus)

  # Bump the epoch counter before saving a checkpoint, so that a call to
  # save() after the training loop is a no-op if a checkpoint was saved last
  # epoch - otherwise it would bump the epoch counter on the checkpoint.
  last_epoch = self._epoch
  self._epoch += 1

  # Save parameters every time we see the end of at least a fraction of batch
  # number of trajectories that are done (not completed -- completed includes
  # truncated and done).
  # Also don't save too frequently, enforce a minimum gap.
  # In async mode a checkpoint is saved on every epoch.
  policy_save_start_time = time.time()
  # TODO(afrozm): Refactor to trax.save_state.
  if (self._n_trajectories_done >=
      self._done_frac_for_policy_save * self.train_env.batch_size and
      self._epoch % self._save_every_n == 0) or self._async_mode:
    self.save()
  policy_save_time = ppo.get_time(policy_save_start_time)

  epoch_time = ppo.get_time(epoch_start_time)

  timing_dict = {
      "epoch": epoch_time,
      "policy_eval": policy_eval_time,
      "trajectory_collection": trajectory_collection_time,
      "preprocessing": preprocessing_time,
      "log_prob_recompute": log_prob_recompute_time,
      "loss_compute": loss_compute_time,
      "optimization": optimization_time,
      "policy_save": policy_save_time,
  }

  # Entries returned by trajectory collection are merged in last, so they
  # take precedence over same-named keys above.
  timing_dict.update(timing_info)

  # Timings are reported against the epoch that was just trained, not the
  # already-incremented counter.
  for k, v in timing_dict.items():
    self._timing_sw.scalar("timing/%s" % k, v, step=last_epoch)

  max_key_len = max(len(k) for k in timing_dict)
  timing_info_list = [
      "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
      for k, v in sorted(timing_dict.items())
  ]
  logging.info("PPO epoch [% 6d], Timings: \n%s", last_epoch,
               "\n".join(timing_info_list))

  # Flush summary writers once in a while.
  if self._epoch % 1000 == 0:
    self.flush_summaries()