예제 #1
0
def find_best_model(model_path, model_num):
    """Evaluate numbered checkpoints under ``model_path`` and print the best.

    The algorithm is taken from ``get_common_args().alg`` and must be one of
    ``'coma'``, ``'qmix'`` or ``'vdn'``. For every index ``num`` in
    ``range(model_num)`` for which both ``<num>_<rnn suffix>`` and
    ``<num>_<critic/mixer suffix>`` exist in ``model_path``, the two files are
    temporarily renamed to their un-prefixed names, a fresh policy is built,
    and ``runner.evaluate_sparse()`` is run to obtain a win rate. The files
    are then renamed back to their numbered names.

    NOTE(review): presumably ``policy(args)`` loads its weights from the
    un-prefixed file names inside the model directory -- confirm against the
    policy classes.

    Args:
        model_path: Directory containing the numbered parameter files.
        model_num: Exclusive upper bound of the checkpoint indices to try.

    Raises:
        Exception: If ``args.alg`` is not one of the supported algorithms.

    Returns:
        None. The per-checkpoint win rates and the best index are printed.
    """
    args = get_common_args()
    if args.alg == 'coma':
        args = get_coma_args(args)
        rnn_suffix = 'rnn_params.pkl'
        critic_suffix = 'critic_params.pkl'
        policy = COMA
    elif args.alg == 'qmix':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'qmix_net_params.pkl'
        policy = QMIX
    elif args.alg == 'vdn':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'vdn_net_params.pkl'
        policy = VDN
    else:
        raise Exception("Not finished")

    env = StarCraft2Env(map_name=args.map,
                        step_mul=args.step_mul,
                        difficulty=args.difficulty,
                        game_version=args.game_version,
                        replay_dir=args.replay_dir)
    # Copy the environment's dimensions onto args for the runner and policy.
    env_info = env.get_env_info()
    args.n_actions = env_info["n_actions"]
    args.n_agents = env_info["n_agents"]
    args.state_shape = env_info["state_shape"]
    args.obs_shape = env_info["obs_shape"]
    args.episode_limit = env_info["episode_limit"]
    args.evaluate_epoch = 100
    runner = Runner(env, args)

    # Un-prefixed names each checkpoint is temporarily renamed to; these do
    # not depend on the checkpoint index, so compute them once.
    active_critic = os.path.join(model_path, critic_suffix)
    active_rnn = os.path.join(model_path, rnn_suffix)

    max_win_rate = 0
    max_win_rate_idx = 0
    for num in range(model_num):
        critic_path = os.path.join(
            model_path, '{}_{}'.format(num, critic_suffix))
        rnn_path = os.path.join(model_path, '{}_{}'.format(num, rnn_suffix))
        # Skip indices for which either parameter file is missing.
        if not (os.path.exists(critic_path) and os.path.exists(rnn_path)):
            continue

        # The try/finally pairs guarantee the numbered file names are
        # restored even if building the policy, evaluating, or the second
        # rename fails; otherwise a crash would leave the checkpoint stranded
        # under the un-prefixed name.
        os.rename(critic_path, active_critic)
        try:
            os.rename(rnn_path, active_rnn)
            try:
                runner.agents.policy = policy(args)
                win_rate = runner.evaluate_sparse()
            finally:
                os.rename(active_rnn, rnn_path)
        finally:
            os.rename(active_critic, critic_path)

        # Strict '>' keeps the earliest index on ties.
        if win_rate > max_win_rate:
            max_win_rate = win_rate
            max_win_rate_idx = num
        print('The win rate of {} is  {}'.format(num, win_rate))
    print('The max win rate is {}, model index is {}'.format(
        max_win_rate, max_win_rate_idx))
예제 #2
0
    e1 = gw.Event(a, 'attack', c)
    e2 = gw.Event(b, 'attack', c)
    cfg.add_reward_rule(e1 & e2, receiver=[a, b], value=[1, 1])

    return cfg


if __name__ == '__main__':
    view_dic = {'pursuit': 5, 'battle': 6, 'double_attack': 4}
    num_neighbor_dic = {'pursuit': 3, 'battle': 4, 'double_attack': 1}
    reward_event_dic = {'pursuit': 0.7, 'battle': 0, 'double_attack': 0}
    for i in range(1):
        args = get_common_args()
        args.alg = 'ours'
        if args.alg.find('coma') > -1:
            args = get_coma_args(args)
        elif args.alg.find('central_v') > -1:
            args = get_centralv_args(args)
        elif args.alg.find('reinforce') > -1:
            args = get_reinforce_args(args)
        else:
            args = get_mixer_args(args)
        if args.alg.find('commnet') > -1:
            args = get_commnet_args(args)
        if args.alg.find('g2anet') > -1:
            args = get_g2anet_args(args)
        # env = StarCraft2Env(map_name=args.map,
        #                     step_mul=args.step_mul,
        #                     difficulty=args.difficulty,
        #                     game_version=args.game_version,
        #                     replay_dir=args.replay_dir)