def update_optimization_state(self, output_dir, policy_and_value_opt_state=None):
  (self._policy_and_value_opt_state, self._model_state, self._epoch,
   self._total_opt_step) = ppo.maybe_restore_opt_state(
       output_dir, policy_and_value_opt_state, self._model_state)
  if self._epoch > 0:
    logging.info("Restored parameters from epoch [%d]", self._epoch)
def test_saves_and_restores_opt_state(self):
  opt_state = 123
  state = 456
  epoch = 7
  opt_step = 89
  output_dir = self.get_temp_dir()
  ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
  restored_data = ppo.maybe_restore_opt_state(output_dir)
  self.assertEqual(restored_data, (opt_state, state, epoch, opt_step))
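# Illustrative sketch (not part of the test suite): the same checkpoint round
# trip exercised by the test above, written as it might appear in a training
# loop. The `run_epoch` callable is a hypothetical stand-in for one epoch of
# optimization; only ppo.save_opt_state and ppo.maybe_restore_opt_state come
# from this codebase, and the `ppo` module import is assumed from the
# surrounding file.
def example_checkpointed_loop(output_dir, opt_state, state, run_epoch, n_epochs):
  opt_step = 0
  for epoch in range(n_epochs):
    opt_state, state, opt_step = run_epoch(opt_state, state, opt_step)
    ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step)
  # After a restart, the most recent checkpoint can be recovered in one call.
  return ppo.maybe_restore_opt_state(output_dir)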
def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_and_value_model=trax_models.FrameStackMLP,
             policy_and_value_optimizer=functools.partial(
                 trax_opt.Adam, learning_rate=1e-3),
             policy_and_value_two_towers=False,
             n_optimizer_steps=N_OPTIMIZER_STEPS,
             print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
             target_kl=0.01,
             boundary=20,
             max_timestep=None,
             max_timestep_eval=20000,
             random_seed=None,
             gamma=GAMMA,
             lambda_=LAMBDA,
             c1=1.0,
             c2=0.01,
             eval_every_n=1000,
             done_frac_for_policy_save=0.5,
             n_evals=1,
             len_history_for_policy=4,
             eval_temperatures=(1.0, 0.5),
             **kwargs):
  """Creates the PPO trainer.

  Args:
    train_env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    output_dir: Output dir.
    policy_and_value_model: Function defining the policy and value network,
      without the policy and value heads.
    policy_and_value_optimizer: Function defining the optimizer.
    policy_and_value_two_towers: Whether to use two separate models as the
      policy and value networks. If False, share their parameters.
    n_optimizer_steps: Number of optimizer steps.
    print_every_optimizer_steps: How often to log during the policy
      optimization process.
    target_kl: Policy iteration early stopping. Set to infinity to disable
      early stopping.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in a
      trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in
      an evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be
      done to checkpoint the policy.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical
      sampling during evaluation.
    **kwargs: Additional keyword arguments passed to the base class.
  """
  # Set in base class constructor.
  self._train_env = None
  self._should_reset = None

  super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

  self._n_optimizer_steps = n_optimizer_steps
  self._print_every_optimizer_steps = print_every_optimizer_steps
  self._target_kl = target_kl
  self._boundary = boundary
  self._max_timestep = max_timestep
  self._max_timestep_eval = max_timestep_eval
  self._gamma = gamma
  self._lambda_ = lambda_
  self._c1 = c1
  self._c2 = c2
  self._eval_every_n = eval_every_n
  self._done_frac_for_policy_save = done_frac_for_policy_save
  self._n_evals = n_evals
  self._len_history_for_policy = len_history_for_policy
  self._eval_temperatures = eval_temperatures

  assert isinstance(self.train_env.action_space, gym.spaces.Discrete)
  n_actions = self.train_env.action_space.n

  # Batch Observations Shape = [1, 1] + OBS, because we will eventually call
  # the policy and value networks on shape [B, T] + OBS.
  batch_observations_shape = (1, 1) + self.train_env.observation_space.shape
  observations_dtype = self.train_env.observation_space.dtype

  self._rng = trax.get_random_number_generator_and_set_seed(random_seed)
  self._rng, key1 = jax_random.split(self._rng, num=2)

  # Initialize the policy and value network.
  policy_and_value_net_params, self._model_state, policy_and_value_net_apply = (
      ppo.policy_and_value_net(
          rng_key=key1,
          batch_observations_shape=batch_observations_shape,
          observations_dtype=observations_dtype,
          n_actions=n_actions,
          bottom_layers_fn=policy_and_value_model,
          two_towers=policy_and_value_two_towers,
      ))
  self._policy_and_value_net_apply = jit(policy_and_value_net_apply)

  # Initialize the optimizer.
  (policy_and_value_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = ppo.optimizer_fn(
       policy_and_value_optimizer, policy_and_value_net_params)

  # Maybe restore the optimization state. If there is nothing to restore, then
  # iteration = 0 and policy_and_value_opt_state is returned as is.
  (restored, self._policy_and_value_opt_state, self._model_state, self._epoch,
   self._total_opt_step) = ppo.maybe_restore_opt_state(
       output_dir, policy_and_value_opt_state, self._model_state)

  if restored:
    logging.info("Restored parameters from iteration [%d]", self._epoch)
    # We should start from the next iteration.
    self._epoch += 1

  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "train"))
  self._timing_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "timing"))
  self._eval_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "eval"))

  self._n_trajectories_done = 0
  self._last_saved_at = 0
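# Illustrative usage sketch (not part of the trainer module): constructing the
# PPO trainer whose __init__ is shown above on a toy gym environment. The
# function name, `output_dir` value, and the choice of "CartPole-v0" are
# placeholders; all keyword arguments are ones documented in __init__, shown
# here with their default-like values purely for illustration.
def example_make_ppo_trainer(output_dir):
  train_env = gym.make("CartPole-v0")
  eval_env = gym.make("CartPole-v0")
  return PPO(
      train_env=train_env,
      eval_env=eval_env,
      output_dir=output_dir,
      policy_and_value_model=trax_models.FrameStackMLP,
      policy_and_value_two_towers=False,
      max_timestep=200,
      gamma=0.99,
      lambda_=0.95,
      c1=1.0,
      c2=0.01,
  )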
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_lr_in_observation: bool, whether to include the learning rate in
      observations.
    observation_range: tuple (low, high), range to clip the observation to.
    start_lr: starting learning rate.
    max_lr: maximum value to clip the learning rate to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_dir: directory with the policy checkpoint.

  Returns:
    a function learning_rate(step): float -> float, the step-dependent lr.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      include_lr_in_observation)
  logging.vlog(
      1, "Building observations took %0.2f sec.", time.time() - start_time)
  if observations.shape[0] == 0:
    return lambda _: start_lr

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(
      1, "Building the policy network took %0.2f sec.",
      time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(
      1, "Restoring the policy parameters took %0.2f sec.",
      time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the next
  # step, but that should be fine.
  ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
  logging.vlog(
      1, "Running the policy took %0.2f sec.", time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  action = utils.gumbel_sample(log_probs[0, -1, :])

  # Get a new learning rate.
  new_lr = online_tune.new_learning_rate(
      action, history, action_multipliers, max_lr)
  return lambda _: new_lr
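# Illustrative sketch (not part of the module): how the schedule above would
# typically be called. `some_history` stands for a History object produced
# during training and `my_policy_dir` for a checkpoint directory; both names
# are placeholders. In practice `policy_dir` is usually supplied via a gin
# binding rather than in code, since its default is gin.REQUIRED.
def example_use_policy_schedule(some_history, my_policy_dir):
  learning_rate = PolicySchedule(
      some_history,
      start_lr=0.001,
      max_lr=10.0,
      policy_dir=my_policy_dir,
  )
  # The returned callable ignores the step argument; it encodes a single LR
  # decision made from the latest observations in `some_history`.
  return learning_rate(0)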
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial controls.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(
      1, "Building observations took %0.2f sec.", time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  assert policy_and_value_vocab_size is None, (
      "Serialized policies are not supported yet.")

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(
      1, "Building the policy network took %0.2f sec.",
      time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(
      1, "Restoring the policy parameters took %0.2f sec.",
      time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # (log_probs, value_preds). We have no way to pass state to the next step,
  # but that should be fine.
  (log_probs, _) = (
      net(np.array([observations]), params=params, state=state, rng=rng))
  logging.vlog(
      1, "Running the policy took %0.2f sec.", time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs) * observations.shape[0], len(action_multipliers))
  action = utils.gumbel_sample(
      log_probs[0, -len(control_configs):, :] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
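# Illustrative sketch (not part of the module): calling the control-based
# variant of PolicySchedule above. The returned callable maps any step to a
# dict of control values keyed by control name, e.g. {"learning_rate": 7.5e-4}.
# `some_history` and `my_policy_dir` are placeholders, as in the earlier
# example; in a real setup `policy_dir` would again come from a gin binding.
def example_use_control_policy_schedule(some_history, my_policy_dir):
  nontrainable_params = PolicySchedule(
      some_history,
      control_configs=(("learning_rate", 1e-3, (1e-9, 10.0), False),),
      policy_dir=my_policy_dir,
      temperature=1.0,
  )
  controls = nontrainable_params(0)
  return controls["learning_rate"]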