def step(self, action):
  """Step the environment.

  One environment step corresponds to self.train_steps training steps.

  Args:
    action: Sequence of action indices, one per control, each an index into
      self.action_multipliers.

  Returns:
    Tuple (observation, reward, done, info). observation is a vector with the
    current values of the observed metrics. reward is the difference in the
    reward metric since the last step. done is set after reaching
    self.env_steps environment steps. info is an empty dict.
  """
  self._current_controls = {  # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config,
          control_action,
          self._trainer.state.history,
          self._action_multipliers,
      )
      for (control_action, control_config) in zip(
          action, self._control_configs)
  }
  last_reward_metric = self._current_reward_metric
  self._trainer.train_epoch(self._train_steps, self._eval_steps)
  self._step += 1
  current_reward_metric = self._current_reward_metric
  observation = self._current_observation
  reward = current_reward_metric - last_reward_metric
  done = self._step == self._env_steps
  return (observation, reward, done, {})
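# A hedged sketch of how an agent loop might drive step() above. Only the
# Gym-style reset()/step() interface is taken from the method itself; the
# helper name, its arguments, and the random policy are illustrative
# assumptions, not part of OnlineTuneEnv.
import numpy as np

def run_random_episode(env, n_controls, n_actions):
  """Runs one episode of `env` with uniformly random per-control actions."""
  env.reset()
  total_reward = 0.0
  done = False
  while not done:
    # One action index per control; each selects an entry of
    # env.action_multipliers.
    action = np.random.randint(n_actions, size=n_controls)
    (_, reward, done, _) = env.step(action)
    total_reward += reward
  return total_reward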
def test_clips_updated_control_with_flipping(self):
  config = ("momentum", None, (0.5, 0.99), True)
  history = trax_history.History()
  self._append_metrics(
      history, online_tune.control_metric("momentum"), [0.985])
  new_control = online_tune.update_control(
      control_config=config,
      action=0,
      history=history,
      action_multipliers=(0.5, 1.0, 2.0),
  )
  np.testing.assert_almost_equal(new_control, 0.99)
def test_clips_updated_control_without_flipping(self):
  config = ("learning_rate", None, (1e-9, 10.0), False)
  history = trax_history.History()
  self._append_metrics(
      history, online_tune.control_metric("learning_rate"), [7.0])
  new_control = online_tune.update_control(
      control_config=config,
      action=2,
      history=history,
      action_multipliers=(0.5, 1.0, 2.0),
  )
  np.testing.assert_almost_equal(new_control, 10.0)
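# Illustrative sketch of the arithmetic the two tests above exercise; this is
# not the trax implementation of online_tune.update_control, only the clipping
# behavior it is expected to reproduce. With flip=True the control is mapped
# through x -> 1 - x before scaling and mapped back afterwards, which suits
# momentum-like controls that live close to 1.0.
import numpy as np

def updated_control_sketch(current_value, multiplier, low, high, flip):
  """Scales a control by `multiplier` and clips the result to [low, high]."""
  maybe_flip = (lambda x: 1.0 - x) if flip else (lambda x: x)
  new_value = maybe_flip(maybe_flip(current_value) * multiplier)
  return np.clip(new_value, low, high)

# Matches the expectations above:
#   momentum (flip): 1 - 0.985 = 0.015; 0.015 * 0.5 = 0.0075;
#     1 - 0.0075 = 0.9925, clipped to 0.99.
#   learning rate (no flip): 7.0 * 2.0 = 14.0, clipped to 10.0.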
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_lr_in_observation: bool, whether to include the learning rate in
      observations.
    observation_range: tuple (low, high), range to clip the observation to.
    start_lr: starting learning rate.
    max_lr: maximum value to clip the learning rate to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction are computed by separate model towers.
    policy_dir: directory with the policy checkpoint.

  Returns:
    a function learning_rate(step): float -> {"learning_rate": float}, the
    step-dependent learning rate.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  lr_config = ("learning_rate", start_lr, (1e-9, max_lr), False)
  if include_lr_in_observation:
    control_configs = (lr_config,)
  else:
    control_configs = None
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range, control_configs)
  logging.vlog(
      1, "Building observations took %0.2f sec.", time.time() - start_time)
  if observations.shape[0] == 0:
    return lambda _: {"learning_rate": start_lr}

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=len(action_multipliers),
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(
      1, "Building the policy network took %0.2f sec.",
      time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(
      1, "Restoring the policy parameters took %0.2f sec.",
      time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
  logging.vlog(
      1, "Running the policy took %0.2f sec.", time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  action = utils.gumbel_sample(log_probs[0, -1, :])

  # Get a new learning rate.
  new_lr = online_tune.update_control(
      lr_config, action.item(), history, action_multipliers)
  return lambda _: {"learning_rate": new_lr}
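# utils.gumbel_sample above is assumed to implement the standard Gumbel-max
# trick; this is a minimal NumPy sketch of that trick, not the trax utility
# itself. Adding Gumbel noise to log-probabilities and taking the argmax
# draws an index from the corresponding categorical distribution.
import numpy as np

def gumbel_sample_sketch(log_probs, rng=np.random):
  """Samples index i with probability proportional to exp(log_probs[i])."""
  u = rng.uniform(low=1e-9, high=1.0, size=np.shape(log_probs))
  gumbel_noise = -np.log(-np.log(u))
  return np.argmax(log_probs + gumbel_noise, axis=-1)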
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Schedule for nontrainable parameters controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of control multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction are computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial control values.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(
      1, "Building observations took %0.2f sec.", time.time() - start_time)

  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  assert policy_and_value_vocab_size is None, (
      "Serialized policies are not supported yet.")

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(
      1, "Building the policy network took %0.2f sec.",
      time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(
      1, "Restoring the policy parameters took %0.2f sec.",
      time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  (log_probs, _) = (
      net(np.array([observations]), params=params, state=state, rng=rng))
  logging.vlog(
      1, "Running the policy took %0.2f sec.", time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs) * observations.shape[0], len(action_multipliers))
  action = utils.gumbel_sample(
      log_probs[0, -len(control_configs):, :] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
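# Hedged usage sketch: the schedule returned above ignores its step argument
# and always yields the controls sampled from the current History; the policy
# is re-queried each time PolicySchedule is called with an updated History.
# The `history` variable, the policy_dir path, and the printed value are
# illustrative only.
schedule = PolicySchedule(history, policy_dir="/tmp/policy_checkpoints")
nontrainable_params = schedule(1000)  # e.g. {"learning_rate": 0.00075}
learning_rate = nontrainable_params["learning_rate"]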