Example #1
  def test_policy_and_value_net(self):
    observation_shape = (3, 4, 5)
    batch_observation_shape = (1, 1) + observation_shape
    n_actions = 2
    n_controls = 3
    pnv_model = ppo.policy_and_value_net(
        n_controls=n_controls,
        n_actions=n_actions,
        vocab_size=None,
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        two_towers=True,
    )
    input_signature = ShapeDtype(batch_observation_shape)
    _, _ = pnv_model.init(input_signature)

    batch = 2
    time_steps = 10
    batch_of_observations = np.random.uniform(
        size=(batch, time_steps) + observation_shape)
    pnv_output = pnv_model(batch_of_observations)

    # Output is a list of length 2: the first element holds the action probabilities
    # and the second the value output.
    self.assertEqual(2, len(pnv_output))
    self.assertEqual(
        (batch, time_steps * n_controls, n_actions), pnv_output[0].shape)
    self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
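The same constructor also supports a single shared body with two output heads. Below is a minimal sketch of that variant (not part of the original test), reusing the names bound in the example above; the two outputs are expected to have the same shapes as asserted there.

shared_net = ppo.policy_and_value_net(
    n_controls=n_controls,
    n_actions=n_actions,
    vocab_size=None,
    bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
    two_towers=False,  # one shared tower instead of separate policy and value towers
)
shared_net.init(ShapeDtype(batch_observation_shape))
log_probs, values = shared_net(batch_of_observations)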
Example #2
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(1.0, ),
         vocab_size=None,
 ):
     policy_and_value_model = functools.partial(
         transformer.TransformerDecoder,
         d_model=2,
         d_ff=2,
         n_layers=0,
         vocab_size=vocab_size,
     )
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     obs_dim = len(observation_metrics)
     if vocab_size is None:
         shape = (1, 1, obs_dim)
         dtype = np.float32
     else:
         shape = (1, 1)
         dtype = np.int32
     input_signature = ShapeDtype(shape, dtype)
     (params, state) = net.init(input_signature)
     policy_dir = self.get_temp_dir()
     # Optimizer slots and parameters should not be used for anything.
     slots = None
     opt_params = None
     opt_state = (params, slots, opt_params)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0,
                        history=history)
     return lr_schedules.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_and_value_vocab_size=vocab_size,
         policy_dir=policy_dir,
     )
Example #3
    def test_inits_policy_by_world_model_checkpoint(self):
        transformer_kwargs = {
            "d_model": 1,
            "d_ff": 1,
            "n_layers": 1,
            "n_heads": 1,
            "max_len": 128,
            "mode": "train",
        }
        rng = jax_random.PRNGKey(123)
        init_kwargs = {
            "input_shapes": (1, 1),
            "input_dtype": np.int32,
            "rng": rng,
        }
        model_fn = functools.partial(models.TransformerLM,
                                     vocab_size=4,
                                     **transformer_kwargs)
        output_dir = self.get_temp_dir()
        # Initialize a world model checkpoint by running the trainer.
        trainer_lib.train(
            output_dir,
            model=model_fn,
            inputs=functools.partial(inputs.random_inputs,
                                     input_shape=(1, 1),
                                     output_shape=(1, 1)),
            train_steps=1,
            eval_steps=1,
        )

        policy = ppo.policy_and_value_net(
            n_actions=3,
            n_controls=2,
            vocab_size=4,
            bottom_layers_fn=functools.partial(models.TransformerDecoder,
                                               **transformer_kwargs),
            two_towers=False,
        )
        (policy_params, policy_state) = policy.initialize_once(**init_kwargs)

        # Initialize policy parameters from world model parameters.
        new_policy_params = ppo.init_policy_from_world_model_checkpoint(
            policy_params, output_dir)
        # Try to run the policy with new parameters.
        observations = np.zeros((1, 100), dtype=np.int32)
        policy(observations,
               params=new_policy_params,
               state=policy_state,
               rng=rng)
Example #4
    def test_inits_policy_by_world_model_checkpoint(self):
        transformer_kwargs = {
            'd_model': 1,
            'd_ff': 1,
            'n_layers': 1,
            'n_heads': 1,
            'max_len': 128,
            'mode': 'train',
        }
        rng = jax_random.PRNGKey(123)
        model_fn = functools.partial(models.TransformerLM,
                                     vocab_size=4,
                                     **transformer_kwargs)
        output_dir = self.get_temp_dir()
        # Initialize a world model checkpoint by running the trainer.
        trainer_lib.train(
            output_dir,
            model=model_fn,
            inputs=functools.partial(inputs.random_inputs,
                                     input_shape=(1, 1),
                                     output_shape=(1, 1)),
            steps=1,
            eval_steps=1,
        )

        make_policy = lambda: ppo.policy_and_value_net(  # pylint: disable=g-long-lambda
            n_actions=3,
            n_controls=2,
            vocab_size=4,
            bottom_layers_fn=functools.partial(models.TransformerDecoder,
                                               **transformer_kwargs),
            two_towers=False,
        )
        policy = make_policy()
        input_signature = ShapeDtype((1, 1), np.int32)
        policy._set_rng_recursive(rng)
        policy_params, policy_state = make_policy().init(input_signature)

        # Initialize policy parameters from world model parameters.
        new_policy_params = ppo.init_policy_from_world_model_checkpoint(
            policy_params, output_dir)
        # Try to run the policy with new parameters.
        observations = np.zeros((1, 100), dtype=np.int32)
        policy(observations,
               weights=new_policy_params,
               state=policy_state,
               rng=rng)
Example #5
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(("eval", "metrics/accuracy"), ),
         action_multipliers=(1.0, ),
 ):
     policy_and_value_model = atari_cnn.FrameStackMLP
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     rng = jax_random.get_prng(seed=0)
     obs_dim = len(observation_metrics)
     (params, state) = net.initialize_once((1, 1, obs_dim), np.float32, rng)
     policy_dir = self.get_temp_dir()
     # Optimizer slots should not be used for anything.
     slots = None
     opt_state = (params, slots)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0)
     return learning_rate.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_dir=policy_dir,
     )
Example #6
  def test_combined_loss(self):
    B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
    batch_observation_shape = (1, 1) + OBS

    net = ppo.policy_and_value_net(
        n_controls=1,
        n_actions=A,
        vocab_size=None,
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        two_towers=True,
    )

    input_signature = ShapeDtype(batch_observation_shape)
    old_params, _ = net.init(input_signature)
    new_params, state = net.init(input_signature)

    # Generate a batch of observations.

    observations = np.random.uniform(size=(B, T + 1) + OBS)
    actions = np.random.randint(0, A, size=(B, T + 1))
    rewards = np.random.uniform(0, 1, size=(B, T))
    mask = np.ones_like(rewards)

    # Just test that this computes at all.
    (new_log_probabs, value_predictions_new) = (
        net(observations, weights=new_params, state=state))
    (old_log_probabs, value_predictions_old) = (
        net(observations, weights=old_params, state=state))

    gamma = 0.99
    lambda_ = 0.95
    epsilon = 0.2
    value_weight = 1.0
    entropy_weight = 0.01

    nontrainable_params = {
        'gamma': gamma,
        'lambda': lambda_,
        'epsilon': epsilon,
        'value_weight': value_weight,
        'entropy_weight': entropy_weight,
    }

    rewards_to_actions = np.eye(value_predictions_old.shape[1])
    (value_loss_1, _) = ppo.value_loss_given_predictions(
        value_predictions_new, rewards, mask, gamma=gamma,
        value_prediction_old=value_predictions_old, epsilon=epsilon)
    (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
        new_log_probabs,
        old_log_probabs,
        value_predictions_old,
        actions,
        rewards_to_actions,
        rewards,
        mask,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)

    (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
        ppo.combined_loss(new_params,
                          old_log_probabs,
                          value_predictions_old,
                          net,
                          observations,
                          actions,
                          rewards_to_actions,
                          rewards,
                          mask,
                          nontrainable_params=nontrainable_params,
                          state=state)
    )

    # Test that these compute at all and are self-consistent.
    self.assertGreater(entropy_bonus, 0.0)
    self.assertNear(value_loss_1, value_loss_2, 1e-6)
    self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
    self.assertNear(
        combined_loss,
        ppo_loss_2 + (value_weight * value_loss_2) -
        (entropy_weight * entropy_bonus),
        1e-6
    )
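For reference, the final assertion above checks the combined PPO objective used in this test:

    combined_loss = ppo_loss + value_weight * value_loss - entropy_weight * entropy_bonus

The value loss is added with weight value_weight, and the entropy bonus is subtracted with weight entropy_weight, so a larger entropy bonus lowers the loss being minimized.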
Example #7
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 policy_and_value_vocab_size=None,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 optimizer_batch_size=64,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=100,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 save_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 separate_eval=True,
                 init_policy_from_world_model_output_dir=None,
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      policy_and_value_vocab_size: Vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      n_optimizer_steps: Number of optimizer steps.
      optimizer_batch_size: Batch size of an optimizer step.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping. Set to infinity to disable
        early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in a
        trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to eval the policy.
      save_every_n: How frequently to save the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      separate_eval: Whether to run separate evaluation using a set of
        temperatures. If False, the training reward is reported as evaluation
        reward with temperature 1.0.
      init_policy_from_world_model_output_dir: Model output dir for initializing
        the policy. If None, initialize randomly.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._optimizer_batch_size = optimizer_batch_size
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._save_every_n = save_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures
        self._separate_eval = separate_eval

        action_space = self.train_env.action_space
        assert isinstance(action_space,
                          (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
        if isinstance(action_space, gym.spaces.Discrete):
            n_actions = action_space.n
            n_controls = 1
        else:
            (n_controls, ) = action_space.nvec.shape
            assert n_controls > 0
            assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
                "Every control must have the same number of actions.")
            n_actions = action_space.nvec[0]
        self._n_actions = n_actions
        self._n_controls = n_controls

        self._rng = trainer_lib.get_random_number_generator_and_set_seed(
            random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        vocab_size = policy_and_value_vocab_size
        self._serialized_sequence_policy = vocab_size is not None
        if self._serialized_sequence_policy:
            self._serialization_kwargs = self._init_serialization(vocab_size)
        else:
            self._serialization_kwargs = {}

        # Initialize the policy and value network.
        policy_and_value_net = ppo.policy_and_value_net(
            n_actions=n_actions,
            n_controls=n_controls,
            vocab_size=vocab_size,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
        self._policy_and_value_net_apply = jit(policy_and_value_net)
        (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
        policy_and_value_net_params, self._model_state = (
            policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype,
                                                 key1))
        if init_policy_from_world_model_output_dir is not None:
            policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
                policy_and_value_net_params,
                init_policy_from_world_model_output_dir)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Restore the optimizer state.
        self._policy_and_value_opt_state = policy_and_value_opt_state
        self._epoch = 0
        self._total_opt_step = 0
        self.update_optimization_state(
            output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
        if self._async_mode:
            logging.info(
                "Saving model on startup to have a model policy file.")
            self.save()

        self._rewards_to_actions = self._init_rewards_to_actions()
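A minimal instantiation sketch based on the constructor documented above. The class is referenced as PPO in the super() call; its import path is not shown on this page, and the choice of environment is an assumption (the constructor asserts a Discrete or MultiDiscrete action space, which gym's CartPole-v0 satisfies). All other arguments keep their documented defaults.

import gym

trainer = PPO(
    train_env=gym.make('CartPole-v0'),  # Discrete action space, as the assertion requires
    eval_env=gym.make('CartPole-v0'),
    output_dir='/tmp/ppo_trainer',
    max_timestep=200,    # cap collected training trajectories at 200 steps
    eval_every_n=100,
    save_every_n=100,
)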
Example #8
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction are computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
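As the docstring describes, PolicySchedule returns a function from the training step to a dict of control values. A minimal usage sketch, assuming a populated History object and a policy checkpoint already saved under policy_dir (as the _make_schedule helpers in the earlier examples arrange):

schedule_fn = PolicySchedule(history, policy_dir=policy_dir)
controls = schedule_fn(1000)  # e.g. {'learning_rate': <sampled value>}; the returned
                              # function ignores its step argument.
learning_rate = controls['learning_rate']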
Example #9
    def test_combined_loss(self):
        self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3)

        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (1, 1) + OBS

        net = ppo.policy_and_value_net(
            n_controls=1,
            n_actions=A,
            vocab_size=None,
            bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
            two_towers=True,
        )

        old_params, _ = net.initialize_once(batch_observation_shape,
                                            np.float32, key1)
        new_params, state = net.initialize_once(batch_observation_shape,
                                                np.float32, key2)

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T + 1))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        (new_log_probabs, value_predictions_new) = (net(observations,
                                                        params=new_params,
                                                        state=state))
        (old_log_probabs, value_predictions_old) = (net(observations,
                                                        params=old_params,
                                                        state=state))

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        c1 = 1.0
        c2 = 0.01

        rewards_to_actions = np.eye(value_predictions_old.shape[1])
        (value_loss_1, _) = ppo.value_loss_given_predictions(
            value_predictions_new,
            rewards,
            mask,
            gamma=gamma,
            value_prediction_old=value_predictions_old,
            epsilon=epsilon)
        (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                         old_log_probabs,
                                                         value_predictions_old,
                                                         actions,
                                                         rewards_to_actions,
                                                         rewards,
                                                         mask,
                                                         gamma=gamma,
                                                         lambda_=lambda_,
                                                         epsilon=epsilon)

        (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
         state) = (ppo.combined_loss(new_params,
                                     old_log_probabs,
                                     value_predictions_old,
                                     net,
                                     observations,
                                     actions,
                                     rewards_to_actions,
                                     rewards,
                                     mask,
                                     gamma=gamma,
                                     lambda_=lambda_,
                                     epsilon=epsilon,
                                     c1=c1,
                                     c2=c2,
                                     state=state))

        # Test that these compute at all and are self-consistent.
        self.assertGreater(entropy_bonus, 0.0)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(
            combined_loss,
            ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus), 1e-6)