Example #1
def train_agent_real_env(
    env, agent_model_dir, event_dir, epoch_data_dir,
    hparams, epoch=0, is_final_epoch=False):
  """Train the PPO agent in the real environment."""
  del epoch_data_dir
  ppo_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  ppo_params_names = ["epochs_num", "epoch_length",
                      "learning_rate", "num_agents", "eval_every_epochs",
                      "optimization_epochs", "effective_num_agents"]

  # This should be overridden.
  ppo_hparams.add_hparam("effective_num_agents", None)
  for param_name in ppo_params_names:
    ppo_param_name = "real_ppo_" + param_name
    if ppo_param_name in hparams:
      ppo_hparams.set_hparam(param_name, hparams.get(ppo_param_name))

  ppo_hparams.epochs_num = _ppo_training_epochs(hparams, epoch,
                                                is_final_epoch, True)
  # We do not save model, as that resets frames that we need at restarts.
  # But we need to save at the last step, so we set it very high.
  ppo_hparams.save_models_every_epochs = 1000000

  environment_spec = rl.standard_atari_env_spec(
      batch_env=env, include_clipping=False
  )

  ppo_hparams.add_hparam("environment_spec", environment_spec)

  rl_trainer_lib.train(ppo_hparams, event_dir + "real", agent_model_dir,
                       name_scope="ppo_real%d" % (epoch + 1))

  # Save unfinished rollouts to history.
  env.reset()
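A minimal sketch of how this helper might be driven from an outer loop, assuming a model-based RL hparams set, the setup_env helper used in Example #5, placeholder directories, and an epochs field; none of these assumptions come from the example above.

# Hypothetical driver; the hparams set name, directories and the epochs
# field are assumptions for illustration only.
hparams = trainer_lib.create_hparams("rlmb_base")  # assumed hparams set name
env = setup_env(hparams, batch_size=hparams.num_agents)  # assumed helper
for epoch in range(hparams.epochs):  # assumed field
  env.start_new_epoch(epoch)
  train_agent_real_env(
      env, "/tmp/agent_model", "/tmp/events", None, hparams,
      epoch=epoch, is_final_epoch=(epoch == hparams.epochs - 1))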
Example #2
def initialize_env_specs(hparams):
  """Initializes env_specs using T2TGymEnvs."""
  if getattr(hparams, "game", None):
    game_name = gym_env.camel_case_name(hparams.game)
    env = gym_env.T2TGymEnv("{}Deterministic-v4".format(game_name),
                            batch_size=hparams.num_agents)
    env.start_new_epoch(0)
    hparams.add_hparam("environment_spec", rl.standard_atari_env_spec(env))
    eval_env = gym_env.T2TGymEnv("{}Deterministic-v4".format(game_name),
                                 batch_size=hparams.num_eval_agents)
    eval_env.start_new_epoch(0)
    hparams.add_hparam(
        "environment_eval_spec", rl.standard_atari_env_eval_spec(eval_env))
  return hparams
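A sketch of calling this initializer with a hand-built hparams container (the TF 1.x tf.contrib.training.HParams class that tensor2tensor uses); the game name and agent counts are placeholder values.

import tensorflow as tf

# Hypothetical caller; field values are illustrative only.
hparams = tf.contrib.training.HParams(game="pong", num_agents=16,
                                      num_eval_agents=2)
hparams = initialize_env_specs(hparams)
# hparams now also carries environment_spec and environment_eval_spec.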
Example #3
  def get_environment_spec(self):
    env_spec = rl.standard_atari_env_spec(self.env_name)
    env_spec.wrappers = [
        [tf_atari_wrappers.IntToBitWrapper, {}],
        [tf_atari_wrappers.StackWrapper, {"history": 4}]
    ]
    env_spec.simulated_env = True
    env_spec.add_hparam("simulation_random_starts", True)
    env_spec.add_hparam("simulation_flip_first_random_for_beginning", True)
    env_spec.add_hparam("intrinsic_reward_scale", 0.0)
    initial_frames_problem = registry.problem(self.initial_frames_problem)
    env_spec.add_hparam("initial_frames_problem", initial_frames_problem)
    env_spec.add_hparam("video_num_input_frames", self.num_input_frames)
    env_spec.add_hparam("video_num_target_frames", self.video_num_target_frames)

    return env_spec
Example #4
  def get_environment_spec(self):
    env_spec = rl.standard_atari_env_spec(
        self.env_name,
        simulated=True,
        resize_height_factor=self.resize_height_factor,
        resize_width_factor=self.resize_width_factor,
        grayscale=self.grayscale)
    env_spec.add_hparam("simulation_random_starts", True)
    env_spec.add_hparam("simulation_flip_first_random_for_beginning", True)
    env_spec.add_hparam("intrinsic_reward_scale", 0.0)
    initial_frames_problem = registry.problem(self.initial_frames_problem)
    env_spec.add_hparam("initial_frames_problem", initial_frames_problem)
    env_spec.add_hparam("video_num_input_frames", self.num_input_frames)
    env_spec.add_hparam("video_num_target_frames", self.video_num_target_frames)

    return env_spec
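Examples #3 and #4 build a simulated-environment spec in two ways: #3 starts from a bare spec and attaches the IntToBitWrapper/StackWrapper wrappers and the simulated_env flag by hand, while #4 requests the simulated spec directly through simulated=True plus the resize and grayscale keyword arguments.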
Example #5
def evaluate_single_config(hparams, agent_model_dir):
    """Evaluate the PPO agent in the real environment."""
    eval_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    eval_hparams.num_agents = hparams.num_agents
    env = setup_env(hparams, batch_size=hparams.num_agents)
    environment_spec = rl.standard_atari_env_spec(env)
    eval_hparams.add_hparam("environment_spec", environment_spec)
    eval_hparams.add_hparam("policy_to_actions_lambda",
                            hparams.policy_to_actions_lambda)

    env.start_new_epoch(0)
    rl_trainer_lib.evaluate(eval_hparams, agent_model_dir)
    rollouts = env.current_epoch_rollouts()[:hparams.num_agents]
    env.close()

    assert len(rollouts) == hparams.num_agents
    return tuple(
        compute_mean_reward(rollouts, clipped) for clipped in (True, False))
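A sketch of invoking the evaluation helper; the checkpoint directory is a placeholder, and hparams is assumed to already carry the ppo_params, num_agents and policy_to_actions_lambda fields the function reads.

# Hypothetical call; the checkpoint path is a placeholder.
mean_clipped, mean_unclipped = evaluate_single_config(
    hparams, "/tmp/agent_model")
print("mean reward clipped/unclipped: %s / %s" % (mean_clipped, mean_unclipped))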
Example #6
  def get_environment_spec(self):
    return rl.standard_atari_env_spec(
        self.env_name,
        resize_height_factor=self.resize_height_factor,
        resize_width_factor=self.resize_width_factor,
        grayscale=self.grayscale)