def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_and_value_model=trax_models.FrameStackMLP,
             policy_and_value_optimizer=functools.partial(
                 trax_opt.Adam, learning_rate=1e-3),
             policy_and_value_two_towers=False,
             n_optimizer_steps=N_OPTIMIZER_STEPS,
             print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
             target_kl=0.01,
             boundary=20,
             max_timestep=None,
             max_timestep_eval=20000,
             random_seed=None,
             gamma=GAMMA,
             lambda_=LAMBDA,
             c1=1.0,
             c2=0.01,
             eval_every_n=1000,
             done_frac_for_policy_save=0.5,
             n_evals=1,
             len_history_for_policy=4,
             eval_temperatures=(1.0, 0.5),
             **kwargs):
  """Sets up the PPO trainer: hyper-parameters, network, optimizer, writers.

  Args:
    train_env: gym.Env used for training. Its action space must be
      gym.spaces.Discrete.
    eval_env: gym.Env used for evaluation.
    output_dir: Directory for summaries and for restoring optimizer state.
    policy_and_value_model: Function building the shared body of the policy
      and value network (the heads are added separately).
    policy_and_value_optimizer: Function building the optimizer.
    policy_and_value_two_towers: If True, the policy and the value network are
      two separate models; if False they share parameters.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: Logging period during policy optimization.
    target_kl: KL threshold for early stopping of policy iteration; use
      infinity to disable early stopping.
    boundary: Trajectories are padded to integer multiples of this number.
    max_timestep: Optional integer cap on the number of time-steps of a
      training trajectory; used when collecting trajectories.
    max_timestep_eval: Optional integer cap on the number of time-steps of an
      evaluation trajectory; used when collecting trajectories.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    c1: Coefficient of the value loss.
    c2: Coefficient of the entropy loss.
    eval_every_n: Evaluation period of the policy.
    done_frac_for_policy_save: Fraction of trajectories that must be done
      before the policy is checkpointed.
    n_evals: Number of evaluation runs.
    len_history_for_policy: Amount of history handed to the policy.
    eval_temperatures: Temperatures tried for categorical sampling during
      evaluation.
    **kwargs: Extra keyword arguments forwarded to the base class.
  """
  # Placeholders; the base class constructor assigns the real values.
  self._train_env = None
  self._should_reset = None

  super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

  # Optimization settings.
  self._n_optimizer_steps = n_optimizer_steps
  self._print_every_optimizer_steps = print_every_optimizer_steps
  self._target_kl = target_kl
  self._gamma = gamma
  self._lambda_ = lambda_
  self._c1 = c1
  self._c2 = c2

  # Trajectory collection settings.
  self._boundary = boundary
  self._max_timestep = max_timestep
  self._max_timestep_eval = max_timestep_eval
  self._len_history_for_policy = len_history_for_policy

  # Evaluation and checkpointing settings.
  self._eval_every_n = eval_every_n
  self._done_frac_for_policy_save = done_frac_for_policy_save
  self._n_evals = n_evals
  self._eval_temperatures = eval_temperatures

  env = self.train_env
  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  # The networks are eventually applied to [B, T] + OBS shaped inputs, hence
  # the (1, 1) + OBS shape used for initialization.
  obs_space = env.observation_space
  obs_batch_shape = (1, 1) + obs_space.shape
  obs_dtype = obs_space.dtype

  seeded_rng = trax.get_random_number_generator_and_set_seed(random_seed)
  self._rng, init_key = jax_random.split(seeded_rng, num=2)

  # Build the joint policy and value network.
  net_params, self._model_state, net_apply = ppo.policy_and_value_net(
      rng_key=init_key,
      batch_observations_shape=obs_batch_shape,
      observations_dtype=obs_dtype,
      n_actions=num_actions,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  self._policy_and_value_net_apply = jit(net_apply)

  # Build the optimizer on top of the freshly initialized parameters.
  optimizer = ppo.optimizer_fn(policy_and_value_optimizer, net_params)
  (fresh_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = optimizer

  # Try to restore the optimization state from `output_dir`. When nothing can
  # be restored the fresh optimizer state is handed back unchanged.
  restore_result = ppo.maybe_restore_opt_state(
      output_dir, fresh_opt_state, self._model_state)
  (restored, self._policy_and_value_opt_state, self._model_state,
   self._epoch, self._total_opt_step) = restore_result

  if restored:
    logging.info("Restored parameters from iteration [%d]", self._epoch)
    # Resume from the iteration after the restored one.
    self._epoch += 1

  def _summary_writer(subdir):
    """Returns a summary writer rooted at `<output_dir>/<subdir>`."""
    return jaxboard.SummaryWriter(os.path.join(self._output_dir, subdir))

  self._train_sw = _summary_writer("train")
  self._timing_sw = _summary_writer("timing")
  self._eval_sw = _summary_writer("eval")

  # Checkpointing bookkeeping.
  self._n_trajectories_done = 0
  self._last_saved_at = 0
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_and_value_net_fn=None,
    policy_and_value_optimizer_fn=None,
    batch_size=BATCH_TRAJECTORIES,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    output_dir=None,
    eval_every_n=1000,
    eval_env=None,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    env_name=None,
    n_evals=1,
    len_history_for_policy=4,
):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Every epoch the loop optionally evaluates the policy, collects a batch of
  trajectories, pads them, computes the combined loss, runs up to
  `n_optimizer_steps` optimizer steps (optionally stopping early on the
  approximate KL), possibly pickles the parameters and writes summaries.

  Args:
    env: Environment to collect training trajectories from. Required; its
      action space must be gym.spaces.Discrete.
    epochs: Exclusive upper bound on the epoch index; the loop starts at the
      restored iteration (0 when nothing is restored).
    policy_and_value_net_fn: Required. Called as `fn(rng_key,
      batch_observations_shape, observations_dtype, n_actions)`; must return
      `(params, apply_fn)`.
    policy_and_value_optimizer_fn: Required. Called with the network params;
      must return `(opt_state, opt_update, get_params)`.
    batch_size: Number of trajectories collected per epoch.
    n_optimizer_steps: Maximum number of optimizer steps per epoch.
    print_every_optimizer_steps: The loss is recomputed and logged every this
      many optimizer steps (and on the last or early-stopped step).
    target_kl: Optimization stops early once the approximate KL exceeds
      1.5 * target_kl (if `enable_early_stopping`).
    boundary: Forwarded to `pad_trajectories`.
    max_timestep: Forwarded to `collect_trajectories`.
    max_timestep_eval: Forwarded to `evaluate_policy`.
    random_seed: Seed forwarded to
      `trax.get_random_number_generator_and_set_seed`.
    gamma: Forwarded to the loss and the optimizer step.
    lambda_: Forwarded to the loss and the optimizer step.
    epsilon: Forwarded (unannealed) to the loss and the optimizer step.
    c1: Forwarded to the loss and the optimizer step.
    c2: Forwarded to the loss and the optimizer step.
    output_dir: Required. Directory receiving summaries and `model-*.pkl`
      checkpoints.
    eval_every_n: Evaluation happens when `(i + 1) % eval_every_n == 0` and on
      the last epoch; the value also gates checkpointing.
    eval_env: Environment forwarded to `evaluate_policy`.
    done_frac_for_policy_save: A checkpoint requires at least this fraction
      of `batch_size` trajectories to be done since the last save.
    enable_early_stopping: Whether the KL-based early stopping is active.
    env_name: Required. Written as a text summary to every writer.
    n_evals: Forwarded to `evaluate_policy`.
    len_history_for_policy: Forwarded to trajectory collection and
      evaluation.

  Returns:
    None. All results are emitted as summaries, logs and checkpoint files.
  """
  assert env
  assert output_dir
  assert env_name
  # Fail before touching the filesystem rather than with a "NoneType is not
  # callable" error after the output directory and writers were created.
  assert policy_and_value_net_fn is not None
  assert policy_and_value_optimizer_fn is not None

  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  train_sw.text("env_name", env_name)
  timing_sw.text("env_name", env_name)
  eval_sw.text("env_name", env_name)

  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] +_OBS
  batch_observations_shape = (1, 1) + env.observation_space.shape
  observations_dtype = env.observation_space.dtype

  assert isinstance(env.action_space, gym.spaces.Discrete)
  n_actions = env.action_space.n

  jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

  # Initialize the policy and value network.
  policy_and_value_net_params, policy_and_value_net_apply = (
      policy_and_value_net_fn(key1, batch_observations_shape,
                              observations_dtype, n_actions))

  # Maybe restore the policy params. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_net_params are returned as is.
  restore, policy_and_value_net_params, iteration = (maybe_restore_params(
      output_dir, policy_and_value_net_params))

  if restore:
    logging.info("Restored parameters from iteration [%d]", iteration)
    # We should start from the next iteration.
    iteration += 1

  policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizers.
  policy_and_value_optimizer = (
      policy_and_value_optimizer_fn(policy_and_value_net_params))
  (policy_and_value_opt_state, policy_and_value_opt_update,
   policy_and_value_get_params) = policy_and_value_optimizer

  n_trajectories_done = 0
  last_saved_at = 0

  logging.info("Starting the PPO training loop.")
  for i in range(iteration, epochs):
    epoch_start_time = time.time()

    # Params we'll use to collect the trajectories.
    policy_and_value_net_params = policy_and_value_get_params(
        policy_and_value_opt_state)

    # A function to get the policy and value predictions. It closes over this
    # epoch's params and is only used within this iteration.
    def get_predictions(observations, rng=None):
      """Returns log-probs, value predictions and key back."""
      key, key1 = jax_random.split(rng, num=2)

      log_probs, value_preds = policy_and_value_net_apply(
          observations, policy_and_value_net_params, rng=key1)

      return log_probs, value_preds, key

    # Evaluate the policy.
    policy_eval_start_time = time.time()
    if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
      jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

      logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

      avg_reward, avg_reward_unclipped = evaluate_policy(
          eval_env,
          get_predictions,
          max_timestep=max_timestep_eval,
          n_evals=n_evals,
          len_history_for_policy=len_history_for_policy,
          rng=key)
      for k, v in avg_reward.items():
        eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
        logging.info(
            "Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f", i, k, v)
      for k, v in avg_reward_unclipped.items():
        eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i)
        logging.info(
            "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f", i, k,
            v)
    policy_eval_time = get_time(policy_eval_start_time)

    trajectory_collection_start_time = time.time()
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    jax_rng_key, key = jax_random.split(jax_rng_key)
    trajs, n_done, timing_info = collect_trajectories(
        env,
        policy_fn=get_predictions,
        n_trajectories=batch_size,
        max_timestep=max_timestep,
        rng=key,
        len_history_for_policy=len_history_for_policy,
        reset=(i == 0) or restore,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
    trajectory_collection_time = get_time(trajectory_collection_start_time)

    logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                 trajectory_collection_time)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)

    train_sw.scalar("train/mean_reward", avg_reward, step=i)

    logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                 avg_reward, max_reward, min_reward,
                 [float(np.sum(traj[2])) for traj in trajs])

    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
        max(len(traj[0]) for traj in trajs),
        min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    padding_start_time = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)
    padding_time = get_time(padding_start_time)

    # Log the same duration that is recorded in the timing summary.
    logging.vlog(1, "Padding trajectories took %0.2f msec.", padding_time)
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Calculate log-probabilities and value predictions of the trajectories.
    # We'll pass these to the loss functions so as to not get recomputed.

    # NOTE:
    # There is a slight problem here, if the policy network contains
    # stochasticity in the log-probabilities (ex: dropout), then calculating
    # these again here is not going to be correct and should be done in the
    # collect function.

    log_prob_recompute_start_time = time.time()
    jax_rng_key, key = jax_random.split(jax_rng_key)
    log_probabs_traj, value_predictions_traj, _ = get_predictions(
        padded_observations, rng=key)
    log_prob_recompute_time = get_time(log_prob_recompute_start_time)

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from 0.1 to 0.0
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
    #                                                           (i /
    #                                                            (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(2, "Starting to compute P&V loss.")
    loss_compute_start_time = time.time()
    cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
        combined_loss(
            policy_and_value_net_params,
            log_probabs_traj,
            value_predictions_traj,
            policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
            c1=c1,
            c2=c2,
            rng=key1))
    loss_compute_time = get_time(loss_compute_start_time)
    logging.vlog(
        1,
        "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
        cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
        loss_compute_time)

    # Start the "latest" losses at the pre-optimization values so the logs
    # after the optimizer loop stay well defined even when the loop body never
    # runs (n_optimizer_steps == 0).
    loss_combined = cur_combined_loss
    loss_ppo = cur_ppo_loss
    loss_value = cur_value_loss

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(1, "Policy and Value Optimization")
    optimization_start_time = time.time()
    keys = jax_random.split(key1, num=n_optimizer_steps)
    for j in range(n_optimizer_steps):
      k1, k2, k3 = jax_random.split(keys[j], num=3)
      t = time.time()
      # Update the optimizer state.
      policy_and_value_opt_state = policy_and_value_opt_step(
          j,
          policy_and_value_opt_state,
          policy_and_value_opt_update,
          policy_and_value_get_params,
          policy_and_value_net_apply,
          log_probabs_traj,
          value_predictions_traj,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          c1=c1,
          c2=c2,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule,
          rng=k1)

      # Compute the approx KL for early stopping.
      new_policy_and_value_net_params = policy_and_value_get_params(
          policy_and_value_opt_state)

      log_probab_actions_new, _ = policy_and_value_net_apply(
          padded_observations, new_policy_and_value_net_params, rng=k2)

      approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj,
                                 reward_mask)

      early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
      if early_stopping:
        logging.vlog(
            1, "Early stopping policy and value optimization at iter: %d, "
            "with approx_kl: %0.2f", j, approx_kl)
        # We don't return right-away, we want the below to execute on the last
        # iteration.

      t2 = time.time()
      if (((j + 1) % print_every_optimizer_steps == 0) or
          (j == n_optimizer_steps - 1) or early_stopping):
        # Compute and log the loss.
        (loss_combined, loss_ppo, loss_value, entropy_bonus) = (
            combined_loss(
                new_policy_and_value_net_params,
                log_probabs_traj,
                value_predictions_traj,
                policy_and_value_net_apply,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                c1=c1,
                c2=c2,
                rng=k3))
        logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(
            1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
            " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
            loss_combined, loss_value, loss_ppo, entropy_bonus)

      if early_stopping:
        break

    optimization_time = get_time(optimization_start_time)

    logging.vlog(
        1, "Total Combined Loss reduction [%0.2f]%%",
        (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

    # Save parameters every time we see the end of at least a fraction of batch
    # number of trajectories that are done (not completed -- completed includes
    # truncated and done).
    # Also don't save too frequently, enforce a minimum gap.
    # Or if this is the last iteration.
    policy_save_start_time = time.time()
    n_trajectories_done += n_done
    # TODO(afrozm): Refactor to trax.save_state.
    if (((n_trajectories_done >= done_frac_for_policy_save * batch_size) and
         (i - last_saved_at > eval_every_n) and
         (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
      logging.vlog(1, "Epoch [% 6d] saving model.", i)
      # List the existing checkpoints first so that the file written below is
      # not part of the clean-up.
      old_model_files = gfile.glob(
          os.path.join(output_dir, "model-??????.pkl"))
      params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
      # The params written here are the ones fetched at the start of this
      # epoch, i.e. the ones the trajectories were collected with.
      with gfile.GFile(params_file, "wb") as f:
        pickle.dump(policy_and_value_net_params, f)
      # Remove the old model files.
      for path in old_model_files:
        gfile.remove(path)
      # Reset this number.
      n_trajectories_done = 0
      last_saved_at = i
    policy_save_time = get_time(policy_save_start_time)

    epoch_time = get_time(epoch_start_time)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
        " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward,
        max_reward, avg_reward, loss_combined, loss_value, loss_ppo,
        entropy_bonus)

    timing_dict = {
        "epoch": epoch_time,
        "policy_eval": policy_eval_time,
        "trajectory_collection": trajectory_collection_time,
        "padding": padding_time,
        "log_prob_recompute": log_prob_recompute_time,
        "loss_compute": loss_compute_time,
        "optimization": optimization_time,
        "policy_save": policy_save_time,
    }

    timing_dict.update(timing_info)

    for k, v in timing_dict.items():
      timing_sw.scalar("timing/%s" % k, v, step=i)

    max_key_len = max(len(k) for k in timing_dict)
    timing_info_list = [
        "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
        for k, v in sorted(timing_dict.items())
    ]
    logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list))

    # Reset restore.
    restore = False

    # Flush summary writers once in a while.
    if (i + 1) % 1000 == 0 or i == epochs - 1:
      train_sw.flush()
      timing_sw.flush()
      eval_sw.flush()
def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_and_value_model=trax_models.FrameStackMLP,
             policy_and_value_optimizer=functools.partial(
                 trax_opt.Adam, learning_rate=1e-3),
             policy_and_value_two_towers=False,
             policy_and_value_vocab_size=None,
             n_optimizer_steps=N_OPTIMIZER_STEPS,
             optimizer_batch_size=64,
             print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
             target_kl=0.01,
             boundary=20,
             max_timestep=100,
             max_timestep_eval=20000,
             random_seed=None,
             gamma=GAMMA,
             lambda_=LAMBDA,
             c1=1.0,
             c2=0.01,
             eval_every_n=1000,
             save_every_n=1000,
             done_frac_for_policy_save=0.5,
             n_evals=1,
             len_history_for_policy=4,
             eval_temperatures=(1.0, 0.5),
             separate_eval=True,
             init_policy_from_world_model_output_dir=None,
             **kwargs):
  """Sets up the PPO trainer: hyper-parameters, network, optimizer, writers.

  Args:
    train_env: gym.Env used for training. Its action space must be
      gym.spaces.Discrete or gym.spaces.MultiDiscrete.
    eval_env: gym.Env used for evaluation.
    output_dir: Output directory.
    policy_and_value_model: Function building the shared body of the policy
      and value network (the heads are added separately).
    policy_and_value_optimizer: Function building the optimizer.
    policy_and_value_two_towers: If True, the policy and the value network are
      two separate models; if False they share parameters.
    policy_and_value_vocab_size: Vocabulary size of a policy and value network
      that operates on a serialized representation. None selects the raw
      continuous representation.
    n_optimizer_steps: Number of optimizer steps.
    optimizer_batch_size: Batch size of one optimizer step.
    print_every_optimizer_steps: Logging period during policy optimization.
    target_kl: KL threshold for early stopping of policy iteration; use
      infinity to disable early stopping.
    boundary: Trajectories are padded to integer multiples of this number.
    max_timestep: Optional integer cap on the number of time-steps of a
      training trajectory; used when collecting trajectories.
    max_timestep_eval: Optional integer cap on the number of time-steps of an
      evaluation trajectory; used when collecting trajectories.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    c1: Coefficient of the value loss.
    c2: Coefficient of the entropy loss.
    eval_every_n: Evaluation period of the policy.
    save_every_n: Saving period of the policy.
    done_frac_for_policy_save: Fraction of trajectories that must be done
      before the policy is checkpointed.
    n_evals: Number of evaluation runs.
    len_history_for_policy: Amount of history handed to the policy.
    eval_temperatures: Temperatures tried for categorical sampling during
      evaluation.
    separate_eval: Whether a separate evaluation over the temperatures is
      run. If False, the training reward is reported as the evaluation
      reward at temperature 1.0.
    init_policy_from_world_model_output_dir: Output directory of a world
      model to initialize the policy from. None means random initialization.
    **kwargs: Extra keyword arguments forwarded to the base class.
  """
  # Placeholders; the base class constructor assigns the real values.
  self._train_env = None
  self._should_reset = None

  super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

  # Optimization settings.
  self._n_optimizer_steps = n_optimizer_steps
  self._optimizer_batch_size = optimizer_batch_size
  self._print_every_optimizer_steps = print_every_optimizer_steps
  self._target_kl = target_kl
  self._gamma = gamma
  self._lambda_ = lambda_
  self._c1 = c1
  self._c2 = c2

  # Trajectory collection settings.
  self._boundary = boundary
  self._max_timestep = max_timestep
  self._max_timestep_eval = max_timestep_eval
  self._len_history_for_policy = len_history_for_policy

  # Evaluation and checkpointing settings.
  self._eval_every_n = eval_every_n
  self._save_every_n = save_every_n
  self._done_frac_for_policy_save = done_frac_for_policy_save
  self._n_evals = n_evals
  self._eval_temperatures = eval_temperatures
  self._separate_eval = separate_eval

  # Derive the number of controls and the (shared) number of actions per
  # control from the action space.
  action_space = self.train_env.action_space
  supported_spaces = (gym.spaces.Discrete, gym.spaces.MultiDiscrete)
  assert isinstance(action_space, supported_spaces)
  if isinstance(action_space, gym.spaces.Discrete):
    n_controls = 1
    n_actions = action_space.n
  else:
    (n_controls,) = action_space.nvec.shape
    assert n_controls > 0
    assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
        "Every control must have the same number of actions.")
    n_actions = action_space.nvec[0]
  self._n_actions = n_actions
  self._n_controls = n_controls

  seeded_rng = trax.get_random_number_generator_and_set_seed(random_seed)
  self._rng, init_key = jax_random.split(seeded_rng, num=2)

  # A vocabulary size switches the policy to the serialized representation.
  vocab_size = policy_and_value_vocab_size
  self._serialized_sequence_policy = vocab_size is not None
  self._serialization_kwargs = (
      self._init_serialization(vocab_size)
      if self._serialized_sequence_policy else {})

  # Build and initialize the joint policy and value network.
  net = ppo.policy_and_value_net(
      n_actions=n_actions,
      n_controls=n_controls,
      vocab_size=vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  self._policy_and_value_net_apply = jit(net)
  (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
  net_params, self._model_state = net.initialize_once(
      batch_obs_shape, obs_dtype, init_key)
  if init_policy_from_world_model_output_dir is not None:
    net_params = ppo.init_policy_from_world_model_checkpoint(
        net_params, init_policy_from_world_model_output_dir)

  # Build the optimizer on top of the network parameters.
  optimizer = ppo.optimizer_fn(policy_and_value_optimizer, net_params)
  (fresh_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = optimizer

  # Start from the fresh state, then let update_optimization_state refresh
  # it from `output_dir`.
  self._policy_and_value_opt_state = fresh_opt_state
  self._epoch = 0
  self._total_opt_step = 0
  self.update_optimization_state(
      output_dir, policy_and_value_opt_state=fresh_opt_state)

  def _summary_writer(subdir):
    """Returns a summary writer rooted at `<output_dir>/<subdir>`."""
    return jaxboard.SummaryWriter(os.path.join(self._output_dir, subdir))

  self._train_sw = _summary_writer("train")
  self._timing_sw = _summary_writer("timing")
  self._eval_sw = _summary_writer("eval")

  # Checkpointing bookkeeping.
  self._n_trajectories_done = 0
  self._last_saved_at = 0

  if self._async_mode:
    logging.info(
        "Saving model on startup to have a model policy file.")
    self.save()

  self._rewards_to_actions = self._init_rewards_to_actions()
def setUp(self):
  """Runs the parent fixture set-up, then stores the RNG key on the test.

  `self.rng_key` holds whatever
  `trax.get_random_number_generator_and_set_seed(0)` returns.
  """
  super(PpoTest, self).setUp()
  seeded_key = trax.get_random_number_generator_and_set_seed(0)
  self.rng_key = seeded_key
def setUp(self):
  """Stores the RNG key used by the tests on the test instance.

  `self.rng_key` holds whatever
  `trax.get_random_number_generator_and_set_seed(0)` returns.
  """
  # NOTE(review): this override does not call the parent class's setUp;
  # confirm the base test case does not rely on it.
  seeded_key = trax.get_random_number_generator_and_set_seed(0)
  self.rng_key = seeded_key
def training_loop(
    env=None,
    env_name="CartPole-v0",
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,  # TODO(afrozm): Implement.
    policy_optimizer_fun=optimizer_fun,
    value_optimizer_fun=optimizer_fun,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    boundary=20,
    max_timestep=None,
    random_seed=None):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Every epoch collects a batch of trajectories, pads them, computes the value
  and PPO losses, then runs `num_optimizer_steps` steps of the policy
  optimizer followed by `num_optimizer_steps` steps of the value optimizer.

  Args:
    env: Environment to train on. If None, `gym.make(env_name)` is used. Its
      action space must be gym.spaces.Discrete.
    env_name: Name passed to `gym.make` when `env` is None.
    epochs: Number of epochs to run.
    policy_net_fun: Required. Called as `fun(rng_key,
      batch_observations_shape, num_actions)`; must return `(params, apply)`.
    value_net_fun: Required. Same calling convention as `policy_net_fun`.
    policy_and_value_net_fun: Not implemented yet; must be None.
    policy_optimizer_fun: Required. Called with the policy params; must return
      `(opt_state, opt_update)`.
    value_optimizer_fun: Required. Called with the value params; must return
      `(opt_state, opt_update)`.
    batch_size: Number of trajectories collected per epoch.
    num_optimizer_steps: Number of optimizer steps per epoch, for each of the
      two optimizers.
    print_every_optimizer_steps: The loss is recomputed and logged every this
      many optimizer steps (and on the last step).
    boundary: Forwarded to `pad_trajectories`.
    max_timestep: Forwarded to `collect_trajectories`.
    random_seed: Seed forwarded to
      `trax.get_random_number_generator_and_set_seed`.

  Returns:
    A tuple `((policy_net_params, value_net_params), average_rewards,
    value_losses, ppo_objective)` where `average_rewards` is a list with one
    entry per epoch and the last two are the per-epoch values stacked with
    `np.stack`.
  """
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  average_rewards = []

  env = env if env is not None else gym.make(env_name)

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] +_OBS
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  # TODO(afrozm): Have a single net for both policy and action.
  assert policy_and_value_net_fun is None

  # Initialize the policy and value functions.
  assert policy_net_fun and value_net_fun
  jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

  policy_net_params, policy_net_apply = policy_net_fun(
      key1, batch_observations_shape, num_actions)
  value_net_params, value_net_apply = value_net_fun(
      key2, batch_observations_shape, num_actions)

  # Initialize the optimizers.
  assert policy_optimizer_fun and value_optimizer_fun

  ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
  value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_net_apply,
        policy_net_params,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))
    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
        max(len(traj[0]) for traj in trajs),
        min(len(traj[0]) for traj in trajs))

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from 0.1 to 0.0. The explicit float() keeps this a true
    # division even where `/` on two ints would truncate.
    epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / float(epochs - 1)))

    t = time.time()
    cur_value_loss = value_loss(
        value_net_apply,
        value_net_params,
        padded_observations,
        padded_rewards,
        reward_mask,
        gamma=GAMMA)

    logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))
    value_losses.append(cur_value_loss)

    t = time.time()
    cur_ppo_loss = ppo_loss(
        policy_net_apply,
        policy_net_params,
        policy_net_params,
        value_net_apply,
        value_net_params,
        padded_observations,
        padded_actions,
        padded_rewards,
        reward_mask,
        gamma=GAMMA,
        lambda_=LAMBDA,
        epsilon=epsilon)
    logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))
    ppo_objective.append(-cur_ppo_loss)

    # Start the "latest" losses at the pre-optimization values so the logs
    # below stay well defined even when the optimizer loops never run
    # (num_optimizer_steps == 0).
    new_ppo_loss = cur_ppo_loss
    new_value_loss = cur_value_loss

    # Run optimizers.
    logging.vlog(1, "PPO Optimization")
    t1 = time.time()

    for j in range(num_optimizer_steps):
      t = time.time()
      # Update the optimizer state.
      ppo_opt_state = ppo_opt_step(
          j,
          ppo_opt_state,
          ppo_opt_update,
          policy_net_apply,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=GAMMA,
          lambda_=LAMBDA,
          epsilon=epsilon)
      t2 = time.time()
      # Get the new params.
      new_policy_net_params = trax_opt.get_params(ppo_opt_state)
      if ((j + 1) %
          print_every_optimizer_steps == 0) or (j == num_optimizer_steps - 1):
        new_ppo_loss = ppo_loss(
            policy_net_apply,
            new_policy_net_params,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=GAMMA,
            lambda_=LAMBDA,
            epsilon=epsilon)
        logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
        logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                     new_ppo_loss)
      # Update the params.
      # NOTE(review): applied after every optimizer step; confirm this nesting
      # against the upstream formatting of this function.
      policy_net_params = new_policy_net_params

    logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                 (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

    logging.vlog(1, "Value Optimization")

    for j in range(num_optimizer_steps):
      t = time.time()
      value_opt_state = value_opt_step(
          j,
          value_opt_state,
          value_opt_update,
          value_net_apply,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=GAMMA)
      t2 = time.time()
      value_net_params = trax_opt.get_params(value_opt_state)
      if ((j + 1) %
          print_every_optimizer_steps == 0) or (j == num_optimizer_steps - 1):
        new_value_loss = value_loss(
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=GAMMA)
        logging.vlog(1, "One value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                     new_value_loss)
    logging.vlog(
        1, "Total value loss reduction [%0.2f]%%",
        (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

    logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

    # Set the optimized params to new params.
    policy_net_params = trax_opt.get_params(ppo_opt_state)
    value_net_params = trax_opt.get_params(value_opt_state)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
        "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
        min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
        get_time(t0))

  logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
  logging.vlog(1, "average_rewards: %s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  policy_net_fun=None,
                  value_net_fun=None,
                  policy_and_value_net_fun=None,
                  policy_optimizer_fun=None,
                  value_optimizer_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  boundary=20,
                  max_timestep=None,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Each epoch collects a batch of trajectories with the current policy, pads
  them, and then runs `num_optimizer_steps` optimizer steps. Two modes exist:

  * Combined: if `policy_and_value_net_fun` is given, one network produces both
    policy and value predictions and is trained on `combined_loss`.
  * Separate: otherwise `policy_net_fun` and `value_net_fun` are both required;
    the policy is optimized first, then the value network.

  Args:
    env: gym.Env to train on. If None, `gym.make(env_name)` is used. Its action
      space must be `gym.spaces.Discrete`.
    env_name: Name of the environment to create when `env` is None.
    epochs: Number of collect-then-optimize iterations.
    policy_net_fun: Called as `policy_net_fun(rng_key, batch_observations_shape,
      num_actions)`, must return `(params, apply_fn)`.
    value_net_fun: Called with the same arguments as `policy_net_fun`, must
      return `(params, apply_fn)`.
    policy_and_value_net_fun: Same calling convention; its `apply_fn` must
      return `(policy_predictions, value_predictions)`.
    policy_optimizer_fun: Called with the policy params, must return
      `(opt_state, opt_update)`.
    value_optimizer_fun: Called with the value params, must return
      `(opt_state, opt_update)`.
    policy_and_value_optimizer_fun: Called with the combined params, must return
      `(opt_state, opt_update)`. Required in combined mode.
    batch_size: Number of trajectories collected per epoch.
    num_optimizer_steps: Optimizer steps per epoch (per network in separate
      mode).
    print_every_optimizer_steps: Losses are recomputed and logged every this
      many optimizer steps, and always on the last step.
    boundary: Forwarded to `pad_trajectories`.
    max_timestep: Forwarded to `collect_trajectories`.
    random_seed: Forwarded to
      `trax.get_random_number_generator_and_set_seed`.
    gamma: Forwarded to the loss and optimizer-step functions.
    lambda_: Forwarded to the PPO / combined loss and optimizer-step functions.
    epsilon: Starting value of the `epsilon` forwarded to the PPO / combined
      loss; it is annealed linearly to 0 over the epochs (kept constant when
      `epochs == 1`). This is not the exploration epsilon handed to
      `collect_trajectories`, which is `10 / (i + 10)`.
    c1: Forwarded to `combined_loss` / `policy_and_value_opt_step`.
    c2: Forwarded to `combined_loss` / `policy_and_value_opt_step`.

  Returns:
    A tuple `((policy_net_params, value_net_params), average_rewards,
    value_losses, ppo_objective)` where the last two are `np.stack`s of the
    per-epoch pre-optimization value loss and negated PPO loss.
    NOTE: in combined mode `policy_net_params` and `value_net_params` are never
    assigned, so the first element is `(None, None)`.
  """
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  env = env if env is not None else gym.make(env_name)

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] +_OBS
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None
  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape, num_actions))

    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    if policy_net_apply is not None:
      assert policy_net_params
      return policy_net_apply(observations, policy_net_params)

    assert policy_and_value_net_apply and policy_and_value_net_params
    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, policy_and_value_net_params)
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    # Per-trajectory totals, computed once and reused for all statistics.
    traj_rewards = [np.sum(traj[2]) for traj in trajs]
    traj_lengths = [len(traj[0]) for traj in trajs]

    avg_reward = float(sum(traj_rewards)) / len(trajs)
    max_reward = max(traj_rewards)
    min_reward = min(traj_rewards)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))
    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(traj_lengths)) / len(trajs), max(traj_lengths),
        min(traj_lengths))

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from `epsilon` to 0.0
    epsilon_schedule = epsilon if epochs == 1 else epsilon * (
        1.0 - (i / (epochs - 1)))

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    # Only exists in combined mode; don't record a None placeholder otherwise.
    if cur_combined_loss is not None:
      combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()
      # Fall back to the pre-optimization losses so the summaries below are
      # well defined even if `num_optimizer_steps` is 0.
      loss_combined, loss_ppo, loss_value = (cur_combined_loss, cur_ppo_loss,
                                             cur_value_loss)
      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state. For the whole loop
        # `policy_and_value_net_params` must stay the params that collected
        # the trajectories: they are the "old" policy of the PPO objective.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          # Get the new params, then compute and log the loss.
          new_policy_and_value_net_params = trax_opt.get_params(
              policy_and_value_opt_state)
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  policy_and_value_net_params,  # old params
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)

      # Switch to the optimized params only after all steps are done; the next
      # epoch's trajectory collection reads them through `get_policy_output`.
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)

      logging.vlog(
          1, "Total Combined Loss reduction [%0.2f]%%",
          (100 * (cur_combined_loss - loss_combined) /
           np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()
      # Zero-step fallbacks, see the combined branch.
      new_ppo_loss = cur_ppo_loss
      new_value_loss = cur_value_loss

      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state. `policy_net_params` stays fixed at the
        # trajectory-collecting params for the whole loop.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          # Get the new params.
          new_policy_net_params = trax_opt.get_params(ppo_opt_state)
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)

      logging.vlog(
          1, "Total PPO loss reduction [%0.2f]%%",
          (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)
      logging.vlog(
          1, "Total value loss reduction [%0.2f]%%",
          (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
      log_params(policy_net_params, "policy_net_params")
    if value_net_params:
      log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
      log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
  if average_rewards:
    logging.vlog(1, "average_rewards: %s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,
    policy_optimizer_fun=None,
    value_optimizer_fun=None,
    policy_and_value_optimizer_fun=None,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    policy_only_num_optimizer_steps=POLICY_ONLY_NUM_OPTIMIZER_STEPS,
    value_only_num_optimizer_steps=VALUE_ONLY_NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Each epoch collects a batch of trajectories with the freshest params held by
  the optimizer state, pads them, and optimizes. Two modes exist:

  * Combined: if `policy_and_value_net_fun` is given, a single network is
    trained on `combined_loss` for `num_optimizer_steps` steps.
  * Separate: otherwise `policy_net_fun` and `value_net_fun` are both required.
    The policy is optimized for up to `policy_only_num_optimizer_steps` steps
    (with KL-based early stopping), then the value network for
    `value_only_num_optimizer_steps` steps.

  Args:
    env: gym.Env to train on; must be truthy and have a
      `gym.spaces.Discrete` action space.
    epochs: Number of collect-then-optimize iterations.
    policy_net_fun: Called as `policy_net_fun(rng_key, batch_observations_shape,
      num_actions)`, must return `(params, apply_fn)`.
    value_net_fun: Called with the same arguments as `policy_net_fun`, must
      return `(params, apply_fn)`.
    policy_and_value_net_fun: Same calling convention; its `apply_fn` must
      return `(policy_predictions, value_predictions)`.
    policy_optimizer_fun: Called with the policy params, must return
      `(opt_state, opt_update)`.
    value_optimizer_fun: Called with the value params, must return
      `(opt_state, opt_update)`.
    policy_and_value_optimizer_fun: Called with the combined params, must return
      `(opt_state, opt_update)`. Required in combined mode.
    batch_size: Number of trajectories collected per epoch.
    num_optimizer_steps: Optimizer steps per epoch in combined mode.
    policy_only_num_optimizer_steps: Maximum policy optimizer steps per epoch in
      separate mode.
    value_only_num_optimizer_steps: Value optimizer steps per epoch in separate
      mode.
    print_every_optimizer_steps: Losses are recomputed and logged every this
      many optimizer steps, and always on the last step of a loop.
    target_kl: In separate mode, policy optimization stops early once the
      approximate KL between the collecting and the updated policy exceeds
      `1.5 * target_kl`. Not used in combined mode.
    boundary: Forwarded to `collect_trajectories` and `pad_trajectories`.
    max_timestep: Forwarded to `collect_trajectories`.
    random_seed: Forwarded to
      `trax.get_random_number_generator_and_set_seed`.
    gamma: Forwarded to the loss and optimizer-step functions.
    lambda_: Forwarded to the PPO / combined loss and optimizer-step functions.
    epsilon: Forwarded unchanged as `epsilon` to the PPO / combined loss. This
      is not the exploration epsilon handed to `collect_trajectories`, which is
      `10 / (i + 10)`.
    c1: Forwarded to `combined_loss` / `policy_and_value_opt_step`.
    c2: Forwarded to `combined_loss` / `policy_and_value_opt_step`.

  Returns:
    A tuple `((policy_net_params, value_net_params), average_rewards,
    value_losses, ppo_objective)` where the last two are `np.stack`s of the
    per-epoch pre-optimization value loss and negated PPO loss.
    NOTE: in combined mode `policy_net_params` and `value_net_params` are never
    assigned, so the first element is `(None, None)`.
  """
  assert env

  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] +_OBS
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None
  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape, num_actions))

    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))

    policy_and_value_net_apply = jit(policy_and_value_net_apply)
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    policy_net_apply = jit(policy_net_apply)
    value_net_apply = jit(value_net_apply)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    # Get the fresh params for collecting the policy.
    if policy_net_apply is not None:
      return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

    assert policy_and_value_net_apply

    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, trax_opt.get_params(policy_and_value_opt_state))
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        boundary=boundary,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))

    # These were the params that were used to collect the trajectory.
    if policy_and_value_net_apply:
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)
    else:
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

    # Per-trajectory totals, computed once and reused for all statistics.
    traj_rewards = [np.sum(traj[2]) for traj in trajs]
    traj_lengths = [len(traj[0]) for traj in trajs]

    avg_reward = float(sum(traj_rewards)) / len(trajs)
    max_reward = max(traj_rewards)
    min_reward = min(traj_rewards)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(2, "Rewards: %s", [float(r) for r in traj_rewards])
    logging.vlog(1, "Average Rewards: %s", average_rewards)

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(traj_lengths)) / len(trajs), max(traj_lengths),
                 min(traj_lengths))
    logging.vlog(2, "Trajectory Lengths: %s", traj_lengths)

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(
         trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Constant epsilon (no annealing over the epochs).
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      logging.vlog(2, "Starting to compute P&V loss.")
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    # `is not None` rather than truthiness: a loss of exactly 0.0 is still a
    # data point and must not be dropped from the history.
    if cur_combined_loss is not None:
      combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()
      # Fall back to the pre-optimization losses so the summaries below are
      # well defined even if `num_optimizer_steps` is 0.
      loss_combined, loss_ppo, loss_value = (cur_combined_loss, cur_ppo_loss,
                                             cur_value_loss)
      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            # for the entirety of this loop, this should refer to params that
            # were used to collect the trajectory.
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == num_optimizer_steps - 1):
          # Compute and log the loss.
          # Get the new params.
          new_policy_and_value_net_params = trax_opt.get_params(
              policy_and_value_opt_state)
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  # old params, that were used to collect the trajectory
                  policy_and_value_net_params,
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)

      # Update the params, now that every optimization step is done.
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)

      logging.vlog(
          1, "Total Combined Loss reduction [%0.2f]%%",
          (100 * (cur_combined_loss - loss_combined) /
           np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()
      # Zero-step fallbacks, see the combined branch.
      new_ppo_loss = cur_ppo_loss
      new_value_loss = cur_value_loss

      # `policy_net_params` (the "old" params) stay fixed during the whole
      # policy optimization, so their log-probabilities are computed once.
      log_probab_actions_old = policy_net_apply(padded_observations,
                                                policy_net_params)

      for j in range(policy_only_num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        # Get the new params.
        new_policy_net_params = trax_opt.get_params(ppo_opt_state)

        # Compute the approx KL for early stopping.
        log_probab_actions_new = policy_net_apply(padded_observations,
                                                  new_policy_net_params)

        approx_kl = np.mean(log_probab_actions_old - log_probab_actions_new)

        early_stopping = approx_kl > 1.5 * target_kl
        if early_stopping:
          logging.vlog(
              1, "Early stopping policy optimization at iter: %d, "
              "with approx_kl: %0.2f", j, approx_kl)
          # We don't return right-away, we want the below to execute on the
          # last iteration.

        # The "last step" is relative to THIS loop's step count, otherwise
        # `new_ppo_loss` may never be refreshed for the summaries below.
        if (((j + 1) % print_every_optimizer_steps == 0) or
            (j == policy_only_num_optimizer_steps - 1) or early_stopping):
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)

        if early_stopping:
          break

      # `policy_net_params` is updated ONLY AFTER all the optimization
      # iterations complete (see below); till then it refers to the params
      # that were used in collecting the trajectories.

      logging.vlog(
          1, "Total PPO loss reduction [%0.2f]%%",
          (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(value_only_num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        # As above: compare against this loop's own step count.
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == value_only_num_optimizer_steps - 1):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)
      logging.vlog(
          1, "Total value loss reduction [%0.2f]%%",
          (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
      log_params(policy_net_params, "policy_net_params")
    if value_net_params:
      log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
      log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective:\n%s", np.stack(ppo_objective))
  if combined_losses:
    logging.vlog(1, "combined_losses:\n%s", np.stack(combined_losses))
  if average_rewards:
    logging.vlog(1, "average_rewards:\n%s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
def training_loop(
    env,
    eval_env,
    env_name,
    policy_and_value_net_fn,
    policy_and_value_optimizer_fn,
    output_dir,
    epochs=EPOCHS,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    eval_every_n=1000,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    n_evals=1,
    len_history_for_policy=4,
    eval_temperatures=(1.0, 0.5),
):
  """Runs the training loop for PPO, with fixed policy and value nets.

  Every epoch optionally evaluates the policy, collects trajectories, pads
  them, runs up to `n_optimizer_steps` optimizer steps on the combined loss
  (with optional KL-based early stopping), optionally checkpoints the params to
  `output_dir` and writes timing summaries. If a checkpoint can be restored
  from `output_dir`, training resumes at the epoch after the restored one.

  Args:
    env: gym.Env to use for training. Must expose `batch_size` and have a
      `gym.spaces.Discrete` action space.
    eval_env: gym.Env to use for evaluation.
    env_name: Name of the environment.
    policy_and_value_net_fn: Function defining the policy and value network.
      Called as `fn(rng_key, batch_observations_shape, observations_dtype,
      n_actions)` and must return `(params, apply_fn)`.
    policy_and_value_optimizer_fn: Function defining the optimizer. Called with
      the params and must return `(opt_state, opt_update, get_params)`.
    output_dir: Output dir.
    epochs: Number of epochs to run for.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy
      optimization process.
    target_kl: Policy iteration early stopping; optimization stops once the
      approximate KL exceeds `1.5 * target_kl`.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in a
      trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in an
      evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    epsilon: Forwarded unchanged as `epsilon` to `combined_loss` and
      `policy_and_value_opt_step` (presumably the PPO clipping range). This is
      not the exploration epsilon handed to `collect_trajectories`, which is
      `10 / (i + 10)`.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be done
      to checkpoint the policy.
    enable_early_stopping: Whether to enable early stopping.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical sampling
      during evaluation.

  Returns:
    None. Results are written to `output_dir` as summaries and a pickled
    `model-??????.pkl` params file.
  """
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  train_sw.text("env_name", env_name)
  timing_sw.text("env_name", env_name)
  eval_sw.text("env_name", env_name)

  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
  # policy and value networks on shape [B, T] +_OBS
  batch_observations_shape = (1, 1) + env.observation_space.shape
  observations_dtype = env.observation_space.dtype

  assert isinstance(env.action_space, gym.spaces.Discrete)
  n_actions = env.action_space.n

  jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

  # Initialize the policy and value network.
  policy_and_value_net_params, policy_and_value_net_apply = (
      policy_and_value_net_fn(key1, batch_observations_shape,
                              observations_dtype, n_actions))

  # Maybe restore the policy params. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_net_params are returned as is.
  restore, policy_and_value_net_params, iteration = (
      maybe_restore_params(output_dir, policy_and_value_net_params))

  if restore:
    logging.info("Restored parameters from iteration [%d]", iteration)
    # We should start from the next iteration.
    iteration += 1

  policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizers.
  policy_and_value_optimizer = (
      policy_and_value_optimizer_fn(policy_and_value_net_params))
  (policy_and_value_opt_state, policy_and_value_opt_update,
   policy_and_value_get_params) = policy_and_value_optimizer

  n_trajectories_done = 0
  last_saved_at = 0

  logging.info("Starting the PPO training loop.")
  for i in range(iteration, epochs):
    epoch_start_time = time.time()

    # Params we'll use to collect the trajectories.
    policy_and_value_net_params = policy_and_value_get_params(
        policy_and_value_opt_state)

    # A function to get the policy and value predictions.
    def get_predictions(observations, rng=None):
      """Returns log-probs, value predictions and key back."""
      key, key1 = jax_random.split(rng, num=2)

      log_probs, value_preds = policy_and_value_net_apply(
          observations, policy_and_value_net_params, rng=key1)

      return log_probs, value_preds, key

    # Evaluate the policy.
    policy_eval_start_time = time.time()
    if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
      jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

      logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

      reward_stats = evaluate_policy(
          eval_env,
          get_predictions,
          temperatures=eval_temperatures,
          max_timestep=max_timestep_eval,
          n_evals=n_evals,
          len_history_for_policy=len_history_for_policy,
          rng=key)
      write_eval_reward_summaries(reward_stats, eval_sw, epoch=i)
    policy_eval_time = get_time(policy_eval_start_time)

    trajectory_collection_start_time = time.time()
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    jax_rng_key, key = jax_random.split(jax_rng_key)
    trajs, n_done, timing_info = collect_trajectories(
        env,
        policy_fn=get_predictions,
        n_trajectories=env.batch_size,
        max_timestep=max_timestep,
        rng=key,
        len_history_for_policy=len_history_for_policy,
        reset=(i == 0) or restore,
        epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
    trajectory_collection_time = get_time(trajectory_collection_start_time)

    logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                 trajectory_collection_time)

    # Per-trajectory totals, computed once and reused for all statistics.
    traj_rewards = [np.sum(traj[2]) for traj in trajs]
    traj_lengths = [len(traj[0]) for traj in trajs]

    avg_reward = float(sum(traj_rewards)) / len(trajs)
    max_reward = max(traj_rewards)
    min_reward = min(traj_rewards)

    train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i)

    logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                 avg_reward, max_reward, min_reward,
                 [float(r) for r in traj_rewards])

    logging.vlog(
        1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
        float(sum(traj_lengths)) / len(trajs), max(traj_lengths),
        min(traj_lengths))
    logging.vlog(2, "Trajectory Lengths: %s", traj_lengths)

    padding_start_time = time.time()
    (_, reward_mask, padded_observations, padded_actions, padded_rewards,
     padded_infos) = pad_trajectories(trajs, boundary=boundary)
    padding_time = get_time(padding_start_time)

    # Log the same number that goes into the timing summary.
    logging.vlog(1, "Padding trajectories took %0.2f msec.", padding_time)
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    log_prob_recompute_start_time = time.time()
    assert ("log_prob_actions" in padded_infos and
            "value_predictions" in padded_infos)
    # These are the actual log-probabs and value predictions seen while picking
    # the actions.
    actual_log_probabs_traj = padded_infos["log_prob_actions"]
    actual_value_predictions_traj = padded_infos["value_predictions"]

    assert (B, T) == actual_log_probabs_traj.shape[:2]
    A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
    assert (B, T, 1) == actual_value_predictions_traj.shape

    # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A) it can do with
    # (B, T, A), so make that change throughout.

    # NOTE: We don't have the log-probabs and value-predictions for the last
    # observation, so we re-calculate for everything, but use the original ones
    # for all but the last time-step.
    jax_rng_key, key = jax_random.split(jax_rng_key)
    log_probabs_traj, value_predictions_traj, _ = get_predictions(
        padded_observations, rng=key)

    assert (B, T + 1, A) == log_probabs_traj.shape
    assert (B, T + 1, 1) == value_predictions_traj.shape

    # Concatenate the last time-step's log-probabs and value predictions to the
    # actual log-probabs and value predictions and use those going forward.
    log_probabs_traj = np.concatenate(
        (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
    value_predictions_traj = np.concatenate(
        (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
        axis=1)

    log_prob_recompute_time = get_time(log_prob_recompute_start_time)

    # Constant epsilon (no annealing over the epochs).
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(2, "Starting to compute P&V loss.")
    loss_compute_start_time = time.time()
    cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
        combined_loss(
            policy_and_value_net_params,
            log_probabs_traj,
            value_predictions_traj,
            policy_and_value_net_apply,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
            c1=c1,
            c2=c2,
            rng=key1))
    loss_compute_time = get_time(loss_compute_start_time)
    logging.vlog(
        1,
        "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
        cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
        loss_compute_time)

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
    logging.vlog(1, "Policy and Value Optimization")
    optimization_start_time = time.time()
    # Fall back to the pre-optimization losses so the summaries below are well
    # defined even if `n_optimizer_steps` is 0.
    loss_combined, loss_ppo, loss_value = (cur_combined_loss, cur_ppo_loss,
                                           cur_value_loss)
    keys = jax_random.split(key1, num=n_optimizer_steps)
    for j in range(n_optimizer_steps):
      k1, k2, k3 = jax_random.split(keys[j], num=3)
      t = time.time()
      # Update the optimizer state.
      policy_and_value_opt_state = policy_and_value_opt_step(
          j,
          policy_and_value_opt_state,
          policy_and_value_opt_update,
          policy_and_value_get_params,
          policy_and_value_net_apply,
          log_probabs_traj,
          value_predictions_traj,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          c1=c1,
          c2=c2,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule,
          rng=k1)

      # Compute the approx KL for early stopping.
      new_policy_and_value_net_params = policy_and_value_get_params(
          policy_and_value_opt_state)

      log_probab_actions_new, _ = policy_and_value_net_apply(
          padded_observations, new_policy_and_value_net_params, rng=k2)

      approx_kl = approximate_kl(log_probab_actions_new, log_probabs_traj,
                                 reward_mask)

      early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
      if early_stopping:
        logging.vlog(
            1, "Early stopping policy and value optimization at iter: %d, "
            "with approx_kl: %0.2f", j, approx_kl)
        # We don't return right-away, we want the below to execute on the last
        # iteration.

      t2 = time.time()
      if (((j + 1) % print_every_optimizer_steps == 0) or
          (j == n_optimizer_steps - 1) or early_stopping):
        # Compute and log the loss.
        (loss_combined, loss_ppo, loss_value, entropy_bonus) = (
            combined_loss(
                new_policy_and_value_net_params,
                log_probabs_traj,
                value_predictions_traj,
                policy_and_value_net_apply,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                c1=c1,
                c2=c2,
                rng=k3))
        logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                     get_time(t, t2))
        logging.vlog(
            1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
            " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, loss_combined,
            loss_value, loss_ppo, entropy_bonus)

      if early_stopping:
        break

    optimization_time = get_time(optimization_start_time)

    logging.vlog(
        1, "Total Combined Loss reduction [%0.2f]%%",
        (100 * (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

    # Save parameters every time we see the end of at least a fraction of batch
    # number of trajectories that are done (not completed -- completed includes
    # truncated and done).
    # Also don't save too frequently, enforce a minimum gap.
    # Or if this is the last iteration.
    policy_save_start_time = time.time()
    n_trajectories_done += n_done
    # TODO(afrozm): Refactor to trax.save_state.
    if (((n_trajectories_done >= done_frac_for_policy_save * env.batch_size) and
         (i - last_saved_at > eval_every_n) and
         (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
      logging.vlog(1, "Epoch [% 6d] saving model.", i)
      old_model_files = gfile.glob(os.path.join(output_dir, "model-??????.pkl"))
      params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
      # Save the params as they are AFTER this epoch's optimization. A restore
      # resumes at epoch i + 1, so saving the pre-optimization params (the ones
      # used to collect this epoch's trajectories) would silently drop this
      # epoch's update -- and the final epoch's update altogether.
      params_to_save = policy_and_value_get_params(policy_and_value_opt_state)
      with gfile.GFile(params_file, "wb") as f:
        pickle.dump(params_to_save, f)
      # Remove the old model files, but never the one we just wrote.
      for path in old_model_files:
        if path != params_file:
          gfile.remove(path)
      # Reset this number.
      n_trajectories_done = 0
      last_saved_at = i
    policy_save_time = get_time(policy_save_start_time)

    epoch_time = get_time(epoch_start_time)

    logging.info(
        "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
        " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i, min_reward,
        max_reward, avg_reward, loss_combined, loss_value, loss_ppo,
        entropy_bonus)

    timing_dict = {
        "epoch": epoch_time,
        "policy_eval": policy_eval_time,
        "trajectory_collection": trajectory_collection_time,
        "padding": padding_time,
        "log_prob_recompute": log_prob_recompute_time,
        "loss_compute": loss_compute_time,
        "optimization": optimization_time,
        "policy_save": policy_save_time,
    }

    # Timings reported by the collect procedure are merged in as well.
    timing_dict.update(timing_info)

    for k, v in timing_dict.items():
      timing_sw.scalar("timing/%s" % k, v, step=i)

    max_key_len = max(len(k) for k in timing_dict)
    timing_info_list = [
        "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
        for k, v in sorted(timing_dict.items())
    ]
    logging.info("Epoch [% 6d], Timings: \n%s", i, "\n".join(timing_info_list))

    # Reset restore.
    restore = False

    # Flush summary writers once in a while.
    if (i + 1) % 1000 == 0 or i == epochs - 1:
      train_sw.flush()
      timing_sw.flush()
      eval_sw.flush()
def __init__(
    self,
    train_env,
    eval_env,
    policy_and_value_model,
    policy_and_value_optimizer_fn,
    policy_and_value_two_towers,
    output_dir,
    n_optimizer_steps,
    print_every_optimizer_steps,
    target_kl,
    boundary,
    max_timestep,
    max_timestep_eval,
    random_seed,
    gamma,
    lambda_,
    c1,
    c2,
    eval_every_n,
    done_frac_for_policy_save,
    n_evals,
    len_history_for_policy,
    eval_temperatures,
):
    """Builds the policy/value network, optimizer and summary writers.

    Args:
      train_env: Environment used for training. Its action space must be a
        `gym.spaces.Discrete`.
      eval_env: Environment used for evaluation.
      policy_and_value_model: Forwarded to `policy_and_value_net` as
        `bottom_layers_fn`.
      policy_and_value_optimizer_fn: Callable invoked with the (possibly
        restored) network parameters; must return an
        `(opt_state, opt_update, get_params)` triple.
      policy_and_value_two_towers: Forwarded to `policy_and_value_net` as
        `two_towers`.
      output_dir: Directory handed to `maybe_restore_params`, created with
        `gfile.makedirs`, and under which the "train", "timing" and "eval"
        summary writers are opened.
      n_optimizer_steps: Number of optimizer steps.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping threshold.
      boundary: Trajectories are padded at integer multiples of this number.
      max_timestep: Maximum number of time-steps in a training trajectory.
      max_timestep_eval: Maximum number of time-steps in an evaluation
        trajectory.
      random_seed: Seed passed to
        `trax.get_random_number_generator_and_set_seed`.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to eval the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
    """
    # Environments.
    self._train_env = train_env
    self._eval_env = eval_env

    # Hyper-parameters are stored verbatim for later use by the trainer.
    self._n_optimizer_steps = n_optimizer_steps
    self._print_every_optimizer_steps = print_every_optimizer_steps
    self._target_kl = target_kl
    self._boundary = boundary
    self._max_timestep = max_timestep
    self._max_timestep_eval = max_timestep_eval
    self._gamma = gamma
    self._lambda_ = lambda_
    self._c1 = c1
    self._c2 = c2
    self._eval_every_n = eval_every_n
    self._done_frac_for_policy_save = done_frac_for_policy_save
    self._n_evals = n_evals
    self._len_history_for_policy = len_history_for_policy
    self._eval_temperatures = eval_temperatures

    # Only discrete action spaces are handled.
    action_space = train_env.action_space
    assert isinstance(action_space, gym.spaces.Discrete)
    num_actions = action_space.n

    # The networks are eventually applied to inputs shaped [B, T] + OBS, so
    # they are initialized against a [1, 1] + OBS observation shape.
    observation_space = train_env.observation_space
    init_observations_shape = (1, 1) + observation_space.shape
    init_observations_dtype = observation_space.dtype

    # Seed the generator, then peel off one key for network initialization.
    seeded_rng = trax.get_random_number_generator_and_set_seed(random_seed)
    self._rng, init_key = jax_random.split(seeded_rng, num=2)

    # Build the policy and value network; keep only the jitted apply function
    # on the instance (the parameters live inside the optimizer state).
    net_params, net_apply = policy_and_value_net(
        rng_key=init_key,
        batch_observations_shape=init_observations_shape,
        observations_dtype=init_observations_dtype,
        n_actions=num_actions,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    self._policy_and_value_net_apply = jit(net_apply)

    # Try to pick up previously saved parameters. When something was
    # restored, training resumes from the iteration after the saved one.
    restored, net_params, epoch = maybe_restore_params(output_dir, net_params)
    if restored:
        logging.info("Restored parameters from iteration [%d]", epoch)
        epoch += 1
    self._epoch = epoch

    # The optimizer is created from the (possibly restored) parameters.
    (
        self._policy_and_value_opt_state,
        self._policy_and_value_opt_update,
        self._policy_and_value_get_params,
    ) = policy_and_value_optimizer_fn(net_params)

    self._output_dir = output_dir
    gfile.makedirs(output_dir)

    def _summary_writer(subdir):
        # One summary writer per sub-directory of the output directory.
        return jaxboard.SummaryWriter(os.path.join(output_dir, subdir))

    self._train_sw = _summary_writer("train")
    self._timing_sw = _summary_writer("timing")
    self._eval_sw = _summary_writer("eval")

    # Bookkeeping for environment resets and policy checkpointing.
    self._should_reset = True
    self._n_trajectories_done = 0
    self._last_saved_at = 0