Example #1
    def __init__(self, path="data/arrange_model", messages=None, mode=1):
        # some parameters
        map_size = 250
        eps = 0.15

        # init the game
        env = magent.GridWorld(load_config(map_size))
        font = FontProvider('data/font_8x8/basic.txt')

        handles = env.get_handles()
        food_handle, handles = handles[0], handles[1:]
        models = []
        models.append(DeepQNetwork(env, handles[0], 'arrange', use_conv=True))

        # load model
        models[0].load(path, 10)

        # init environment
        env.reset()
        generate_map(mode, env, map_size, food_handle, handles, messages, font)

        # save to member variable
        self.env = env
        self.food_handle = food_handle
        self.handles = handles
        self.eps = eps
        self.models = models
        self.done = False
        self.map_size = map_size
        self.new_rule_ct = 0
        self.pos_reward_ct = set()
        self.num = None

        self.ct = 0
Example #2
    def __init__(self, map_size, seed):
        env = magent.GridWorld(get_config(map_size), map_size=map_size)

        handles = env.get_handles()

        names = ["predetor", "prey"]
        super().__init__(env, handles, names, map_size, seed)
Example #3
    def __init__(self, path="data/pursuit_model", total_step=500):
        # some parameters
        map_size = 1000
        eps = 0.00

        # init the game
        env = magent.GridWorld(load_config(map_size))

        handles = env.get_handles()
        models = []
        models.append(DeepQNetwork(env, handles[0], 'predator', use_conv=True))
        models.append(DeepQNetwork(env, handles[1], 'prey', use_conv=True))

        # load model
        models[0].load(path, 423, 'predator')
        models[1].load(path, 423, 'prey')

        # init environment
        env.reset()
        generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
        self.done = False
        self.total_handles = [
            self.env.get_num(self.handles[0]),
            self.env.get_num(self.handles[1])
        ]
        print(env.get_view2attack(handles[0]))
        plt.show()
Example #4
def test_model(net: model.DQNModel, device: torch.device,
               gw_config) -> Tuple[float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    deer_handle, tiger_handle = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        test_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    env = data.MAgentEnv(test_env, tiger_handle, reset_env_func=reset_env)
    preproc = model.MAgentPreprocessor(device)
    agent = ptan.agent.DQNAgent(net,
                                ptan.actions.ArgmaxActionSelector(),
                                device,
                                preprocessor=preproc)

    obs = env.reset()
    steps = 0
    rewards = 0.0

    while True:
        actions = agent(obs)[0]
        obs, r, dones, _ = env.step(actions)
        steps += len(obs)
        rewards += sum(r)
        if dones[0]:
            break

    return rewards / COUNT_TIGERS, steps / COUNT_TIGERS
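A minimal call-site sketch for test_model above, building the network from the environment's own spaces as the other examples here do. The data.config_forest helper (see Example #26) and the checkpoint path are illustrative assumptions, not part of the original snippet:

device = torch.device("cpu")
cfg = data.config_forest(MAP_SIZE)  # assumed config helper, as in Example #26
probe_env = magent.GridWorld(cfg, map_size=MAP_SIZE)
_, tiger_handle = probe_env.get_handles()
view = probe_env.get_view_space(tiger_handle)
view = (view[-1],) + view[:2]  # channels-first, as in the other examples
net = model.DQNModel(view, probe_env.get_feature_space(tiger_handle),
                     probe_env.get_action_space(tiger_handle)[0]).to(device)
net.load_state_dict(torch.load("saves/tiger_best.dat"))  # hypothetical path
mean_reward, mean_steps = test_model(net, device, cfg)
print("reward/tiger=%.3f steps/tiger=%.3f" % (mean_reward, mean_steps))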
Example #5
    def __init__(self, map_size, reward_args, max_frames):
        EzPickle.__init__(self, map_size, reward_args, max_frames)
        env = magent.GridWorld(load_config(map_size, **reward_args))
        handles = env.get_handles()

        names = ["omnivore"]
        super().__init__(env, handles[1:], names, map_size, max_frames)
Example #6
    def __init__(self, map_size, reward_args, max_frames, seed):
        EzPickle.__init__(self, map_size, reward_args, max_frames, seed)
        env = magent.GridWorld(get_config(map_size, **reward_args), map_size=map_size)
        self.leftID = 0
        self.rightID = 1
        names = ["red", "blue"]
        super().__init__(env, env.get_handles(), names, map_size, max_frames, seed)
Example #7
    def __init__(self, map_size, reward_args, max_frames, seed):
        EzPickle.__init__(self, map_size, reward_args, max_frames, seed)
        env = magent.GridWorld(get_config(map_size, **reward_args), map_size=map_size)

        handles = env.get_handles()

        names = ["predator", "prey"]
        super().__init__(env, handles, names, map_size, max_frames, seed)
Example #8
    def __init__(self, map_size, minimap_mode, reward_args, max_cycles):
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles)
        assert map_size >= 16, "size of map must be at least 16"
        env = magent.GridWorld(load_config(map_size, minimap_mode, **reward_args))
        reward_vals = np.array([KILL_REWARD] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]
        names = ["redmelee", "redranged", "bluemelee", "blueranged"]
        super().__init__(env, env.get_handles(), names, map_size, max_cycles, reward_range, minimap_mode)
Example #9
    def __init__(self, map_size, minimap_mode, reward_args, max_cycles):
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles)
        env = magent.GridWorld(load_config(map_size, minimap_mode, **reward_args))
        handles = env.get_handles()
        reward_vals = np.array([5] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]
        names = ["omnivore"]
        super().__init__(env, handles[1:], names, map_size, max_cycles, reward_range, minimap_mode)
Example #10
    def __init__(self, map_size, max_frames, seed):
        EzPickle.__init__(self, map_size, max_frames, seed)
        env = magent.GridWorld(get_config(map_size), map_size=map_size)

        handles = env.get_handles()

        names = ["deer", "tiger"]
        super().__init__(env, handles, names, map_size, max_frames, seed)
Example #11
def test_model(net: model.DQNModel, device: torch.device,
               gw_config) -> Tuple[float, float, float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    group_a, group_b = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(group_a, method="random", n=COUNT_AGENTS_1)
        test_env.add_agents(group_b, method="random", n=COUNT_AGENTS_2)

    env_a = data.MAgentEnv(test_env,
                           group_a,
                           reset_env_func=reset_env,
                           is_slave=True)
    env_b = data.MAgentEnv(test_env,
                           group_b,
                           reset_env_func=reset_env,
                           steps_limit=MAX_EPISODE)
    preproc = model.MAgentPreprocessor(device)
    agent_a = ptan.agent.DQNAgent(net,
                                  ptan.actions.ArgmaxActionSelector(),
                                  device,
                                  preprocessor=preproc)
    agent_b = ptan.agent.DQNAgent(net,
                                  ptan.actions.ArgmaxActionSelector(),
                                  device,
                                  preprocessor=preproc)

    a_obs = env_a.reset()
    b_obs = env_b.reset()
    a_steps = 0
    a_rewards = 0.0
    b_steps = 0
    b_rewards = 0.0

    while True:
        a_actions = agent_a(a_obs)[0]
        b_actions = agent_b(b_obs)[0]
        a_obs, a_r, a_dones, _ = env_a.step(a_actions)
        b_obs, b_r, b_dones, _ = env_b.step(b_actions)
        a_steps += len(a_obs)
        a_rewards += sum(a_r)
        if a_dones[0]:
            break
        b_steps += len(b_obs)
        b_rewards += sum(b_r)
        if b_dones[0]:
            break

    return (
        a_rewards / COUNT_AGENTS_1,
        a_steps / COUNT_AGENTS_1,
        b_rewards / COUNT_AGENTS_2,
        b_steps / COUNT_AGENTS_2,
    )
Example #12
    def __init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features):
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles, extra_features)
        assert map_size >= 12, "size of map must be at least 12"
        env = magent.GridWorld(get_config(map_size, minimap_mode, **reward_args), map_size=map_size)
        self.leftID = 0
        self.rightID = 1
        reward_vals = np.array([KILL_REWARD] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]
        names = ["red", "blue"]
        super().__init__(env, env.get_handles(), names, map_size, max_cycles, reward_range, minimap_mode, extra_features)
Example #13
    def __init__(self, map_size, minimap_mode, reward_args, max_cycles):
        EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles)
        assert map_size >= 20, "size of map must be at least 20"
        env = magent.GridWorld(get_config(map_size, minimap_mode, **reward_args), map_size=map_size)

        handles = env.get_handles()
        reward_vals = np.array([1, -1, -1, -1, -1] + list(reward_args.values()))
        reward_range = [np.minimum(reward_vals, 0).sum(), np.maximum(reward_vals, 0).sum()]
        names = ["predator", "prey"]
        super().__init__(env, handles, names, map_size, max_cycles, reward_range, minimap_mode)
Example #14
def test_model(net_deer: model.DQNModel, net_tiger: model.DQNModel,
               device: torch.device,
               gw_config) -> Tuple[float, float, float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    deer_handle, tiger_handle = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        test_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    deer_env = data.MAgentEnv(test_env,
                              deer_handle,
                              reset_env_func=reset_env,
                              is_slave=True)
    tiger_env = data.MAgentEnv(test_env,
                               tiger_handle,
                               reset_env_func=reset_env)
    preproc = model.MAgentPreprocessor(device)
    deer_agent = ptan.agent.DQNAgent(net_deer,
                                     ptan.actions.ArgmaxActionSelector(),
                                     device,
                                     preprocessor=preproc)
    tiger_agent = ptan.agent.DQNAgent(net_tiger,
                                      ptan.actions.ArgmaxActionSelector(),
                                      device,
                                      preprocessor=preproc)

    t_obs = tiger_env.reset()
    d_obs = deer_env.reset()
    deer_steps = 0
    deer_rewards = 0.0
    tiger_steps = 0
    tiger_rewards = 0.0

    while True:
        d_actions = deer_agent(d_obs)[0]
        t_actions = tiger_agent(t_obs)[0]
        d_obs, d_r, d_dones, _ = deer_env.step(d_actions)
        t_obs, t_r, t_dones, _ = tiger_env.step(t_actions)
        tiger_steps += len(t_obs)
        tiger_rewards += sum(t_r)
        if t_dones[0]:
            break
        deer_steps += len(d_obs)
        deer_rewards += sum(d_r)
        if d_dones[0]:
            break

    return deer_rewards / COUNT_DEERS, deer_steps / COUNT_DEERS, \
           tiger_rewards / COUNT_TIGERS, tiger_steps / COUNT_TIGERS
Example #15
    def __init__(self,
                 path="data/battle_model_3_players",
                 total_step=1000,
                 add_counter=10,
                 add_interval=50):
        # some parameters
        map_size = 125
        eps = 0.00

        # init the game
        env = magent.GridWorld(utils.load_config(map_size))

        handles = env.get_handles()
        models = []
        models.append(
            DeepQNetwork(env,
                         handles[0],
                         'trusty-battle-game-l1',
                         use_conv=True))
        # models.append(DeepQNetwork(env, handles[1], 'trusty-battle-game-l2', use_conv=True))
        models.append(
            DeepQNetwork(env,
                         handles[1],
                         'trusty-battle-game-r',
                         use_conv=True))

        # load model
        # tf.reset_default_graph()
        models[0].load(path, 1, 'trusty-battle-game-l1')
        # models[1].load(path, 1, 'trusty-battle-game-l2')
        # tf.reset_default_graph()
        models[1].load(path, 1, 'trusty-battle-game-r')  # index 1: the 'l2' model above is commented out

        # init environment
        env.reset()
        utils.generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
        self.add_interval = add_interval
        self.add_counter = add_counter
        self.done = False
        self.total_handles = [
            self.env.get_num(self.handles[0]),
            self.env.get_num(self.handles[1])
        ]
Example #16
    def __init__(self, config, **kwargs):
        self.env = magent.GridWorld(config, **kwargs)
        self.num_agents = 2
        self.agents = ["predator", "prey"]
        self.agent_order = self.agents[:]

        self.action_spaces = {agent: Discrete(3) for agent in self.agents}
        self.observation_spaces = {agent: Discrete(4) for agent in self.agents}

        self.display_wait = 0.0
        self.rewards = {agent: 0 for agent in self.agents}
        self.dones = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self.num_moves = 0
Example #17
    def __init__(self,
                 path="data/battle_model",
                 total_step=1000,
                 add_counter=10,
                 add_interval=50):
        # some parameters
        map_size = 125
        eps = 0.05

        # init the game
        env = magent.GridWorld(load_config(map_size))

        handles = env.get_handles()
        models = []
        models.append(
            DeepQNetwork(env,
                         handles[0],
                         'trusty-battle-game-l',
                         use_conv=True))
        models.append(
            DeepQNetwork(env,
                         handles[1],
                         'trusty-battle-game-r',
                         use_conv=True))

        # load model
        models[0].load(path, 0, 'trusty-battle-game-l')
        models[1].load(path, 0, 'trusty-battle-game-r')

        # init environment
        env.reset()
        generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
        self.add_interval = add_interval
        self.add_counter = add_counter
        self.done = False
        print(env.get_view2attack(handles[0]))
        plt.show()
Example #18
    def __init__(self, path="data/against_v2", total_step=500):
        # some parameters
        map_size = 125
        eps = 0.00

        # init the game
        env = magent.GridWorld("battle", map_size=map_size)

        handles = env.get_handles()
        models = []
        models.append(
            DeepQNetwork(env,
                         handles[0],
                         'trusty-battle-game-l',
                         use_conv=True))
        models.append(DeepQNetwork(env, handles[1], 'battle', use_conv=True))

        # load model
        # models[0].load(path, 999, 'against-a')
        # # models[0].load('data/battle_model_1000_vs_500', 1500, 'trusty-battle-game-l')
        # models[1].load(path, 999, 'battle')
        #
        models[0].load("data/battle_model_1000_vs_500", 1500,
                       'trusty-battle-game-l')
        models[1].load("data/battle_model_1000_vs_500", 1500,
                       'trusty-battle-game-r')

        # init environment
        env.reset()
        x0, y0, x1, y1 = utils.generate_map(env, map_size, handles)
        # generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
        self.done = False
        self.total_handles = [
            self.env.get_num(self.handles[0]),
            self.env.get_num(self.handles[1])
        ]
Example #19
def main(args):
    # Initialize the environment
    env = magent.GridWorld('battle', map_size=args.map_size)
    env.set_render_dir(
        os.path.join(BASE_DIR, 'examples/battle_model', 'build/render'))
    handles = env.get_handles()

    tf_config = tf.ConfigProto(allow_soft_placement=True,
                               log_device_placement=False)
    tf_config.gpu_options.allow_growth = True

    log_dir = os.path.join(BASE_DIR, 'data/tmp/{}'.format(args.algo))
    model_dir = os.path.join(BASE_DIR, 'data/models/{}'.format(args.algo))

    start_from = 0

    sess = tf.Session(config=tf_config)
    models = [
        spawn_ai(args.algo, sess, env, handles[0], args.algo + '-me',
                 args.max_steps),
        spawn_ai(args.algo, sess, env, handles[1], args.algo + '-opponent',
                 args.max_steps)
    ]
    sess.run(tf.global_variables_initializer())
    runner = tools.Runner(sess,
                          env,
                          handles,
                          args.map_size,
                          args.max_steps,
                          models,
                          play,
                          render_every=args.save_every if args.render else 0,
                          save_every=args.save_every,
                          tau=0.01,
                          log_name=args.algo,
                          log_dir=log_dir,
                          model_dir=model_dir,
                          train=True)

    for k in range(start_from, start_from + args.n_round):
        eps = linear_decay(k, [0, int(args.n_round * 0.8), args.n_round],
                           [1, 0.2, 0.1])
        runner.run(eps, k)
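linear_decay itself is not shown in this excerpt; an implementation consistent with the call above is a piecewise-linear schedule over the given breakpoints. A sketch of such an assumed helper, not the original:

import numpy as np

def linear_decay(step, breakpoints, values):
    # Piecewise-linear interpolation of values over breakpoints, clamped
    # outside the range. With the call above, eps decays from 1 to 0.2
    # over the first 80% of rounds, then from 0.2 to 0.1 afterwards.
    return float(np.interp(step, breakpoints, values))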
Example #20
    def __init__(self,
                 models,
                 path="data/models",
                 total_step=1000,
                 add_counter=10,
                 add_interval=1000):
        # set the map size, which limits the range of movement
        map_size = 25

        eps = 0.05

        # init the game
        env = magent.GridWorld(load_config(map_size))

        handles = env.get_handles()  #[c_int(0),c_int(1)]
        #models = []
        #models.append(DeepQNetwork(env, handles[0], 'ac-0', use_conv=True))
        #models.append(ActorCritic('ac-0', env, handles[0]))
        #models.append(DeepQNetwork(env, handles[1], 'ac-1', use_conv=True))

        # load model
        #models[0].load(path, 0, 'ac-0')
        #models[1].load(path, 0, 'ac-1')

        # init environment: load the map
        env.reset()
        generate_map(env, map_size, handles)

        # save to member variable
        self.env = env
        self.handles = handles
        self.eps = eps
        self.models = models
        self.map_size = map_size
        self.total_step = total_step
        # interval (in steps) at which a new batch of agents is added
        self.add_interval = add_interval
        # maximum number of times agents can be added
        self.add_counter = add_counter
        self.done = False
        print(env.get_view2attack(handles[0]))
        plt.show()
Example #21
    def reset_world(self,
                    map_name,
                    map_size,
                    num_walls,
                    num_predators,
                    num_preys,
                    algorithm,
                    render_dir='./build/render'):
        self.env = magent.GridWorld(map_name, map_size=map_size)
        self.env.set_render_dir(render_dir)

        # get group handles
        self.predator, self.prey = self.env.get_handles()
        self.alg = algorithm
        # init env and agents
        self.env.reset()
        self.env.add_walls(method="random", n=num_walls)
        self.env.add_agents(self.predator, method="random", n=num_predators)
        self.env.add_agents(self.prey, method="random", n=num_preys)
        self.get_observation()
Example #22
        "--render", default="render", help="Directory to store renders, default=render"
    )
    parser.add_argument(
        "--walls-density", type=float, default=0.04, help="Density of walls, default=0.04"
    )
    parser.add_argument(
        "--count_a", type=int, default=20, help="Size of the first group, default=100"
    )
    parser.add_argument(
        "--count_b", type=int, default=20, help="Size of the second group, default=100"
    )
    parser.add_argument("--max-steps", type=int, help="Set limit of steps")

    args = parser.parse_args()

    env = magent.GridWorld("battle", map_size=args.map_size)
    env.set_render_dir(args.render)
    a_handle, b_handle = env.get_handles()

    env.reset()
    env.add_walls(method="random", n=args.map_size * args.map_size * args.walls_density)
    env.add_agents(a_handle, method="random", n=args.count_a)
    env.add_agents(b_handle, method="random", n=args.count_b)

    v = env.get_view_space(a_handle)
    v = (v[-1],) + v[:2]
    net_a = model.DQNModel(v, env.get_feature_space(a_handle), env.get_action_space(a_handle)[0])
    net_a.load_state_dict(torch.load(args.model_a))
    print(net_a)

    v = env.get_view_space(b_handle)
Example #23
from typing import List, Tuple

import magent

from magent.builtin.rule_model import RandomActor
from magent.model import BaseModel
from numpy import ndarray

MAP_SIZE: int = 64

if __name__ == "__main__":
    environment: magent.GridWorld = magent.GridWorld("forest",
                                                     map_size=MAP_SIZE)
    environment.set_render_dir("render")

    deer_handle: int
    tiger_handle: int
    deer_handle, tiger_handle = environment.get_handles()

    models: List[BaseModel] = [
        RandomActor(environment, deer_handle),
        RandomActor(environment, tiger_handle)
    ]

    environment.reset()
    environment.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * 0.04)
    environment.add_agents(deer_handle, method="random", n=5)
    environment.add_agents(tiger_handle, method="random", n=2)

    tiger_view_space: Tuple = environment.get_view_space(tiger_handle)
    tiger_feature_space: Tuple = environment.get_feature_space(tiger_handle)
Example #24
"""
First demo: shows the usage of the API
"""

import magent
# try:
#     from magent.builtin.mx_model import DeepQNetwork
# except ImportError as e:
from magent.builtin.tf_model import DeepQNetwork

if __name__ == "__main__":
    map_size = 100

    # init the game "pursuit" (config files are stored in python/magent/builtin/config/)
    env = magent.GridWorld("pursuit", map_size=map_size)
    #
    env.set_render_dir("build/render")

    # get group handles
    predator, prey = env.get_handles()

    # init env and agents
    env.reset()
    env.add_walls(method="random", n=map_size * map_size * 0.01)
    env.add_agents(predator, method="random", n=map_size * map_size * 0.02)
    env.add_agents(prey, method="random", n=map_size * map_size * 0.02)

    # init two models
    model1 = DeepQNetwork(env, predator, "predator")
    model2 = DeepQNetwork(env, prey, "prey")
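The demo stops after constructing the two models; a minimal interaction loop in the usual MAgent style would continue along these lines (the e-greedy policy arguments and the eps value are assumptions):

    done = False
    while not done:
        # observe, infer, and set actions for each group
        for handle, m in [(predator, model1), (prey, model2)]:
            obs = env.get_observation(handle)
            ids = env.get_agent_id(handle)
            acts = m.infer_action(obs, ids, policy="e_greedy", eps=0.1)
            env.set_action(handle, acts)

        # advance the simulation one step, render a frame, remove dead agents
        done = env.step()
        env.render()
        env.clear_dead()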
Example #25
    parser.add_argument(
        '--map_size', type=int, default=40,
        help='set the size of map')  # with map_size 40, the number of agents is 64
    parser.add_argument('--max_steps',
                        type=int,
                        default=400,
                        help='set the max steps')
    parser.add_argument('--idx', nargs='*', required=True)
    parser.add_argument('--neighbor_nums',
                        type=int,
                        default=-1,
                        help='set neighbors')  # -1 means a fully connected graph

    args = parser.parse_args()

    # Initialize the environment
    env = magent.GridWorld('battle', map_size=args.map_size)
    env.set_render_dir(
        os.path.join(BASE_DIR, 'examples/battle_model', 'build/render'))
    handles = env.get_handles()

    tf_config = tf.ConfigProto(allow_soft_placement=True,
                               log_device_placement=False)
    tf_config.gpu_options.allow_growth = True
    #
    main_model_dir = os.path.join(BASE_DIR,
                                  'data/models/{}-0'.format(args.algo))
    if args.algo == args.oppo:
        oppo_model_dir = os.path.join(BASE_DIR,
                                      'data/models/{}-1'.format(args.oppo))
    else:
        oppo_model_dir = os.path.join(BASE_DIR,
                                      'data/models/{}-0'.format(args.oppo))
Example #26
        default="forest",
        choices=["forest", "double_attack"],
        help="GridWorld mode, could be 'forest' or 'double_attack', default='forest'",
    )

    args = parser.parse_args()

    if args.mode == "forest":
        config = data.config_forest(args.map_size)
    elif args.mode == "double_attack":
        config = data.config_double_attack(args.map_size)
    else:
        config = None

    env = magent.GridWorld(config, map_size=args.map_size)
    env.set_render_dir(args.render)
    deer_handle, tiger_handle = env.get_handles()

    env.reset()
    env.add_walls(method="random",
                  n=args.map_size * args.map_size * args.walls_density)
    env.add_agents(deer_handle, method="random", n=args.deers)
    env.add_agents(tiger_handle, method="random", n=args.tigers)

    v = env.get_view_space(tiger_handle)
    v = (v[-1], ) + v[:2]
    net = model.DQNModel(v, env.get_feature_space(tiger_handle),
                         env.get_action_space(tiger_handle)[0])
    net.load_state_dict(torch.load(args.model))
    print(net)
Example #27
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--map_size", type=int, default=60)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--name", type=str, default="battle")
    parser.add_argument("--eval", action="store_true")
    parser.add_argument('--alg', default='dqn', choices=['dqn', 'drqn', 'a2c'])
    args = parser.parse_args()

    # set logger
    log.basicConfig(level=log.INFO, filename=args.name + '.log')
    console = log.StreamHandler()
    console.setLevel(log.INFO)
    log.getLogger('').addHandler(console)

    # init the game
    env = magent.GridWorld(get_config(args.map_size))
    env.set_render_dir("build/render")

    # two groups of agents
    names = [args.name + "-l", args.name + "-r"]
    handles = env.get_handles()

    # sample eval observation set
    eval_obs = None
    if args.eval:
        print("sample eval set...")
        env.reset()
        generate_map(env, args.map_size, handles)
        eval_obs = buffer.sample_observation(env, handles, 2048, 500)[0]

    # init models
Example #28
        if args.alg.find('commnet') > -1:
            args = get_commnet_args(args)
        if args.alg.find('g2anet') > -1:
            args = get_g2anet_args(args)
        # env = StarCraft2Env(map_name=args.map,
        #                     step_mul=args.step_mul,
        #                     difficulty=args.difficulty,
        #                     game_version=args.game_version,
        #                     replay_dir=args.replay_dir)
        # env = magent.GridWorld("battle", map_size=30)
        args.map_size = 80  # pursuit:180 270;battle:80 100
        args.env_name = 'battle'
        args.map = args.alg
        args.name_time = 'est'
        # alt_wo_per alt_wo_dq
        env = magent.GridWorld(args.env_name, map_size=args.map_size)
        # env = magent.GridWorld(get_config_double_attack(args.map_size))  # pursuit 180 270 330
        handles = env.get_handles()
        eval_obs = None
        feature_dim = env.get_feature_space(handles[0])
        view_dim = env.get_view_space(handles[0])
        real_view_shape = view_dim
        v_dim_total = view_dim[0] * view_dim[1] * view_dim[2]
        obs_shape = (v_dim_total + feature_dim[0],)

        # env_info = env.get_env_info()
        # print(env.action_space[0][0])
        args.n_actions = env.action_space[0][0]
        args.fixed_n_actions = env.action_space[1][0]
        args.n_agents = 5  # pursuit:8;battle:5
        args.more_walls = 1  # pursuit:1;battle:1 10
Example #29
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_every", type=int, default=5)
    parser.add_argument("--n_round", type=int, default=200)
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--load_from", type=int)
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--map_size", type=int, default=500)
    parser.add_argument("--name", type=str, default="tiger")
    parser.add_argument('--alg', default='dqn', choices=['dqn', 'drqn', 'a2c'])
    args = parser.parse_args()

    # init the game
    env = magent.GridWorld("double_attack", map_size=args.map_size)
    env.set_render_dir("build/render")

    # two groups of animals
    deer_handle, tiger_handle = env.get_handles()

    # init two models
    models = [
        RandomActor(env, deer_handle, tiger_handle),
    ]

    batch_size = 512
    unroll = 8

    if args.alg == 'dqn':
        from magent.builtin.tf_model import DeepQNetwork
Example #30
from senario_battle import generate_map
import magent
import numpy as np
env = magent.GridWorld('battle', map_size=40)
env.reset()
handles = env.get_handles()
generate_map(env, 40, handles)
state = list(env.get_observation(handles[0]))
handle = handles[0]
nei_len = 10
poss = env.get_pos(handle)
nei_space = nei_len**2
act_prob = []
for i in range(poss.shape[0]):
    sum_cnt = 0
    act_p = np.zeros(env.get_action_space(handle)[0])
    for j in range(poss.shape[0]):
        if i != j:
            if np.sum(np.square(poss[i] - poss[j])) < nei_space:
                # act_p+=acts_onehot[j]
                sum_cnt += 1
    print(sum_cnt)
# n_action = [env.get_action_space(handles[0])[0], env.get_action_space(handles[1])[0]]
# print(n_action)
# print(state.len())

# import matplotlib.pyplot as plt
# plt.imshow(state) #Needs to be in row,col order
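The pairwise loop above is O(n²) in interpreted Python; the same per-agent neighbor counts can be computed with numpy broadcasting. A sketch reusing poss and nei_space from the snippet:

# pairwise squared distances via broadcasting; subtract 1 to exclude self-pairs
diffs = poss[:, None, :].astype(np.int64) - poss[None, :, :]
sq_dists = np.sum(np.square(diffs), axis=-1)
neighbor_counts = np.sum(sq_dists < nei_space, axis=1) - 1
print(neighbor_counts)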