Пример #1
0
def make_parallel_env(args):
    """Build a vectorized environment for the game named by ``args.env_name``.

    Reads ``args.env_name``, ``args.num_agents``, ``args.seed`` and
    ``args.n_rollout_threads``.

    Returns a ``DummyVecEnv`` around a single factory (rank 0) when
    ``args.n_rollout_threads == 1``; otherwise a ``SubprocVecEnv`` with one
    factory per rollout thread. Each factory seeds its environment with
    ``args.seed + rank * 1000``.

    Raised by a factory when it is invoked:
        AssertionError: a two-player game was requested with
            ``args.num_agents != 2``.
        NotImplementedError: ``args.env_name`` is not one of the supported
            names.
    """
    def get_env_fn(rank):
        def init_env():
            if args.env_name == "StagHunt":
                assert args.num_agents == 2, (
                    "only 2 agents is supported, check the config.py.")
                env = MGEnv(args)
            elif args.env_name in ("StagHuntGW", "EscalationGW"):
                assert args.num_agents == 2, (
                    "only 2 agent is supported, check the config.py.")
                env = GridWorldEnv(args)
            elif args.env_name == "multi_StagHuntGW":
                env = multi_GridWorldEnv(args)
            else:
                # The message is still printed, and is now also attached to
                # the exception so it survives in tracebacks.
                message = ("Can not support the " + args.env_name +
                           " environment.")
                print(message)
                raise NotImplementedError(message)
            # Offset the seed by rank so each worker gets a distinct value.
            env.seed(args.seed + rank * 1000)
            return env

        return init_env

    if args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv(
            [get_env_fn(i) for i in range(args.n_rollout_threads)])
Пример #2
0
def make_parallel_env(env_id, n_rollout_threads, seed, discrete_action):
    """Create a vectorized set of ``env_id`` environments.

    A single rollout thread yields a ``DummyVecEnv`` around one factory
    (rank 0); any other count yields a ``SubprocVecEnv`` with one factory per
    rank. Each factory builds the environment via ``make_env`` and then seeds
    both the environment and NumPy's global RNG with ``seed + rank * 1000``.
    """
    def build_factory(worker_rank):
        def create():
            instance = make_env(env_id, discrete_action=discrete_action)
            worker_seed = seed + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #3
0
def make_parallel_env_transport(env_id, conf, seed, discrete_action=True):
    """Wrap a single ``Transport(conf)`` environment in a ``DummyVecEnv``.

    ``env_id`` and ``discrete_action`` are accepted but not used by this
    body. The factory seeds NumPy's global RNG with ``seed + rank * 1000``
    (rank is always 0 here) before constructing the environment.
    """
    def build_factory(worker_rank):
        def create():
            np.random.seed(seed + worker_rank * 1000)
            return Transport(conf)

        return create

    only_factory = build_factory(0)
    return DummyVecEnv([only_factory])
Пример #4
0
def make_parallel_env(env_id, n_rollout_threads, seed, num_controlled_lagents, num_controlled_ragents, reward_type, render):
    """Create a vectorized set of ``MultiAgentEnv`` instances.

    A single rollout thread yields a ``DummyVecEnv`` around one factory
    (rank 0); any other count yields a ``SubprocVecEnv`` with one factory per
    rank. Each factory seeds both its environment and NumPy's global RNG
    with ``seed + rank * 1000``.
    """
    def build_factory(worker_rank):
        def create():
            instance = MultiAgentEnv(
                env_id,
                num_controlled_lagents,
                num_controlled_ragents,
                reward_type,
                render,
            )
            worker_seed = seed + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #5
0
def make_parallel_env(original_drug_smile, original_target, Hyperparams,
                      atoms_, model_to_explain, original_drug,
                      original_target_aff, pred_aff, device, cof):
    """Wrap one ``make_env(...)`` environment in a ``DummyVecEnv``.

    All arguments are forwarded positionally, in order, to ``make_env`` when
    the factory is invoked. Only a single factory is ever created.
    """
    def create():
        return make_env(original_drug_smile, original_target, Hyperparams,
                        atoms_, model_to_explain, original_drug,
                        original_target_aff, pred_aff, device, cof)

    return DummyVecEnv([create])
Пример #6
0
def make_parallel_env(num_agents, n_rollout_threads, seed, shape_file):
    """Create a vectorized set of ``HeavyObjectEnv`` instances.

    A single rollout thread yields a ``DummyVecEnv`` around one factory;
    any other count yields a ``SubprocVecEnv`` with one factory per rank.
    ``seed`` is accepted but not applied: no seeding is performed here.
    """
    def build_factory(worker_rank):
        # worker_rank is unused because no per-worker seeding takes place.
        def create():
            return HeavyObjectEnv(num_agents=num_agents,
                                  shape_file=shape_file)

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #7
0
def make_parallel_env(**kwargs):
    """Create a vectorized set of environments from keyword arguments.

    Every keyword argument is forwarded unchanged to ``make_env``. The keys
    ``'seed'`` and ``'n_rollout_threads'`` are additionally read here: the
    former to seed each environment and NumPy's global RNG with
    ``seed + rank * 1000``, the latter to choose between a ``DummyVecEnv``
    (exactly one thread) and a ``SubprocVecEnv`` (one factory per rank).
    """
    def build_factory(worker_rank):
        def create():
            instance = make_env(**kwargs)
            worker_seed = kwargs['seed'] + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    thread_count = kwargs['n_rollout_threads']
    if thread_count == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(thread_count)])
Пример #8
0
def make_parallel_env(args):
    """Build a vectorized ``MPEEnv`` environment from ``args``.

    Reads ``args.env_name``, ``args.seed`` and ``args.n_rollout_threads``.

    Returns a ``DummyVecEnv`` around a single factory (rank 0) when
    ``args.n_rollout_threads == 1``; otherwise a ``SubprocVecEnv`` with one
    factory per rollout thread. Each factory seeds its environment with
    ``args.seed + rank * 1000``.

    Raised by a factory when it is invoked:
        NotImplementedError: ``args.env_name`` is not ``"MPE"``.
    """
    def get_env_fn(rank):
        def init_env():
            if args.env_name != "MPE":
                # The message is still printed, and is now also attached to
                # the exception so it survives in tracebacks.
                message = ("Can not support the " + args.env_name +
                           " environment.")
                print(message)
                raise NotImplementedError(message)
            env = MPEEnv(args)
            # Offset the seed by rank so each worker gets a distinct value.
            env.seed(args.seed + rank * 1000)
            return env
        return init_env
    if args.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)])
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(args.n_rollout_threads)])
Пример #9
0
def make_parallel_env(scenario_name, n_rollout_threads, seed, use_discrete_action, use_max_speed, world_params):
    """Create a vectorized set of ``scenario_name`` environments.

    A single rollout thread yields a ``DummyVecEnv`` around one factory
    (rank 0); any other count yields a ``SubprocVecEnv`` with one factory per
    rank. Each factory builds the environment via ``make_env`` and then seeds
    both the environment and NumPy's global RNG with ``seed + rank * 1000``.
    """
    def build_factory(worker_rank):
        def create():
            env_options = dict(
                use_discrete_action=use_discrete_action,
                use_max_speed=use_max_speed,
                world_params=world_params,
            )
            instance = make_env(scenario_name, **env_options)
            worker_seed = seed + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #10
0
def make_parallel_env(config):
    """Build a vectorized environment described by ``config``.

    Reads ``config.env_id``, ``config.discrete_action`` and
    ``config.n_rollout_threads``.

    Returns a ``DummyVecEnv`` around a single factory when
    ``config.n_rollout_threads == 1``; otherwise a ``SubprocVecEnv`` with one
    factory per rollout thread. No seeding is applied to the environments.
    """
    def get_env_fn(rank):
        # The argument passed in by the callers below is a worker index
        # (0 / i), so it is named ``rank`` and no longer shadows ``config``.
        # It is unused because no per-worker seeding is applied.
        def init_env(config=config):
            # ``config`` defaults to the outer config so the factory works
            # both when a wrapper calls it with no arguments and when a
            # wrapper passes a config explicitly.
            env = make_env(config.env_id, config, discrete_action=config.discrete_action)
            return env
        return init_env
    # NOTE(review): DummyVecEnv receives ``config`` as a second argument while
    # SubprocVecEnv does not -- confirm against the wrappers' signatures.
    if config.n_rollout_threads == 1:
        return DummyVecEnv([get_env_fn(0)], config)
    else:
        return SubprocVecEnv([get_env_fn(i) for i in range(config.n_rollout_threads)])
Пример #11
0
Файл: main.py Проект: xuezzee/-
def make_parallel_env_transport(env_id,
                                conf,
                                n_rollout_threads,
                                seed,
                                discrete_action=True):
    """Create a vectorized set of ``Transport(conf)`` environments.

    ``env_id`` and ``discrete_action`` are accepted but not used by this
    body. Each factory seeds NumPy's global RNG with ``seed + rank * 1000``
    before constructing its environment. A single rollout thread yields a
    ``DummyVecEnv`` (rank 0 only); any other count yields a
    ``SubprocVecEnv`` with one factory per rank.
    """
    def build_factory(worker_rank):
        def create():
            np.random.seed(seed + worker_rank * 1000)
            return Transport(conf)

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #12
0
def make_parallel_env(env_id, n_rollout_threads, seed):
    """Create a vectorized set of football environments.

    Environment settings are read from the ``config`` mapping in the
    enclosing scope (keys ``"academy_scenario"``, ``"scoring"``,
    ``"render_mode"`` and ``"num_to_control"``); ``env_id`` is accepted but
    not used by this body. Each factory seeds both its environment and
    NumPy's global RNG with ``seed + rank * 1000``. A single rollout thread
    yields a ``DummyVecEnv``; any other count yields a ``SubprocVecEnv``
    with one factory per rank.
    """
    def build_factory(worker_rank):
        def create():
            instance = football_env.create_environment(
                env_name=config["academy_scenario"],
                rewards=config["scoring"],
                render=config["render_mode"],
                number_of_left_players_agent_controls=config["num_to_control"],
                representation='simple115v2',
            )
            worker_seed = seed + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)])
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)])
Пример #13
0
def make_parallel_football_env(seed_dir, n_rollout_threads, seed, dump_freq,
                               representation, render):
    """Create a named, vectorized set of football environments.

    Each factory builds its environment via ``make_football_env`` and seeds
    both the environment and NumPy's global RNG with ``seed + rank * 1000``.
    A single rollout thread yields a ``DummyVecEnv`` named
    ``"DummyVecEnv_football"``; any other count yields a ``SubprocVecEnv``
    with one factory per rank, named after the number of subprocesses.
    """
    def build_factory(worker_rank):
        def create():
            instance = make_football_env(
                seed_dir=seed_dir,
                dump_freq=dump_freq,
                representation=representation,
                render=render,
            )
            worker_seed = seed + worker_rank * 1000
            instance.seed(worker_seed)
            np.random.seed(worker_seed)
            return instance

        return create

    if n_rollout_threads == 1:
        return DummyVecEnv([build_factory(0)], name="DummyVecEnv_football")
    return SubprocVecEnv(
        [build_factory(idx) for idx in range(n_rollout_threads)],
        name=f"VecEnv_football_{n_rollout_threads}subprocesses")