Exemple #1
0
 def __init__(self, args):
     """Store environment sizes and build the mixing policy named by ``args.alg``.

     Args:
         args: configuration namespace providing ``n_actions``, ``n_agents``,
             ``state_shape``, ``obs_shape`` and ``alg`` (``'vdn'`` or
             ``'qmix'``), plus whatever the chosen policy itself reads.

     Raises:
         ValueError: if ``args.alg`` is neither ``'vdn'`` nor ``'qmix'``.
     """
     self.n_actions = args.n_actions
     self.n_agents = args.n_agents
     self.state_shape = args.state_shape
     self.obs_shape = args.obs_shape
     if args.alg == 'vdn':
         self.policy = VDN(args)
     elif args.alg == 'qmix':
         self.policy = QMIX(args)
     else:
         # Fail loudly: the previous catch-all ``else`` silently built QMIX
         # for any unrecognised name, hiding configuration typos.
         raise ValueError("No such algorithm: {!r}".format(args.alg))
     self.args = args
 def __init__(self, args):
     """Record environment sizes and instantiate the policy selected by ``args.alg``.

     Args:
         args: configuration namespace providing ``n_actions``, ``n_agents``,
             ``state_shape``, ``obs_shape`` and ``alg``, plus whatever the
             chosen policy reads.

     Raises:
         Exception: when ``args.alg`` is not one of the supported names.
     """
     self.n_actions = args.n_actions
     self.n_agents = args.n_agents
     self.state_shape = args.state_shape
     self.obs_shape = args.obs_shape

     # Each factory defers the class lookup until it is called, so only the
     # selected policy class has to be resolvable.
     factories = {
         'vdn': lambda: VDN(args),
         'qmix': lambda: QMIX(args),
         'coma': lambda: COMA(args),
         'qtran_alt': lambda: QtranAlt(args),
         'qtran_base': lambda: QtranBase(args),
         'maven': lambda: MAVEN(args),
         'central_v': lambda: CentralV(args),
         'reinforce': lambda: Reinforce(args),
     }
     factory = factories.get(args.alg)
     if factory is None:
         raise Exception("No such algorithm")
     self.policy = factory()

     self.args = args
     print('Init Agents')
Exemple #3
0
 def __init__(self, args):
     """Record environment sizes and build the policy selected by ``args.alg``.

     The policy class is imported lazily from the ``policy`` package, so only
     the selected algorithm's module (and its dependencies) is loaded.

     Args:
         args: configuration namespace providing ``n_actions``, ``n_agents``,
             ``state_shape``, ``obs_shape`` and ``alg``, plus whatever the
             chosen policy reads.

     Raises:
         ValueError: if ``args.alg`` names no known algorithm.
         ImportError: if the policy module or class cannot be imported.
     """
     import importlib

     # Supported algorithms: args.alg -> (module path, class name).
     # Adding an algorithm is a one-line change here.
     policies = {
         'vdn': ('policy.vdn', 'VDN'),
         'iql': ('policy.iql', 'IQL'),
         'qmix': ('policy.qmix', 'QMIX'),
         'coma': ('policy.coma', 'COMA'),
         'qtran_alt': ('policy.qtran_alt', 'QtranAlt'),
         'qtran_base': ('policy.qtran_base', 'QtranBase'),
         'maven': ('policy.maven', 'MAVEN'),
         'central_v': ('policy.central_v', 'CentralV'),
         'reinforce': ('policy.reinforce', 'Reinforce'),
     }

     self.n_actions = args.n_actions
     self.n_agents = args.n_agents
     self.state_shape = args.state_shape
     self.obs_shape = args.obs_shape

     if args.alg not in policies:
         raise ValueError("No such algorithm: {!r}".format(args.alg))
     module_path, class_name = policies[args.alg]
     module = importlib.import_module(module_path)
     try:
         policy_cls = getattr(module, class_name)
     except AttributeError as err:
         # Keep the exception type a ``from X import Y`` statement would raise.
         raise ImportError("cannot import name {!r} from {!r}".format(
             class_name, module_path)) from err
     self.policy = policy_cls(args)
     self.args = args
Exemple #4
0
 def __init__(self, args):
     """Set up agent sizes, the learning policy and an optional fixed policy.

     Args:
         args: configuration namespace. Besides the sizes read below,
             ``alg`` selects the learning policy and ``use_fixed_model``
             toggles building a ``VDN_F`` policy from freshly created
             common/mixer args with loading enabled and learning disabled.

     Raises:
         Exception: when ``args.alg`` is not one of the supported names.
     """
     self.n_actions = args.n_actions
     # Twice the configured agent count is tracked here (presumably both
     # sides of the 'battle' map -- TODO confirm).
     self.n_agents = args.n_agents * 2
     self.state_shape = args.state_shape
     self.obs_shape = args.obs_shape
     self.idact_shape = args.id_dim + args.n_actions
     self.search_actions = np.eye(args.n_actions)
     self.search_ids = np.zeros(self.n_agents)

     # Lambdas defer the class lookup, so only the selected policy class
     # must be resolvable.
     make_policy = {
         'vdn': lambda: VDN(args),
         'qmix': lambda: QMIX(args),
         'ours': lambda: OURS(args),
         'coma': lambda: COMA(args),
         'qtran_alt': lambda: QtranAlt(args),
         'qtran_base': lambda: QtranBase(args),
         'maven': lambda: MAVEN(args),
         'central_v': lambda: CentralV(args),
         'reinforce': lambda: Reinforce(args),
     }.get(args.alg)
     if make_policy is None:
         raise Exception("No such algorithm")
     self.policy = make_policy()

     if args.use_fixed_model:
         fixed_args = get_common_args()
         fixed_args.load_model = True
         fixed_args = get_mixer_args(fixed_args)
         # Frozen policy: no learning, exploration rates pinned to zero.
         fixed_args.learn = False
         fixed_args.epsilon = 0
         fixed_args.min_epsilon = 0
         fixed_args.map = 'battle'
         # Environment-dependent settings mirrored from the live config,
         # in the same order as before.
         for attr in ('n_actions', 'episode_limit', 'n_agents',
                      'state_shape', 'feature_shape', 'view_shape',
                      'obs_shape', 'real_view_shape', 'load_num'):
             setattr(fixed_args, attr, getattr(args, attr))
         fixed_args.use_ja = False
         fixed_args.mlp_hidden_dim = [512, 512]
         self.fixed_policy = VDN_F(fixed_args)

     self.args = args
     print('Init Agents')
Exemple #5
0
 def __init__(self, args):
     """Record environment sizes and instantiate the policy selected by ``args.alg``.

     Args:
         args: configuration namespace providing ``n_actions``, ``n_agents``,
             ``state_shape``, ``obs_shape`` and ``alg``, plus whatever the
             chosen policy reads.

     Raises:
         Exception: when ``args.alg`` is not one of the supported names.
     """
     self.n_actions = args.n_actions
     self.n_agents = args.n_agents
     self.state_shape = args.state_shape
     self.obs_shape = args.obs_shape

     # Lazy factories: a class name is only looked up when selected.
     builders = {
         'vdn': lambda: VDN(args),
         'qmix': lambda: QMIX(args),
         'coma': lambda: COMA(args),
         'qtran_alt': lambda: QtranAlt(args),
         'qtran_base': lambda: QtranBase(args),
     }
     if args.alg not in builders:
         raise Exception("No such algorithm")
     self.policy = builders[args.alg]()
     self.args = args
Exemple #6
0
class Agents:
    """Wrap a VDN or QMIX policy with action-selection and training helpers."""

    def __init__(self, args):
        """Store environment sizes and build the mixing policy.

        Args:
            args: configuration namespace providing ``n_actions``,
                ``n_agents``, ``state_shape``, ``obs_shape`` and ``alg``
                (``'vdn'`` or ``'qmix'``); the methods below additionally
                read ``last_action``, ``reuse_network``, ``episode_limit``
                and ``save_cycle``.

        Raises:
            ValueError: if ``args.alg`` is neither ``'vdn'`` nor ``'qmix'``.
        """
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        if args.alg == 'vdn':
            self.policy = VDN(args)
        elif args.alg == 'qmix':
            self.policy = QMIX(args)
        else:
            # Fail loudly instead of silently falling back to QMIX, which
            # hid configuration typos.
            raise ValueError("No such algorithm: {!r}".format(args.alg))
        self.args = args

    def choose_action(self, obs, last_action, agent_num, avail_actions,
                      epsilon):
        """Pick an action for one agent with epsilon-greedy exploration.

        Args:
            obs: the agent's observation array (copied, not modified).
            last_action: concatenated onto the network input when
                ``args.last_action`` is set.
            agent_num: integer index of the agent; one-hot encoded and
                concatenated onto the input when ``args.reuse_network`` is set.
            avail_actions: array whose nonzero entries mark available actions.
            epsilon: probability of picking a uniformly random available action.

        Returns:
            The chosen action index. NOTE(review): the random branch returns a
            NumPy integer, the greedy branch a 0-dim torch tensor.

        Side effect: writes the recurrent network's new hidden state into
        ``self.policy.eval_hidden[:, agent_num, :]``.
        """
        inputs = obs.copy()
        avail_actions_ind = np.nonzero(avail_actions)[0]
        # One-hot encoding of the agent index.
        agent_id = np.zeros(self.n_agents)
        agent_id[agent_num] = 1.

        # inputs is an ndarray, so extra features are concatenated, not appended.
        if self.args.last_action:
            inputs = np.hstack((inputs, last_action))
        if self.args.reuse_network:
            inputs = np.hstack((inputs, agent_id))
        hidden_state = self.policy.eval_hidden[:, agent_num, :]
        # Add a leading batch dimension: (input_dim,) -> (1, input_dim).
        inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)
        avail_actions = torch.tensor(avail_actions,
                                     dtype=torch.float32).unsqueeze(0)
        q_value, self.policy.eval_hidden[:, agent_num, :] = (
            self.policy.eval_rnn.forward(inputs, hidden_state))
        # Mask unavailable actions so the greedy argmax can never pick them.
        q_value[avail_actions == 0.0] = -float("inf")
        if np.random.uniform() < epsilon:
            action = np.random.choice(avail_actions_ind)
        else:
            action = torch.argmax(q_value)
        return action

    def _get_episode_len(self, terminated, episode_idx):
        """Return the length of one episode in a batch.

        The length is one past the first step whose terminated flag is 1. An
        episode with no such step within ``episode_limit`` ran to the limit,
        so its length is ``episode_limit``.
        """
        for transition_idx in range(self.args.episode_limit):
            if terminated[episode_idx, transition_idx, 0] == 1:
                return transition_idx + 1
        return self.args.episode_limit

    def _get_max_episode_len(self, batch):
        """Return the longest episode length in ``batch`` (0 for an empty batch).

        Episodes that never set their terminated flag count as full-length.
        Previously they were ignored, which could truncate their data or
        even yield 0 for a batch of unterminated episodes.
        """
        terminated = batch['terminated']
        return max((self._get_episode_len(terminated, episode_idx)
                    for episode_idx in range(terminated.shape[0])),
                   default=0)

    def train(self, batch, train_step):
        """Run one learning step on a batch of episodes.

        Episodes differ in length, so every entry of ``batch`` is truncated in
        place along axis 1 to the longest episode. ``avail_u`` keeps one extra
        step (presumably the next-step availability mask -- TODO confirm).
        The model is saved every ``args.save_cycle`` steps, skipping step 0.
        """
        max_episode_len = self._get_max_episode_len(batch)
        for key in batch.keys():
            if key != 'avail_u':
                batch[key] = batch[key][:, :max_episode_len]
            else:
                batch[key] = batch[key][:, :max_episode_len + 1]
        self.policy.learn(batch, max_episode_len, train_step)
        if train_step > 0 and train_step % self.args.save_cycle == 0:
            self.policy.save_model(train_step)

    def check_consistency(self, batch):
        """Debug helper: verify ``o[t] == o_next[t - 1]`` inside each episode.

        For every mismatching step it prints the compared ``o`` and ``o_next``
        values, then prints the episode index if any step in it mismatched.
        Only steps up to the episode's first terminal flag are checked, so
        padding after the end is ignored.
        """
        terminated = batch['terminated']
        for episode_idx in range(terminated.shape[0]):
            # Recomputed per episode; previously it leaked between episodes
            # and used the *last* terminal flag.
            episode_len = self._get_episode_len(terminated, episode_idx)
            if episode_len <= 1:
                continue  # no consecutive pairs to compare
            for transition_idx in range(1, episode_len):
                a = batch['o'][episode_idx, transition_idx]
                b = batch['o_next'][episode_idx, transition_idx - 1]
                if not (a == b).all():
                    print('o is', a)
                    # Print the value actually compared (index t - 1).
                    print('o_next is', b)
            a = batch['o'][episode_idx, 1:episode_len]
            b = batch['o_next'][episode_idx, :episode_len - 1]
            if not (a == b).all():
                print(episode_idx)