def find_best_model(model_path, model_num):
    """Evaluate numbered checkpoints and print the index with the best win rate.

    For every ``num`` in ``range(model_num)`` the function looks for the pair
    ``<num>_<critic file>`` / ``<num>_<rnn file>`` inside ``model_path``.  When
    both exist they are temporarily renamed to their un-numbered names, a new
    policy is built, ``runner.evaluate_sparse()`` is run, and the files are
    renamed back.  The per-checkpoint win rate and the overall best one are
    printed; nothing is returned.

    Args:
        model_path: Directory that holds the numbered checkpoint files.
        model_num: Number of checkpoint indices to scan (``0 .. model_num-1``).

    Raises:
        Exception: If ``args.alg`` is not one of ``'coma'``, ``'qmix'``, ``'vdn'``.
    """
    args = get_common_args()
    if args.alg == 'coma':
        args = get_coma_args(args)
        rnn_suffix = 'rnn_params.pkl'
        critic_suffix = 'critic_params.pkl'
        policy = COMA
    elif args.alg == 'qmix':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'qmix_net_params.pkl'
        policy = QMIX
    elif args.alg == 'vdn':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'vdn_net_params.pkl'
        policy = VDN
    else:
        raise Exception("Not finished")

    env = StarCraft2Env(map_name=args.map,
                        step_mul=args.step_mul,
                        difficulty=args.difficulty,
                        game_version=args.game_version,
                        replay_dir=args.replay_dir)
    try:
        env_info = env.get_env_info()
        args.n_actions = env_info["n_actions"]
        args.n_agents = env_info["n_agents"]
        args.state_shape = env_info["state_shape"]
        args.obs_shape = env_info["obs_shape"]
        args.episode_limit = env_info["episode_limit"]
        # NOTE(review): presumably the number of evaluation episodes per
        # checkpoint -- confirm against Runner.evaluate_sparse.
        args.evaluate_epoch = 100
        runner = Runner(env, args)

        # Un-numbered file names the checkpoints are staged under.
        # NOTE(review): presumably these are the names the policy constructor
        # loads from -- confirm against the policy classes.
        active_critic = os.path.join(model_path, critic_suffix)
        active_rnn = os.path.join(model_path, rnn_suffix)

        max_win_rate = 0
        max_win_rate_idx = 0
        for num in range(model_num):
            critic_path = os.path.join(model_path, str(num) + '_' + critic_suffix)
            rnn_path = os.path.join(model_path, str(num) + '_' + rnn_suffix)
            if not (os.path.exists(critic_path) and os.path.exists(rnn_path)):
                continue

            # Track which files were actually moved so that a failure at any
            # point (second rename, policy construction, evaluation) still
            # puts every moved file back under its numbered name.
            staged = []
            try:
                for numbered, active in ((critic_path, active_critic),
                                         (rnn_path, active_rnn)):
                    os.rename(numbered, active)
                    staged.append((numbered, active))
                runner.agents.policy = policy(args)
                win_rate = runner.evaluate_sparse()
            finally:
                for numbered, active in reversed(staged):
                    os.rename(active, numbered)

            if win_rate > max_win_rate:
                max_win_rate = win_rate
                max_win_rate_idx = num
            print('The win rate of {} is {}'.format(num, win_rate))

        print('The max win rate is {}, model index is {}'.format(
            max_win_rate, max_win_rate_idx))
    finally:
        # The environment is created locally, so release it here as well.
        env.close()
e1 = gw.Event(a, 'attack', c) e2 = gw.Event(b, 'attack', c) cfg.add_reward_rule(e1 & e2, receiver=[a, b], value=[1, 1]) return cfg if __name__ == '__main__': view_dic = {'pursuit': 5, 'battle': 6, 'double_attack': 4} num_neighbor_dic = {'pursuit': 3, 'battle': 4, 'double_attack': 1} reward_event_dic = {'pursuit': 0.7, 'battle': 0, 'double_attack': 0} for i in range(1): args = get_common_args() args.alg = 'ours' if args.alg.find('coma') > -1: args = get_coma_args(args) elif args.alg.find('central_v') > -1: args = get_centralv_args(args) elif args.alg.find('reinforce') > -1: args = get_reinforce_args(args) else: args = get_mixer_args(args) if args.alg.find('commnet') > -1: args = get_commnet_args(args) if args.alg.find('g2anet') > -1: args = get_g2anet_args(args) # env = StarCraft2Env(map_name=args.map, # step_mul=args.step_mul, # difficulty=args.difficulty, # game_version=args.game_version, # replay_dir=args.replay_dir)