def _evaluate_checkpoint(runner, policy, args, moves):
    """Temporarily rename checkpoint files into place and evaluate a fresh policy.

    Each ``(numbered_path, canonical_path)`` pair in *moves* is renamed to
    its canonical (un-numbered) name, ``policy(args)`` is installed as
    ``runner.agents.policy`` and the result of ``runner.evaluate_sparse()``
    is returned. Every file that was moved is renamed back in ``finally``,
    so an exception during construction or evaluation never leaves a
    checkpoint stranded under the canonical name.
    """
    moved = []
    try:
        for numbered_path, canonical_path in moves:
            os.rename(numbered_path, canonical_path)
            moved.append((numbered_path, canonical_path))
        runner.agents.policy = policy(args)
        return runner.evaluate_sparse()
    finally:
        # Only undo renames that actually happened.
        for numbered_path, canonical_path in moved:
            os.rename(canonical_path, numbered_path)


def find_best_model(model_path, model_num):
    """Evaluate numbered checkpoints in *model_path* and print the best one.

    For every ``num`` in ``range(model_num)`` whose ``<num>_<critic suffix>``
    and ``<num>_<rnn suffix>`` files both exist, the two files are
    temporarily renamed to the un-numbered names. A fresh policy of the
    configured algorithm is then built and evaluated with
    ``runner.evaluate_sparse()``, and the files are renamed back.
    NOTE(review): whether ``policy(args)`` actually loads the renamed files
    depends on the policy class and ``args`` (not visible here); confirm.

    Any file that already exists under an un-numbered name is moved aside
    for the duration and restored afterwards, so it is not overwritten. The
    SC2 environment is closed on exit.

    The per-checkpoint win rates and the best one are printed; nothing is
    returned. If no checkpoint beats a win rate of 0, index 0 is reported.

    Args:
        model_path: Directory holding the numbered checkpoint files.
        model_num: Number of checkpoint indices to try.

    Raises:
        Exception: If ``args.alg`` is not ``'coma'``, ``'qmix'`` or ``'vdn'``.
    """
    args = get_common_args()
    if args.alg == 'coma':
        args = get_coma_args(args)
        rnn_suffix = 'rnn_params.pkl'
        critic_suffix = 'critic_params.pkl'
        policy = COMA
    elif args.alg == 'qmix':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'qmix_net_params.pkl'
        policy = QMIX
    elif args.alg == 'vdn':
        args = get_mixer_args(args)
        rnn_suffix = 'rnn_net_params.pkl'
        critic_suffix = 'vdn_net_params.pkl'
        policy = VDN
    else:
        raise Exception("Not finished")

    env = StarCraft2Env(map_name=args.map,
                        step_mul=args.step_mul,
                        difficulty=args.difficulty,
                        game_version=args.game_version,
                        replay_dir=args.replay_dir)
    try:
        env_info = env.get_env_info()
        args.n_actions = env_info["n_actions"]
        args.n_agents = env_info["n_agents"]
        args.state_shape = env_info["state_shape"]
        args.obs_shape = env_info["obs_shape"]
        args.episode_limit = env_info["episode_limit"]
        args.evaluate_epoch = 100
        runner = Runner(env, args)

        critic_target = os.path.join(model_path, critic_suffix)
        rnn_target = os.path.join(model_path, rnn_suffix)
        max_win_rate = 0
        max_win_rate_idx = 0
        stashed = []  # (backup_path, canonical_path) pairs restored on exit
        try:
            # A file may already sit under a canonical name. Renaming a
            # numbered checkpoint onto it would silently overwrite it on
            # POSIX (and fail on Windows), so move it aside first.
            for target in (critic_target, rnn_target):
                if os.path.exists(target):
                    backup = target + '.bak'
                    while os.path.exists(backup):
                        backup += '.bak'
                    os.rename(target, backup)
                    stashed.append((backup, target))

            for num in range(model_num):
                critic_path = os.path.join(
                    model_path, '{}_{}'.format(num, critic_suffix))
                rnn_path = os.path.join(
                    model_path, '{}_{}'.format(num, rnn_suffix))
                if not (os.path.exists(critic_path)
                        and os.path.exists(rnn_path)):
                    continue
                win_rate = _evaluate_checkpoint(
                    runner, policy, args,
                    [(critic_path, critic_target), (rnn_path, rnn_target)])
                if win_rate > max_win_rate:
                    max_win_rate = win_rate
                    max_win_rate_idx = num
                print('The win rate of {} is {}'.format(num, win_rate))
        finally:
            for backup, target in stashed:
                os.rename(backup, target)
    finally:
        # Shut down the SC2 environment even if evaluation fails.
        env.close()
    print('The max win rate is {}, model index is {}'.format(
        max_win_rate, max_win_rate_idx))
def __init__(self, args):
    """Record shape bookkeeping and build the policy selected by ``args.alg``.

    When ``args.use_fixed_model`` is truthy, a second ``VDN_F`` policy is
    also built from a fresh ``get_common_args()`` namespace. That namespace
    is configured with ``load_model=True``, ``learn=False``, zero epsilon
    and ``map='battle'``, and the shape fields are copied from *args*.

    Args:
        args: Configuration namespace. Always read: ``n_actions``,
            ``n_agents``, ``state_shape``, ``obs_shape``, ``id_dim``,
            ``alg``, ``use_fixed_model``. Read only for the fixed model:
            ``episode_limit``, ``feature_shape``, ``view_shape``,
            ``real_view_shape``, ``load_num``.

    Raises:
        Exception: If ``args.alg`` matches none of the known algorithms.
    """
    self.n_actions = args.n_actions
    # NOTE(review): agent count is doubled -- presumably two equally sized
    # teams; confirm.
    self.n_agents = args.n_agents * 2
    self.state_shape = args.state_shape
    self.obs_shape = args.obs_shape
    self.idact_shape = args.id_dim + args.n_actions
    self.search_actions = np.eye(args.n_actions)  # identity: one row per action
    self.search_ids = np.zeros(self.n_agents)

    # Checked in order with ``==``. The lambdas defer each class-name
    # lookup until that algorithm is actually chosen.
    policy_factories = (
        ('vdn', lambda: VDN(args)),
        ('qmix', lambda: QMIX(args)),
        ('ours', lambda: OURS(args)),
        ('coma', lambda: COMA(args)),
        ('qtran_alt', lambda: QtranAlt(args)),
        ('qtran_base', lambda: QtranBase(args)),
        ('maven', lambda: MAVEN(args)),
        ('central_v', lambda: CentralV(args)),
        ('reinforce', lambda: Reinforce(args)),
    )
    for alg_name, make_policy in policy_factories:
        if args.alg == alg_name:
            self.policy = make_policy()
            break
    else:
        raise Exception("No such algorithm")

    if args.use_fixed_model:
        fixed_args = get_common_args()
        fixed_args.load_model = True
        fixed_args = get_mixer_args(fixed_args)
        # learn=False and zero epsilon -- presumably a frozen, greedy
        # opponent; confirm against VDN_F.
        fixed_args.learn = False
        fixed_args.epsilon = 0  # alternative value noted in the original: 1
        fixed_args.min_epsilon = 0
        fixed_args.map = 'battle'
        shared_fields = (
            'n_actions',
            'episode_limit',
            'n_agents',
            'state_shape',
            'feature_shape',
            'view_shape',
            'obs_shape',
            'real_view_shape',
            'load_num',
        )
        for field_name in shared_fields:
            setattr(fixed_args, field_name, getattr(args, field_name))
        fixed_args.use_ja = False
        fixed_args.mlp_hidden_dim = [512, 512]
        self.fixed_policy = VDN_F(fixed_args)

    self.args = args
    print('Init Agents')
if __name__ == '__main__':
    # Script entry point: assemble the run configuration for the 'ours'
    # algorithm on the 'battle' environment.
    # Per-environment settings keyed by environment name. NOTE(review): these
    # dicts are not referenced in the visible part of this block -- presumably
    # used further down; confirm.
    view_dic = {'pursuit': 5, 'battle': 6, 'double_attack': 4}
    num_neighbor_dic = {'pursuit': 3, 'battle': 4, 'double_attack': 1}
    reward_event_dic = {'pursuit': 0.7, 'battle': 0, 'double_attack': 0}
    # range(1): the body runs exactly once.
    for i in range(1):
        args = get_common_args()
        # Overrides whatever get_common_args() set for the algorithm.
        args.alg = 'ours'
        # Algorithm-specific defaults are chosen by substring match on the
        # name. As written (alg == 'ours'), this chain falls through to
        # get_mixer_args.
        if args.alg.find('coma') > -1:
            args = get_coma_args(args)
        elif args.alg.find('central_v') > -1:
            args = get_centralv_args(args)
        elif args.alg.find('reinforce') > -1:
            args = get_reinforce_args(args)
        else:
            args = get_mixer_args(args)
        if args.alg.find('commnet') > -1:
            args = get_commnet_args(args)
        if args.alg.find('g2anet') > -1:
            args = get_g2anet_args(args)
        # Commented-out environment constructors:
        # env = StarCraft2Env(map_name=args.map,
        #                     step_mul=args.step_mul,
        #                     difficulty=args.difficulty,
        #                     game_version=args.game_version,
        #                     replay_dir=args.replay_dir)
        # env = magent.GridWorld("battle", map_size=30)
        args.map_size = 80  # map sizes noted: pursuit 180, 270; battle 80, 100
        args.env_name = 'battle'
        # NOTE(review): map is set to the algorithm name -- presumably to
        # label run/output directories; confirm.
        args.map = args.alg
        args.name_time = 'est'  # other values noted: alt_wo_per, alt_wo_dq