def update_optimization_state(self,
                              output_dir,
                              policy_and_value_opt_state=None):
    """Maybe restores the optimization state from a checkpoint in output_dir."""
    (self._policy_and_value_opt_state, self._model_state, self._epoch,
     self._total_opt_step) = ppo.maybe_restore_opt_state(
         output_dir, policy_and_value_opt_state, self._model_state)

    if self._epoch > 0:
        logging.info("Restored parameters from epoch [%d]", self._epoch)
Example #2
def test_saves_and_restores_opt_state(self):
    opt_state = 123
    state = 456
    epoch = 7
    opt_step = 89
    output_dir = self.get_temp_dir()
    ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
    restored_data = ppo.maybe_restore_opt_state(output_dir)
    self.assertEqual(restored_data, (opt_state, state, epoch, opt_step))
Example #3
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=None,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 **kwargs):
        """Creates the PPO trainer.

        Args:
          train_env: gym.Env to use for training.
          eval_env: gym.Env to use for evaluation.
          output_dir: Output dir.
          policy_and_value_model: Function defining the policy and value network,
            without the policy and value heads.
          policy_and_value_optimizer: Function defining the optimizer.
          policy_and_value_two_towers: Whether to use two separate models as the
            policy and value networks. If False, share their parameters.
          n_optimizer_steps: Number of optimizer steps.
          print_every_optimizer_steps: How often to log during the policy
            optimization process.
          target_kl: Policy iteration early stopping. Set to infinity to disable
            early stopping.
          boundary: We pad trajectories at integer multiples of this number.
          max_timestep: If set to an integer, maximum number of time-steps in
            a trajectory. Used in the collect procedure.
          max_timestep_eval: If set to an integer, maximum number of time-steps in
            an evaluation trajectory. Used in the collect procedure.
          random_seed: Random seed.
          gamma: Reward discount factor.
          lambda_: N-step TD-error discount factor in GAE.
          c1: Value loss coefficient.
          c2: Entropy loss coefficient.
          eval_every_n: How frequently to eval the policy.
          done_frac_for_policy_save: Fraction of the trajectories that should be
            done to checkpoint the policy.
          n_evals: Number of times to evaluate.
          len_history_for_policy: How much of history to give to the policy.
          eval_temperatures: Sequence of temperatures to try for categorical
            sampling during evaluation.
          **kwargs: Additional keyword arguments passed to the base class.
        """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures

        assert isinstance(self.train_env.action_space, gym.spaces.Discrete)
        n_actions = self.train_env.action_space.n

        # Batch Observations Shape = [1, 1] + OBS, because we will eventually
        # call the policy and value networks on inputs of shape [B, T] + OBS.
        batch_observations_shape = (1,
                                    1) + self.train_env.observation_space.shape
        observations_dtype = self.train_env.observation_space.dtype

        self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        # Initialize the policy and value network.
        policy_and_value_net_params, self._model_state, policy_and_value_net_apply = (
            ppo.policy_and_value_net(
                rng_key=key1,
                batch_observations_shape=batch_observations_shape,
                observations_dtype=observations_dtype,
                n_actions=n_actions,
                bottom_layers_fn=policy_and_value_model,
                two_towers=policy_and_value_two_towers,
            ))
        self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Maybe restore the optimization state. If there is nothing to restore, then
        # iteration = 0 and policy_and_value_opt_state is returned as is.
        (restored, self._policy_and_value_opt_state, self._model_state,
         self._epoch, self._total_opt_step) = ppo.maybe_restore_opt_state(
             output_dir, policy_and_value_opt_state, self._model_state)

        if restored:
            logging.info("Restored parameters from iteration [%d]",
                         self._epoch)
            # We should start from the next iteration.
            self._epoch += 1

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
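A minimal construction sketch, not taken from the source: it assumes the surrounding module's imports (trax_models, trax_opt) are in scope and that the environments expose a discrete action space, as the assert in the constructor requires. The environment name, output directory, and overridden keyword values are illustrative.

import functools

import gym

# Illustrative environments; any pair of discrete-action gym.Envs works here.
train_env = gym.make("CartPole-v0")
eval_env = gym.make("CartPole-v0")

trainer = PPO(
    train_env=train_env,
    eval_env=eval_env,
    output_dir="/tmp/ppo_cartpole",      # checkpoints and summaries go here
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_optimizer=functools.partial(
        trax_opt.Adam, learning_rate=1e-3),
    max_timestep=200,                    # cap trajectory length while collecting
    eval_every_n=100,                    # evaluate more often than the default
)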
Example #4
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
    """Learning rate schedule controlled by a learned policy.

    Args:
      history: the history of training and evaluation (History object).
      observation_metrics: list of pairs (mode, metric), as in the History object.
      include_lr_in_observation: bool, whether to include the learning rate in
        observations.
      observation_range: tuple (low, high), range to clip the observation to.
      start_lr: starting learning rate.
      max_lr: maximum value to clip the learning rate to.
      action_multipliers: sequence of LR multipliers that policy actions
        correspond to.
      policy_and_value_model: Trax model to use as the policy.
      policy_and_value_two_towers: bool, whether the action distribution and value
        prediction is computed by separate model towers.
      policy_dir: directory with the policy checkpoint.

    Returns:
      a function learning_rate(step): float -> float, the step-dependent lr.
    """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        include_lr_in_observation)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        return lambda _: start_lr

    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_actions=len(action_multipliers),
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    action = utils.gumbel_sample(log_probs[0, -1, :])

    # Get a new learning rate.
    new_lr = online_tune.new_learning_rate(action, history, action_multipliers,
                                           max_lr)
    return lambda _: new_lr
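A minimal usage sketch, not from the source: calling PolicySchedule directly with an explicit policy_dir instead of relying on a gin binding. `history` is assumed to be a trax History object from an ongoing training run, and the checkpoint directory is illustrative.

# Hypothetical call; the checkpoint directory must contain a policy trained by
# the PPO trainer above, otherwise the assert on the restored opt_state fires.
lr_schedule = PolicySchedule(history, policy_dir="/tmp/online_tune_policy")
learning_rate = lr_schedule(1000)  # the step argument is ignored; the sampled
                                   # learning rate holds until the next rebuild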
Example #5
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

    Args:
      history: the history of training and evaluation (History object).
      observation_metrics: list of pairs (mode, metric), as in the History object.
      include_controls_in_observation: bool, whether to include the controls in
        observations.
      control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
      observation_range: tuple (low, high), range to clip the metrics to.
      action_multipliers: sequence of LR multipliers that policy actions
        correspond to.
      policy_and_value_model: Trax model to use as the policy.
      policy_and_value_two_towers: bool, whether the action distribution and value
        prediction is computed by separate model towers.
      policy_and_value_vocab_size: vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      policy_dir: directory with the policy checkpoint.
      temperature: temperature for sampling from the policy.

    Returns:
      a function nontrainable_params(step): float -> {"name": float}, the
      step-dependent schedule for nontrainable parameters.
    """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # The network returns (log_probs, value_preds). We have no way to pass state
    # to the next step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
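A minimal usage sketch for this multi-control variant, under the same assumptions as the previous example: the returned schedule maps any step to a dict keyed by the control names in control_configs.

# Hypothetical call; `history` and the checkpoint directory are illustrative.
schedule = PolicySchedule(
    history,
    control_configs=(("learning_rate", 1e-3, (1e-9, 10.0), False),),
    policy_dir="/tmp/online_tune_policy",
)
controls = schedule(1000)            # the step argument is ignored
new_lr = controls["learning_rate"]   # one entry per configured control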