Example no. 1
 def test_doesnt_analyze_box_action_space(self):
   space = gym.spaces.Box(shape=(2, 3), low=0, high=1)
   with self.assertRaises(AssertionError):
     ppo.analyze_action_space(space)
Example no. 2
 def test_doesnt_analyze_multi_discrete_action_space_with_unequal_categories(
     self
 ):
   space = gym.spaces.MultiDiscrete(nvec=(2, 3))
   with self.assertRaises(AssertionError):
     ppo.analyze_action_space(space)
Example no. 3
 def test_analyzes_discrete_action_space(self):
   space = gym.spaces.Discrete(n=5)
   (n_controls, n_actions) = ppo.analyze_action_space(space)
   self.assertEqual(n_controls, 1)
   self.assertEqual(n_actions, 5)
Example no. 4
 def test_analyzes_multi_discrete_action_space_with_equal_categories(self):
   space = gym.spaces.MultiDiscrete(nvec=(3, 3))
   (n_controls, n_actions) = ppo.analyze_action_space(space)
   self.assertEqual(n_controls, 2)
   self.assertEqual(n_actions, 3)
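Taken together, the four tests above pin down the contract of ppo.analyze_action_space: a Discrete space maps to a single control with n actions, a MultiDiscrete space with identical category counts maps to one control per dimension, and any other space fails an assertion. The sketch below is only an assumed illustration of that contract, not the actual trax implementation.

# Illustrative sketch of the behavior exercised by the tests above
# (assumed contract, not the actual trax implementation).
import gym
import numpy as np

def analyze_action_space(action_space):
  """Returns (n_controls, n_actions) for supported action spaces."""
  if isinstance(action_space, gym.spaces.Discrete):
    # A single control choosing among `n` actions.
    return (1, action_space.n)
  # Only MultiDiscrete spaces with equal category counts are supported.
  assert isinstance(action_space, gym.spaces.MultiDiscrete)
  nvec = np.asarray(action_space.nvec)
  assert np.all(nvec == nvec[0])
  return (len(nvec), int(nvec[0]))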
Example no. 5
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir=None,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 policy_and_value_vocab_size=None,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 optimizer_batch_size=64,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=100,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 value_weight=1.0,
                 entropy_weight=0.01,
                 epsilon=0.1,
                 eval_every_n=1000,
                 save_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 separate_eval=True,
                 init_policy_from_world_model_output_dir=None,
                 controller=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      policy_and_value_vocab_size: Vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      n_optimizer_steps: Number of optimizer steps.
      optimizer_batch_size: Batch size of an optimizer step.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping. Set to infinity to disable
        early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in a
        trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      value_weight: Value loss coefficient.
      entropy_weight: Entropy loss coefficient.
      epsilon: Clipping coefficient.
      eval_every_n: How frequently to eval the policy.
      save_every_n: How frequently to save the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      separate_eval: Whether to run separate evaluation using a set of
        temperatures. If False, the training reward is reported as evaluation
        reward with temperature 1.0.
      init_policy_from_world_model_output_dir: Model output dir for initializing
        the policy. If None, initialize randomly.
      controller: Function history -> (step -> {'name': value}) controlling
        nontrainable parameters.
      should_save_checkpoints: Whether to save policy checkpoints.
      should_write_summaries: Whether to save summaries.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        self._n_optimizer_steps = n_optimizer_steps
        self._optimizer_batch_size = optimizer_batch_size
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._nontrainable_params = {
            'gamma': np.array(gamma),
            'lambda': np.array(lambda_),
            'value_weight': np.array(value_weight),
            'entropy_weight': np.array(entropy_weight),
            'epsilon': np.array(epsilon),
        }
        self._eval_every_n = eval_every_n
        self._save_every_n = save_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures
        self._separate_eval = separate_eval
        self._controller = controller
        self._should_save_checkpoints = should_save_checkpoints
        self._should_write_summaries = should_write_summaries
        self._history = None

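        # Work out how many parallel controls the action space has and how
        # many discrete actions each control can take.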
        (n_controls,
         n_actions) = ppo.analyze_action_space(train_env.action_space)

        self._rng = trainer_lib.get_random_number_generator_and_set_seed(
            random_seed)

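        # With a vocabulary size set, observations and actions are serialized
        # into discrete tokens for the policy and value network; otherwise the
        # raw continuous representation is used.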
        self._policy_and_value_vocab_size = policy_and_value_vocab_size
        if self._policy_and_value_vocab_size is not None:
            self._serialization_kwargs = ppo.init_serialization(
                vocab_size=self._policy_and_value_vocab_size,
                observation_space=train_env.observation_space,
                action_space=train_env.action_space,
                n_timesteps=(self._max_timestep + 1),
            )
        else:
            self._serialization_kwargs = {}
        self._init_policy_from_world_model_output_dir = (
            init_policy_from_world_model_output_dir)

        self._rewards_to_actions = ppo.init_rewards_to_actions(
            self._policy_and_value_vocab_size,
            train_env.observation_space,
            train_env.action_space,
            n_timesteps=(self._max_timestep + 1),
        )

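        # Build the joint policy-and-value network on top of the bottom-layers
        # model; parameters are shared unless two separate towers are requested.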
        self._policy_and_value_net_fn = functools.partial(
            ppo.policy_and_value_net,
            n_actions=n_actions,
            n_controls=n_controls,
            vocab_size=self._policy_and_value_vocab_size,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
        self._policy_and_value_net_apply = jit(self._policy_and_value_net_fn())
        self._policy_and_value_optimizer = policy_and_value_optimizer()

        # Super ctor calls reset(), which uses fields initialized above.
        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)
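For context, here is a minimal usage sketch of the constructor above. It assumes the class is importable as PPO from its defining module and uses gym's CartPole-v0 with illustrative hyperparameter values; none of these specifics come from the source.

# Minimal usage sketch (assumptions: `PPO` is the class defined above and has
# been imported; the environment and hyperparameter values are illustrative).
import gym

train_env = gym.make('CartPole-v0')
eval_env = gym.make('CartPole-v0')

trainer = PPO(
    train_env=train_env,
    eval_env=eval_env,
    output_dir='/tmp/ppo_cartpole',
    n_optimizer_steps=10,
    optimizer_batch_size=32,
    max_timestep=200,
    gamma=0.99,
    lambda_=0.95,
    epsilon=0.1,
)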