def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function: builds a planner agent and rolls it out on `env`.

  NOTE(review): reads `planner_hparams`, `model_dir`, `agent_type`,
  `video_writer`, `log_every_steps` and `make_agent` from the enclosing
  scope (closure) — confirm against the outer function.
  """
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  # Kwargs for the learned simulated environment the planner rolls out in.
  simulated_kwargs = rl.make_simulated_env_kwargs(
      base_env,
      loop_hparams,
      batch_size=planner_hparams.batch_size,
      model_dir=model_dir,
  )
  agent = make_agent(
      agent_type,
      env,
      policy_hparams,
      policy_dir,
      sampling_temp,
      simulated_kwargs,
      loop_hparams.frame_stack_size,
      planner_hparams.planning_horizon,
      planner_hparams.rollout_agent_type,
      num_rollouts=planner_hparams.num_rollouts,
      inner_batch_size=planner_hparams.batch_size,
      video_writer=video_writer,
      env_type=planner_hparams.env_type,
      uct_const=planner_hparams.uct_const,
      uct_std_normalization=planner_hparams.uct_std_normalization,
      uniform_first_action=planner_hparams.uniform_first_action,
  )
  rl_utils.run_rollouts(
      env, agent, env.reset(), log_every_steps=log_every_steps)
  # Each env in the batch should have contributed exactly one rollout.
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function: forwards planner hparams as agent kwargs and rolls out.

  NOTE(review): reads `planner_hparams`, `model_dir`, `agent_type`,
  `video_writers`, `log_every_steps` and `make_agent` from the enclosing
  scope (closure) — confirm against the outer function.
  """
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  sim_env_kwargs = rl.make_simulated_env_kwargs(
      base_env,
      loop_hparams,
      batch_size=planner_hparams.batch_size,
      model_dir=model_dir,
  )
  # Forward every planner hparam as a keyword argument, except the ones
  # passed explicitly below under different parameter names.
  # Presumably `values()` returns a fresh dict (HParams-style) — verify.
  forwarded = planner_hparams.values()
  for explicit_key in ("batch_size", "rollout_agent_type", "env_type"):
    forwarded.pop(explicit_key)
  agent = make_agent(
      agent_type,
      env,
      policy_hparams,
      policy_dir,
      sampling_temp,
      sim_env_kwargs,
      loop_hparams.frame_stack_size,
      planner_hparams.rollout_agent_type,
      inner_batch_size=planner_hparams.batch_size,
      env_type=planner_hparams.env_type,
      video_writers=video_writers,
      **forwarded
  )
  # Only pass video writers to the rollout loop if the agent does not
  # record its own videos.
  rollout_kwargs = {}
  if not agent.records_own_videos:
    rollout_kwargs["video_writers"] = video_writers
  rl_utils.run_rollouts(
      env,
      agent,
      env.reset(),
      log_every_steps=log_every_steps,
      **rollout_kwargs
  )
  # Each env in the batch should have contributed exactly one rollout.
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function.

  Builds an agent from hparams and rolls it out either on the real env or,
  in "agent_simulated" mode, on a learned simulated env seeded with random
  starts collected from the real env.

  NOTE(review): reads `agent_type`, `planner_hparams`, `model_dir`,
  `video_writers`, `eval_mode`, `random_starts_step_limit`,
  `log_every_steps`, `make_agent_from_hparams` and
  `collect_frames_for_random_starts` from the enclosing scope (closure) —
  confirm against the outer function.
  """
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  agent = make_agent_from_hparams(
      agent_type, base_env, env, loop_hparams, policy_hparams,
      planner_hparams, model_dir, policy_dir, sampling_temp, video_writers)
  if eval_mode == "agent_simulated":
    # Collect real frames (with a single-env copy of base_env) to use as
    # random starting points for the simulated environment.
    real_env = base_env.new_like(batch_size=1)
    stacked_env = rl_utils.BatchStackWrapper(
        real_env, loop_hparams.frame_stack_size)
    collect_frames_for_random_starts(
        real_env, stacked_env, agent, loop_hparams.frame_stack_size,
        random_starts_step_limit, log_every_steps)
    initial_frame_chooser = rl_utils.make_initial_frame_chooser(
        real_env, loop_hparams.frame_stack_size,
        simulation_random_starts=True,
        simulation_flip_first_random_for_beginning=False,
        split=None,
    )
    # Replace `env` with the (stacked) simulated environment; rollouts
    # below then happen entirely in simulation.
    env_fn = rl.make_simulated_env_fn_from_hparams(
        real_env, loop_hparams, batch_size=loop_hparams.eval_batch_size,
        initial_frame_chooser=initial_frame_chooser, model_dir=model_dir)
    sim_env = env_fn(in_graph=False)
    env = rl_utils.BatchStackWrapper(sim_env, loop_hparams.frame_stack_size)
  # Only pass video writers if the agent does not record its own videos.
  kwargs = {}
  if not agent.records_own_videos:
    kwargs["video_writers"] = video_writers
  # -1 marks "no limit" on the env; run_rollouts expects None for that.
  step_limit = base_env.rl_env_max_episode_steps
  if step_limit == -1:
    step_limit = None
  rl_utils.run_rollouts(
      env, agent, env.reset(), log_every_steps=log_every_steps,
      step_limit=step_limit, **kwargs)
  if eval_mode == "agent_real":
    # In real mode, each env in the batch should have contributed exactly
    # one rollout. (In simulated mode `env` is a different object, so the
    # check is skipped.)
    assert len(base_env.current_epoch_rollouts()) == env.batch_size
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function: rolls out a planner agent on the stacked batched env.

  NOTE(review): reads `planner_hparams`, `model_dir`, `agent_type` and
  `make_agent` from the enclosing scope (closure) — confirm against the
  outer function.
  """
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  # One simulated env instance per planner rollout.
  simulated_env_kwargs = rl.make_simulated_env_kwargs(
      base_env,
      loop_hparams,
      batch_size=planner_hparams.num_rollouts,
      model_dir=model_dir,
  )
  agent = make_agent(
      agent_type,
      env,
      policy_hparams,
      policy_dir,
      sampling_temp,
      simulated_env_kwargs,
      loop_hparams.frame_stack_size,
      planner_hparams.planning_horizon,
      planner_hparams.rollout_agent_type,
  )
  rl_utils.run_rollouts(env, agent, env.reset())
  # Each env in the batch should have contributed exactly one rollout.
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
def collect_frames_for_random_starts(
    storage_env, stacked_env, agent, frame_stack_size,
    random_starts_step_limit, log_every_steps=None):
  """Collects frames from real env for random starts of simulated env."""
  # The parameter is kept for call-site compatibility but is unused here.
  del frame_stack_size
  storage_env.start_new_epoch(0)
  tf.logging.info(
      "Collecting %d frames for random starts.", random_starts_step_limit)
  rl_utils.run_rollouts(
      stacked_env,
      agent,
      stacked_env.reset(),
      step_limit=random_starts_step_limit,
      many_rollouts_from_each_env=True,
      log_every_steps=log_every_steps,
  )
  # Save unfinished rollouts to history.
  stacked_env.reset()