Example #1
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function."""
  base_env = env
  # Stack frames so the policy sees frame_stack_size observations at once.
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  # Kwargs for the simulated (world-model) env the planner rolls out in.
  sim_env_kwargs = rl.make_simulated_env_kwargs(
      base_env, loop_hparams, batch_size=planner_hparams.batch_size,
      model_dir=model_dir)
  agent = make_agent(
      agent_type, env, policy_hparams, policy_dir, sampling_temp,
      sim_env_kwargs, loop_hparams.frame_stack_size,
      planner_hparams.planning_horizon, planner_hparams.rollout_agent_type,
      num_rollouts=planner_hparams.num_rollouts,
      inner_batch_size=planner_hparams.batch_size,
      video_writer=video_writer,
      env_type=planner_hparams.env_type,
      uct_const=planner_hparams.uct_const,
      uct_std_normalization=planner_hparams.uct_std_normalization,
      uniform_first_action=planner_hparams.uniform_first_action)
  # Roll out until every env copy in the batch finishes an episode.
  rl_utils.run_rollouts(
      env, agent, env.reset(), log_every_steps=log_every_steps)
  # Post-condition: one finished rollout per env copy in the batch.
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
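Note that this eval_fn is a closure: planner_hparams, model_dir, agent_type, video_writer and log_every_steps are not parameters, so they must be bound in an enclosing scope. A minimal sketch of that binding, using a hypothetical make_eval_fn factory (the factory name and signature are illustrative, not part of the tensor2tensor API):

def make_eval_fn(agent_type, planner_hparams, model_dir,
                 video_writer=None, log_every_steps=None):
  """Hypothetical factory binding the names eval_fn closes over."""
  def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
    # ... body as in Example #1, using the outer arguments above ...
    del env, loop_hparams, policy_hparams, policy_dir, sampling_temp
  return eval_fn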
Example #2
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function."""
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  sim_env_kwargs = rl.make_simulated_env_kwargs(
      base_env, loop_hparams, batch_size=planner_hparams.batch_size,
      model_dir=model_dir
  )
  # Forward all planner hyperparameters as kwargs, minus the ones that
  # are passed explicitly below.
  planner_kwargs = planner_hparams.values()
  planner_kwargs.pop("batch_size")
  planner_kwargs.pop("rollout_agent_type")
  planner_kwargs.pop("env_type")
  agent = make_agent(
      agent_type, env, policy_hparams, policy_dir, sampling_temp,
      sim_env_kwargs, loop_hparams.frame_stack_size,
      planner_hparams.rollout_agent_type,
      inner_batch_size=planner_hparams.batch_size,
      env_type=planner_hparams.env_type,
      video_writers=video_writers, **planner_kwargs
  )
  # Only hand video writers to run_rollouts if the agent does not
  # record its own videos.
  kwargs = {}
  if not agent.records_own_videos:
    kwargs["video_writers"] = video_writers
  rl_utils.run_rollouts(
      env, agent, env.reset(), log_every_steps=log_every_steps, **kwargs
  )
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
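The pop-then-splat pattern above relies on HParams-style objects whose values() method returns a plain dict; keys already consumed as explicit arguments are removed before the remainder is forwarded. A minimal self-contained sketch of the pattern (toy names, not the real make_agent call):

def make_toy_agent(planning_horizon, num_rollouts):
  """Stand-in for make_agent; accepts only the leftover keys."""
  return (planning_horizon, num_rollouts)

planner_kwargs = {"batch_size": 16, "rollout_agent_type": "policy",
                  "env_type": "simulated", "planning_horizon": 8,
                  "num_rollouts": 64}
for consumed in ("batch_size", "rollout_agent_type", "env_type"):
  planner_kwargs.pop(consumed)  # Already passed explicitly elsewhere.
assert make_toy_agent(**planner_kwargs) == (8, 64)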
Example #3
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function."""
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  agent = make_agent_from_hparams(
      agent_type, base_env, env, loop_hparams, policy_hparams,
      planner_hparams, model_dir, policy_dir, sampling_temp, video_writers)

  if eval_mode == "agent_simulated":
    # Evaluate inside the learned world model: collect real frames for
    # random starts, then build a simulated env seeded from them.
    real_env = base_env.new_like(batch_size=1)
    stacked_env = rl_utils.BatchStackWrapper(
        real_env, loop_hparams.frame_stack_size)
    collect_frames_for_random_starts(
        real_env, stacked_env, agent, loop_hparams.frame_stack_size,
        random_starts_step_limit, log_every_steps)
    initial_frame_chooser = rl_utils.make_initial_frame_chooser(
        real_env, loop_hparams.frame_stack_size,
        simulation_random_starts=True,
        simulation_flip_first_random_for_beginning=False,
        split=None,
    )
    env_fn = rl.make_simulated_env_fn_from_hparams(
        real_env, loop_hparams, batch_size=loop_hparams.eval_batch_size,
        initial_frame_chooser=initial_frame_chooser, model_dir=model_dir)
    sim_env = env_fn(in_graph=False)
    env = rl_utils.BatchStackWrapper(sim_env, loop_hparams.frame_stack_size)

  # Only hand video writers to run_rollouts if the agent does not
  # record its own videos.
  kwargs = {}
  if not agent.records_own_videos:
    kwargs["video_writers"] = video_writers
  # rl_env_max_episode_steps uses -1 to mean "no limit".
  step_limit = base_env.rl_env_max_episode_steps
  if step_limit == -1:
    step_limit = None
  rl_utils.run_rollouts(
      env, agent, env.reset(), log_every_steps=log_every_steps,
      step_limit=step_limit, **kwargs)
  if eval_mode == "agent_real":
    assert len(base_env.current_epoch_rollouts()) == env.batch_size
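The step-limit handling above converts between two conventions: rl_env_max_episode_steps reports -1 for "unlimited", while run_rollouts expects None in that case. A tiny sketch of the convention (illustrative helper, not a tensor2tensor function):

def normalize_step_limit(limit):
  """Map the -1 "unlimited" sentinel to None; pass other limits through."""
  return None if limit == -1 else limit

assert normalize_step_limit(-1) is None
assert normalize_step_limit(100) == 100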
Example #4
def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
  """Eval function."""
  base_env = env
  env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
  # Size the simulated planner env so each planning rollout gets its
  # own batch slot.
  sim_env_kwargs = rl.make_simulated_env_kwargs(
      base_env, loop_hparams, batch_size=planner_hparams.num_rollouts,
      model_dir=model_dir)
  agent = make_agent(
      agent_type, env, policy_hparams, policy_dir, sampling_temp,
      sim_env_kwargs, loop_hparams.frame_stack_size,
      planner_hparams.planning_horizon, planner_hparams.rollout_agent_type)
  rl_utils.run_rollouts(env, agent, env.reset())
  assert len(base_env.current_epoch_rollouts()) == env.batch_size
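The closing assert in each of these examples checks the same post-condition: after run_rollouts returns, the underlying env holds exactly one finished rollout per env copy in the batch. A toy model of that invariant (hypothetical stand-in class, not the real env):

class _ToyEpochEnv:
  """Minimal stand-in exposing batch_size and current_epoch_rollouts()."""

  def __init__(self, batch_size):
    self.batch_size = batch_size
    # Pretend run_rollouts finished one rollout per env copy.
    self._rollouts = [object() for _ in range(batch_size)]

  def current_epoch_rollouts(self):
    return self._rollouts

toy = _ToyEpochEnv(batch_size=4)
assert len(toy.current_epoch_rollouts()) == toy.batch_size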
Example #5
def collect_frames_for_random_starts(
    storage_env, stacked_env, agent, frame_stack_size, random_starts_step_limit,
    log_every_steps=None
):
  """Collects frames from real env for random starts of simulated env."""
  del frame_stack_size  # Unused; kept for interface compatibility.
  storage_env.start_new_epoch(0)
  tf.logging.info(
      "Collecting %d frames for random starts.", random_starts_step_limit
  )
  rl_utils.run_rollouts(
      stacked_env, agent, stacked_env.reset(),
      step_limit=random_starts_step_limit,
      many_rollouts_from_each_env=True,
      log_every_steps=log_every_steps,
  )
  # Save unfinished rollouts to history.
  stacked_env.reset()
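The trailing stacked_env.reset() matters here: per the comment, resetting moves any still-unfinished rollout into the epoch history, so its frames remain available as random starts. A toy sketch of that flush-on-reset behavior (stand-in class; the real semantics live in tensor2tensor's env wrappers):

class _ToyStorageEnv:
  """Stand-in env that flushes the in-progress rollout on reset."""

  def __init__(self):
    self.history = []
    self._in_progress = ["frame_0", "frame_1"]  # Unfinished rollout.

  def reset(self):
    if self._in_progress:
      self.history.append(self._in_progress)
      self._in_progress = []

env = _ToyStorageEnv()
env.reset()
assert len(env.history) == 1  # The unfinished rollout was saved.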