Example #1
def _current_observation(self):
    observations = online_tune.history_to_observations(
        self._trainer.history,
        self._observation_metrics,
        self._observation_range,
        self._control_configs
        if self._include_controls_in_observation else None,
    )
    assert observations.shape[0] > 0, 'No values in history for any metric.'
    return observations[-1, :]
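
The method above is a snippet from an environment class, so everything it reads comes from attributes set elsewhere. A minimal sketch of the assumed surrounding context, for orientation only (the attribute names are taken from the snippet; the class name and constructor are hypothetical):

class OnlineTuneEnvContext:
    """Hypothetical container for the attributes used by _current_observation."""

    def __init__(self, trainer, observation_metrics, observation_range,
                 control_configs, include_controls_in_observation):
        self._trainer = trainer  # exposes .history, a trax History object
        self._observation_metrics = observation_metrics  # e.g. (("eval", "metrics/loss"),)
        self._observation_range = observation_range  # e.g. (0.0, 10.0)
        # e.g. (("learning_rate", 1e-3, (1e-9, 10.0), False),)
        self._control_configs = control_configs
        self._include_controls_in_observation = include_controls_in_observation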
Example #2
def test_clips_observations(self):
    history = trax_history.History()
    self._append_metrics(history, ("eval", "loss"), [-10, 10])
    observations = online_tune.history_to_observations(
        history,
        metrics=(("eval", "loss"),),
        observation_range=(-2, 2),
        control_configs=None,
    )
    np.testing.assert_array_equal(observations, [[-2], [2]])
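
The tests above call a small _append_metrics helper that is not shown here. A minimal sketch of what it could look like, assuming History.append(mode, metric, step, value) as the underlying API (the loop and step numbering are inferred from the calls above, not copied from the test suite):

def _append_metrics(self, history, metric, values):
    # Record one value per step for the given (mode, metric) pair.
    (mode, metric_name) = metric
    for (step, value) in enumerate(values):
        history.append(mode, metric_name, step, value)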
Example #3
def test_converts_history_to_observations_without_controls(self):
    history = trax_history.History()
    self._append_metrics(history, ("train", "loss"), [1.0, 0.07])
    self._append_metrics(history, ("eval", "accuracy"), [0.12, 0.68])
    observations = online_tune.history_to_observations(
        history,
        metrics=(("eval", "accuracy"), ("train", "loss")),
        observation_range=(-1, 1),
        control_configs=None,
    )
    np.testing.assert_array_almost_equal(
        observations, [[0.12, 1.0], [0.68, 0.07]])
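
A quick way to read the expected array: rows index steps in the order they were appended, and columns follow the order of the metrics tuple, so the ("train", "loss") value at the latest step is the bottom-right entry. A small numpy check of that layout, using the values from the assertion above:

import numpy as np

expected = np.array([[0.12, 1.00],   # step 0: (eval accuracy, train loss)
                     [0.68, 0.07]])  # step 1: (eval accuracy, train loss)
assert expected.shape == (2, 2)      # (n_steps, n_metrics)
assert expected[-1, 1] == 0.07       # train loss at the latest step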
Example #4
def test_converts_history_to_observations_with_controls(self):
    history = trax_history.History()
    self._append_metrics(history, ("train", "training/learning_rate"),
                         [1e-3, 1e-4])
    observations = online_tune.history_to_observations(
        history,
        metrics=(),
        observation_range=(0, 5),
        control_configs=(("learning_rate", None, (1e-9, 10.0), False),),
    )
    self.assertEqual(observations.shape, (2, 1))
    ((log_lr_1,), (log_lr_2,)) = observations
    self.assertGreater(log_lr_1, log_lr_2)
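
The assertGreater check only makes sense if the control value enters the observation on a logarithmic scale relative to its allowed range, so that a learning rate of 1e-3 maps to a larger observation than 1e-4. A standalone numpy sketch of one plausible rescaling, for illustration only (the actual transform lives in online_tune and may differ in detail):

import numpy as np

def control_observation_sketch(value, bounds, observation_range, flip=False):
    # Map log(value) linearly from log(bounds) onto observation_range.
    (low, high) = bounds
    (obs_low, obs_high) = observation_range
    t = (np.log(value) - np.log(low)) / (np.log(high) - np.log(low))
    if flip:
        t = 1.0 - t
    return obs_low + t * (obs_high - obs_low)

obs_1 = control_observation_sketch(1e-3, (1e-9, 10.0), (0, 5))
obs_2 = control_observation_sketch(1e-4, (1e-9, 10.0), (0, 5))
assert obs_1 > obs_2  # mirrors the assertGreater in the test above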
Example #5
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
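
As the early-return branch shows, PolicySchedule can be called before any of the observation metrics have been logged; in that case it never touches the policy network and simply returns the start values from control_configs. A hedged usage sketch of just that path (the policy_dir value is a placeholder; a real checkpoint only matters once the history is non-empty):

# With an empty history, the schedule falls back to the start controls.
empty_history = trax_history.History()
schedule = PolicySchedule(empty_history, policy_dir="/path/to/policy_checkpoints")
print(schedule(0))  # {'learning_rate': 0.001}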