def train_agent_real_env(
    env, agent_model_dir,
    event_dir, epoch_data_dir, hparams, epoch=0, is_final_epoch=False):
  """Train the PPO agent in the real environment."""
  del epoch_data_dir
  ppo_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  ppo_params_names = [
      "epochs_num", "epoch_length", "learning_rate", "num_agents",
      "eval_every_epochs", "optimization_epochs", "effective_num_agents"
  ]

  # This should be overridden.
  ppo_hparams.add_hparam("effective_num_agents", None)
  for param_name in ppo_params_names:
    ppo_param_name = "real_ppo_" + param_name
    if ppo_param_name in hparams:
      ppo_hparams.set_hparam(param_name, hparams.get(ppo_param_name))

  ppo_hparams.epochs_num = _ppo_training_epochs(
      hparams, epoch, is_final_epoch, True)
  # We do not save the model, as that resets frames that we need at restarts.
  # But we do need to save at the last step, so we set the interval very high.
  ppo_hparams.save_models_every_epochs = 1000000

  environment_spec = rl.standard_atari_env_spec(
      batch_env=env, include_clipping=False)
  ppo_hparams.add_hparam("environment_spec", environment_spec)

  rl_trainer_lib.train(ppo_hparams, event_dir + "real", agent_model_dir,
                       name_scope="ppo_real%d" % (epoch + 1))

  # Save unfinished rollouts to history.
  env.reset()
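# Hedged usage sketch, not part of the original module: in the model-based
# training loop this function would be called once per epoch, roughly as
# below (the loop shape and the `hparams.epochs` field are assumptions):
#
#   for epoch in range(hparams.epochs):
#     # ... collect frames, train the world model, train in simulation ...
#     train_agent_real_env(
#         env, agent_model_dir, event_dir, epoch_data_dir, hparams,
#         epoch=epoch, is_final_epoch=(epoch == hparams.epochs - 1))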
def initialize_env_specs(hparams):
  """Initializes env_specs using T2TGymEnvs."""
  if getattr(hparams, "game", None):
    game_name = gym_env.camel_case_name(hparams.game)
    env = gym_env.T2TGymEnv("{}Deterministic-v4".format(game_name),
                            batch_size=hparams.num_agents)
    env.start_new_epoch(0)
    hparams.add_hparam("environment_spec", rl.standard_atari_env_spec(env))
    eval_env = gym_env.T2TGymEnv("{}Deterministic-v4".format(game_name),
                                 batch_size=hparams.num_eval_agents)
    eval_env.start_new_epoch(0)
    hparams.add_hparam(
        "environment_eval_spec", rl.standard_atari_env_eval_spec(eval_env))
  return hparams
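# Hedged usage sketch, not part of the original module (assumes
# tf.contrib.training.HParams and illustrative values; the ppo_params set
# name is a placeholder):
#
#   from tensorflow.contrib.training import HParams
#   loop_hparams = HParams(game="pong", num_agents=2, num_eval_agents=2,
#                          ppo_params="ppo_original_params")
#   loop_hparams = initialize_env_specs(loop_hparams)
#   # loop_hparams now carries "environment_spec" and "environment_eval_spec",
#   # each wrapping a batched T2TGymEnv over "PongDeterministic-v4".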
def get_environment_spec(self):
  """Returns the environment spec for the simulated environment."""
  env_spec = rl.standard_atari_env_spec(self.env_name)
  env_spec.wrappers = [
      [tf_atari_wrappers.IntToBitWrapper, {}],
      [tf_atari_wrappers.StackWrapper, {"history": 4}]
  ]
  env_spec.simulated_env = True
  env_spec.add_hparam("simulation_random_starts", True)
  env_spec.add_hparam("simulation_flip_first_random_for_beginning", True)
  env_spec.add_hparam("intrinsic_reward_scale", 0.0)
  initial_frames_problem = registry.problem(self.initial_frames_problem)
  env_spec.add_hparam("initial_frames_problem", initial_frames_problem)
  env_spec.add_hparam("video_num_input_frames", self.num_input_frames)
  env_spec.add_hparam("video_num_target_frames", self.video_num_target_frames)
  return env_spec
def get_environment_spec(self):
  """Returns the spec for the simulated Atari environment."""
  env_spec = rl.standard_atari_env_spec(
      self.env_name, simulated=True,
      resize_height_factor=self.resize_height_factor,
      resize_width_factor=self.resize_width_factor,
      grayscale=self.grayscale)
  env_spec.add_hparam("simulation_random_starts", True)
  env_spec.add_hparam("simulation_flip_first_random_for_beginning", True)
  env_spec.add_hparam("intrinsic_reward_scale", 0.0)
  initial_frames_problem = registry.problem(self.initial_frames_problem)
  env_spec.add_hparam("initial_frames_problem", initial_frames_problem)
  env_spec.add_hparam("video_num_input_frames", self.num_input_frames)
  env_spec.add_hparam("video_num_target_frames", self.video_num_target_frames)
  return env_spec
def evaluate_single_config(hparams, agent_model_dir):
  """Evaluate the PPO agent in the real environment."""
  eval_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  eval_hparams.num_agents = hparams.num_agents
  env = setup_env(hparams, batch_size=hparams.num_agents)
  environment_spec = rl.standard_atari_env_spec(env)
  eval_hparams.add_hparam("environment_spec", environment_spec)
  eval_hparams.add_hparam(
      "policy_to_actions_lambda", hparams.policy_to_actions_lambda)
  env.start_new_epoch(0)
  rl_trainer_lib.evaluate(eval_hparams, agent_model_dir)
  rollouts = env.current_epoch_rollouts()[:hparams.num_agents]
  env.close()
  assert len(rollouts) == hparams.num_agents
  return tuple(
      compute_mean_reward(rollouts, clipped) for clipped in (True, False))
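# Hedged usage sketch, not part of the original module (the checkpoint
# directory is a placeholder, and the sampling lambda is an assumed way of
# turning a policy distribution into actions):
#
#   hparams.add_hparam("policy_to_actions_lambda",
#                      lambda policy: policy.sample())
#   mean_reward_clipped, mean_reward = evaluate_single_config(
#       hparams, agent_model_dir="/tmp/ppo_agent")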
def get_environment_spec(self):
  """Returns the spec for the real (non-simulated) Atari environment."""
  return rl.standard_atari_env_spec(
      self.env_name,
      resize_height_factor=self.resize_height_factor,
      resize_width_factor=self.resize_width_factor,
      grayscale=self.grayscale)