import os
from typing import Tuple

import magent
import ptan
import torch

from lib import data, model  # local project modules (assumed layout) with the MAgentEnv wrapper and DQN models

# Constants such as MAP_SIZE, WALLS_DENSITY, COUNT_DEERS, COUNT_TIGERS,
# COUNT_AGENTS_1, COUNT_AGENTS_2 and MAX_EPISODE are assumed to be defined at
# module level in the surrounding scripts.


def test_model(net_deer: model.DQNModel, net_tiger: model.DQNModel,
               device: torch.device,
               gw_config) -> Tuple[float, float, float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    deer_handle, tiger_handle = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        test_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

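    # deer_env shares the same underlying GridWorld as the tiger wrapper below;
    # is_slave=True presumably defers the actual simulation step (and clean-up of
    # dead agents) to the master tiger_env, so both groups act in lockstep.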
    deer_env = data.MAgentEnv(test_env,
                              deer_handle,
                              reset_env_func=reset_env,
                              is_slave=True)
    tiger_env = data.MAgentEnv(test_env,
                               tiger_handle,
                               reset_env_func=reset_env)
    preproc = model.MAgentPreprocessor(device)
    deer_agent = ptan.agent.DQNAgent(net_deer,
                                     ptan.actions.ArgmaxActionSelector(),
                                     device,
                                     preprocessor=preproc)
    tiger_agent = ptan.agent.DQNAgent(net_tiger,
                                      ptan.actions.ArgmaxActionSelector(),
                                      device,
                                      preprocessor=preproc)

    t_obs = tiger_env.reset()
    d_obs = deer_env.reset()
    deer_steps = 0
    deer_rewards = 0.0
    tiger_steps = 0
    tiger_rewards = 0.0

    while True:
        d_actions = deer_agent(d_obs)[0]
        t_actions = tiger_agent(t_obs)[0]
        d_obs, d_r, d_dones, _ = deer_env.step(d_actions)
        t_obs, t_r, t_dones, _ = tiger_env.step(t_actions)
        tiger_steps += len(t_obs)
        tiger_rewards += sum(t_r)
        if t_dones[0]:
            break
        deer_steps += len(d_obs)
        deer_rewards += sum(d_r)
        if d_dones[0]:
            break

    return deer_rewards / COUNT_DEERS, deer_steps / COUNT_DEERS, \
           tiger_rewards / COUNT_TIGERS, tiger_steps / COUNT_TIGERS
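

# A hedged usage sketch, not taken from the original script: a small helper showing
# how the two-network evaluation above could be reported from a training loop. The
# arguments are assumed to be the trained deer/tiger networks, the torch device and
# the MAgent forest config used during training.
def report_forest_test(net_deer, net_tiger, device, gw_config):
    d_reward, d_steps, t_reward, t_steps = test_model(
        net_deer, net_tiger, device, gw_config)
    print("Test: deer reward=%.3f, steps=%.1f; tiger reward=%.3f, steps=%.1f" % (
        d_reward, d_steps, t_reward, t_steps))
    return d_reward, t_reward

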
def test_model(net: model.DQNModel, device: torch.device,
               gw_config) -> Tuple[float, float, float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    group_a, group_b = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(group_a, method="random", n=COUNT_AGENTS_1)
        test_env.add_agents(group_b, method="random", n=COUNT_AGENTS_2)

    env_a = data.MAgentEnv(test_env,
                           group_a,
                           reset_env_func=reset_env,
                           is_slave=True)
    env_b = data.MAgentEnv(test_env,
                           group_b,
                           reset_env_func=reset_env,
                           steps_limit=MAX_EPISODE)
    preproc = model.MAgentPreprocessor(device)
    agent_a = ptan.agent.DQNAgent(net,
                                  ptan.actions.ArgmaxActionSelector(),
                                  device,
                                  preprocessor=preproc)
    agent_b = ptan.agent.DQNAgent(net,
                                  ptan.actions.ArgmaxActionSelector(),
                                  device,
                                  preprocessor=preproc)

    a_obs = env_a.reset()
    b_obs = env_b.reset()
    a_steps = 0
    a_rewards = 0.0
    b_steps = 0
    b_rewards = 0.0

    while True:
        a_actions = agent_a(a_obs)[0]
        b_actions = agent_b(b_obs)[0]
        a_obs, a_r, a_dones, _ = env_a.step(a_actions)
        b_obs, b_r, b_dones, _ = env_b.step(b_actions)
        a_steps += len(a_obs)
        a_rewards += sum(a_r)
        if a_dones[0]:
            break
        b_steps += len(b_obs)
        b_rewards += sum(b_r)
        if b_dones[0]:
            break

    return a_rewards / COUNT_AGENTS_1, a_steps / COUNT_AGENTS_1, \
           b_rewards / COUNT_AGENTS_2, b_steps / COUNT_AGENTS_2
def test_model(net: model.DQNModel, device: torch.device,
               gw_config) -> Tuple[float, float]:
    test_env = magent.GridWorld(gw_config, map_size=MAP_SIZE)
    deer_handle, tiger_handle = test_env.get_handles()

    def reset_env():
        test_env.reset()
        test_env.add_walls(method="random",
                           n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        test_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        test_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    env = data.MAgentEnv(test_env, tiger_handle, reset_env_func=reset_env)
    preproc = model.MAgentPreprocessor(device)
    agent = ptan.agent.DQNAgent(net,
                                ptan.actions.ArgmaxActionSelector(),
                                device,
                                preprocessor=preproc)

    obs = env.reset()
    steps = 0
    rewards = 0.0

    while True:
        actions = agent(obs)[0]
        obs, r, dones, _ = env.step(actions)
        steps += len(obs)
        rewards += sum(r)
        if dones[0]:
            break

    return rewards / COUNT_TIGERS, steps / COUNT_TIGERS
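

# A hedged sketch, not taken from the original script: one way the single-group
# evaluation above could drive checkpointing during training. The helper name,
# the `best_reward` bookkeeping and the checkpoint file name are illustrative
# assumptions.
def maybe_save_best(net, device, gw_config, saves_path, best_reward):
    mean_reward, mean_steps = test_model(net, device, gw_config)
    if best_reward is None or mean_reward > best_reward:
        # keep the checkpoint with the highest mean tiger reward seen so far
        torch.save(net.state_dict(),
                   os.path.join(saves_path, "best_%.3f.dat" % mean_reward))
        best_reward = mean_reward
    return best_reward, mean_reward, mean_steps

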
    device = torch.device("cuda" if args.cuda else "cpu")
    saves_path = os.path.join("saves", args.name)
    os.makedirs(saves_path, exist_ok=True)

    m_env = magent.GridWorld(config, map_size=MAP_SIZE)

    # two groups of animals: deer and tigers
    deer_handle, tiger_handle = m_env.get_handles()

    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        m_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    env = data.MAgentEnv(m_env, tiger_handle, reset_env_func=reset_env)

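    # 'double_attack_nn' presumably selects a NoisyNet-style DQN (exploration via
    # noisy layers rather than epsilon-greedy); other modes use the plain model.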
    if args.mode == 'double_attack_nn':
        net = model.DQNNoisyModel(
            env.single_observation_space.spaces[0].shape,
            env.single_observation_space.spaces[1].shape,
            m_env.get_action_space(tiger_handle)[0]).to(device)
    else:
        net = model.DQNModel(
            env.single_observation_space.spaces[0].shape,
            env.single_observation_space.spaces[1].shape,
            m_env.get_action_space(tiger_handle)[0]).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    print(net)

    if args.mode == 'double_attack':
    saves_path = os.path.join("saves", args.name)
    os.makedirs(saves_path, exist_ok=True)

    m_env = magent.GridWorld(config, map_size=MAP_SIZE)

    # two groups of animals: deer and tigers
    deer_handle, tiger_handle = m_env.get_handles()

    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(deer_handle, method="random", n=COUNT_DEERS)
        m_env.add_agents(tiger_handle, method="random", n=COUNT_TIGERS)

    deer_env = data.MAgentEnv(m_env,
                              deer_handle,
                              reset_env_func=lambda: None,
                              is_slave=True)
    tiger_env = data.MAgentEnv(m_env,
                               tiger_handle,
                               reset_env_func=reset_env,
                               is_slave=False)

    deer_obs = data.MAgentEnv.handle_obs_space(m_env, deer_handle)
    tiger_obs = data.MAgentEnv.handle_obs_space(m_env, tiger_handle)

    net_deer = model.DQNModel(
        deer_obs.spaces[0].shape, deer_obs.spaces[1].shape,
        m_env.get_action_space(deer_handle)[0]).to(device)
    tgt_net_deer = ptan.agent.TargetNet(net_deer)
    print(net_deer)
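
    # Sketch of the likely continuation (reconstructed from the symmetric deer code
    # above; the names mirror the deer variables): the tiger network and its target
    # copy are presumably built the same way from tiger_obs.
    net_tiger = model.DQNModel(
        tiger_obs.spaces[0].shape, tiger_obs.spaces[1].shape,
        m_env.get_action_space(tiger_handle)[0]).to(device)
    tgt_net_tiger = ptan.agent.TargetNet(net_tiger)
    print(net_tiger)
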
    device = torch.device("cuda" if args.cuda else "cpu")
    saves_path = os.path.join("saves", args.name)
    os.makedirs(saves_path, exist_ok=True)

    m_env = magent.GridWorld(config, map_size=MAP_SIZE)

    a_handle, b_handle = m_env.get_handles()

    def reset_env():
        m_env.reset()
        m_env.add_walls(method="random", n=MAP_SIZE * MAP_SIZE * WALLS_DENSITY)
        m_env.add_agents(a_handle, method="random", n=COUNT_AGENTS_1)
        m_env.add_agents(b_handle, method="random", n=COUNT_AGENTS_2)

    a_env = data.MAgentEnv(m_env,
                           a_handle,
                           reset_env_func=lambda: None,
                           is_slave=True)
    b_env = data.MAgentEnv(m_env,
                           b_handle,
                           reset_env_func=reset_env,
                           is_slave=False,
                           steps_limit=MAX_EPISODE)

    obs = data.MAgentEnv.handle_obs_space(m_env, a_handle)

    net = model.DQNModel(obs.spaces[0].shape, obs.spaces[1].shape,
                         m_env.get_action_space(a_handle)[0]).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    print(net)

    action_selector = ptan.actions.EpsilonGreedyActionSelector(