def define_batch_env(env_ctor, num_agents, isolate_envs):
    """Build `num_agents` environments and wrap them as one in-graph batch.

    Everything is created inside the 'environments' variable scope.

    Args:
        env_ctor: Zero-argument callable that creates one environment. It is
            called directly for 'none' and handed to `wrappers.Async`
            otherwise.
        num_agents: Number of environments to create.
        isolate_envs: One of 'none', 'thread' or 'process'. 'none' calls the
            constructor inline and marks the batch as blocking; the other two
            wrap each environment in `wrappers.Async` with the matching
            `strategy` and mark the batch as non-blocking.

    Returns:
        The `in_graph_batch_env.InGraphBatchEnv` wrapping a
        `batch_env.BatchEnv` over the created environments.

    Raises:
        NotImplementedError: If `isolate_envs` is none of the known modes.
    """
    with tf.variable_scope('environments'):
        if isolate_envs == 'none':
            run_blocking = True

            def build_env(ctor):
                # No isolation: construct the environment right here.
                return ctor()
        else:
            run_blocking = False
            if isolate_envs == 'thread':
                build_env = functools.partial(
                    wrappers.Async, strategy='thread')
            elif isolate_envs == 'process':
                build_env = functools.partial(
                    wrappers.Async, strategy='process')
            else:
                raise NotImplementedError(isolate_envs)

        wrapped = []
        for _ in range(num_agents):
            wrapped.append(build_env(env_ctor))

        batched = batch_env.BatchEnv(wrapped, run_blocking)
        in_graph = in_graph_batch_env.InGraphBatchEnv(batched)
    return in_graph
def define_batch_env(env_ctor, num_agents, env_processes):
    """Build `num_agents` environments and wrap them as one in-graph batch.

    Everything is created inside the 'environments' variable scope.

    Args:
        env_ctor: Zero-argument callable that creates one environment. It is
            called directly when `env_processes` is falsy and handed to
            `wrappers.ExternalProcess` otherwise.
        num_agents: Number of environments to create.
        env_processes: Truthy to wrap every environment in
            `wrappers.ExternalProcess`; the batch is then created with
            `blocking=False`. Falsy builds the environments inline with
            `blocking=True`.

    Returns:
        The `in_graph_batch_env.InGraphBatchEnv` wrapping a
        `batch_env.BatchEnv` over the created environments.
    """
    with tf.variable_scope('environments'):
        use_processes = bool(env_processes)

        instances = []
        for _ in range(num_agents):
            if use_processes:
                instances.append(wrappers.ExternalProcess(env_ctor))
            else:
                instances.append(env_ctor())

        batched = batch_env.BatchEnv(instances, blocking=not use_processes)
        in_graph = in_graph_batch_env.InGraphBatchEnv(batched)
    return in_graph