import os
import pprint

import tensorflow as tf

from tensor2tensor.models.research import rl
from tensor2tensor.rl import rl_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib

# Helpers referenced below (make_agent_from_hparams,
# collect_frames_for_random_starts, infer_last_epoch_num, setup_directories,
# make_relative_timing_fn, make_log_fn, train_world_model, train_agent,
# train_agent_real_env, load_metrics) are defined elsewhere in this module.
def evaluate_world_model(
    agent_type, loop_hparams, planner_hparams, model_dir, policy_dir,
    random_starts_step_limit, debug_video_path, log_every_steps
):
  """Evaluates the world model."""
  if debug_video_path:
    debug_video_path = os.path.join(debug_video_path, "0.avi")

  storage_env = rl_utils.setup_env(loop_hparams, batch_size=1, max_num_noops=0)
  stacked_env = rl_utils.BatchStackWrapper(
      storage_env, loop_hparams.frame_stack_size
  )
  policy_hparams = trainer_lib.create_hparams(loop_hparams.base_algo_params)
  agent = make_agent_from_hparams(
      agent_type, storage_env, stacked_env, loop_hparams, policy_hparams,
      planner_hparams, model_dir, policy_dir,
      # TODO(koz4k): Loop over eval_sampling_temps?
      sampling_temp=loop_hparams.eval_sampling_temps[0],
  )
  collect_frames_for_random_starts(
      storage_env, stacked_env, agent, loop_hparams.frame_stack_size,
      random_starts_step_limit, log_every_steps
  )
  return rl_utils.evaluate_world_model(
      storage_env, loop_hparams, model_dir, debug_video_path, split=None
  )
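# A usage sketch for the function above (hypothetical, not from the original
# source): "rlmb_base" and "planner_small" are assumed registered hparams-set
# names and "policy" an assumed agent type; substitute real values as needed.
def _example_evaluate_world_model(model_dir, policy_dir):
  """Hypothetical driver showing one way to call evaluate_world_model."""
  loop_hparams = trainer_lib.create_hparams("rlmb_base")
  planner_hparams = trainer_lib.create_hparams("planner_small")
  return evaluate_world_model(
      agent_type="policy",
      loop_hparams=loop_hparams,
      planner_hparams=planner_hparams,
      model_dir=model_dir,
      policy_dir=policy_dir,
      random_starts_step_limit=10000,
      debug_video_path="",  # An empty path disables debug videos.
      log_every_steps=100,
  )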
def setup_and_load_epoch(hparams, data_dir, which_epoch_data=None):
  """Load T2TGymEnv with data from one epoch.

  Args:
    hparams: hparams.
    data_dir: data directory.
    which_epoch_data: data from which epoch to load.

  Returns:
    env.
  """
  t2t_env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops
  )
  # Load data.
  if which_epoch_data is not None:
    if which_epoch_data == "last":
      which_epoch_data = infer_last_epoch_num(data_dir)
    assert isinstance(which_epoch_data, int), \
        "{}".format(type(which_epoch_data))
    t2t_env.start_new_epoch(which_epoch_data, data_dir)
  else:
    t2t_env.start_new_epoch(-999)
  return t2t_env
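# Usage sketch (hypothetical, not part of the original module): reload the
# rollouts recorded during the most recent epoch. "rlmb_base" is an assumed
# hparams-set name; any registered model-based RL hparams set should work.
def _example_load_last_epoch(data_dir):
  """Hypothetical helper restoring the env from the latest epoch's data."""
  hparams = trainer_lib.create_hparams("rlmb_base")
  return setup_and_load_epoch(hparams, data_dir, which_epoch_data="last")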
def initialize_env_specs(hparams):
  """Initializes env_specs using T2TGymEnvs."""
  env = rl_utils.setup_env(
      hparams, hparams.batch_size, hparams.eval_max_num_noops
  )
  env.start_new_epoch(0)

  # TODO(afrozm): Decouple env_fn from hparams and return both, is there
  # even a need to return hparams? Just return the env_fn?
  hparams.add_hparam("env_fn", rl.make_real_env_fn(env))
  return hparams
def initialize_env_specs(hparams, env_problem_name):
  """Initializes env_specs using the appropriate env."""
  if env_problem_name:
    env = registry.env_problem(env_problem_name, batch_size=hparams.batch_size)
  else:
    env = rl_utils.setup_env(
        hparams, hparams.batch_size, hparams.eval_max_num_noops,
        hparams.rl_env_max_episode_steps, env_name=hparams.rl_env_name
    )
    env.start_new_epoch(0)
  return rl.make_real_env_fn(env)
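# Usage sketch (hypothetical): the two construction paths above. Passing an
# env_problem_name builds a registered EnvProblem; passing None (or "") falls
# through to the T2TGymEnv path. "rlmb_base" is an assumed hparams-set name.
def _example_make_env_fn(env_problem_name=None):
  """Hypothetical helper returning an env_fn for either env flavor."""
  hparams = trainer_lib.create_hparams("rlmb_base")
  return initialize_env_specs(hparams, env_problem_name)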
def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
  """Run the main training loop."""
  if report_fn:
    assert report_metric is not None

  # Directories
  subdirectories = [
      "data", "tmp", "world_model", ("world_model", "debug_videos"),
      "policy", "eval_metrics"
  ]
  directories = setup_directories(output_dir, subdirectories)

  epoch = -1
  data_dir = directories["data"]
  env = rl_utils.setup_env(
      hparams, batch_size=hparams.real_batch_size,
      max_num_noops=hparams.max_num_noops,
      rl_env_max_episode_steps=hparams.rl_env_max_episode_steps
  )
  env.start_new_epoch(epoch, data_dir)

  if hparams.wm_policy_param_sharing:
    policy_model_dir = directories["world_model"]
  else:
    policy_model_dir = directories["policy"]
  learner = rl_utils.LEARNERS[hparams.base_algo](
      hparams.frame_stack_size, policy_model_dir, policy_model_dir,
      hparams.epochs
  )

  # Timing log function
  log_relative_time = make_relative_timing_fn()

  # Per-epoch state
  epoch_metrics = []
  metrics = {}

  # Collect data from the real environment.
  tf.logging.info("Initial training of the policy in real environment.")
  train_agent_real_env(env, learner, hparams, epoch)
  metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
      env.current_epoch_rollouts(), clipped=True
  )
  tf.logging.info("Mean training reward (initial): {}".format(
      metrics["mean_reward/train/clipped"]
  ))
  env.generate_data(data_dir)

  eval_metrics_writer = tf.summary.FileWriter(directories["eval_metrics"])
  world_model_steps_num = 0

  for epoch in range(hparams.epochs):
    log = make_log_fn(epoch, log_relative_time)

    # Train world model
    log("Training world model")
    world_model_steps_num = train_world_model(
        env, data_dir, directories["world_model"], hparams,
        world_model_steps_num, epoch
    )

    # Train agent
    log("Training policy in simulated environment.")
    train_agent(env, learner, directories["world_model"], hparams, epoch)

    env.start_new_epoch(epoch, data_dir)

    # Train agent on real env (short)
    log("Training policy in real environment.")
    train_agent_real_env(env, learner, hparams, epoch)

    if hparams.stop_loop_early:
      return 0.0

    env.generate_data(data_dir)

    metrics = load_metrics(directories["eval_metrics"], epoch)
    if metrics:
      # Skip eval if metrics have already been written for this epoch.
      # Otherwise we'd overwrite them with wrong data.
      log("Metrics found for this epoch, skipping evaluation.")
    else:
      metrics["mean_reward/train/clipped"] = rl_utils.compute_mean_reward(
          env.current_epoch_rollouts(), clipped=True
      )
      log("Mean training reward: {}".format(
          metrics["mean_reward/train/clipped"]
      ))

      eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_model_dir)
      log("Agent eval metrics:\n{}".format(pprint.pformat(eval_metrics)))
      metrics.update(eval_metrics)

      if hparams.eval_world_model:
        debug_video_path = os.path.join(
            directories["world_model", "debug_videos"],
            "{}.avi".format(env.current_epoch)
        )
        wm_metrics = rl_utils.evaluate_world_model(
            env, hparams, directories["world_model"], debug_video_path
        )
        log("World model eval metrics:\n{}".format(pprint.pformat(wm_metrics)))
        metrics.update(wm_metrics)

      rl_utils.summarize_metrics(eval_metrics_writer, metrics, epoch)

      # Report metrics
      if report_fn:
        if report_metric == "mean_reward":
          metric_name = rl_utils.get_metric_name(
              sampling_temp=hparams.eval_sampling_temps[0],
              max_num_noops=hparams.eval_max_num_noops,
              clipped=False
          )
          report_fn(eval_metrics[metric_name], epoch)
        else:
          report_fn(eval_metrics[report_metric], epoch)

    epoch_metrics.append(metrics)

  # Return the evaluation metrics from the final epoch
  return epoch_metrics[-1]
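# Minimal end-to-end invocation sketch (hypothetical, not from the original
# file): "rlmb_base" is an assumed hparams-set name. A report_fn is optional;
# when given, report_metric must be set as well, per the assertion above.
def _example_run(output_dir):
  """Hypothetical entry point running the full model-based training loop."""
  loop_hparams = trainer_lib.create_hparams("rlmb_base")
  final_metrics = training_loop(loop_hparams, output_dir)
  tf.logging.info("Final epoch metrics:\n%s", pprint.pformat(final_metrics))
  return final_metrics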