def test_policy_and_value_net(self):
  observation_shape = (3, 4, 5)
  batch_observation_shape = (1, 1) + observation_shape
  n_actions = 2
  n_controls = 3
  pnv_model = ppo.policy_and_value_net(
      n_controls=n_controls,
      n_actions=n_actions,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  input_signature = ShapeDtype(batch_observation_shape)
  _, _ = pnv_model.init(input_signature)
  batch = 2
  time_steps = 10
  batch_of_observations = np.random.uniform(
      size=(batch, time_steps) + observation_shape)
  pnv_output = pnv_model(batch_of_observations)
  # The output is a list of length 2: first the log probabilities of
  # actions, then the value predictions.
  self.assertEqual(2, len(pnv_output))
  self.assertEqual(
      (batch, time_steps * n_controls, n_actions), pnv_output[0].shape)
  self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
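# For intuition on the shapes asserted above: the policy head emits one
# categorical distribution per (timestep, control) pair, flattened along a
# single axis. A minimal numpy sketch of that layout (a hypothetical
# illustration assuming the per-timestep control distributions are laid out
# contiguously; this is not code from ppo.py):
def _action_layout_sketch():
  import numpy as onp
  batch, time_steps, n_controls, n_actions = 2, 10, 3, 2
  per_control = onp.zeros((batch, time_steps, n_controls, n_actions))
  flat = per_control.reshape((batch, time_steps * n_controls, n_actions))
  assert flat.shape == (2, 30, 2)  # Matches pnv_output[0].shape above.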
def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(('eval', 'metrics/accuracy'),),
    action_multipliers=(1.0,),
    vocab_size=None,
):
  policy_and_value_model = functools.partial(
      transformer.TransformerDecoder,
      d_model=2,
      d_ff=2,
      n_layers=0,
      vocab_size=vocab_size,
  )
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      n_controls=len(control_configs),
      vocab_size=vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=False,
  )
  obs_dim = len(observation_metrics)
  if vocab_size is None:
    shape = (1, 1, obs_dim)
    dtype = np.float32
  else:
    shape = (1, 1)
    dtype = np.int32
  input_signature = ShapeDtype(shape, dtype)
  (params, state) = net.init(input_signature)
  policy_dir = self.get_temp_dir()
  # Optimizer slots and parameters should not be used for anything.
  slots = None
  opt_params = None
  opt_state = (params, slots, opt_params)
  ppo.save_opt_state(policy_dir, opt_state, state, epoch=0,
                     total_opt_step=0, history=history)
  return lr_schedules.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_controls_in_observation=False,
      action_multipliers=action_multipliers,
      control_configs=control_configs,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_and_value_vocab_size=vocab_size,
      policy_dir=policy_dir,
  )
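# The helper above returns a schedule: a function from a step number to a
# dict of control values (see the PolicySchedule docstring further down). A
# hypothetical usage sketch in a test body:
#
#   schedule = self._make_schedule(history, control_configs=(
#       ('learning_rate', 1e-3, (1e-9, 10.0), False),))
#   lr = schedule(step)['learning_rate']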
def test_inits_policy_by_world_model_checkpoint(self):
  transformer_kwargs = {
      "d_model": 1,
      "d_ff": 1,
      "n_layers": 1,
      "n_heads": 1,
      "max_len": 128,
      "mode": "train",
  }
  rng = jax_random.PRNGKey(123)
  init_kwargs = {
      "input_shapes": (1, 1),
      "input_dtype": np.int32,
      "rng": rng,
  }
  model_fn = functools.partial(
      models.TransformerLM, vocab_size=4, **transformer_kwargs)
  output_dir = self.get_temp_dir()
  # Initialize a world model checkpoint by running the trainer.
  trainer_lib.train(
      output_dir,
      model=model_fn,
      inputs=functools.partial(
          inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1)),
      train_steps=1,
      eval_steps=1,
  )
  policy = ppo.policy_and_value_net(
      n_actions=3,
      n_controls=2,
      vocab_size=4,
      bottom_layers_fn=functools.partial(
          models.TransformerDecoder, **transformer_kwargs),
      two_towers=False,
  )
  (policy_params, policy_state) = policy.initialize_once(**init_kwargs)
  # Initialize policy parameters from world model parameters.
  new_policy_params = ppo.init_policy_from_world_model_checkpoint(
      policy_params, output_dir)
  # Try to run the policy with new parameters.
  observations = np.zeros((1, 100), dtype=np.int32)
  policy(observations, params=new_policy_params, state=policy_state, rng=rng)
def test_inits_policy_by_world_model_checkpoint(self):
  transformer_kwargs = {
      'd_model': 1,
      'd_ff': 1,
      'n_layers': 1,
      'n_heads': 1,
      'max_len': 128,
      'mode': 'train',
  }
  rng = jax_random.PRNGKey(123)
  model_fn = functools.partial(
      models.TransformerLM, vocab_size=4, **transformer_kwargs)
  output_dir = self.get_temp_dir()
  # Initialize a world model checkpoint by running the trainer.
  trainer_lib.train(
      output_dir,
      model=model_fn,
      inputs=functools.partial(
          inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1)),
      steps=1,
      eval_steps=1,
  )
  make_policy = lambda: ppo.policy_and_value_net(  # pylint: disable=g-long-lambda
      n_actions=3,
      n_controls=2,
      vocab_size=4,
      bottom_layers_fn=functools.partial(
          models.TransformerDecoder, **transformer_kwargs),
      two_towers=False,
  )
  policy = make_policy()
  input_signature = ShapeDtype((1, 1), np.int32)
  policy._set_rng_recursive(rng)
  policy_params, policy_state = make_policy().init(input_signature)
  # Initialize policy parameters from world model parameters.
  new_policy_params = ppo.init_policy_from_world_model_checkpoint(
      policy_params, output_dir)
  # Try to run the policy with new parameters.
  observations = np.zeros((1, 100), dtype=np.int32)
  policy(observations, weights=new_policy_params, state=policy_state,
         rng=rng)
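# Conceptually, initializing the policy from a world-model checkpoint
# substitutes the pretrained weights of the shared bottom (Transformer)
# layers into the policy's parameter tree, keeping the freshly initialized
# policy and value heads. A minimal sketch of that idea (a hypothetical
# helper with an assumed parameter layout; the real logic lives in
# ppo.init_policy_from_world_model_checkpoint):
def _substitute_bottom_params_sketch(policy_params, model_params):
  # Assumed layout: the bottom layers occupy the first slot of the policy's
  # parameter tuple; the remaining slots (the heads) are left untouched.
  (_, *head_params) = policy_params
  return (model_params, *head_params)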
def _make_schedule(
    self,
    history,
    control_configs,
    observation_metrics=(("eval", "metrics/accuracy"),),
    action_multipliers=(1.0,),
):
  policy_and_value_model = atari_cnn.FrameStackMLP
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      n_controls=len(control_configs),
      vocab_size=None,
      bottom_layers_fn=policy_and_value_model,
      two_towers=False,
  )
  rng = jax_random.get_prng(seed=0)
  obs_dim = len(observation_metrics)
  (params, state) = net.initialize_once((1, 1, obs_dim), np.float32, rng)
  policy_dir = self.get_temp_dir()
  # Optimizer slots should not be used for anything.
  slots = None
  opt_state = (params, slots)
  ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
  return learning_rate.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_controls_in_observation=False,
      action_multipliers=action_multipliers,
      control_configs=control_configs,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_dir=policy_dir,
  )
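# Note: the checkpoint written by ppo.save_opt_state above is what
# PolicySchedule restores later. As used in the PolicySchedule code below,
# ppo.maybe_restore_opt_state(policy_dir) returns a tuple
# (opt_state, state, epoch, opt_step), with the network parameters as the
# first element of opt_state:
#
#   (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
#   (params, _) = opt_state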
def test_combined_loss(self):
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS
  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  input_signature = ShapeDtype(batch_observation_shape)
  old_params, _ = net.init(input_signature)
  new_params, state = net.init(input_signature)
  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)
  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, weights=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, weights=old_params, state=state))
  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  value_weight = 1.0
  entropy_weight = 0.01
  nontrainable_params = {
      'gamma': gamma,
      'lambda': lambda_,
      'epsilon': epsilon,
      'value_weight': value_weight,
      'entropy_weight': entropy_weight,
  }
  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs, old_log_probabs, value_predictions_old, actions,
      rewards_to_actions, rewards, mask, gamma=gamma, lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params, old_log_probabs, value_predictions_old,
                        net, observations, actions, rewards_to_actions,
                        rewards, mask,
                        nontrainable_params=nontrainable_params,
                        state=state))
  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss,
      ppo_loss_2 + (value_weight * value_loss_2)
      - (entropy_weight * entropy_bonus),
      1e-6)
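# For reference, a self-contained numpy sketch of the standard clipped PPO
# surrogate that the test above exercises. This is the textbook formula
# (Schulman et al., 2017), not code lifted from ppo.py; `advantages` stands
# in for the GAE estimates that ppo.py derives from rewards and value
# predictions using `gamma` and `lambda_`:
def _clipped_surrogate_sketch(new_log_probs, old_log_probs, advantages,
                              epsilon=0.2):
  import numpy as onp
  # Per-action probability ratio between the new and the old policy.
  ratios = onp.exp(new_log_probs - old_log_probs)
  clipped = onp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Maximize the pessimistic (clipped) objective, i.e. minimize its negation.
  return -onp.mean(onp.minimum(ratios * advantages, clipped * advantages))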
def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_and_value_model=trax_models.FrameStackMLP,
             policy_and_value_optimizer=functools.partial(
                 trax_opt.Adam, learning_rate=1e-3),
             policy_and_value_two_towers=False,
             policy_and_value_vocab_size=None,
             n_optimizer_steps=N_OPTIMIZER_STEPS,
             optimizer_batch_size=64,
             print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
             target_kl=0.01,
             boundary=20,
             max_timestep=100,
             max_timestep_eval=20000,
             random_seed=None,
             gamma=GAMMA,
             lambda_=LAMBDA,
             c1=1.0,
             c2=0.01,
             eval_every_n=1000,
             save_every_n=1000,
             done_frac_for_policy_save=0.5,
             n_evals=1,
             len_history_for_policy=4,
             eval_temperatures=(1.0, 0.5),
             separate_eval=True,
             init_policy_from_world_model_output_dir=None,
             **kwargs):
  """Creates the PPO trainer.

  Args:
    train_env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    output_dir: Output dir.
    policy_and_value_model: Function defining the policy and value network,
      without the policy and value heads.
    policy_and_value_optimizer: Function defining the optimizer.
    policy_and_value_two_towers: Whether to use two separate models as the
      policy and value networks. If False, share their parameters.
    policy_and_value_vocab_size: Vocabulary size of a policy and value
      network operating on serialized representation. If None, use raw
      continuous representation.
    n_optimizer_steps: Number of optimizer steps.
    optimizer_batch_size: Batch size of an optimizer step.
    print_every_optimizer_steps: How often to log during the policy
      optimization process.
    target_kl: Policy iteration early stopping. Set to infinity to disable
      early stopping.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in a
      trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in
      an evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    save_every_n: How frequently to save the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be
      done to checkpoint the policy.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical
      sampling during evaluation.
    separate_eval: Whether to run separate evaluation using a set of
      temperatures. If False, the training reward is reported as evaluation
      reward with temperature 1.0.
    init_policy_from_world_model_output_dir: Model output dir for
      initializing the policy. If None, initialize randomly.
    **kwargs: Additional keyword arguments passed to the base class.
  """
  # Set in base class constructor.
  self._train_env = None
  self._should_reset = None

  super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

  self._n_optimizer_steps = n_optimizer_steps
  self._optimizer_batch_size = optimizer_batch_size
  self._print_every_optimizer_steps = print_every_optimizer_steps
  self._target_kl = target_kl
  self._boundary = boundary
  self._max_timestep = max_timestep
  self._max_timestep_eval = max_timestep_eval
  self._gamma = gamma
  self._lambda_ = lambda_
  self._c1 = c1
  self._c2 = c2
  self._eval_every_n = eval_every_n
  self._save_every_n = save_every_n
  self._done_frac_for_policy_save = done_frac_for_policy_save
  self._n_evals = n_evals
  self._len_history_for_policy = len_history_for_policy
  self._eval_temperatures = eval_temperatures
  self._separate_eval = separate_eval

  action_space = self.train_env.action_space
  assert isinstance(
      action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
  if isinstance(action_space, gym.spaces.Discrete):
    n_actions = action_space.n
    n_controls = 1
  else:
    (n_controls,) = action_space.nvec.shape
    assert n_controls > 0
    assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
        "Every control must have the same number of actions.")
    n_actions = action_space.nvec[0]
  self._n_actions = n_actions
  self._n_controls = n_controls

  self._rng = trainer_lib.get_random_number_generator_and_set_seed(
      random_seed)
  self._rng, key1 = jax_random.split(self._rng, num=2)

  vocab_size = policy_and_value_vocab_size
  self._serialized_sequence_policy = vocab_size is not None
  if self._serialized_sequence_policy:
    self._serialization_kwargs = self._init_serialization(vocab_size)
  else:
    self._serialization_kwargs = {}

  # Initialize the policy and value network.
  policy_and_value_net = ppo.policy_and_value_net(
      n_actions=n_actions,
      n_controls=n_controls,
      vocab_size=vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  self._policy_and_value_net_apply = jit(policy_and_value_net)
  (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
  policy_and_value_net_params, self._model_state = (
      policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype, key1))
  if init_policy_from_world_model_output_dir is not None:
    policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
        policy_and_value_net_params, init_policy_from_world_model_output_dir)

  # Initialize the optimizer.
  (policy_and_value_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = ppo.optimizer_fn(
       policy_and_value_optimizer, policy_and_value_net_params)

  # Restore the optimizer state.
  self._policy_and_value_opt_state = policy_and_value_opt_state
  self._epoch = 0
  self._total_opt_step = 0
  self.update_optimization_state(
      output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "train"))
  self._timing_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "timing"))
  self._eval_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "eval"))

  self._n_trajectories_done = 0
  self._last_saved_at = 0
  if self._async_mode:
    logging.info("Saving model on startup to have a model policy file.")
    self.save()

  self._rewards_to_actions = self._init_rewards_to_actions()
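# How a gym action space maps onto (n_actions, n_controls) in the
# constructor above, with illustrative sizes (standard gym semantics):
#
#   gym.spaces.Discrete(6)            -> n_actions=6, n_controls=1
#   gym.spaces.MultiDiscrete([5, 5])  -> n_actions=5, n_controls=2
#
# MultiDiscrete spaces are supported only when every control has the same
# number of actions, as the assertion in the constructor enforces.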
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls
      in observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value
      network operating on serialized representation. If None, use raw
      continuous representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(1, "Building observations took %0.2f sec.",
               time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  assert policy_and_value_vocab_size is None, (
      "Serialized policies are not supported yet.")
  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, "Building the policy network took %0.2f sec.",
               time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  (log_probs, _) = (
      net(np.array([observations]), params=params, state=state, rng=rng))
  logging.vlog(1, "Running the policy took %0.2f sec.",
               time.time() - start_time)
  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs) * observations.shape[0],
      len(action_multipliers))
  action = utils.gumbel_sample(
      log_probs[0, -len(control_configs):, :] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
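# A usage sketch: PolicySchedule acts as a factory for a learning rate
# schedule. Given a History object and a policy checkpoint directory (the
# path below is illustrative), it returns a mapping from step to control
# values:
#
#   lr_schedule = PolicySchedule(history, policy_dir='/tmp/policy')
#   nontrainable_params = lr_schedule(step)  # e.g. {'learning_rate': 8e-4}
#
# The controls are computed once per call to PolicySchedule, from the latest
# observations in `history`, so the returned lambda ignores `step`.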
def test_combined_loss(self):
  self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)
  B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
  batch_observation_shape = (1, 1) + OBS
  net = ppo.policy_and_value_net(
      n_controls=1,
      n_actions=A,
      vocab_size=None,
      bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
      two_towers=True,
  )
  old_params, _ = net.initialize_once(
      batch_observation_shape, np.float32, key1)
  new_params, state = net.initialize_once(
      batch_observation_shape, np.float32, key2)
  # Generate a batch of observations.
  observations = np.random.uniform(size=(B, T + 1) + OBS)
  actions = np.random.randint(0, A, size=(B, T + 1))
  rewards = np.random.uniform(0, 1, size=(B, T))
  mask = np.ones_like(rewards)
  # Just test that this computes at all.
  (new_log_probabs, value_predictions_new) = (
      net(observations, params=new_params, state=state))
  (old_log_probabs, value_predictions_old) = (
      net(observations, params=old_params, state=state))
  gamma = 0.99
  lambda_ = 0.95
  epsilon = 0.2
  c1 = 1.0
  c2 = 0.01
  rewards_to_actions = np.eye(value_predictions_old.shape[1])
  (value_loss_1, _) = ppo.value_loss_given_predictions(
      value_predictions_new, rewards, mask, gamma=gamma,
      value_prediction_old=value_predictions_old, epsilon=epsilon)
  (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
      new_log_probabs, old_log_probabs, value_predictions_old, actions,
      rewards_to_actions, rewards, mask, gamma=gamma, lambda_=lambda_,
      epsilon=epsilon)
  (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
      ppo.combined_loss(new_params, old_log_probabs, value_predictions_old,
                        net, observations, actions, rewards_to_actions,
                        rewards, mask, gamma=gamma, lambda_=lambda_,
                        epsilon=epsilon, c1=c1, c2=c2, state=state))
  # Test that these compute at all and are self-consistent.
  self.assertGreater(entropy_bonus, 0.0)
  self.assertNear(value_loss_1, value_loss_2, 1e-6)
  self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
  self.assertNear(
      combined_loss, ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus),
      1e-6)
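# For reference, a self-contained numpy sketch of generalized advantage
# estimation (GAE), which is the role `gamma` and `lambda_` play in the test
# above. This is the textbook recursion (Schulman et al., 2015), not code
# lifted from ppo.py:
def _gae_sketch(rewards, values, gamma=0.99, lambda_=0.95):
  import numpy as onp
  # `values` has one more entry than `rewards`: the bootstrap value for the
  # state after the last reward.
  deltas = rewards + gamma * values[1:] - values[:-1]  # One-step TD errors.
  advantages = onp.zeros_like(deltas)
  gae = 0.0
  for t in reversed(range(len(deltas))):
    gae = deltas[t] + gamma * lambda_ * gae
    advantages[t] = gae
  return advantages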