Example #1
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=None,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      n_optimizer_steps: Number of optimizer steps.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Approximate-KL threshold for early stopping of the policy
        optimization. Set to infinity to disable early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in
        a trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to evaluate the policy.
      done_frac_for_policy_save: Fraction of the trajectories in a batch that
        should be done before we checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures

        assert isinstance(self.train_env.action_space, gym.spaces.Discrete)
        n_actions = self.train_env.action_space.n

        # Batch Observations Shape = [1, 1] + OBS, because we will eventually
        # call the policy and value networks on shape [B, T] + OBS.
        batch_observations_shape = (1,
                                    1) + self.train_env.observation_space.shape
        observations_dtype = self.train_env.observation_space.dtype

        self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        # Initialize the policy and value network.
        policy_and_value_net_params, self._model_state, policy_and_value_net_apply = (
            ppo.policy_and_value_net(
                rng_key=key1,
                batch_observations_shape=batch_observations_shape,
                observations_dtype=observations_dtype,
                n_actions=n_actions,
                bottom_layers_fn=policy_and_value_model,
                two_towers=policy_and_value_two_towers,
            ))
        self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Maybe restore the optimization state. If there is nothing to restore, then
        # iteration = 0 and policy_and_value_opt_state is returned as is.
        (restored, self._policy_and_value_opt_state, self._model_state,
         self._epoch, self._total_opt_step) = ppo.maybe_restore_opt_state(
             output_dir, policy_and_value_opt_state, self._model_state)

        if restored:
            logging.info("Restored parameters from iteration [%d]",
                         self._epoch)
            # We should start from the next iteration.
            self._epoch += 1

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
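
Below is a minimal construction sketch for the trainer above. The import paths, the module exporting `PPO`, and the use of plain gym environments are assumptions for illustration only; the real base class may expect wrapped environments and gin configuration.

import functools
import gym

from tensor2tensor.trax import models as trax_models   # assumed import path
from tensor2tensor.trax import optimizers as trax_opt  # assumed import path
from tensor2tensor.trax.rl import ppo_trainer          # assumed module exporting PPO

trainer = ppo_trainer.PPO(
    train_env=gym.make("CartPole-v0"),   # assumption: plain gym envs are accepted
    eval_env=gym.make("CartPole-v0"),
    output_dir="/tmp/ppo_cartpole",
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_optimizer=functools.partial(trax_opt.Adam, learning_rate=1e-3),
    n_optimizer_steps=30,
    max_timestep=200,
)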
Example #2
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_and_value_net_fn=None,
    policy_and_value_optimizer_fn=None,
    batch_size=BATCH_TRAJECTORIES,
    n_optimizer_steps=N_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    max_timestep_eval=20000,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01,
    output_dir=None,
    eval_every_n=1000,
    eval_env=None,
    done_frac_for_policy_save=0.5,
    enable_early_stopping=True,
    env_name=None,
    n_evals=1,
    len_history_for_policy=4,
):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    assert env
    assert output_dir
    assert env_name

    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # the policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key
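        # Note: get_predictions splits the incoming RNG and hands the fresh
        # half back to the caller, so the key can be threaded through repeated
        # calls without reusing randomness.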

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            avg_reward, avg_reward_unclipped = evaluate_policy(
                eval_env,
                get_predictions,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            for k, v in avg_reward.items():
                eval_sw.scalar("eval/mean_reward/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (clipped) [%s] = %10.2f",
                    i, k, v)
            for k, v in avg_reward_unclipped.items():
                eval_sw.scalar("eval/mean_reward_unclipped/%s" % k, v, step=i)
                logging.info(
                    "Epoch [% 6d] Policy Evaluation (unclipped) [%s] = %10.2f",
                    i, k, v)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/mean_reward", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Calculate log-probabilities and value predictions of the trajectories.
        # We'll pass these to the loss functions so they don't get recomputed.

        # NOTE:
        # There is a slight problem here: if the policy network contains
        # stochasticity in the log-probabilities (e.g. dropout), then
        # recalculating them here is not correct; it should be done in the
        # collect function instead.

        log_prob_recompute_start_time = time.time()
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)
        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from 0.1 to 0.0
        # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
        #                                                           (i /
        #                                                            (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))
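        # The combined loss logged above follows the usual PPO composition,
        # roughly ppo_loss + c1 * value_loss - c2 * entropy_bonus; the exact
        # sign conventions are defined in combined_loss.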

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)
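            # approximate_kl gives a cheap, reward-masked estimate of the KL
            # divergence between the old and new policies from their log-probs;
            # optimization stops once it exceeds 1.5 * target_kl below.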

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right away; we want the block below to
                # execute on the last iteration too.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters whenever at least `done_frac_for_policy_save` of a
        # batch worth of trajectories are done (not merely completed --
        # completed includes both truncated and done).
        # Also don't save too frequently; enforce a minimum gap.
        # Always save on the last iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if (((n_trajectories_done >= done_frac_for_policy_save * batch_size)
             and (i - last_saved_at > eval_every_n) and
             (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
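
For reference, a self-contained NumPy sketch of the clipped surrogate that `epsilon` controls in the loop above. This is the textbook PPO objective (to be maximized), not code lifted from `combined_loss`.

import numpy as np

def clipped_surrogate(new_log_probs, old_log_probs, advantages, reward_mask,
                      epsilon=0.1):
  """Textbook PPO clipped objective, averaged over unpadded timesteps."""
  ratios = np.exp(new_log_probs - old_log_probs)
  clipped = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  objective = np.minimum(ratios * advantages, clipped * advantages) * reward_mask
  return objective.sum() / reward_mask.sum()

# Toy check: with identical old/new log-probs the ratio is 1 and the objective
# reduces to the masked average advantage.
lp = np.log(np.full((2, 3), 0.5))
adv = np.ones((2, 3))
mask = np.array([[1, 1, 0], [1, 1, 1]], dtype=np.float32)
assert np.isclose(clipped_surrogate(lp, lp, adv, mask), 1.0)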
Example #3
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 policy_and_value_vocab_size=None,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 optimizer_batch_size=64,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=100,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 save_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 separate_eval=True,
                 init_policy_from_world_model_output_dir=None,
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      policy_and_value_vocab_size: Vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      n_optimizer_steps: Number of optimizer steps.
      optimizer_batch_size: Batch size of an optimizer step.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Approximate-KL threshold for early stopping of the policy
        optimization. Set to infinity to disable early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in a
        trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to evaluate the policy.
      save_every_n: How frequently to save the policy.
      done_frac_for_policy_save: Fraction of the trajectories in a batch that
        should be done before we checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      separate_eval: Whether to run separate evaluation using a set of
        temperatures. If False, the training reward is reported as evaluation
        reward with temperature 1.0.
      init_policy_from_world_model_output_dir: Model output dir for initializing
        the policy. If None, initialize randomly.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._optimizer_batch_size = optimizer_batch_size
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._save_every_n = save_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures
        self._separate_eval = separate_eval

        action_space = self.train_env.action_space
        assert isinstance(action_space,
                          (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
        if isinstance(action_space, gym.spaces.Discrete):
            n_actions = action_space.n
            n_controls = 1
        else:
            (n_controls, ) = action_space.nvec.shape
            assert n_controls > 0
            assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
                "Every control must have the same number of actions.")
            n_actions = action_space.nvec[0]
        self._n_actions = n_actions
        self._n_controls = n_controls
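        # For example, gym.spaces.MultiDiscrete([5, 5, 5]) gives n_controls=3
        # and n_actions=5, while gym.spaces.Discrete(4) gives n_controls=1 and
        # n_actions=4.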

        self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        vocab_size = policy_and_value_vocab_size
        self._serialized_sequence_policy = vocab_size is not None
        if self._serialized_sequence_policy:
            self._serialization_kwargs = self._init_serialization(vocab_size)
        else:
            self._serialization_kwargs = {}

        # Initialize the policy and value network.
        policy_and_value_net = ppo.policy_and_value_net(
            n_actions=n_actions,
            n_controls=n_controls,
            vocab_size=vocab_size,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
        self._policy_and_value_net_apply = jit(policy_and_value_net)
        (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
        policy_and_value_net_params, self._model_state = (
            policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype,
                                                 key1))
        if init_policy_from_world_model_output_dir is not None:
            policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
                policy_and_value_net_params,
                init_policy_from_world_model_output_dir)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Restore the optimizer state.
        self._policy_and_value_opt_state = policy_and_value_opt_state
        self._epoch = 0
        self._total_opt_step = 0
        self.update_optimization_state(
            output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
        if self._async_mode:
            logging.info(
                "Saving model on startup to have a model policy file.")
            self.save()

        self._rewards_to_actions = self._init_rewards_to_actions()
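
A standalone sketch of the temperature-scaled categorical sampling that `eval_temperatures` refers to. The helper name and the use of NumPy are illustrative, not this class's implementation.

import numpy as np

def sample_with_temperature(log_probs, temperature, rng):
  """Samples an action index from log-probabilities at a given temperature."""
  logits = np.asarray(log_probs) / temperature
  probs = np.exp(logits - logits.max())
  probs /= probs.sum()
  return rng.choice(len(probs), p=probs)

rng = np.random.default_rng(0)
action = sample_with_temperature(np.log([0.7, 0.2, 0.1]), temperature=0.5, rng=rng)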
Example #4
    def setUp(self):
        super(PpoTest, self).setUp()
        self.rng_key = trax.get_random_number_generator_and_set_seed(0)
Example #5
    def setUp(self):
        self.rng_key = trax.get_random_number_generator_and_set_seed(0)
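
Both setUp variants pin the seed to 0 so the JAX RNG stream is identical across test runs. A minimal illustration of that determinism, assuming the helper returns a standard jax.random.PRNGKey:

from jax import random as jax_random

key_a, key_b = jax_random.PRNGKey(0), jax_random.PRNGKey(0)
sub_a = jax_random.split(key_a, num=2)[0]
sub_b = jax_random.split(key_b, num=2)[0]
assert (sub_a == sub_b).all()  # identical seeds yield identical subkeys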
Example #6
def training_loop(
        env=None,
        env_name="CartPole-v0",
        epochs=EPOCHS,
        policy_net_fun=None,
        value_net_fun=None,
        policy_and_value_net_fun=None,  # TODO(afrozm): Implement.
        policy_optimizer_fun=optimizer_fun,
        value_optimizer_fun=optimizer_fun,
        batch_size=BATCH_TRAJECTORIES,
        num_optimizer_steps=NUM_OPTIMIZER_STEPS,
        print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
        boundary=20,
        max_timestep=None,
        random_seed=None):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    value_losses = []
    ppo_objective = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually
    # call the policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    # TODO(afrozm): Have a single net for both policy and value.
    assert policy_and_value_net_fun is None

    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(
        key2, batch_observations_shape, num_actions)

    # Initialize the optimizers.
    assert policy_optimizer_fun and value_optimizer_fun

    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

    for i in range(epochs):
        t = time.time()
        t0 = t
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        trajs = collect_trajectories(
            env,
            policy_net_apply,
            policy_net_params,
            num_trajectories=batch_size,
            policy=POLICY,
            max_timestep=max_timestep,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)
        average_rewards.append(avg_reward)

        logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                     avg_reward, max_reward, min_reward)
        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     get_time(t))
        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linear annealing from 0.1 to 0.0
        epsilon = 0.1 if epochs == 1 else 0.1 * (1.0 - (i / (epochs - 1)))
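        # E.g. with epochs=11 the clipping range shrinks linearly over the run:
        # epsilon = 0.10, 0.09, ..., 0.01, 0.00.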

        t = time.time()
        cur_value_loss = value_loss(value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=GAMMA)

        logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))
        value_losses.append(cur_value_loss)

        t = time.time()
        cur_ppo_loss = ppo_loss(policy_net_apply,
                                policy_net_params,
                                policy_net_params,
                                value_net_apply,
                                value_net_params,
                                padded_observations,
                                padded_actions,
                                padded_rewards,
                                reward_mask,
                                gamma=GAMMA,
                                lambda_=LAMBDA,
                                epsilon=epsilon)
        logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))
        ppo_objective.append(-cur_ppo_loss)

        # Run optimizers.
        logging.vlog(1, "PPO Optimization")
        t1 = time.time()

        for j in range(num_optimizer_steps):
            t = time.time()
            # Update the optimizer state.
            ppo_opt_state = ppo_opt_step(j,
                                         ppo_opt_state,
                                         ppo_opt_update,
                                         policy_net_apply,
                                         policy_net_params,
                                         value_net_apply,
                                         value_net_params,
                                         padded_observations,
                                         padded_actions,
                                         padded_rewards,
                                         reward_mask,
                                         gamma=GAMMA,
                                         lambda_=LAMBDA,
                                         epsilon=epsilon)
            t2 = time.time()
            # Get the new params.
            new_policy_net_params = trax_opt.get_params(ppo_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_ppo_loss = ppo_loss(policy_net_apply,
                                        new_policy_net_params,
                                        policy_net_params,
                                        value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_actions,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=GAMMA,
                                        lambda_=LAMBDA,
                                        epsilon=epsilon)
                logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                             get_time(t, t2))
                logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                             new_ppo_loss)
            # Update the params.
            policy_net_params = new_policy_net_params

        logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                     (100 *
                      (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

        logging.vlog(1, "Value Optimization")

        for j in range(num_optimizer_steps):
            t = time.time()
            value_opt_state = value_opt_step(j,
                                             value_opt_state,
                                             value_opt_update,
                                             value_net_apply,
                                             padded_observations,
                                             padded_rewards,
                                             reward_mask,
                                             gamma=GAMMA)
            t2 = time.time()
            value_net_params = trax_opt.get_params(value_opt_state)
            if ((j + 1) % print_every_optimizer_steps
                    == 0) or (j == num_optimizer_steps - 1):
                new_value_loss = value_loss(value_net_apply,
                                            value_net_params,
                                            padded_observations,
                                            padded_rewards,
                                            reward_mask,
                                            gamma=GAMMA)
                logging.vlog(1, "One value grad desc took: %0.2f msec",
                             get_time(t, t2))
                logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]",
                             cur_value_loss, new_value_loss)
        logging.vlog(
            1, "Total value loss reduction [%0.2f]%%",
            (100 * (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

        logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

        # Set the optimized params to new params.
        policy_net_params = trax_opt.get_params(ppo_opt_state)
        value_net_params = trax_opt.get_params(value_opt_state)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
            "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
            min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
            get_time(t0))

    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
    logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
    logging.vlog(1, "average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
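
An illustrative NumPy sketch of what padding a single reward sequence to a multiple of `boundary` looks like. The real `pad_trajectories` also pads observations (to length T + 1, including the final state) and actions, which this toy version omits.

import numpy as np

boundary = 20
rewards = np.ones(7, dtype=np.float32)                 # a trajectory of length 7
t = int(np.ceil(len(rewards) / boundary)) * boundary   # padded length T = 20
padded_rewards = np.pad(rewards, (0, t - len(rewards)))
reward_mask = np.pad(np.ones_like(rewards), (0, t - len(rewards)))
assert padded_rewards.shape == (t,) and reward_mask.sum() == 7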
Example #7
def training_loop(env=None,
                  env_name="CartPole-v0",
                  epochs=EPOCHS,
                  policy_net_fun=None,
                  value_net_fun=None,
                  policy_and_value_net_fun=None,
                  policy_optimizer_fun=None,
                  value_optimizer_fun=None,
                  policy_and_value_optimizer_fun=None,
                  batch_size=BATCH_TRAJECTORIES,
                  num_optimizer_steps=NUM_OPTIMIZER_STEPS,
                  print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                  boundary=20,
                  max_timestep=None,
                  random_seed=None,
                  gamma=GAMMA,
                  lambda_=LAMBDA,
                  epsilon=EPSILON,
                  c1=1.0,
                  c2=0.01):
    """Runs the training loop for PPO, with fixed policy and value nets."""
    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    value_losses = []
    ppo_objective = []
    combined_losses = []
    average_rewards = []

    env = env if env is not None else gym.make(env_name)

    # Batch Observations Shape = [-1, -1] + OBS, because we will eventually
    # call the policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (-1, -1) + env.observation_space.shape

    assert isinstance(env.action_space, gym.spaces.Discrete)
    num_actions = env.action_space.n

    policy_and_value_net_params, policy_and_value_net_apply = None, None
    policy_and_value_opt_state, policy_and_value_opt_update = None, None
    policy_net_params, policy_net_apply = None, None
    value_net_params, value_net_apply = None, None
    if policy_and_value_net_fun is not None:
        jax_rng_key, subkey = jax_random.split(jax_rng_key)

        # Initialize the policy and value network.
        policy_and_value_net_params, policy_and_value_net_apply = (
            policy_and_value_net_fun(subkey, batch_observations_shape,
                                     num_actions))

        # Initialize the optimizers.
        policy_and_value_opt_state, policy_and_value_opt_update = (
            policy_and_value_optimizer_fun(policy_and_value_net_params))
    else:
        # Initialize the policy and value functions.
        assert policy_net_fun and value_net_fun
        jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

        policy_net_params, policy_net_apply = policy_net_fun(
            key1, batch_observations_shape, num_actions)
        value_net_params, value_net_apply = value_net_fun(
            key2, batch_observations_shape, num_actions)

        # Initialize the optimizers.
        ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
        value_opt_state, value_opt_update = value_optimizer_fun(
            value_net_params)

    # A function that will call the appropriate policy function with parameters.
    def get_policy_output(observations):
        if policy_net_apply is not None:
            assert policy_net_params
            return policy_net_apply(observations, policy_net_params)

        assert policy_and_value_net_apply and policy_and_value_net_params
        policy_predictions, unused_value_predictions = policy_and_value_net_apply(
            observations, policy_and_value_net_params)
        return policy_predictions

    for i in range(epochs):
        t = time.time()
        t0 = t
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        trajs = collect_trajectories(
            env,
            policy_fun=get_policy_output,
            num_trajectories=batch_size,
            policy=POLICY,
            max_timestep=max_timestep,
            epsilon=(10.0 / (i + 10.0)))  # this is a different epsilon.

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)
        average_rewards.append(avg_reward)

        logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                     avg_reward, max_reward, min_reward)
        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     get_time(t))
        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))

        t = time.time()
        (_, reward_mask, padded_observations, padded_actions,
         padded_rewards) = pad_trajectories(trajs, boundary=boundary)

        logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        # Linearly anneal epsilon down to 0.0 over the course of training.
        epsilon_schedule = epsilon if epochs == 1 else epsilon * (
            1.0 - (i / (epochs - 1)))

        # Compute value and ppo losses.
        cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
        if policy_and_value_net_apply is not None:
            t = time.time()
            cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
                combined_loss(policy_and_value_net_params,
                              policy_and_value_net_params,
                              policy_and_value_net_apply,
                              padded_observations,
                              padded_actions,
                              padded_rewards,
                              reward_mask,
                              gamma=gamma,
                              lambda_=lambda_,
                              epsilon=epsilon_schedule,
                              c1=c1,
                              c2=c2))
            logging.vlog(
                1,
                "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
                cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
        else:
            t = time.time()
            cur_value_loss = value_loss(value_net_apply,
                                        value_net_params,
                                        padded_observations,
                                        padded_rewards,
                                        reward_mask,
                                        gamma=gamma)

            logging.vlog(1, "Calculating value loss took %0.2f msec.",
                         get_time(t))

            t = time.time()
            cur_ppo_loss = ppo_loss(policy_net_apply,
                                    policy_net_params,
                                    policy_net_params,
                                    value_net_apply,
                                    value_net_params,
                                    padded_observations,
                                    padded_actions,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=gamma,
                                    lambda_=lambda_,
                                    epsilon=epsilon_schedule)
            logging.vlog(1, "Calculating PPO loss took %0.2f msec.",
                         get_time(t))

        value_losses.append(cur_value_loss)
        ppo_objective.append(-1.0 * cur_ppo_loss)
        combined_losses.append(cur_combined_loss)

        if policy_and_value_net_apply:
            logging.vlog(1, "Policy and Value Optimization")
            t1 = time.time()
            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                policy_and_value_opt_state = policy_and_value_opt_step(
                    j,
                    policy_and_value_opt_state,
                    policy_and_value_opt_update,
                    policy_and_value_net_apply,
                    policy_and_value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    c1=c1,
                    c2=c2,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule)
                t2 = time.time()
                # Get the new params.
                new_policy_and_value_net_params = trax_opt.get_params(
                    policy_and_value_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    # Compute and log the loss.
                    (loss_combined, loss_ppo, loss_value,
                     unused_entropy_bonus) = (
                         combined_loss(
                             new_policy_and_value_net_params,
                             policy_and_value_net_params,  # old params
                             policy_and_value_net_apply,
                             padded_observations,
                             padded_actions,
                             padded_rewards,
                             reward_mask,
                             gamma=gamma,
                             lambda_=lambda_,
                             epsilon=epsilon_schedule,
                             c1=c1,
                             c2=c2))
                    logging.vlog(
                        1, "One Policy and Value grad desc took: %0.2f msec",
                        get_time(t, t2))
                    logging.vlog(
                        1,
                        "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
                        cur_combined_loss, loss_combined, loss_value, loss_ppo)
                # Update the params.
                policy_and_value_net_params = new_policy_and_value_net_params

            logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                         (100 * (cur_combined_loss - loss_combined) /
                          np.abs(cur_combined_loss)))

            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
                " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, loss_combined,
                loss_value, loss_ppo, get_time(t1))
        else:
            # Run optimizers.
            logging.vlog(1, "PPO Optimization")
            t1 = time.time()

            for j in range(num_optimizer_steps):
                t = time.time()
                # Update the optimizer state.
                ppo_opt_state = ppo_opt_step(
                    j,
                    ppo_opt_state,
                    ppo_opt_update,
                    policy_net_apply,
                    policy_net_params,
                    value_net_apply,
                    value_net_params,
                    padded_observations,
                    padded_actions,
                    padded_rewards,
                    reward_mask,
                    gamma=gamma,
                    lambda_=lambda_,
                    epsilon=epsilon_schedule,
                )
                t2 = time.time()
                # Get the new params.
                new_policy_net_params = trax_opt.get_params(ppo_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    new_ppo_loss = ppo_loss(
                        policy_net_apply,
                        new_policy_net_params,
                        policy_net_params,
                        value_net_apply,
                        value_net_params,
                        padded_observations,
                        padded_actions,
                        padded_rewards,
                        reward_mask,
                        gamma=gamma,
                        lambda_=lambda_,
                        epsilon=epsilon_schedule,
                    )
                    logging.vlog(1, "One PPO grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]",
                                 cur_ppo_loss, new_ppo_loss)
                # Update the params.
                policy_net_params = new_policy_net_params

            logging.vlog(
                1, "Total PPO loss reduction [%0.2f]%%",
                (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

            logging.vlog(1, "Value Optimization")

            for j in range(num_optimizer_steps):
                t = time.time()
                value_opt_state = value_opt_step(j,
                                                 value_opt_state,
                                                 value_opt_update,
                                                 value_net_apply,
                                                 padded_observations,
                                                 padded_rewards,
                                                 reward_mask,
                                                 gamma=gamma)
                t2 = time.time()
                value_net_params = trax_opt.get_params(value_opt_state)
                if ((j + 1) % print_every_optimizer_steps
                        == 0) or (j == num_optimizer_steps - 1):
                    new_value_loss = value_loss(value_net_apply,
                                                value_net_params,
                                                padded_observations,
                                                padded_rewards,
                                                reward_mask,
                                                gamma=gamma)
                    logging.vlog(1, "One value grad desc took: %0.2f msec",
                                 get_time(t, t2))
                    logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]",
                                 cur_value_loss, new_value_loss)
            logging.vlog(
                1, "Total value loss reduction [%0.2f]%%",
                (100 *
                 (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

            logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

            # Set the optimized params to new params.
            policy_net_params = trax_opt.get_params(ppo_opt_state)
            value_net_params = trax_opt.get_params(value_opt_state)

            logging.info(
                "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
                "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]",
                i, min_reward, max_reward, avg_reward, new_ppo_loss,
                new_value_loss, get_time(t0))

    # Log the parameters, just for the sake of it.
    if policy_net_params:
        log_params(policy_net_params, "policy_net_params")
    if value_net_params:
        log_params(value_net_params, "value_net_params")
    if policy_and_value_net_params:
        log_params(policy_and_value_net_params, "policy_and_value_net_params")

    if value_losses:
        logging.vlog(1, "value_losses: %s", np.stack(value_losses))
    if ppo_objective:
        logging.vlog(1, "ppo_objective: %s", np.stack(ppo_objective))
    if average_rewards:
        logging.vlog(1, "average_rewards: %s", average_rewards)

    return ((policy_net_params, value_net_params), average_rewards,
            np.stack(value_losses), np.stack(ppo_objective))
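
A toy NumPy version of the masked value loss that `c1` weights in the combined objective above. The real `value_loss` computes its return targets from rewards using gamma, whereas here they are passed in pre-computed.

import numpy as np

def masked_value_loss(value_predictions, return_targets, reward_mask):
  """Mean squared error between V(s) and return targets over unpadded steps."""
  sq_err = np.square(value_predictions - return_targets) * reward_mask
  return sq_err.sum() / reward_mask.sum()

preds = np.array([[1.0, 0.5, 0.0]])
targets = np.array([[1.0, 1.0, 7.0]])  # last step is padding
mask = np.array([[1.0, 1.0, 0.0]])
assert np.isclose(masked_value_loss(preds, targets, mask), 0.125)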
Example #8
def training_loop(
    env=None,
    epochs=EPOCHS,
    policy_net_fun=None,
    value_net_fun=None,
    policy_and_value_net_fun=None,
    policy_optimizer_fun=None,
    value_optimizer_fun=None,
    policy_and_value_optimizer_fun=None,
    batch_size=BATCH_TRAJECTORIES,
    num_optimizer_steps=NUM_OPTIMIZER_STEPS,
    policy_only_num_optimizer_steps=POLICY_ONLY_NUM_OPTIMIZER_STEPS,
    value_only_num_optimizer_steps=VALUE_ONLY_NUM_OPTIMIZER_STEPS,
    print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
    target_kl=0.01,
    boundary=20,
    max_timestep=None,
    random_seed=None,
    gamma=GAMMA,
    lambda_=LAMBDA,
    epsilon=EPSILON,
    c1=1.0,
    c2=0.01):
  """Runs the training loop for PPO, with fixed policy and value nets."""
  assert env
  jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

  value_losses = []
  ppo_objective = []
  combined_losses = []
  average_rewards = []

  # Batch Observations Shape = [-1, -1] + OBS, because we will eventually call
  # the policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (-1, -1) + env.observation_space.shape

  assert isinstance(env.action_space, gym.spaces.Discrete)
  num_actions = env.action_space.n

  policy_and_value_net_params, policy_and_value_net_apply = None, None
  policy_and_value_opt_state, policy_and_value_opt_update = None, None
  policy_net_params, policy_net_apply = None, None
  value_net_params, value_net_apply = None, None
  if policy_and_value_net_fun is not None:
    jax_rng_key, subkey = jax_random.split(jax_rng_key)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fun(subkey, batch_observations_shape, num_actions))

    # Initialize the optimizers.
    policy_and_value_opt_state, policy_and_value_opt_update = (
        policy_and_value_optimizer_fun(policy_and_value_net_params))

    policy_and_value_net_apply = jit(policy_and_value_net_apply)
  else:
    # Initialize the policy and value functions.
    assert policy_net_fun and value_net_fun
    jax_rng_key, key1, key2 = jax_random.split(jax_rng_key, num=3)

    policy_net_params, policy_net_apply = policy_net_fun(
        key1, batch_observations_shape, num_actions)
    value_net_params, value_net_apply = value_net_fun(key2,
                                                      batch_observations_shape,
                                                      num_actions)

    policy_net_apply = jit(policy_net_apply)
    value_net_apply = jit(value_net_apply)

    # Initialize the optimizers.
    ppo_opt_state, ppo_opt_update = policy_optimizer_fun(policy_net_params)
    value_opt_state, value_opt_update = value_optimizer_fun(value_net_params)

  # A function that will call the appropriate policy function with parameters.
  def get_policy_output(observations):
    # Get the fresh params for collecting the policy.
    if policy_net_apply is not None:
      return policy_net_apply(observations, trax_opt.get_params(ppo_opt_state))

    assert policy_and_value_net_apply

    policy_predictions, unused_value_predictions = policy_and_value_net_apply(
        observations, trax_opt.get_params(policy_and_value_opt_state))
    return policy_predictions

  for i in range(epochs):
    t = time.time()
    t0 = t
    logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
    trajs = collect_trajectories(
        env,
        policy_fun=get_policy_output,
        num_trajectories=batch_size,
        policy=POLICY,
        max_timestep=max_timestep,
        boundary=boundary,
        epsilon=(10.0 / (i + 10.0)))  # exploration epsilon, not the PPO clipping epsilon.

    logging.vlog(1, "Collecting trajectories took %0.2f msec.", get_time(t))

    # These were the params that were used to collect the trajectory.
    if policy_and_value_net_apply:
      policy_and_value_net_params = trax_opt.get_params(
          policy_and_value_opt_state)
    else:
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

    avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
    max_reward = max(np.sum(traj[2]) for traj in trajs)
    min_reward = min(np.sum(traj[2]) for traj in trajs)
    average_rewards.append(avg_reward)

    logging.vlog(1, "Rewards average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 avg_reward, max_reward, min_reward)
    logging.vlog(2, "Rewards: %s", [float(np.sum(traj[2])) for traj in trajs])
    logging.vlog(1, "Average Rewards: %s", average_rewards)

    logging.vlog(1,
                 "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
                 float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
                 max(len(traj[0]) for traj in trajs),
                 min(len(traj[0]) for traj in trajs))
    logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs])

    t = time.time()
    (_, reward_mask, padded_observations, padded_actions,
     padded_rewards) = pad_trajectories(
         trajs, boundary=boundary)

    logging.vlog(1, "Padding trajectories took %0.2f msec.", get_time(t))
    logging.vlog(1, "Padded Observations' shape [%s]",
                 str(padded_observations.shape))
    logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape))
    logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape))

    # Some assertions.
    B, T = padded_actions.shape  # pylint: disable=invalid-name
    assert (B, T) == padded_rewards.shape
    assert (B, T) == reward_mask.shape
    assert (B, T + 1) == padded_observations.shape[:2]
    assert (B, T + 1) + env.observation_space.shape == padded_observations.shape

    # Linear annealing from 0.1 to 0.0
    # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
    #                                                           (i /
    #                                                            (epochs - 1)))

    # Constant epsilon.
    epsilon_schedule = epsilon

    # Compute value and ppo losses.
    cur_value_loss, cur_ppo_loss, cur_combined_loss = None, None, None
    if policy_and_value_net_apply is not None:
      logging.vlog(2, "Starting to compute P&V loss.")
      t = time.time()
      cur_combined_loss, cur_ppo_loss, cur_value_loss, _ = (
          combined_loss(
              policy_and_value_net_params,
              policy_and_value_net_params,
              policy_and_value_net_apply,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
              c1=c1,
              c2=c2))
      logging.vlog(
          1, "Calculating P&V loss [%10.2f(%10.2f, %10.2f)] took %0.2f msec.",
          cur_combined_loss, cur_value_loss, cur_ppo_loss, get_time(t))
    else:
      t = time.time()
      cur_value_loss = value_loss(
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_rewards,
          reward_mask,
          gamma=gamma)

      logging.vlog(1, "Calculating value loss took %0.2f msec.", get_time(t))

      t = time.time()
      cur_ppo_loss = ppo_loss(
          policy_net_apply,
          policy_net_params,
          policy_net_params,
          value_net_apply,
          value_net_params,
          padded_observations,
          padded_actions,
          padded_rewards,
          reward_mask,
          gamma=gamma,
          lambda_=lambda_,
          epsilon=epsilon_schedule)
      logging.vlog(1, "Calculating PPO loss took %0.2f msec.", get_time(t))

    value_losses.append(cur_value_loss)
    ppo_objective.append(-1.0 * cur_ppo_loss)
    if cur_combined_loss:
      combined_losses.append(cur_combined_loss)

    if policy_and_value_net_apply:
      logging.vlog(1, "Policy and Value Optimization")
      t1 = time.time()
      for j in range(num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        policy_and_value_opt_state = policy_and_value_opt_step(
            j,
            policy_and_value_opt_state,
            policy_and_value_opt_update,
            policy_and_value_net_apply,
            # for the entirety of this loop, this should refer to params that
            # were used to collect the trajectory.
            policy_and_value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            c1=c1,
            c2=c2,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule)
        t2 = time.time()
        if ((j + 1) %
            print_every_optimizer_steps == 0) or (j == num_optimizer_steps - 1):
          # Compute and log the loss.
          # Get the new params.
          new_policy_and_value_net_params = trax_opt.get_params(
              policy_and_value_opt_state)
          (loss_combined, loss_ppo, loss_value, unused_entropy_bonus) = (
              combined_loss(
                  new_policy_and_value_net_params,
                  # old params, that were used to collect the trajectory
                  policy_and_value_net_params,
                  policy_and_value_net_apply,
                  padded_observations,
                  padded_actions,
                  padded_rewards,
                  reward_mask,
                  gamma=gamma,
                  lambda_=lambda_,
                  epsilon=epsilon_schedule,
                  c1=c1,
                  c2=c2))
          logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(
              1,
              "Combined Loss(value, ppo) [%10.2f] -> [%10.2f(%10.2f,%10.2f)]",
              cur_combined_loss, loss_combined, loss_value, loss_ppo)

      # Update the params.
      policy_and_value_net_params = new_policy_and_value_net_params

      logging.vlog(
          1, "Total Combined Loss reduction [%0.2f]%%",
          (100 *
           (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], Combined"
          " Loss(value, ppo) [%10.2f(%10.2f,%10.2f)], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, loss_combined, loss_value,
          loss_ppo, get_time(t1))
    else:
      # Run optimizers.
      logging.vlog(1, "PPO Optimization")
      t1 = time.time()

      for j in range(policy_only_num_optimizer_steps):
        t = time.time()
        # Update the optimizer state.
        ppo_opt_state = ppo_opt_step(
            j,
            ppo_opt_state,
            ppo_opt_update,
            policy_net_apply,
            policy_net_params,
            value_net_apply,
            value_net_params,
            padded_observations,
            padded_actions,
            padded_rewards,
            reward_mask,
            gamma=gamma,
            lambda_=lambda_,
            epsilon=epsilon_schedule,
        )
        t2 = time.time()
        # Get the new params.
        new_policy_net_params = trax_opt.get_params(ppo_opt_state)

        # policy_net_params still holds the "old" params that were used to
        # collect the trajectories.

        # Compute the approx KL for early stopping.
        log_probab_actions_old = policy_net_apply(padded_observations,
                                                  policy_net_params)
        log_probab_actions_new = policy_net_apply(padded_observations,
                                                  new_policy_net_params)

        approx_kl = np.mean(log_probab_actions_old - log_probab_actions_new)

        early_stopping = approx_kl > 1.5 * target_kl
        if early_stopping:
          logging.vlog(
              1, "Early stopping policy optimization at iter: %d, "
              "with approx_kl: %0.2f", j, approx_kl)
          # We don't return right away; we want the block below to execute on
          # the last iteration.

        if (((j + 1) % print_every_optimizer_steps == 0) or
            (j == policy_only_num_optimizer_steps - 1) or early_stopping):
          new_ppo_loss = ppo_loss(
              policy_net_apply,
              new_policy_net_params,
              policy_net_params,
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_actions,
              padded_rewards,
              reward_mask,
              gamma=gamma,
              lambda_=lambda_,
              epsilon=epsilon_schedule,
          )
          logging.vlog(1, "One PPO grad desc took: %0.2f msec", get_time(t, t2))
          logging.vlog(1, "PPO loss [%10.2f] -> [%10.2f]", cur_ppo_loss,
                       new_ppo_loss)

        if early_stopping:
          break

      # Update the params only after we complete all the optimization
      # iterations; until then `policy_net_params` should keep referring to the
      # params that were used to collect the trajectories.
      # policy_net_params = trax_opt.get_params(ppo_opt_state)

      logging.vlog(1, "Total PPO loss reduction [%0.2f]%%",
                   (100 * (cur_ppo_loss - new_ppo_loss) / np.abs(cur_ppo_loss)))

      logging.vlog(1, "Value Optimization")

      for j in range(value_only_num_optimizer_steps):
        t = time.time()
        value_opt_state = value_opt_step(
            j,
            value_opt_state,
            value_opt_update,
            value_net_apply,
            padded_observations,
            padded_rewards,
            reward_mask,
            gamma=gamma)
        t2 = time.time()
        value_net_params = trax_opt.get_params(value_opt_state)
        if ((j + 1) % print_every_optimizer_steps == 0) or (
            j == value_only_num_optimizer_steps - 1):
          new_value_loss = value_loss(
              value_net_apply,
              value_net_params,
              padded_observations,
              padded_rewards,
              reward_mask,
              gamma=gamma)
          logging.vlog(1, "One value grad desc took: %0.2f msec",
                       get_time(t, t2))
          logging.vlog(1, "Value loss [%10.2f] -> [%10.2f]", cur_value_loss,
                       new_value_loss)
      logging.vlog(1, "Total value loss reduction [%0.2f]%%",
                   (100 *
                    (cur_value_loss - new_value_loss) / np.abs(cur_value_loss)))

      logging.vlog(1, "Grad desc took %0.2f msec", get_time(t1))

      # Set the optimized params to new params.
      policy_net_params = trax_opt.get_params(ppo_opt_state)
      value_net_params = trax_opt.get_params(value_opt_state)

      logging.info(
          "Epoch [% 6d], Reward[min, max, avg] [%10.2f,%10.2f,%10.2f], "
          "ppo loss [%10.2f], value loss [%10.2f], took [%10.2f msec]", i,
          min_reward, max_reward, avg_reward, new_ppo_loss, new_value_loss,
          get_time(t0))

  # Log the parameters, just for the sake of it.
  if policy_net_params:
    log_params(policy_net_params, "policy_net_params")
  if value_net_params:
    log_params(value_net_params, "value_net_params")
  if policy_and_value_net_params:
    log_params(policy_and_value_net_params, "policy_and_value_net_params")

  if value_losses:
    logging.vlog(1, "value_losses: %s", np.stack(value_losses))
  if ppo_objective:
    logging.vlog(1, "ppo_objective:\n%s", np.stack(ppo_objective))
  if combined_losses:
    logging.vlog(1, "combined_losses:\n%s", np.stack(combined_losses))
  if average_rewards:
    logging.vlog(1, "average_rewards:\n%s", average_rewards)

  return ((policy_net_params, value_net_params), average_rewards,
          np.stack(value_losses), np.stack(ppo_objective))
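Example #8 optimizes `ppo_loss` against the parameters that collected the trajectories, with a clipping parameter `epsilon`. The definition of `ppo_loss` is not included in this excerpt; the snippet below is only a minimal NumPy sketch of the standard clipped PPO surrogate such a loss is typically built on (the loss would be its negation). All names here are illustrative.

import numpy as np


def clipped_ppo_objective(log_probs_new, log_probs_old, advantages,
                          reward_mask, epsilon=0.1):
  """Clipped surrogate objective, averaged over unmasked time-steps.

  log_probs_new, log_probs_old: log-probabilities of the actions actually
    taken, shape [B, T]. advantages, reward_mask: shape [B, T].
  """
  ratios = np.exp(log_probs_new - log_probs_old)  # pi_new(a|s) / pi_old(a|s)
  clipped_ratios = np.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  per_step = np.minimum(ratios * advantages, clipped_ratios * advantages)
  return np.sum(per_step * reward_mask) / np.sum(reward_mask)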
Example #9
def training_loop(
        env,
        eval_env,
        env_name,
        policy_and_value_net_fn,
        policy_and_value_optimizer_fn,
        output_dir,
        epochs=EPOCHS,
        n_optimizer_steps=N_OPTIMIZER_STEPS,
        print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
        target_kl=0.01,
        boundary=20,
        max_timestep=None,
        max_timestep_eval=20000,
        random_seed=None,
        gamma=GAMMA,
        lambda_=LAMBDA,
        epsilon=EPSILON,
        c1=1.0,
        c2=0.01,
        eval_every_n=1000,
        done_frac_for_policy_save=0.5,
        enable_early_stopping=True,
        n_evals=1,
        len_history_for_policy=4,
        eval_temperatures=(1.0, 0.5),
):
    """Runs the training loop for PPO, with fixed policy and value nets.

  Args:
    env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    env_name: Name of the environment.
    policy_and_value_net_fn: Function defining the policy and value network.
    policy_and_value_optimizer_fn: Function defining the optimizer.
    output_dir: Output dir.
    epochs: Number of epochs to run for.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy optimization
      process.
    target_kl: Approximate-KL threshold for early stopping of policy iteration.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in
      a trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in an
      evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    epsilon: Clipping parameter for the PPO surrogate objective. The
      exploration epsilon used when collecting trajectories is scheduled
      separately inside the loop.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be done
      to checkpoint the policy.
    enable_early_stopping: Whether to enable early stopping.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical sampling
      during evaluation.
  """
    gfile.makedirs(output_dir)

    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    timing_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "timing"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    train_sw.text("env_name", env_name)
    timing_sw.text("env_name", env_name)
    eval_sw.text("env_name", env_name)

    jax_rng_key = trax.get_random_number_generator_and_set_seed(random_seed)

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # the policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + env.observation_space.shape
    observations_dtype = env.observation_space.dtype

    assert isinstance(env.action_space, gym.spaces.Discrete)
    n_actions = env.action_space.n

    jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net_fn(key1, batch_observations_shape,
                                observations_dtype, n_actions))

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restore, policy_and_value_net_params, iteration = (maybe_restore_params(
        output_dir, policy_and_value_net_params))

    if restore:
        logging.info("Restored parameters from iteration [%d]", iteration)
        # We should start from the next iteration.
        iteration += 1

    policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (policy_and_value_opt_state, policy_and_value_opt_update,
     policy_and_value_get_params) = policy_and_value_optimizer

    n_trajectories_done = 0
    last_saved_at = 0

    logging.info("Starting the PPO training loop.")
    for i in range(iteration, epochs):
        epoch_start_time = time.time()

        # Params we'll use to collect the trajectories.
        policy_and_value_net_params = policy_and_value_get_params(
            policy_and_value_opt_state)

        # A function to get the policy and value predictions.
        def get_predictions(observations, rng=None):
            """Returns log-probs, value predictions and key back."""
            key, key1 = jax_random.split(rng, num=2)

            log_probs, value_preds = policy_and_value_net_apply(
                observations, policy_and_value_net_params, rng=key1)

            return log_probs, value_preds, key

        # Evaluate the policy.
        policy_eval_start_time = time.time()
        if ((i + 1) % eval_every_n == 0) or (i == epochs - 1):
            jax_rng_key, key = jax_random.split(jax_rng_key, num=2)

            logging.vlog(1, "Epoch [% 6d] evaluating policy.", i)

            reward_stats = evaluate_policy(
                eval_env,
                get_predictions,
                temperatures=eval_temperatures,
                max_timestep=max_timestep_eval,
                n_evals=n_evals,
                len_history_for_policy=len_history_for_policy,
                rng=key)
            write_eval_reward_summaries(reward_stats, eval_sw, epoch=i)
        policy_eval_time = get_time(policy_eval_start_time)

        trajectory_collection_start_time = time.time()
        logging.vlog(1, "Epoch [% 6d] collecting trajectories.", i)
        jax_rng_key, key = jax_random.split(jax_rng_key)
        trajs, n_done, timing_info = collect_trajectories(
            env,
            policy_fn=get_predictions,
            n_trajectories=env.batch_size,
            max_timestep=max_timestep,
            rng=key,
            len_history_for_policy=len_history_for_policy,
            reset=(i == 0) or restore,
            epsilon=(10.0 / (i + 10.0)))  # exploration epsilon, not the PPO clipping epsilon.
        trajectory_collection_time = get_time(trajectory_collection_start_time)

        logging.vlog(1, "Collecting trajectories took %0.2f msec.",
                     trajectory_collection_time)

        avg_reward = float(sum(np.sum(traj[2]) for traj in trajs)) / len(trajs)
        max_reward = max(np.sum(traj[2]) for traj in trajs)
        min_reward = min(np.sum(traj[2]) for traj in trajs)

        train_sw.scalar("train/reward_mean_truncated", avg_reward, step=i)

        logging.vlog(1,
                     "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s",
                     avg_reward, max_reward, min_reward,
                     [float(np.sum(traj[2])) for traj in trajs])

        logging.vlog(
            1, "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]",
            float(sum(len(traj[0]) for traj in trajs)) / len(trajs),
            max(len(traj[0]) for traj in trajs),
            min(len(traj[0]) for traj in trajs))
        logging.vlog(2, "Trajectory Lengths: %s",
                     [len(traj[0]) for traj in trajs])

        padding_start_time = time.time()
        (_, reward_mask, padded_observations, padded_actions, padded_rewards,
         padded_infos) = pad_trajectories(trajs, boundary=boundary)
        padding_time = get_time(padding_start_time)

        logging.vlog(1, "Padding trajectories took %0.2f msec.",
                     get_time(padding_start_time))
        logging.vlog(1, "Padded Observations' shape [%s]",
                     str(padded_observations.shape))
        logging.vlog(1, "Padded Actions' shape [%s]",
                     str(padded_actions.shape))
        logging.vlog(1, "Padded Rewards' shape [%s]",
                     str(padded_rewards.shape))

        # Some assertions.
        B, T = padded_actions.shape  # pylint: disable=invalid-name
        assert (B, T) == padded_rewards.shape
        assert (B, T) == reward_mask.shape
        assert (B, T + 1) == padded_observations.shape[:2]
        assert (B, T +
                1) + env.observation_space.shape == padded_observations.shape

        log_prob_recompute_start_time = time.time()
        assert ("log_prob_actions" in padded_infos
                and "value_predictions" in padded_infos)
        # These are the actual log-probabs and value predictions seen while picking
        # the actions.
        actual_log_probabs_traj = padded_infos["log_prob_actions"]
        actual_value_predictions_traj = padded_infos["value_predictions"]

        assert (B, T) == actual_log_probabs_traj.shape[:2]
        A = actual_log_probabs_traj.shape[2]  # pylint: disable=invalid-name
        assert (B, T, 1) == actual_value_predictions_traj.shape

        # TODO(afrozm): log-probabs doesn't need to be (B, T+1, A) it can do with
        # (B, T, A), so make that change throughout.

        # NOTE: We don't have the log-probabs and value-predictions for the last
        # observation, so we re-calculate for everything, but use the original ones
        # for all but the last time-step.
        jax_rng_key, key = jax_random.split(jax_rng_key)
        log_probabs_traj, value_predictions_traj, _ = get_predictions(
            padded_observations, rng=key)

        assert (B, T + 1, A) == log_probabs_traj.shape
        assert (B, T + 1, 1) == value_predictions_traj.shape

        # Concatenate the last time-step's log-probabs and value predictions to the
        # actual log-probabs and value predictions and use those going forward.
        log_probabs_traj = np.concatenate(
            (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1)
        value_predictions_traj = np.concatenate(
            (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]),
            axis=1)

        log_prob_recompute_time = get_time(log_prob_recompute_start_time)

        # Linear annealing from 0.1 to 0.0
        # epsilon_schedule = epsilon if epochs == 1 else epsilon * (1.0 -
        #                                                           (i /
        #                                                            (epochs - 1)))

        # Constant epsilon.
        epsilon_schedule = epsilon

        # Compute value and ppo losses.
        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(2, "Starting to compute P&V loss.")
        loss_compute_start_time = time.time()
        cur_combined_loss, cur_ppo_loss, cur_value_loss, entropy_bonus = (
            combined_loss(policy_and_value_net_params,
                          log_probabs_traj,
                          value_predictions_traj,
                          policy_and_value_net_apply,
                          padded_observations,
                          padded_actions,
                          padded_rewards,
                          reward_mask,
                          gamma=gamma,
                          lambda_=lambda_,
                          epsilon=epsilon_schedule,
                          c1=c1,
                          c2=c2,
                          rng=key1))
        loss_compute_time = get_time(loss_compute_start_time)
        logging.vlog(
            1,
            "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.",
            cur_combined_loss, cur_value_loss, cur_ppo_loss, entropy_bonus,
            get_time(loss_compute_start_time))

        jax_rng_key, key1 = jax_random.split(jax_rng_key, num=2)
        logging.vlog(1, "Policy and Value Optimization")
        optimization_start_time = time.time()
        keys = jax_random.split(key1, num=n_optimizer_steps)
        for j in range(n_optimizer_steps):
            k1, k2, k3 = jax_random.split(keys[j], num=3)
            t = time.time()
            # Update the optimizer state.
            policy_and_value_opt_state = policy_and_value_opt_step(
                j,
                policy_and_value_opt_state,
                policy_and_value_opt_update,
                policy_and_value_get_params,
                policy_and_value_net_apply,
                log_probabs_traj,
                value_predictions_traj,
                padded_observations,
                padded_actions,
                padded_rewards,
                reward_mask,
                c1=c1,
                c2=c2,
                gamma=gamma,
                lambda_=lambda_,
                epsilon=epsilon_schedule,
                rng=k1)

            # Compute the approx KL for early stopping.
            new_policy_and_value_net_params = policy_and_value_get_params(
                policy_and_value_opt_state)

            log_probab_actions_new, _ = policy_and_value_net_apply(
                padded_observations, new_policy_and_value_net_params, rng=k2)

            approx_kl = approximate_kl(log_probab_actions_new,
                                       log_probabs_traj, reward_mask)

            early_stopping = enable_early_stopping and approx_kl > 1.5 * target_kl
            if early_stopping:
                logging.vlog(
                    1,
                    "Early stopping policy and value optimization at iter: %d, "
                    "with approx_kl: %0.2f", j, approx_kl)
                # We don't return right away; we want the block below to
                # execute on the last iteration.

            t2 = time.time()
            if (((j + 1) % print_every_optimizer_steps == 0)
                    or (j == n_optimizer_steps - 1) or early_stopping):
                # Compute and log the loss.
                (loss_combined, loss_ppo, loss_value,
                 entropy_bonus) = (combined_loss(
                     new_policy_and_value_net_params,
                     log_probabs_traj,
                     value_predictions_traj,
                     policy_and_value_net_apply,
                     padded_observations,
                     padded_actions,
                     padded_rewards,
                     reward_mask,
                     gamma=gamma,
                     lambda_=lambda_,
                     epsilon=epsilon_schedule,
                     c1=c1,
                     c2=c2,
                     rng=k3))
                logging.vlog(
                    1, "One Policy and Value grad desc took: %0.2f msec",
                    get_time(t, t2))
                logging.vlog(
                    1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->"
                    " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss,
                    loss_combined, loss_value, loss_ppo, entropy_bonus)

            if early_stopping:
                break

        optimization_time = get_time(optimization_start_time)

        logging.vlog(
            1, "Total Combined Loss reduction [%0.2f]%%",
            (100 *
             (cur_combined_loss - loss_combined) / np.abs(cur_combined_loss)))

        # Save parameters once at least `done_frac_for_policy_save` of a
        # batch's trajectories have ended in `done` (not merely completed --
        # completed also includes truncated trajectories), but only on eval
        # epochs and not more often than every `eval_every_n` epochs; always
        # save on the last iteration.
        policy_save_start_time = time.time()
        n_trajectories_done += n_done
        # TODO(afrozm): Refactor to trax.save_state.
        if ((
            (n_trajectories_done >= done_frac_for_policy_save * env.batch_size)
                and (i - last_saved_at > eval_every_n) and
            (((i + 1) % eval_every_n == 0))) or (i == epochs - 1)):
            logging.vlog(1, "Epoch [% 6d] saving model.", i)
            old_model_files = gfile.glob(
                os.path.join(output_dir, "model-??????.pkl"))
            params_file = os.path.join(output_dir, "model-%06d.pkl" % i)
            with gfile.GFile(params_file, "wb") as f:
                pickle.dump(policy_and_value_net_params, f)
            # Remove the old model files.
            for path in old_model_files:
                gfile.remove(path)
            # Reset this number.
            n_trajectories_done = 0
            last_saved_at = i
        policy_save_time = get_time(policy_save_start_time)

        epoch_time = get_time(epoch_start_time)

        logging.info(
            "Epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined"
            " Loss(value, ppo, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", i,
            min_reward, max_reward, avg_reward, loss_combined, loss_value,
            loss_ppo, entropy_bonus)

        timing_dict = {
            "epoch": epoch_time,
            "policy_eval": policy_eval_time,
            "trajectory_collection": trajectory_collection_time,
            "padding": padding_time,
            "log_prob_recompute": log_prob_recompute_time,
            "loss_compute": loss_compute_time,
            "optimization": optimization_time,
            "policy_save": policy_save_time,
        }

        timing_dict.update(timing_info)

        for k, v in timing_dict.items():
            timing_sw.scalar("timing/%s" % k, v, step=i)

        max_key_len = max(len(k) for k in timing_dict)
        timing_info_list = [
            "%s : % 10.2f" % (k.rjust(max_key_len + 1), v)
            for k, v in sorted(timing_dict.items())
        ]
        logging.info("Epoch [% 6d], Timings: \n%s", i,
                     "\n".join(timing_info_list))

        # Reset restore.
        restore = False

        # Flush summary writers once in a while.
        if (i + 1) % 1000 == 0 or i == epochs - 1:
            train_sw.flush()
            timing_sw.flush()
            eval_sw.flush()
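Both training loops above pass `gamma` and `lambda_` down to the loss functions, which (per the docstrings) use them for Generalized Advantage Estimation. The GAE computation itself is not shown in these excerpts; the following is a small NumPy sketch of how masked GAE advantages are commonly computed from padded rewards and value predictions. Function and argument names are illustrative.

import numpy as np


def gae_advantages(padded_rewards, value_predictions, reward_mask,
                   gamma=0.99, lambda_=0.95):
  """Masked Generalized Advantage Estimation.

  padded_rewards, reward_mask: [B, T]; value_predictions: [B, T + 1]
  (one extra bootstrap value per trajectory).
  """
  # One-step TD errors, zeroed out on padded time-steps.
  deltas = (padded_rewards
            + gamma * value_predictions[:, 1:]
            - value_predictions[:, :-1]) * reward_mask
  advantages = np.zeros_like(deltas)
  gae = np.zeros(deltas.shape[0])
  for t in reversed(range(deltas.shape[1])):
    gae = deltas[:, t] + gamma * lambda_ * reward_mask[:, t] * gae
    advantages[:, t] = gae
  return advantages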
Example #10
  def __init__(
      self,
      train_env,
      eval_env,
      policy_and_value_model,
      policy_and_value_optimizer_fn,
      policy_and_value_two_towers,
      output_dir,
      n_optimizer_steps,
      print_every_optimizer_steps,
      target_kl,
      boundary,
      max_timestep,
      max_timestep_eval,
      random_seed,
      gamma,
      lambda_,
      c1,
      c2,
      eval_every_n,
      done_frac_for_policy_save,
      n_evals,
      len_history_for_policy,
      eval_temperatures,
  ):
    self._train_env = train_env
    self._eval_env = eval_env
    self._n_optimizer_steps = n_optimizer_steps
    self._print_every_optimizer_steps = print_every_optimizer_steps
    self._target_kl = target_kl
    self._boundary = boundary
    self._max_timestep = max_timestep
    self._max_timestep_eval = max_timestep_eval
    self._gamma = gamma
    self._lambda_ = lambda_
    self._c1 = c1
    self._c2 = c2
    self._eval_every_n = eval_every_n
    self._done_frac_for_policy_save = done_frac_for_policy_save
    self._n_evals = n_evals
    self._len_history_for_policy = len_history_for_policy
    self._eval_temperatures = eval_temperatures

    assert isinstance(self._train_env.action_space, gym.spaces.Discrete)
    n_actions = self._train_env.action_space.n

    # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
    # the policy and value networks on shape [B, T] + OBS.
    batch_observations_shape = (1, 1) + self._train_env.observation_space.shape
    observations_dtype = self._train_env.observation_space.dtype

    self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
    self._rng, key1 = jax_random.split(self._rng, num=2)

    # Initialize the policy and value network.
    policy_and_value_net_params, policy_and_value_net_apply = (
        policy_and_value_net(
            rng_key=key1,
            batch_observations_shape=batch_observations_shape,
            observations_dtype=observations_dtype,
            n_actions=n_actions,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
    )
    self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

    # Maybe restore the policy params. If there is nothing to restore, then
    # iteration = 0 and policy_and_value_net_params are returned as is.
    restored, policy_and_value_net_params, self._epoch = (
        maybe_restore_params(output_dir, policy_and_value_net_params))

    if restored:
      logging.info("Restored parameters from iteration [%d]", self._epoch)
      # We should start from the next iteration.
      self._epoch += 1

    # Initialize the optimizers.
    policy_and_value_optimizer = (
        policy_and_value_optimizer_fn(policy_and_value_net_params))
    (self._policy_and_value_opt_state, self._policy_and_value_opt_update,
     self._policy_and_value_get_params) = policy_and_value_optimizer

    self._output_dir = output_dir
    gfile.makedirs(self._output_dir)

    # Create summary writers and history.
    self._train_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "train"))
    self._timing_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "timing"))
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, "eval"))

    self._should_reset = True
    self._n_trajectories_done = 0

    self._last_saved_at = 0
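Examples #9 and #10 stop policy optimization early once `approximate_kl(...)` exceeds `1.5 * target_kl`. That helper is not defined in these excerpts; below is a minimal NumPy sketch of one way a masked KL estimate between the old and new categorical policies can be computed. The names and the exact estimator are assumptions, not the library's implementation.

import numpy as np


def masked_approximate_kl(log_probs_new, log_probs_old, mask):
  """KL(old || new) for categorical policies, averaged over unmasked steps.

  log_probs_new, log_probs_old: [B, T, A] log-probabilities over actions;
  mask: [B, T] with 1.0 on valid time-steps and 0.0 on padding.
  """
  # Exact per-step KL between the two categorical distributions.
  per_step_kl = np.sum(
      np.exp(log_probs_old) * (log_probs_old - log_probs_new), axis=-1)
  return np.sum(per_step_kl * mask) / np.sum(mask)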