예제 #1
0
 def vectorized(env_id, seed, num_processes, log_dir, add_timestep,
          sensors={DEFAULT_SENSOR_NAME: None},
          addl_repeat_count=0, preprocessing_fn=None,
          env_specific_kwargs=None,
          vis_interval=20,
          visdom_name='main',
          visdom_log_file=None,
          visdom_server='localhost',
          visdom_port='8097',
          num_val_processes=0,
          gae_gamma=None):
     '''Return a vectorized environment for ``env_id``.

     ``env_id`` is split on ``'_'`` into ``(simulator, scenario)``. If the
     simulator is ``habitat`` (case-insensitive) the environment comes from
     ``make_habitat_vector_env``, which is handed ``num_processes`` directly.
     Otherwise one environment per rank in ``range(num_processes)`` is
     requested from ``EnvFactory.call_to_run`` and the resulting list is
     wrapped in ``DummyVecEnv`` (exactly one process) or
     ``DistributedEnv.new`` (any other count).

     Args:
         env_id: Identifier of the form ``'<simulator>_<scenario>'``.
         seed: Forwarded unchanged to the environment constructor(s).
         num_processes: Number of environments to create.
         log_dir: Forwarded to the environment constructor(s).
         add_timestep: Forwarded to ``call_to_run``; not passed on the
             habitat path.
         sensors: Forwarded to ``call_to_run``; not passed on the habitat
             path. The default dict is never modified in this function.
         addl_repeat_count: Forwarded to ``call_to_run``; not passed on the
             habitat path.
         preprocessing_fn: Forwarded to the environment constructor(s).
         env_specific_kwargs: Extra keyword arguments for the simulator.
             ``None`` (the default) means "no extra arguments".
         vis_interval, visdom_name, visdom_log_file, visdom_server,
         visdom_port: Logging/visualisation settings, forwarded unchanged.
         num_val_processes: Forwarded to the environment constructor(s).
         gae_gamma: Only passed to ``DistributedEnv.new``.

     Returns:
         Whatever ``make_habitat_vector_env``, ``DummyVecEnv`` or
         ``DistributedEnv.new`` returns for the chosen path.

     Raises:
         ValueError: If ``env_id`` does not split into exactly two parts
             on ``'_'``.
     '''
     # A dict literal as default would be a single object shared by every
     # call; start each call from its own empty dict instead.
     if env_specific_kwargs is None:
         env_specific_kwargs = {}

     simulator, scenario = env_id.split('_')
     if simulator.lower() in ['habitat']:  # These simulators internally handle vectorization/distribution
         env = make_habitat_vector_env(
                         scenario=scenario,
                         num_processes=num_processes,
                         preprocessing_fn=preprocessing_fn,
                         log_dir=log_dir,
                         num_val_processes=num_val_processes,
                         vis_interval=vis_interval,
                         visdom_name=visdom_name,
                         visdom_log_file=visdom_log_file,
                         visdom_server=visdom_server,
                         visdom_port=visdom_port,
                         seed=seed,
                         **env_specific_kwargs)
     else:  # These simulators must be manually vectorized
         # Each rank receives its own shallow copy of the kwargs so that any
         # per-rank keys written downstream cannot leak into sibling ranks
         # or back into the caller's dict.
         envs = [EnvFactory.call_to_run(env_id, seed,
                         rank, log_dir, add_timestep,
                         sensors=sensors,
                         addl_repeat_count=addl_repeat_count,
                         preprocessing_fn=preprocessing_fn,
                         env_specific_kwargs=dict(env_specific_kwargs),
                         vis_interval=vis_interval,
                         visdom_name=visdom_name,
                         visdom_log_file=visdom_log_file,
                         visdom_server=visdom_server,
                         visdom_port=visdom_port,
                         num_val_processes=num_val_processes,
                         num_processes=num_processes)
                 for rank in range(num_processes)]
         if num_processes == 1:
             env = DummyVecEnv(envs)
         else:
             env = DistributedEnv.new(envs,
                         gae_gamma=gae_gamma,
                         distribution_method=DistributedEnv.distribution_schemes.vectorize)
     return env
예제 #2
0
        def _thunk():
            '''Build, wrap and seed the environment for this worker.

            Takes no arguments; everything it needs (``env_id``, ``rank``,
            ``seed``, ``log_dir``, ``sensors``, ``preprocessing_fn``, the
            visdom settings, ...) is read from the enclosing scope.

            The simulator is chosen from the prefix of ``env_id``. A Habitat
            environment is returned as soon as it is built. Every other
            environment is then, in order: replaced by ``make_atari`` if it is
            an Atari env, wrapped in a monitor when ``log_dir`` is set,
            wrapped by ``wrap_deepmind`` if Atari, given a ``SkipWrapper`` when
            extra action repeats are requested and the env has no
            ``repeat_count`` of its own, wrapped in ``SensorEnvWrapper`` when
            it is not embodied, wrapped in ``ProcessObservationWrapper`` when
            ``preprocessing_fn`` is set, and finally seeded with
            ``seed + rank``.

            Returns:
                The constructed (and possibly wrapped) environment.

            Raises:
                NotImplementedError: If ``add_timestep`` is truthy on a
                    non-Habitat environment.
                AssertionError: If a non-embodied environment is given a
                    ``sensors`` mapping whose length is not 1.
            '''
            if env_id.startswith("dm"):
                _, domain, task = env_id.split('.')
                env = dm_control2gym.make(domain_name=domain, task_name=task)
            elif env_id.startswith("Gibson"):
                env = GibsonEnv(env_id=env_id,
                                gibson_config=gibson_config,
                                blind=blind,
                                blank_sensor=blank_sensor,
                                start_locations_file=start_locations_file,
                                target_dim=target_dim,
                                **env_specific_kwargs)
            elif env_id.startswith("DummyGibson"):
                env = DummyGibsonEnv(env_id=env_id,
                                     gibson_config=gibson_config,
                                     blind=blind,
                                     blank_sensor=blank_sensor,
                                     start_locations_file=start_locations_file,
                                     target_dim=target_dim,
                                     **env_specific_kwargs)
            elif env_id.startswith("Doom"):
                # Work on a copy: writing the per-rank keys into the captured
                # dict would leak them to whoever else holds a reference to it.
                doom_kwargs = dict(env_specific_kwargs)
                doom_kwargs['repeat_count'] = addl_repeat_count + 1
                num_train_processes = num_processes - num_val_processes
                # 1 (train only), 2 test only
                doom_kwargs['randomize_textures'] = 1 if rank < num_train_processes else 2
                # NOTE(security): eval() executes whatever text precedes the
                # first '.' in `scenario`. Acceptable only while scenario
                # comes from trusted configuration -- TODO replace with an
                # explicit name -> class registry.
                vizdoom_class = eval(scenario.split('.')[0])
                env = vizdoom_class(**doom_kwargs)
            elif env_id.startswith("Habitat"):
                # NOTE(review): `rank` is passed as num_processes here, which
                # looks unintended -- confirm against make_habitat_vector_env.
                env = make_habitat_vector_env(
                                num_processes=rank,
                                target_dim=target_dim,
                                preprocessing_fn=preprocessing_fn,
                                log_dir=log_dir,
                                num_val_processes=num_val_processes,
                                visdom_name=visdom_name,
                                visdom_log_file=visdom_log_file,
                                visdom_server=visdom_server,
                                visdom_port=visdom_port,
                                seed=seed,
                                **env_specific_kwargs)
                # Env is now responsible for logging, preprocessing,
                # repeat_count, so none of the wrappers below are applied.
                return env
            else:
                env = gym.make(env_id)

            is_atari = hasattr(gym.envs, 'atari') and isinstance(
                env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
            if is_atari:
                env = make_atari(env_id)

            if add_timestep:
                raise NotImplementedError("AddTimestep not implemented for SensorDict")

            if log_dir is not None:
                monitor_dir = os.path.join(log_dir, visdom_name)
                os.makedirs(monitor_dir, exist_ok=True)
                print("Visdom log file", visdom_log_file)
                first_val_process = num_processes - num_val_processes
                # Only rank 0 and the rank equal to
                # num_processes - num_val_processes get the visdom monitor.
                if (rank == 0 or rank == first_val_process) and visdom_log_file is not None:
                    env = VisdomMonitor(env,
                                   directory=monitor_dir,
                                   video_callable=lambda x: x % vis_interval == 0,
                                   uid=str(rank),
                                   server=visdom_server,
                                   port=visdom_port,
                                   visdom_log_file=visdom_log_file,
                                   visdom_env=visdom_name)
                else:
                    print("Not using visdom")
                    env = wrappers.Monitor(env,
                                   directory=monitor_dir,
                                   uid=str(rank))

            if is_atari:
                env = wrap_deepmind(env)

            if addl_repeat_count > 0:
                # Skip the wrapper when the env already exposes its own
                # repeat_count attribute.
                if not hasattr(env, 'repeat_count') and not hasattr(env.unwrapped, 'repeat_count'):
                    # Previously passed `repeat_count`, a name never bound in
                    # this function; the guard above is on addl_repeat_count.
                    env = SkipWrapper(addl_repeat_count)(env)

            if sensors is not None:
                # Envs exposing `is_embodied` are left unwrapped here.
                is_embodied = hasattr(env, 'is_embodied') or hasattr(env.unwrapped, 'is_embodied')
                if not is_embodied:
                    assert len(sensors) == 1, 'Can only handle one sensor'
                    sensor_name = next(iter(sensors))
                    env = SensorEnvWrapper(env, name=sensor_name)

            if preprocessing_fn is not None:
                transform, space = preprocessing_fn(env.observation_space)
                env = ProcessObservationWrapper(env, transform, space)
            env.seed(seed + rank)
            return env