Example #1
0
class VDN:
    """Value Decomposition Network (VDN) learner.

    Keeps an evaluation/target pair of per-agent ``RNN`` networks (one network
    shared by all agents) and an evaluation/target pair of ``VDNNet`` mixers,
    and trains them with a padding-masked TD loss over batches of whole episodes.
    """

    def __init__(self, args):
        """Build the networks, optionally load saved weights, create the optimizer.

        Args:
            args: configuration namespace. Fields read here: ``n_actions``,
                ``n_agents``, ``state_shape``, ``obs_shape``, ``last_action``,
                ``reuse_network``, ``cuda``, ``model_dir``, ``alg``, ``map``,
                ``load_model``, ``optimizer`` and ``lr`` (``RNN`` may read more).

        Raises:
            Exception: ``args.load_model`` is set but
                ``<model_dir>/rnn_net_params.pkl`` does not exist.
        """
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        # RNN input = observation [+ previous action one-hot] [+ agent-id one-hot].
        input_shape = self.obs_shape
        if args.last_action:
            input_shape += self.n_actions
        if args.reuse_network:
            input_shape += self.n_agents

        # Networks. The construction order is left unchanged in case parameter
        # initialisation draws from the global RNG.
        self.eval_rnn = RNN(input_shape, args)  # per-agent action-value network
        self.target_rnn = RNN(input_shape, args)
        self.eval_vdn_net = VDNNet()  # mixer over the agents' Q values
        self.target_vdn_net = VDNNet()
        self.args = args
        if self.args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_vdn_net.cuda()
            self.target_vdn_net.cuda()

        self.model_dir = args.model_dir + '/' + args.alg + '/' + args.map
        # Load previously saved evaluation weights if requested.
        # NOTE(review): save_model() writes '<num>_rnn_net_params.pkl', but this
        # loads the unprefixed 'rnn_net_params.pkl'; the file must be renamed.
        if self.args.load_model:
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_vdn = self.model_dir + '/vdn_net_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.eval_vdn_net.load_state_dict(
                    torch.load(path_vdn, map_location=map_location))
                print('Successfully load the model: {} and {}'.format(
                    path_rnn, path_vdn))
            else:
                raise Exception("No model!")

        # Start the target networks from the same parameters as the eval networks.
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_vdn_net.load_state_dict(self.eval_vdn_net.state_dict())

        self.eval_parameters = list(self.eval_vdn_net.parameters()) + list(
            self.eval_rnn.parameters())
        # NOTE(review): only "RMS" is handled; for any other value no optimizer
        # is created and learn() fails when it reaches self.optimizer.
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters,
                                                 lr=args.lr)

        # Recurrent hidden states. During execution there is one per agent;
        # during learning one per agent per sampled episode (see init_hidden()).
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg VDN')

    def learn(self,
              batch,
              max_episode_len,
              train_step,
              epsilon=None):
        """Run one gradient step on a batch of padded episodes.

        The sampled data is four-dimensional: (episode, transition within the
        episode, agent, feature). Action selection also feeds the network a
        hidden state that depends on earlier transitions, so transitions cannot
        be sampled independently. Instead, several whole episodes are sampled,
        and the same time index of every episode is fed to the network at once.

        Args:
            batch: dict of array-likes. It must contain at least 'o', 'o_next',
                'u', 'u_onehot', 'r', 'avail_u_next', 'terminated' and 'padded'.
                Every entry is converted to a tensor *in place* ('u' as int64,
                the rest as float32), so the caller's dict is modified.
            max_episode_len: number of time steps to unroll.
            train_step: index of this learning step. The target networks are
                synchronised when it is a positive multiple of
                ``args.target_update_cycle``.
            epsilon: unused; kept for interface compatibility with other learners.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # convert every batch entry to a tensor
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        # TODO: pymarl does not use the last transition of each episode; find out why.
        u, r, avail_u_next, terminated = batch['u'], batch['r'], \
                                         batch['avail_u_next'], batch['terminated']
        # 1 for real transitions, 0 for padding: zeroes the TD error of padded
        # steps so they do not affect learning.
        mask = 1 - batch["padded"].float()
        if self.args.cuda:
            u = u.cuda()
            r = r.cuda()
            # Keep the availability mask on the same device as q_targets.
            avail_u_next = avail_u_next.cuda()
            mask = mask.cuda()
            terminated = terminated.cuda()
        # Per-agent Q values, (episode_num, max_episode_len, n_agents, n_actions).
        q_evals, q_targets = self.get_q_values(batch, max_episode_len)

        # Q value of the action each agent actually took; squeeze away the
        # trailing singleton axis left by gather.
        q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)

        # Greedy next-step value: push unavailable actions to a large negative
        # value so they never win the max over actions.
        q_targets[avail_u_next == 0.0] = -9999999
        q_targets = q_targets.max(dim=3)[0]

        q_total_eval = self.eval_vdn_net(q_evals)
        q_total_target = self.target_vdn_net(q_targets)

        targets = r + self.args.gamma * q_total_target * (1 - terminated)

        td_error = targets.detach() - q_total_eval
        masked_td_error = mask * td_error  # drop the TD error of padded steps

        # A plain .mean() would also count padded entries, so divide the
        # squared-error sum by the number of real transitions instead.
        loss = (masked_td_error**2).sum() / mask.sum()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters,
                                       self.args.grad_norm_clip)
        self.optimizer.step()

        if train_step > 0 and train_step % self.args.target_update_cycle == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_vdn_net.load_state_dict(self.eval_vdn_net.state_dict())

    def _get_inputs(self, batch, transition_idx):
        """Build the RNN inputs for time step ``transition_idx`` of every episode.

        Each agent's input is its observation, optionally followed by its
        previous action one-hot (zeros at t=0) and its one-hot agent id.

        Args:
            batch: tensor batch with 'o', 'o_next' and 'u_onehot'. The original
                design describes these as (episode, time, agent, feature).
            transition_idx: time index to extract.

        Returns:
            ``(inputs, inputs_next)``, each of shape
            ``(episode_num * n_agents, input_dim)``. ``inputs_next`` uses
            'o_next' and the action taken at ``transition_idx``.
        """
        # All of u_onehot is taken because the previous step's action is needed.
        obs, obs_next, u_onehot = batch['o'][:, transition_idx], \
                                  batch['o_next'][:, transition_idx], batch['u_onehot'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [obs], [obs_next]

        if self.args.last_action:
            if transition_idx == 0:  # no previous action yet: use a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # Along the agent axis the one-hot ids form an identity matrix:
            # agent i (row i) gets e_i, e.g. (1, 0, 0, 0, 0) for agent 0 of 5.
            agent_ids = torch.eye(self.args.n_agents).unsqueeze(0).expand(
                episode_num, -1, -1)
            inputs.append(agent_ids)
            inputs_next.append(agent_ids)
        # Flatten episodes x agents into rows and concatenate the feature parts.
        # All agents share one network; the id part keeps each row's identity.
        inputs = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1) for x in inputs],
            dim=1)
        inputs_next = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next],
            dim=1)
        return inputs, inputs_next

    def get_q_values(self, batch, max_episode_len):
        """Unroll the eval and target RNNs over ``max_episode_len`` steps.

        Updates ``self.eval_hidden`` / ``self.target_hidden`` as it goes, so
        init_hidden() must have been called first.

        Returns:
            ``(q_evals, q_targets)``, each stacked to
            ``(episode_num, max_episode_len, n_agents, -1)``, where the last axis
            is the RNN output width (presumably n_actions).
        """
        episode_num = batch['o'].shape[0]
        q_evals, q_targets = [], []
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
                self.target_hidden = self.target_hidden.cuda()
            # RNN outputs have one row per (episode, agent).
            q_eval, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
            q_target, self.target_hidden = self.target_rnn(
                inputs_next, self.target_hidden)

            # Restore the (episode_num, n_agents, -1) layout.
            q_evals.append(q_eval.view(episode_num, self.n_agents, -1))
            q_targets.append(q_target.view(episode_num, self.n_agents, -1))
        # Stack the per-step results along a new time axis (dim 1).
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        return q_evals, q_targets

    def init_hidden(self, episode_num):
        """Reset eval/target hidden states to zeros, one per agent per episode.

        Both get shape ``(episode_num, n_agents, args.rnn_hidden_dim)`` on the CPU.
        """
        self.eval_hidden = torch.zeros(
            (episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros(
            (episode_num, self.n_agents, self.args.rnn_hidden_dim))

    def save_model(self, train_step):
        """Save the eval VDN mixer and eval RNN weights under ``self.model_dir``.

        Files are named ``<train_step // save_cycle>_vdn_net_params.pkl`` and
        ``<train_step // save_cycle>_rnn_net_params.pkl``. The directory is
        created if needed.
        """
        num = str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        torch.save(self.eval_vdn_net.state_dict(),
                   self.model_dir + '/' + num + '_vdn_net_params.pkl')
        torch.save(self.eval_rnn.state_dict(),
                   self.model_dir + '/' + num + '_rnn_net_params.pkl')
Example #2
0
class QtranAlt:
    """QTRAN-alt learner.

    Components:
      - an eval/target pair of per-agent ``RNN`` networks (individual Q);
      - an eval/target pair of ``QtranQAlt`` joint action-value networks,
        with one joint-Q head per agent;
      - a single ``QtranV`` state-value network, which has no target copy.

    The loss is ``l_td + lambda_opt * l_opt + lambda_nopt * l_nopt``.
    """

    def __init__(self, args):
        """Build the networks, optionally load saved weights, create the optimizer.

        Args:
            args: configuration namespace. Fields read here: ``n_actions``,
                ``n_agents``, ``state_shape``, ``obs_shape``, ``last_action``,
                ``reuse_network``, ``cuda``, ``model_dir``, ``alg``, ``map``,
                ``load_model``, ``optimizer`` and ``lr`` (the networks may read more).

        Raises:
            Exception: ``args.load_model`` is set but
                ``<model_dir>/rnn_net_params.pkl`` does not exist.
        """
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        # RNN input = observation [+ this agent's previous action one-hot]
        # [+ agent-id one-hot].
        rnn_input_shape = self.obs_shape
        if args.last_action:
            rnn_input_shape += self.n_actions
        if args.reuse_network:
            rnn_input_shape += self.n_agents
        self.args = args

        # Networks. The construction order is left unchanged in case parameter
        # initialisation draws from the global RNG.
        self.eval_rnn = RNN(rnn_input_shape, args)  # individual networks
        self.target_rnn = RNN(rnn_input_shape, args)

        self.eval_joint_q = QtranQAlt(args)  # counterfactual joint networks
        self.target_joint_q = QtranQAlt(args)
        self.v = QtranV(args)

        if self.args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_joint_q.cuda()
            self.target_joint_q.cuda()
            self.v.cuda()

        self.model_dir = args.model_dir + '/' + args.alg + '/' + args.map
        # Load previously saved weights if requested.
        # NOTE(review): save_model() writes '<num>_...' prefixed file names, but
        # this loads the unprefixed names; the files must be renamed first.
        if self.args.load_model:
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_joint_q = self.model_dir + '/joint_q_params.pkl'
                path_v = self.model_dir + '/v_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.eval_joint_q.load_state_dict(
                    torch.load(path_joint_q, map_location=map_location))
                self.v.load_state_dict(
                    torch.load(path_v, map_location=map_location))
                print('Successfully load the model: {}, {} and {}'.format(
                    path_rnn, path_joint_q, path_v))
            else:
                raise Exception("No model!")

        # Start the target networks from the same parameters as the eval networks.
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_joint_q.load_state_dict(self.eval_joint_q.state_dict())

        self.eval_parameters = list(self.eval_joint_q.parameters()) + \
                               list(self.v.parameters()) + \
                               list(self.eval_rnn.parameters())
        # NOTE(review): only "RMS" is handled; for any other value no optimizer
        # is created and learn() fails when it reaches self.optimizer.
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters,
                                                 lr=args.lr)

        # Recurrent hidden states. During execution there is one per agent;
        # during learning one per agent per sampled episode (see init_hidden()).
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg QTRAN-alt')

    def learn(self,
              batch,
              max_episode_len,
              train_step,
              epsilon=None):
        """Run one QTRAN-alt gradient step on a batch of padded episodes.

        The sampled data is four-dimensional: (episode, transition within the
        episode, agent, feature). Action selection also feeds the network a
        hidden state that depends on earlier transitions, so transitions cannot
        be sampled independently. Instead, several whole episodes are sampled,
        and the same time index of every episode is fed to the network at once.

        Args:
            batch: dict of array-likes with at least 'o', 'o_next', 's',
                's_next', 'u', 'u_onehot', 'r', 'avail_u', 'avail_u_next',
                'terminated' and 'padded'. Every entry is converted to a tensor
                *in place* ('u' as int64, the rest as float32), so the caller's
                dict is modified.
            max_episode_len: number of time steps to unroll.
            train_step: index of this learning step. The target networks are
                synchronised when it is a positive multiple of
                ``args.target_update_cycle``.
            epsilon: unused; kept for interface compatibility with other learners.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # convert every batch entry to a tensor
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        u, r, avail_u, avail_u_next, terminated = batch['u'], batch['r'], batch['avail_u'], \
                                                  batch['avail_u_next'], batch['terminated']
        # 1 for real transitions, 0 for padding. It is repeated once per agent
        # because every loss below keeps an agent axis.
        mask = 1 - batch["padded"].float().repeat(1, 1, self.n_agents)
        if self.args.cuda:
            u = u.cuda()
            r = r.cuda()
            avail_u = avail_u.cuda()
            avail_u_next = avail_u_next.cuda()
            terminated = terminated.cuda()
            mask = mask.cuda()
        # Per-agent Q values and RNN hidden states, each
        # (episode_num, max_episode_len, n_agents, n_actions / hidden_dim).
        individual_q_evals, individual_q_targets, hidden_evals, hidden_targets = self._get_individual_q(
            batch, max_episode_len)

        # Each agent's individually greedy action at t and t+1, plus one-hot
        # encodings. Unavailable actions are pushed to a large negative value.
        individual_q_clone = individual_q_evals.clone()
        individual_q_clone[avail_u == 0.0] = -999999
        individual_q_targets[avail_u_next == 0.0] = -999999

        opt_onehot_eval = torch.zeros(*individual_q_clone.shape)
        opt_action_eval = individual_q_clone.argmax(dim=3, keepdim=True)
        opt_onehot_eval = opt_onehot_eval.scatter(-1, opt_action_eval.cpu(), 1)

        opt_onehot_target = torch.zeros(*individual_q_targets.shape)
        opt_action_target = individual_q_targets.argmax(dim=3, keepdim=True)
        opt_onehot_target = opt_onehot_target.scatter(
            -1, opt_action_target.cpu(), 1)

        # ------------------------------- L_td -------------------------------
        # joint_q_* have one head per agent:
        # (episode_num, max_episode_len, n_agents, n_actions). v has a single
        # value per step. joint_q_evals is reused by L_nopt below.
        joint_q_evals, joint_q_targets, v = self.get_qtran(
            batch, opt_onehot_target, hidden_evals, hidden_targets)

        # Joint Q of the executed actions, and target joint Q of the greedy
        # next actions; both (episode_num, max_episode_len, n_agents).
        joint_q_chosen = torch.gather(joint_q_evals, dim=-1, index=u).squeeze(-1)
        joint_q_opt = torch.gather(joint_q_targets,
                                   dim=-1,
                                   index=opt_action_target).squeeze(-1)

        y_dqn = r.repeat(1, 1, self.n_agents) + self.args.gamma * joint_q_opt * (
            1 - terminated.repeat(1, 1, self.n_agents))
        td_error = joint_q_chosen - y_dqn.detach()
        l_td = ((td_error * mask)**2).sum() / mask.sum()

        # ------------------------------- L_opt ------------------------------
        # Sum over agents of each agent's best *available* Q:
        # (episode_num, max_episode_len). individual_q_clone is used so that
        # unavailable actions cannot win the max.
        q_sum_opt = individual_q_clone.max(dim=-1)[0].sum(dim=-1)

        # Joint Q evaluated at the current greedy actions (joint_q_evals above
        # was evaluated at the executed actions).
        joint_q_opt_evals, _, _ = self.get_qtran(batch,
                                                 opt_onehot_eval,
                                                 hidden_evals,
                                                 hidden_targets,
                                                 hat=True)
        joint_q_opt_evals = torch.gather(
            joint_q_opt_evals, dim=-1, index=opt_action_eval).squeeze(-1)

        # QTRAN-alt computes l_opt per agent, so add an agent axis to
        # q_sum_opt and v.
        q_sum_opt = q_sum_opt.unsqueeze(-1).expand(-1, -1, self.n_agents)
        v = v.unsqueeze(-1).expand(-1, -1, self.n_agents)
        # joint_q_opt_evals is held fixed (detached) for this term.
        opt_error = q_sum_opt - joint_q_opt_evals.detach() + v
        l_opt = ((opt_error * mask)**2).sum() / mask.sum()

        # ------------------------------ L_nopt ------------------------------
        # L_nopt constrains the minimum d over an agent's executable actions.
        # Give unavailable actions a large Q so they never attain that minimum.
        individual_q_evals[avail_u == 0.0] = 999999

        # Sum of the executed-action Q values of all *other* agents:
        #   1. executed-action Q of every agent:
        #      (episode_num, max_episode_len, n_agents, 1)
        q_all_chosen = torch.gather(individual_q_evals, dim=-1, index=u)
        #   2. give every agent row a copy of all agents' values:
        #      (episode_num, max_episode_len, n_agents, n_agents),
        #      then zero the diagonal so each row holds only the other agents.
        q_all_chosen = q_all_chosen.view((episode_num, max_episode_len, 1,
                                          -1)).repeat(1, 1, self.n_agents, 1)
        q_mask = (1 - torch.eye(self.n_agents)).unsqueeze(0).unsqueeze(0)
        if self.args.cuda:
            q_mask = q_mask.cuda()
        q_other_chosen = q_all_chosen * q_mask
        #   3. sum, then broadcast over the current agent's n_actions so it can
        #      be added to each of that agent's actions.
        q_other_sum = q_other_chosen.sum(dim=-1, keepdim=True).repeat(
            1, 1, 1, self.n_actions)

        # First term of d: each own action's Q plus the others' executed Q.
        q_sum_nopt = individual_q_evals + q_other_sum

        # joint_q_evals is (episode_num, max_episode_len, n_agents, n_actions),
        # so v needs an n_actions axis as well.
        v = v.unsqueeze(-1).expand(-1, -1, -1, self.n_actions)
        # joint_q_evals is held fixed (detached) for this term.
        d = q_sum_nopt - joint_q_evals.detach() + v
        d = d.min(dim=-1)[0]
        l_nopt = ((d * mask)**2).sum() / mask.sum()

        loss = l_td + self.args.lambda_opt * l_opt + self.args.lambda_nopt * l_nopt
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters,
                                       self.args.grad_norm_clip)
        self.optimizer.step()

        if train_step > 0 and train_step % self.args.target_update_cycle == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_joint_q.load_state_dict(self.eval_joint_q.state_dict())

    def _get_individual_q(self, batch, max_episode_len):
        """Unroll the eval and target RNNs over ``max_episode_len`` steps.

        Updates ``self.eval_hidden`` / ``self.target_hidden`` as it goes, so
        init_hidden() must have been called first.

        Returns:
            ``(q_evals, q_targets, hidden_evals, hidden_targets)``, each stacked
            along a new time axis to ``(episode_num, max_episode_len, n_agents, -1)``.
            The last axis is the RNN output width (presumably n_actions) for the
            Q tensors and the hidden size for the hidden-state tensors.
        """
        episode_num = batch['o'].shape[0]
        q_evals, q_targets, hidden_evals, hidden_targets = [], [], [], []
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_individual_inputs(
                batch, transition_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
                inputs_next = inputs_next.cuda()
                self.target_hidden = self.target_hidden.cuda()
            q_eval, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
            q_target, self.target_hidden = self.target_rnn(
                inputs_next, self.target_hidden)
            # Keep a copy of this step's hidden states for the joint networks.
            hidden_eval = self.eval_hidden.clone()
            hidden_target = self.target_hidden.clone()

            # Restore the (episode_num, n_agents, -1) layout.
            q_evals.append(q_eval.view(episode_num, self.n_agents, -1))
            q_targets.append(q_target.view(episode_num, self.n_agents, -1))
            hidden_evals.append(hidden_eval.view(episode_num, self.n_agents, -1))
            hidden_targets.append(
                hidden_target.view(episode_num, self.n_agents, -1))
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        hidden_evals = torch.stack(hidden_evals, dim=1)
        hidden_targets = torch.stack(hidden_targets, dim=1)
        return q_evals, q_targets, hidden_evals, hidden_targets

    def _get_individual_inputs(self, batch, transition_idx):
        """Build the RNN inputs for time step ``transition_idx`` of every episode.

        Each agent's input is its observation, optionally followed by its
        previous action one-hot (zeros at t=0) and its one-hot agent id.

        Returns:
            ``(inputs, inputs_next)``, each of shape
            ``(episode_num * n_agents, input_dim)``. ``inputs_next`` uses
            'o_next' and the action taken at ``transition_idx``.
        """
        # All of u_onehot is taken because the previous step's action is needed.
        obs, obs_next, u_onehot = batch['o'][:, transition_idx], \
                                  batch['o_next'][:, transition_idx], batch['u_onehot'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [obs], [obs_next]
        if self.args.last_action:
            if transition_idx == 0:  # no previous action yet: use a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # Along the agent axis the one-hot ids form an identity matrix:
            # agent i (row i) gets e_i, e.g. (1, 0, 0, 0, 0) for agent 0 of 5.
            agent_ids = torch.eye(self.args.n_agents).unsqueeze(0).expand(
                episode_num, -1, -1)
            inputs.append(agent_ids)
            inputs_next.append(agent_ids)
        # Flatten episodes x agents into rows and concatenate the feature parts.
        # All agents share one network; the id part keeps each row's identity.
        inputs = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1) for x in inputs],
            dim=1)
        inputs_next = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next],
            dim=1)
        return inputs, inputs_next

    def get_qtran(self,
                  batch,
                  local_opt_actions,
                  hidden_evals,
                  hidden_targets=None,
                  hat=False):
        """Evaluate the joint action-value networks and the state-value network.

        Args:
            batch: tensor batch with 's', 's_next' and 'u_onehot'. Only the first
                ``max_episode_len`` steps are used; that length is read from
                ``hidden_evals``.
            local_opt_actions: one-hot actions fed to ``target_joint_q``
                (normal path) or to ``eval_joint_q`` (``hat=True``).
            hidden_evals: eval RNN hidden states,
                ``(episode_num, max_episode_len, n_agents, hidden_dim)``.
            hidden_targets: target RNN hidden states. Only needed when ``hat``
                is False; may be None otherwise.
            hat: if True, evaluate only ``eval_joint_q`` on ``local_opt_actions``.

        Returns:
            ``(q_evals, q_targets, v)``:
              - ``q_evals`` / ``q_targets`` are reshaped to
                ``(episode_num, max_episode_len, -1, n_actions)``, i.e. one
                joint-Q head per agent;
              - ``v`` is reshaped to ``(episode_num, -1)``;
              - ``q_targets`` and ``v`` are None when ``hat`` is True.
        """
        episode_num, max_episode_len, _, _ = hidden_evals.shape
        s = batch['s'][:, :max_episode_len]
        s_next = batch['s_next'][:, :max_episode_len]
        u_onehot = batch['u_onehot'][:, :max_episode_len]
        v_state = s.clone()

        # The state has no agent axis, but every agent's joint-Q head needs it:
        # broadcast it to (episode, time, n_agents, state_dim).
        s = s.unsqueeze(-2).expand(-1, -1, self.n_agents, -1)
        s_next = s_next.unsqueeze(-2).expand(-1, -1, self.n_agents, -1)
        # Append each agent's one-hot id. Along the agent axis the ids form an
        # identity matrix: row i is e_i.
        agent_ids = torch.eye(self.n_agents).unsqueeze(0).unsqueeze(0).expand(
            episode_num, max_episode_len, -1, -1)
        s_eval = torch.cat([s, agent_ids], dim=-1)
        s_target = torch.cat([s_next, agent_ids], dim=-1)
        if self.args.cuda:
            s_eval = s_eval.cuda()
            s_target = s_target.cuda()
            v_state = v_state.cuda()
            u_onehot = u_onehot.cuda()
            hidden_evals = hidden_evals.cuda()
            # hidden_targets defaults to None; calling .cuda() on it would raise.
            if hidden_targets is not None:
                hidden_targets = hidden_targets.cuda()
            local_opt_actions = local_opt_actions.cuda()
        if hat:
            q_evals = self.eval_joint_q(s_eval, hidden_evals,
                                        local_opt_actions)
            q_targets = None
            v = None
            q_evals = q_evals.view(episode_num, max_episode_len, -1,
                                   self.n_actions)
        else:
            q_evals = self.eval_joint_q(s_eval, hidden_evals, u_onehot)
            q_targets = self.target_joint_q(s_target, hidden_targets,
                                            local_opt_actions)
            v = self.v(v_state, hidden_evals)
            q_evals = q_evals.view(episode_num, max_episode_len, -1,
                                   self.n_actions)
            q_targets = q_targets.view(episode_num, max_episode_len, -1,
                                       self.n_actions)
            v = v.view(episode_num, -1)

        return q_evals, q_targets, v

    def init_hidden(self, episode_num):
        """Reset eval/target hidden states to zeros, one per agent per episode.

        Both get shape ``(episode_num, n_agents, args.rnn_hidden_dim)`` on the CPU.
        """
        self.eval_hidden = torch.zeros(
            (episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros(
            (episode_num, self.n_agents, self.args.rnn_hidden_dim))

    def save_model(self, train_step):
        """Save the eval RNN, eval joint-Q and V weights under ``self.model_dir``.

        Files are prefixed with ``train_step // save_cycle``. The directory is
        created if needed.
        """
        num = str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        torch.save(self.eval_rnn.state_dict(),
                   self.model_dir + '/' + num + '_rnn_net_params.pkl')
        torch.save(self.eval_joint_q.state_dict(),
                   self.model_dir + '/' + num + '_joint_q_params.pkl')
        torch.save(self.v.state_dict(),
                   self.model_dir + '/' + num + '_v_params.pkl')
class QTran:
    """QTRAN-base learner for cooperative multi-agent Q-learning.

    Components:
      * ``eval_rnn`` / ``target_rnn``: one recurrent utility network shared by
        all agents, producing per-action Q-values for each agent.
      * ``eval_joint_q`` / ``target_joint_q`` (``QtranQBase``): joint
        action-value network fed with states, agent hidden states and a joint
        action (one-hot per agent).
      * ``v`` (``QtranV``): state-value network used in the QTRAN constraints.

    ``learn`` minimises ``l_td + lambda_opt * l_opt + lambda_nopt * l_nopt``.
    """

    def __init__(self, args, itr):
        """Build the networks, optionally load a checkpoint, and set up the optimizer.

        Args:
            args: configuration namespace (n_actions, n_agents, state_shape,
                obs_shape, last_action, reuse_network, alg, cuda, load_model,
                model_dir, optim, lr, rnn_hidden_dim, ...).
            itr: run index; not used by this class.

        Raises:
            Exception: if ``args.alg`` is not ``'qtran_base'``, or if
                ``args.load_model`` is set but no checkpoint exists.
        """
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        self.args = args

        # RNN input = observation, optionally extended with extra one-hot features.
        rnn_input_shape = self.obs_shape
        if args.last_action:
            rnn_input_shape += self.n_actions  # one-hot of this agent's previous action
        if args.reuse_network:
            rnn_input_shape += self.n_agents  # one-hot agent id (network shared by all agents)

        # Per-agent utility networks.
        self.eval_rnn = RNN(rnn_input_shape, args)
        self.target_rnn = RNN(rnn_input_shape, args)

        # Only the qtran_base variant is implemented.
        if self.args.alg == 'qtran_base':
            self.eval_joint_q = QtranQBase(args)
            self.target_joint_q = QtranQBase(args)
        elif self.args.alg == 'qtran_alt':
            raise Exception("Not supported yet.")
        else:
            raise Exception('QTRAN只有qtran_base!')

        self.v = QtranV(args)

        if self.args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_joint_q.cuda()
            self.target_joint_q.cuda()
            self.v.cuda()

        # Load an existing checkpoint if requested.
        # NOTE(review): this reads un-prefixed file names, while save_model writes
        # "<k>_..." prefixed names; a saved checkpoint must be renamed before loading.
        if self.args.load_model:
            if os.path.exists(self.args.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.args.model_dir + '/rnn_net_params.pkl'
                path_joint_q = self.args.model_dir + '/joint_q_params.pkl'
                path_v = self.args.model_dir + '/v_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(torch.load(path_rnn, map_location=map_location))
                self.eval_joint_q.load_state_dict(torch.load(path_joint_q, map_location=map_location))
                self.v.load_state_dict(torch.load(path_v, map_location=map_location))
                print('Successfully load the model: {}, {} and {}'.format(path_rnn, path_joint_q, path_v))
            else:
                raise Exception("No model!")

        # Start the target networks as exact copies of the eval networks.
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_joint_q.load_state_dict(self.eval_joint_q.state_dict())

        self.eval_parameters = list(self.eval_joint_q.parameters()) + \
                               list(self.v.parameters()) + \
                               list(self.eval_rnn.parameters())
        # NOTE(review): no optimizer is created for any other value of args.optim,
        # so learn() would fail with AttributeError; confirm only "RMS" is used.
        if args.optim == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)

        # Recurrent states: during learning, one per (episode, agent) for the eval
        # and target networks; allocated by init_hidden().
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg QTRAN-base')

    def learn(self, batch, max_episode_len, train_step, epsilon=None):
        """Run one QTRAN-base gradient step on a batch of whole episodes.

        Whole episodes are sampled, rather than independent transitions, because
        the agent networks are recurrent. The hidden state at step t depends on
        all earlier steps, so every episode is unrolled from its first
        transition, and all episodes are processed in lock-step.

        Args:
            batch: dict of per-episode arrays indexed as
                [episode, transition, (agent,) feature]. Converted to tensors
                in place ('a' -> LongTensor, every other key -> float Tensor).
            max_episode_len: number of time steps to unroll.
            train_step: index of this learning step; the target networks are
                hard-copied from the eval networks every
                ``args.target_update_period`` steps.
            epsilon: unused; kept for interface compatibility.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        # Actions are used as gather indices, so they must be int64.
        for key in batch.keys():
            if key == 'a':
                batch[key] = torch.LongTensor(batch[key])
            else:
                batch[key] = torch.Tensor(batch[key])
        u, r, avail_u, avail_u_next, terminated = batch['a'], batch['r'], batch['avail_a'], \
                                                  batch['next_avail_a'], batch['done']
        # 1 for real transitions, 0 for padding, so padded steps contribute nothing to the losses.
        mask = (1 - batch["padded"].float()).squeeze(-1)
        if self.args.cuda:
            u = u.cuda()
            r = r.cuda()
            avail_u = avail_u.cuda()
            avail_u_next = avail_u_next.cuda()
            terminated = terminated.cuda()
            mask = mask.cuda()
        # Shared loss normaliser. It is clamped so an all-padding batch gives a zero
        # loss instead of NaN (0 / 0), which would poison every parameter.
        n_valid = mask.sum().clamp(min=1)

        # Per-agent Q-values and hidden states, each shaped
        # (episode_num, max_episode_len, n_agents, n_actions or rnn_hidden_dim).
        individual_q_evals, individual_q_targets, hidden_evals, hidden_targets = \
            self._get_individual_q(batch, max_episode_len)

        # Greedy per-agent actions (as one-hots): current step from the eval net,
        # next step from the target net. Unavailable actions are first pushed to a
        # very low value so argmax/max never select them.
        individual_q_clone = individual_q_evals.clone()
        individual_q_clone[avail_u == 0.0] = - 999999
        individual_q_targets[avail_u_next == 0.0] = - 999999

        opt_onehot_eval = torch.zeros(*individual_q_clone.shape)
        opt_action_eval = individual_q_clone.argmax(dim=3, keepdim=True)
        opt_onehot_eval = opt_onehot_eval.scatter(-1, opt_action_eval.cpu(), 1)

        opt_onehot_target = torch.zeros(*individual_q_targets.shape)
        opt_action_target = individual_q_targets.argmax(dim=3, keepdim=True)
        opt_onehot_target = opt_onehot_target.scatter(-1, opt_action_target.cpu(), 1)

        # ---- L_td: TD error of the joint Q on the executed joint action ----
        # joint_q_evals, joint_q_targets and v are (episode_num, max_episode_len).
        # joint_q_evals is reused for L_nopt below.
        joint_q_evals, joint_q_targets, v = self.get_qtran(batch, hidden_evals, hidden_targets, opt_onehot_target)
        y_dqn = r.squeeze(-1) + self.args.gamma * joint_q_targets * (1 - terminated.squeeze(-1))
        td_error = joint_q_evals - y_dqn.detach()
        l_td = ((td_error * mask) ** 2).sum() / n_valid

        # ---- L_opt: at the greedy local actions, sum_i Q_i - Q_jt + V should be 0 ----
        # Use the masked clone so an unavailable action's Q-value can never be the max.
        q_sum_opt = individual_q_clone.max(dim=-1)[0].sum(dim=-1)  # (episode_num, max_episode_len)
        # Joint Q evaluated at the greedy local actions instead of the executed ones.
        joint_q_hat_opt, _, _ = self.get_qtran(batch, hidden_evals, hidden_targets, opt_onehot_eval, hat=True)
        opt_error = q_sum_opt - joint_q_hat_opt.detach() + v  # joint Q is held fixed for this term
        l_opt = ((opt_error * mask) ** 2).sum() / n_valid

        # ---- L_nopt: for the executed actions, sum_i Q_i - Q_jt + V must be >= 0 ----
        # u is presumably (episode_num, max_episode_len, n_agents, 1); after the squeeze
        # q_individual is (episode_num, max_episode_len, n_agents).
        q_individual = torch.gather(individual_q_evals, dim=-1, index=u).squeeze(-1)
        q_sum_nopt = q_individual.sum(dim=-1)  # (episode_num, max_episode_len)
        nopt_error = q_sum_nopt - joint_q_evals.detach() + v  # joint Q is held fixed for this term
        nopt_error = nopt_error.clamp(max=0)  # only violations (negative values) are penalised
        l_nopt = ((nopt_error * mask) ** 2).sum() / n_valid

        loss = l_td + self.args.lambda_opt * l_opt + self.args.lambda_nopt * l_nopt
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.clip_norm)
        self.optimizer.step()

        # Hard-copy eval -> target every target_update_period learning steps.
        if train_step > 0 and train_step % self.args.target_update_period == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_joint_q.load_state_dict(self.eval_joint_q.state_dict())

    def _get_individual_q(self, batch, max_episode_len):
        """Unroll the eval and target RNNs over the first ``max_episode_len`` steps.

        Returns:
            (q_evals, q_targets, hidden_evals, hidden_targets), each stacked over
            time to (episode_num, max_episode_len, n_agents, -1). Entry t of the
            target outputs comes from the target net fed with step t's
            ``inputs_next``.
        """
        episode_num = batch['o'].shape[0]
        q_evals, q_targets, hidden_evals, hidden_targets = [], [], [], []
        for transition_idx in range(max_episode_len):
            # Observations extended with last action / agent id, flattened to
            # (episode_num * n_agents, features).
            inputs, inputs_next = self._get_individual_inputs(batch, transition_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
                self.target_hidden = self.target_hidden.cuda()

            # Before the first target pass on inputs_next, run the target net once on
            # step 0's inputs. Its hidden state then reflects the first transition
            # instead of starting from zeros.
            if transition_idx == 0:
                _, self.target_hidden = self.target_rnn(inputs, self.eval_hidden)
            q_eval, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
            q_target, self.target_hidden = self.target_rnn(inputs_next, self.target_hidden)
            hidden_eval, hidden_target = self.eval_hidden.clone(), self.target_hidden.clone()

            # Back to per-episode, per-agent layout: (episode_num, n_agents, -1).
            q_evals.append(q_eval.view(episode_num, self.n_agents, -1))
            q_targets.append(q_target.view(episode_num, self.n_agents, -1))
            hidden_evals.append(hidden_eval.view(episode_num, self.n_agents, -1))
            hidden_targets.append(hidden_target.view(episode_num, self.n_agents, -1))
        # Lists of max_episode_len tensors -> (episode_num, max_episode_len, n_agents, -1).
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        hidden_evals = torch.stack(hidden_evals, dim=1)
        hidden_targets = torch.stack(hidden_targets, dim=1)
        return q_evals, q_targets, hidden_evals, hidden_targets

    def _get_individual_inputs(self, batch, transition_idx):
        """Build the RNN inputs for one time step across all episodes.

        Returns:
            (inputs, inputs_next), each (episode_num * n_agents, features). They are
            built from batch['o'] and batch['next_o'] at ``transition_idx``,
            optionally concatenated with the previous-action one-hot (zeros at step
            0) and an agent-id one-hot.
        """
        # The whole onehot_a tensor is taken because the previous step's action is needed.
        obs, obs_next, u_onehot = batch['o'][:, transition_idx], \
                                  batch['next_o'][:, transition_idx], batch['onehot_a'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [], []
        inputs.append(obs)
        inputs_next.append(obs_next)
        if self.args.last_action:
            if transition_idx == 0:  # no previous action at the first step: use a zero vector
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # obs is (episode, agent, feature); agent k's row gets the k-th unit
            # vector, i.e. an identity matrix per episode.
            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
            inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
        # Concatenate the features and flatten episodes x agents into rows. The agents
        # share one network; the id one-hot keeps their data distinguishable.
        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)
        return inputs, inputs_next

    def get_qtran(self, batch, hidden_evals, hidden_targets, local_opt_actions, hat=False):
        """Evaluate the joint-Q (and V) networks over all time steps.

        With ``hat=False`` this returns (q_evals, q_targets, v):
          * the eval joint Q on the executed actions (batch['onehot_a']) at the
            current states;
          * the target joint Q on ``local_opt_actions`` at the next states;
          * V at the current states.

        With ``hat=True`` only the eval joint Q on ``local_opt_actions`` at the
        current states is computed, and the other two outputs are None. All
        returned tensors are (episode_num, max_episode_len).
        """
        episode_num, max_episode_len, _, _ = hidden_targets.shape
        states = batch['s'][:, :max_episode_len]
        states_next = batch['next_s'][:, :max_episode_len]
        u_onehot = batch['onehot_a'][:, :max_episode_len]
        if self.args.cuda:
            states = states.cuda()
            states_next = states_next.cuda()
            u_onehot = u_onehot.cuda()
            hidden_evals = hidden_evals.cuda()
            hidden_targets = hidden_targets.cuda()
            local_opt_actions = local_opt_actions.cuda()
        if hat:
            q_evals = self.eval_joint_q(states, hidden_evals, local_opt_actions)
            q_targets = None
            v = None
            q_evals = q_evals.view(episode_num, -1, 1).squeeze(-1)
        else:
            q_evals = self.eval_joint_q(states, hidden_evals, u_onehot)
            q_targets = self.target_joint_q(states_next, hidden_targets, local_opt_actions)
            v = self.v(states, hidden_evals)
            q_evals = q_evals.view(episode_num, -1, 1).squeeze(-1)
            q_targets = q_targets.view(episode_num, -1, 1).squeeze(-1)
            v = v.view(episode_num, -1, 1).squeeze(-1)

        return q_evals, q_targets, v

    def init_hidden(self, episode_num):
        """Allocate zero eval/target hidden states of shape (episode_num, n_agents, rnn_hidden_dim)."""
        self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))

    def save_model(self, train_step):
        """Save the eval RNN, joint-Q and V networks under ``args.model_dir``.

        Args:
            train_step: either a string tag used verbatim as the file prefix
                (including ``str`` subclasses), or a step count that is
                converted to ``train_step // args.save_model_period``.
        """
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)

        if isinstance(train_step, str):
            num = train_step
        else:
            num = str(train_step // self.args.save_model_period)

        torch.save(self.eval_rnn.state_dict(), self.args.model_dir + '/' + num + '_rnn_net_params.pkl')
        torch.save(self.eval_joint_q.state_dict(), self.args.model_dir + '/' + num + '_joint_q_params.pkl')
        torch.save(self.v.state_dict(), self.args.model_dir + '/' + num + '_v_params.pkl')
Example #4
0
class QMIX_PG():
    def __init__(self, agent, args):
        """Build the actor-critic learner: a policy optimiser plus two QMIX critics.

        Each critic stack is a per-agent RNN utility network plus a ``QMixNet``
        mixer. Every eval network gets a target copy initialised from that same
        eval network. Critic stack 1 is optimised by ``self.optimizer`` and
        stack 2 by ``self.optimizer_2``; both are created only when
        ``args.optimizer == "RMS"``.

        Args:
            agent: object exposing ``agent.policy`` (an ``nn.Module``) whose
                parameters are optimised with RMSprop at ``args.actor_lr``.
            args: configuration namespace.

        Raises:
            Exception: if ``args.load_model`` is set but no checkpoint exists.
        """
        self.args = args
        # Learnable log-temperature, optimised with Adam (presumably the SAC-style
        # entropy temperature). It is moved to the GPU first and only then marked
        # as requiring grad, so it stays a leaf tensor that Adam can update.
        self.log_alpha = torch.zeros(1, dtype=torch.float32)
        if args.cuda:
            self.log_alpha = self.log_alpha.cuda()
        self.log_alpha.requires_grad = True

        self.alpha = self.args.alpha
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=3e-4)

        self.agent = agent
        self.policy_params = list(agent.policy.parameters())
        self.policy_optimiser = torch.optim.RMSprop(
            params=self.policy_params,
            lr=args.actor_lr)
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape

        # RNN input = observation, optionally extended with extra one-hot features.
        input_shape = self.obs_shape
        if args.last_action:
            input_shape += self.n_actions  # previous-action one-hot
        if args.reuse_network:
            input_shape += self.n_agents  # agent-id one-hot

        # Two independent critic stacks (utility RNN + mixer), each with a target copy.
        self.eval_rnn = RNN(input_shape, args)
        self.target_rnn = RNN(input_shape, args)
        self.eval_rnn_2 = RNN(input_shape, args)
        self.target_rnn_2 = RNN(input_shape, args)
        self.eval_qmix_net = QMixNet(args)
        self.target_qmix_net = QMixNet(args)
        self.eval_qmix_net_2 = QMixNet(args)
        self.target_qmix_net_2 = QMixNet(args)
        if self.args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_rnn_2.cuda()
            self.target_rnn_2.cuda()
            self.eval_qmix_net.cuda()
            self.target_qmix_net.cuda()
            self.eval_qmix_net_2.cuda()
            self.target_qmix_net_2.cuda()

        # The run's hyper-parameters are encoded in the model directory name.
        tmp = ('clamp2-5_rewardscale10_'
               f'{args.buffer_size}_{args.actor_buffer_size}_{args.critic_buffer_size}_'
               f'{args.actor_train_steps}_{args.critic_train_steps}_'
               f'{args.actor_update_delay}_{args.critic_lr}')

        self.model_dir = args.model_dir + '/linear_mix/' + 'qmix_sac_cf' + '/' + tmp + '/' + args.map
        # Load an existing checkpoint if requested; both critic stacks start from it.
        if self.args.load_model:
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_qmix = self.model_dir + '/qmix_net_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.eval_rnn_2.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.eval_qmix_net.load_state_dict(
                    torch.load(path_qmix, map_location=map_location))
                self.eval_qmix_net_2.load_state_dict(
                    torch.load(path_qmix, map_location=map_location))
                print('Successfully load the model: {} and {}'.format(
                    path_rnn, path_qmix))
            else:
                raise Exception("No model!")

        # Each target network starts as an exact copy of its OWN eval network.
        # (Previously target_rnn_2 was copied from eval_rnn, so critic 2's target
        # did not match critic 2 unless both eval RNNs were loaded from one file.)
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_rnn_2.load_state_dict(self.eval_rnn_2.state_dict())
        self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())
        self.target_qmix_net_2.load_state_dict(
            self.eval_qmix_net_2.state_dict())

        self.eval_parameters = list(self.eval_qmix_net.parameters()) + list(
            self.eval_rnn.parameters())
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters,
                                                 lr=args.critic_lr)

        self.eval_parameters_2 = list(
            self.eval_qmix_net_2.parameters()) + list(
                self.eval_rnn_2.parameters())
        if args.optimizer == "RMS":
            self.optimizer_2 = torch.optim.RMSprop(self.eval_parameters_2,
                                                   lr=args.critic_lr)

        # Recurrent states for the critic RNNs; allocated per batch by init_hidden().
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg QMIX')

    def train_actor(self,
                    batch,
                    max_episode_len,
                    actor_sample_times=1):  # EpisodeBatch
        """Run one policy-gradient update of ``self.agent.policy`` with a counterfactual baseline.

        One action is sampled from the current policy for every
        (episode, step, agent). The per-agent advantage is

            -alpha * log pi_i(a_i) + Q_tot(a) - Q_tot(E_{pi_i}[Q_i], a_{-i})

        Here Q_tot is ``eval_qmix_net`` applied to per-agent utilities from
        ``eval_rnn``. The baseline swaps agent i's sampled utility for its
        expectation under pi_i and keeps the other agents' sampled utilities.
        Both Q_tot terms and the advantage are detached, so only the policy
        receives gradients; the critic modules are additionally frozen during
        ``backward``. After the step, ``self.soft_update()`` is called.

        Args:
            batch: dict of per-episode arrays, converted to tensors in place
                ('u' -> long, every other key -> float32).
            max_episode_len: number of time steps to unroll.
            actor_sample_times: currently ignored. Exactly one action is sampled
                per step, because the multi-sample path is unfinished.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # convert the batch to tensors in place
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        s = batch['s']
        terminated = batch['terminated']
        # 1 for real transitions; padding and every step after termination are masked out.
        mask = 1 - batch["padded"].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        # One entry per (episode, step, agent), flattened to line up with log_pi_sample.
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        avail_u = batch["avail_u"]

        if self.args.cuda:
            s = s.cuda()
            avail_u = avail_u.cuda()
            mask = mask.cuda()

        q_evals, actions_prob, actions_logprobs = [], [], []
        for transition_idx in range(max_episode_len):
            # Only the current-step inputs are needed for the actor update.
            inputs, _ = self._get_inputs(batch, transition_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
            # Critic utilities for every action, reshaped to (episode_num, n_agents, n_actions).
            q_eval, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
            q_evals.append(q_eval.view(episode_num, self.n_agents, -1))

            # Policy logits. Unavailable actions get a huge negative logit, so the
            # softmax gives them (numerically) zero probability.
            agent_outs, self.policy_hidden = self.agent.policy(
                inputs, self.policy_hidden)
            reshaped_avail_actions = avail_u[:, transition_idx].reshape(
                episode_num * self.n_agents, -1)
            agent_outs[reshaped_avail_actions == 0] = -1e11
            agent_outs = F.softmax(agent_outs, dim=1)
            agent_outs = agent_outs.view(episode_num, self.n_agents, -1)
            actions_prob.append(agent_outs)

            # Avoid log(0): add 1e-8 only where a probability underflowed to exactly 0.
            zero_guard = (agent_outs == 0.0).float() * 1e-8
            actions_logprobs.append(torch.log(agent_outs + zero_guard))

        # Stack over time: (episode_num, max_episode_len, n_agents, n_actions).
        q_vals = torch.stack(q_evals, dim=1)
        actions_prob = torch.stack(actions_prob, dim=1)
        log_prob_pi = torch.stack(actions_logprobs, dim=1)

        # The gather/squeeze(3) below assume exactly one sample per step.
        actor_sample_times = 1
        samples = torch.multinomial(
            actions_prob.view(-1, self.n_actions),
            actor_sample_times,
            replacement=True)
        cur_pol_sample_actions = samples.view(
            episode_num, -1, self.n_agents, actor_sample_times)
        # Utility of each agent's sampled action: (episode_num, max_episode_len, n_agents).
        q_curpi_sample_actions = torch.gather(
            q_vals, dim=3, index=cur_pol_sample_actions).squeeze(3)

        # Q_tot of the sampled joint action, repeated once per agent.
        q_curpi_sample = self.eval_qmix_net(q_curpi_sample_actions, s).detach()
        Q_curpi_sample = q_curpi_sample.repeat(repeats=(1, 1, self.n_agents))

        # E_{a_i ~ pi_i}[Q_i(a_i)] for every agent: (episode_num, max_episode_len, n_agents).
        q_i_mean_negi_mean = torch.sum(actions_prob * q_vals, dim=-1)

        # Counterfactual mixer inputs. For each agent i, take a copy of the sampled
        # utilities with slot i replaced by agent i's policy expectation. Stacking on
        # dim 2 gives (episode_num, max_episode_len, i, n_agents).
        q_curpi_sample_actions_detached = q_curpi_sample_actions.detach()
        q_i_mean_negi_sample = []
        for i in range(self.n_agents):
            q_temp = q_curpi_sample_actions_detached.clone()
            q_temp[:, :, i] = q_i_mean_negi_mean[:, :, i]
            q_i_mean_negi_sample.append(q_temp)
        q_i_mean_negi_sample = torch.stack(q_i_mean_negi_sample, dim=2)
        # Merge (time, i) into a single axis of max_episode_len * n_agents rows.
        q_i_mean_negi_sample = q_i_mean_negi_sample.view(
            episode_num, -1, self.n_agents)

        # Matching states: row t * n_agents + i holds s[:, t].
        s_repeat = s.repeat(repeats=(1, 1, self.n_agents))
        s_repeat = s_repeat.view(
            episode_num, self.n_agents * max_episode_len,
            self.args.state_shape)

        Q_i_mean_negi_sample = self.eval_qmix_net(
            q_i_mean_negi_sample, s_repeat).detach()
        Q_i_mean_negi_sample = Q_i_mean_negi_sample.view(
            episode_num, -1, self.n_agents)

        # log pi_i of each sampled action, flattened to (episode_num * T * n_agents,).
        log_pi_sample = torch.gather(
            log_prob_pi, dim=3,
            index=cur_pol_sample_actions).squeeze(-1).view(-1)

        advantages = (-self.alpha * log_pi_sample + Q_curpi_sample.view(-1) -
                      Q_i_mean_negi_sample.view(-1)).detach()

        policy_loss = -((advantages * log_pi_sample) * mask).sum() / mask.sum()

        self.policy_optimiser.zero_grad()
        # Freeze the critic so the policy loss does not accumulate gradients in it.
        disable_gradients(self.eval_rnn)
        disable_gradients(self.eval_qmix_net)
        policy_loss.backward()
        enable_gradients(self.eval_rnn)
        enable_gradients(self.eval_qmix_net)

        torch.nn.utils.clip_grad_norm_(self.policy_params, self.args.grad_norm_clip)
        self.policy_optimiser.step()

        self.soft_update()

    def train_critic(self,
                     batch,
                     max_episode_len,
                     train_step,
                     epsilon=None):  # train_step表示是第几次学习,用来控制更新target_net网络的参数
        '''
        在learn的时候,抽取到的数据是四维的,四个维度分别为 1——第几个episode 2——episode中第几个transition
        3——第几个agent的数据 4——具体obs维度。因为在选动作时不仅需要输入当前的inputs,还要给神经网络输入hidden_state,
        hidden_state和之前的经验相关,因此就不能随机抽取经验进行学习。所以这里一次抽取多个episode,然后一次给神经网络
        传入每个episode的同一个位置的transition
        '''
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # 把batch里的数据转化成tensor
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        s, s_next, u, r, avail_u, avail_u_next, terminated = batch['s'], batch['s_next'], batch['u'], \
                                                             batch['r'], batch['avail_u'], batch['avail_u_next'], \
                                                             batch['terminated']
        mask = 1 - batch["padded"].float()  # 用来把那些填充的经验的TD-error置0,从而不让它们影响到学习
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents)
        # r = 10. * (r - r.mean()) / (r.std() + 1e-6)  # normalize with batch mean and std; plus a small number to prevent numerical problem
        r = (r - r.mean()) / (r.std() + 1e-6)  # todo
        # 得到每个agent对应的Q值,维度为(episode个数, max_episode_len, n_agents, n_actions)
        q_evals, q_targets = self.get_q_values(batch, max_episode_len)
        q_evals_2, q_targets_2 = self.get_q_values_2(batch, max_episode_len)
        if self.args.cuda:
            s = s.cuda()
            u = u.cuda()
            r = r.cuda()
            s_next = s_next.cuda()
            terminated = terminated.cuda()
            mask = mask.cuda()
        # 取每个agent动作对应的Q值,并且把最后不需要的一维去掉,因为最后一维只有一个值了
        q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)
        q_evals_2 = torch.gather(q_evals_2, dim=3, index=u).squeeze(3)

        episode_num = batch['o'].shape[0]
        q_targets_sample = []

        actions_prob = []
        actions_probs_nozero = []
        actions_logprobs = []
        # self.agent.init_hidden(batch.batch_size)
        # self.policy_hidden= self.policy.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(
                batch, transition_idx)  # 给obs加last_action、agent_id
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
            # agent_outs, self.target_policy_hidden = self.target_policy(inputs_next, self.target_policy_hidden)
            agent_outs, self.policy_hidden = self.agent.policy(
                inputs_next, self.policy_hidden)
            avail_actions_ = avail_u[:, transition_idx]
            reshaped_avail_actions = avail_actions_.reshape(
                episode_num * self.n_agents, -1)
            agent_outs[reshaped_avail_actions == 0] = -1e11

            # agent_outs = gumbel_softmax(agent_outs, hard=True)  # one-hot TODO ac_sample1
            # agent_outs = agent_outs.view(episode_num, self.n_agents, -1) # (1,3,9)
            # action_next = agent_outs.max(dim=2, keepdim=True)[1] # (1,3,1)
            # # action_next = torch.nonzero(agent_outs).squeeze()
            # actions_next_sample.append(action_next) # 选择动作的序号

            # agent_outs = gumbel_softmax(agent_outs, hard=True)  # 概率 TODO ac_mean
            agent_outs = F.softmax(agent_outs / 1, dim=1)  # 概率分布
            agent_outs = agent_outs.view(episode_num, self.n_agents,
                                         -1)  # 概率有0
            actions_prob.append(agent_outs)  # 每个动作的概率

            # actions_prob_nozero=agent_outs.clone()
            # actions_prob_nozero[reshaped_avail_actions.view(episode_num, self.n_agents, -1) == 0] = 1e-11 # 概率没有0
            # actions_probs_nozero.append(actions_prob_nozero)

            # Have to deal with situation of 0.0 probabilities because we can't do log 0
            z = agent_outs == 0.0
            z = z.float() * 1e-8
            actions_logprobs.append(torch.log(agent_outs + z))

        actions_prob = torch.stack(actions_prob, dim=1)  # Concat over time
        log_prob_pi = torch.stack(actions_logprobs, dim=1)  # Concat over time

        # actions_probs_nozero = torch.stack(actions_probs_nozero, dim=1)  # Concat over time
        # actions_probs_nozero[mask == 0] = 1.0
        # log_prob_pi = torch.log(actions_probs_nozero)

        # Updating alpha wrt entropy
        # alpha = 0.0  # trade-off between exploration (max entropy) and exploitation (max Q)
        target_entropy = -1. * self.n_actions
        if self.args.auto_entropy is True:
            #  alpha_loss = -(self.log_alpha * (log_prob + target_entropy).detach()).mean() #target_entropy=-2
            alpha_loss = (torch.sum(actions_prob.detach() *
                                    (-self.log_alpha *
                                     (log_prob_pi + target_entropy).detach()),
                                    dim=-1) * mask).sum() / mask.sum()
            # print('alpha loss: ',alpha_loss)
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp()
        else:
            self.alpha = 1.
            alpha_loss = 0

        # Calculated baseline
        q_targets_sample = torch.sum(
            actions_prob * (q_targets - self.alpha * log_prob_pi),
            dim=-1).view(episode_num, max_episode_len,
                         -1).detach()  # (1,60,3) TODO ac_mean
        q_targets_sample_2 = torch.sum(
            actions_prob * (q_targets_2 - self.alpha * log_prob_pi),
            dim=-1).view(episode_num, max_episode_len, -1).detach()  # (1,60,3)

        # actions_next_sample = torch.stack(actions_next_sample, dim=1)  # Concat over time # (1,60,3,1) TODO ac_sample1
        # # q_targets_sample = q_targets[:,:,actions_next]
        # q_targets_sample = torch.gather(q_targets, dim=3, index=actions_next_sample).squeeze(3) # (1,60,3)

        q_total_eval = self.eval_qmix_net(q_evals, s)  # [1, 60, 1]
        q_total_target = self.target_qmix_net(q_targets_sample, s_next)

        q_total_eval_2 = self.eval_qmix_net_2(q_evals_2, s)  # [1, 60, 1]
        q_total_target_2 = self.target_qmix_net_2(q_targets_sample_2, s_next)

        q_total_target_min = torch.min(q_total_target, q_total_target_2)
        q_total_target = q_total_target_min

        targets = r + self.args.gamma * q_total_target * (1 - terminated)

        ### update q1
        td_error = (q_total_eval - targets.detach())
        masked_td_error = mask * td_error  # 抹掉填充的经验的td_error
        # 不能直接用mean,因为还有许多经验是没用的,所以要求和再比真实的经验数,才是真正的均值
        loss = (masked_td_error**2).sum() / mask.sum()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters,
                                       self.args.grad_norm_clip)
        self.optimizer.step()

        ### update q2
        td_error_2 = (q_total_eval_2 - targets.detach())
        masked_td_error_2 = mask * td_error_2  # 抹掉填充的经验的td_error
        # 不能直接用mean,因为还有许多经验是没用的,所以要求和再比真实的经验数,才是真正的均值
        loss_2 = (masked_td_error_2**2).sum() / mask.sum()
        self.optimizer_2.zero_grad()
        loss_2.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters_2,
                                       self.args.grad_norm_clip)
        self.optimizer_2.step()

        # if train_step > 0 and train_step % self.args.target_update_cycle == 0: # 200
        #     self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        #     self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())
        #     self.target_policy.load_state_dict(self.agent.policy.state_dict()) #TODO
        # Update the frozen target models

        if train_step % 10000 == 0:  # how often to save the model args.save_cycle = 5000
            self.save_model(train_step)  # TODO

    def _get_inputs(self, batch, transition_idx):
        """Assemble the per-agent RNN inputs for a single transition index.

        Every row is the agent's observation, optionally followed by its
        one-hot previous action (``args.last_action``) and its one-hot agent
        id (``args.reuse_network``).

        :param batch: dict with tensors ``'o'``, ``'o_next'`` and
            ``'u_onehot'``, indexed as [episode, transition, agent, ...].
        :param transition_idx: time step to extract.
        :return: ``(inputs, inputs_next)``, each reshaped to
            ``(episode_num * n_agents, feature_dim)``.
        """
        current_obs = batch['o'][:, transition_idx]
        following_obs = batch['o_next'][:, transition_idx]
        actions_onehot = batch['u_onehot'][:]
        n_episodes = current_obs.shape[0]
        n_agents = self.args.n_agents

        cur_parts = [current_obs]
        nxt_parts = [following_obs]

        if self.args.last_action:
            taken_now = actions_onehot[:, transition_idx]
            # At the first step there is no previous action: feed a zero vector.
            if transition_idx == 0:
                previous = torch.zeros_like(taken_now)
            else:
                previous = actions_onehot[:, transition_idx - 1]
            cur_parts.append(previous)
            nxt_parts.append(taken_now)

        if self.args.reuse_network:
            # All agents share one network, so tag each row with the agent's
            # one-hot id; stacked over agents this is an identity matrix.
            agent_ids = torch.eye(n_agents).unsqueeze(0).expand(n_episodes, -1, -1)
            cur_parts.append(agent_ids)
            nxt_parts.append(agent_ids)

        def _flatten(parts):
            # Merge episode and agent axes, then join features column-wise.
            return torch.cat(
                [part.reshape(n_episodes * n_agents, -1) for part in parts],
                dim=1)

        return _flatten(cur_parts), _flatten(nxt_parts)

    def get_q_values(self, batch, max_episode_len):
        """Unroll the first critic's eval/target RNNs across all time steps.

        Recurrent state is read from and written back to ``self.eval_hidden``
        and ``self.target_hidden``, so ``init_hidden`` must have been called.

        :param batch: episode batch understood by ``self._get_inputs``.
        :param max_episode_len: number of transitions to unroll.
        :return: ``(q_evals, q_targets)``, each stacked along dim 1 into
            ``(episode_num, max_episode_len, n_agents, -1)``.
        """
        n_episodes = batch['o'].shape[0]
        eval_steps = []
        target_steps = []
        for step in range(max_episode_len):
            # Observation + last action + agent id, flattened per agent
            cur_in, next_in = self._get_inputs(batch, step)
            if self.args.cuda:
                cur_in = cur_in.cuda()
                next_in = next_in.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
                self.target_hidden = self.target_hidden.cuda()
            q_cur, self.eval_hidden = self.eval_rnn(cur_in, self.eval_hidden)
            q_next, self.target_hidden = self.target_rnn(next_in, self.target_hidden)
            # Undo the (episode * agent) flattening done by _get_inputs
            eval_steps.append(q_cur.view(n_episodes, self.n_agents, -1))
            target_steps.append(q_next.view(n_episodes, self.n_agents, -1))
        return torch.stack(eval_steps, dim=1), torch.stack(target_steps, dim=1)

    def get_q_values_2(self, batch, max_episode_len):
        """Unroll the second critic's eval/target RNNs across all time steps.

        Mirrors ``get_q_values`` but uses ``eval_rnn_2`` / ``target_rnn_2``.
        Recurrent state is read from and written back to
        ``self.eval_hidden_2`` / ``self.target_hidden_2``, so ``init_hidden``
        must have been called first.

        :param batch: episode batch understood by ``self._get_inputs``.
        :param max_episode_len: number of transitions to unroll.
        :return: ``(q_evals, q_targets)``, each stacked along dim 1 into
            ``(episode_num, max_episode_len, n_agents, -1)``.
        """
        episode_num = batch['o'].shape[0]
        q_evals, q_targets = [], []
        for transition_idx in range(max_episode_len):
            # Observation + last action + agent id, flattened per agent
            inputs, inputs_next = self._get_inputs(batch, transition_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
                # Move the hidden state this critic actually uses (the
                # original moved the first critic's eval_hidden/target_hidden).
                self.eval_hidden_2 = self.eval_hidden_2.cuda()
                self.target_hidden_2 = self.target_hidden_2.cuda()
            q_eval, self.eval_hidden_2 = self.eval_rnn_2(
                inputs, self.eval_hidden_2)
            q_target, self.target_hidden_2 = self.target_rnn_2(
                inputs_next, self.target_hidden_2)

            # Undo the (episode * agent) flattening done by _get_inputs
            q_evals.append(q_eval.view(episode_num, self.n_agents, -1))
            q_targets.append(q_target.view(episode_num, self.n_agents, -1))
        # List of max_episode_len tensors -> one tensor with time on dim 1
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        return q_evals, q_targets

    def init_hidden(self, episode_num):
        """Reset every recurrent state to zeros for a new batch of episodes.

        Allocates ``(episode_num, n_agents, rnn_hidden_dim)`` zero tensors for
        both critics (eval/target) and the policy network, and moves them to
        the GPU when ``args.cuda`` is set.

        :param episode_num: number of episodes in the batch.
        """
        shape = (episode_num, self.n_agents, self.args.rnn_hidden_dim)
        self.eval_hidden = torch.zeros(shape)
        self.target_hidden = torch.zeros(shape)
        self.eval_hidden_2 = torch.zeros(shape)
        self.target_hidden_2 = torch.zeros(shape)
        self.policy_hidden = torch.zeros(shape)
        if self.args.cuda:
            self.eval_hidden = self.eval_hidden.cuda()
            self.target_hidden = self.target_hidden.cuda()
            # Move the second critic's own tensors; the original copied the
            # first critic's tensors here, so eval_hidden_2 aliased eval_hidden.
            self.eval_hidden_2 = self.eval_hidden_2.cuda()
            self.target_hidden_2 = self.target_hidden_2.cuda()
            self.policy_hidden = self.policy_hidden.cuda()

    def save_model(self, train_step):
        """Write the eval mixer, agent RNN and policy weights to disk.

        Files land in ``self.model_dir`` (created if missing) and are prefixed
        with ``train_step // args.save_cycle``.

        :param train_step: current training step.
        """
        prefix = self.model_dir + '/' + str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        def _dump(module, suffix):
            # One state_dict per file, sharing the step-derived prefix
            torch.save(module.state_dict(), prefix + suffix)

        _dump(self.eval_qmix_net, '_qmix_net_params.pkl')
        _dump(self.eval_rnn, '_rnn_net_params.pkl')
        _dump(self.agent.policy, '_policy_net_params.pkl')

    def soft_update(self, tau=0.005):
        """Polyak-average every eval network into its target network.

        Applies ``target <- tau * eval + (1 - tau) * target`` parameter-wise to
        both critics' agent RNNs and mixers.

        :param tau: interpolation factor; defaults to the value that was
            previously hard-coded (0.005). Also stored on ``self.tau``.
        """
        self.tau = tau
        pairs = (
            (self.eval_rnn, self.target_rnn),
            (self.eval_qmix_net, self.target_qmix_net),
            (self.eval_rnn_2, self.target_rnn_2),
            (self.eval_qmix_net_2, self.target_qmix_net_2),
        )
        for eval_net, target_net in pairs:
            for param, target_param in zip(eval_net.parameters(),
                                           target_net.parameters()):
                # Operate on .data so the update is not recorded by autograd
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)
Example #5
0
class Q_Decom:
    """Value-decomposition learner for QMIX, VDN and the weighted-QMIX variants.

    ``args.alg`` selects the mixer: any name containing ``'qmix'`` uses
    ``QMIXMixer``; ``'cwqmix'`` / ``'owqmix'`` additionally train a Q* critic
    (a separate ``RNN`` plus a ``QStar`` mixer); ``'vdn'`` uses ``VDNMixer``.
    """

    def __init__(self, args, itr):
        """Build the networks, optionally load saved weights, create optimizers.

        :param args: run configuration (obs_shape, n_actions, n_agents,
            last_action, reuse_network, alg, cuda, model_dir, map,
            load_model, optim, lr, ...).
        :param itr: run index appended to the model directory.
        :raises Exception: for an unsupported ``args.alg``, or when
            ``args.load_model`` is set but no saved model exists.
        """
        self.args = args
        input_shape = self.args.obs_shape
        # RNN input = obs [+ one-hot last action] [+ one-hot agent id]
        if args.last_action:
            input_shape += self.args.n_actions
        if args.reuse_network:
            input_shape += self.args.n_agents
        # Per-agent utility networks
        self.eval_rnn = RNN(input_shape, args)
        self.target_rnn = RNN(input_shape, args)
        # Integer flag marking the weighted-QMIX variants; checked throughout
        self.wqmix = 0
        if self.args.alg == 'cwqmix' or self.args.alg == 'owqmix':
            self.wqmix = 1

        # QMIX is the default value-decomposition mixer
        if 'qmix' in self.args.alg:
            self.eval_mix_net = QMIXMixer(args)
            self.target_mix_net = QMIXMixer(args)
            if self.wqmix > 0:
                self.qstar_eval_mix = QStar(args)
                self.qstar_target_mix = QStar(args)
                self.qstar_eval_rnn = RNN(input_shape, args)
                self.qstar_target_rnn = RNN(input_shape, args)
                # Weight used in learn() for transitions whose error is not positive
                if self.args.alg == 'cwqmix':
                    self.alpha = 0.75
                elif self.args.alg == 'owqmix':
                    self.alpha = 0.5
                else:
                    raise Exception('没有这个算法')
        elif self.args.alg == 'vdn':
            self.eval_mix_net = VDNMixer()
            self.target_mix_net = VDNMixer()
        else:
            # Fail fast: without a mixer every later step hits AttributeError
            raise Exception('Unsupported!')

        if args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_mix_net.cuda()
            self.target_mix_net.cuda()
            if self.wqmix > 0:
                self.qstar_eval_mix.cuda()
                self.qstar_target_mix.cuda()
                self.qstar_eval_rnn.cuda()
                self.qstar_target_rnn.cuda()
        # Directory checkpoints are loaded from / saved to
        self.model_dir = args.model_dir + '/' + args.alg + '/' + args.map + '/' + str(
            itr)
        if args.load_model:
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_mix = self.model_dir + '/' + self.args.alg + '_net_params.pkl'

                map_location = 'cuda:0' if args.cuda else 'cpu'

                self.eval_rnn.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.eval_mix_net.load_state_dict(
                    torch.load(path_mix, map_location=map_location))
                if self.wqmix > 0:
                    path_agent_rnn = self.model_dir + '/rnn_net_params2.pkl'
                    path_qstar = self.model_dir + '/' + 'qstar_net_params.pkl'
                    self.qstar_eval_rnn.load_state_dict(
                        torch.load(path_agent_rnn, map_location=map_location))
                    self.qstar_eval_mix.load_state_dict(
                        torch.load(path_qstar, map_location=map_location))
                print('成功加载模型 %s' % path_rnn + ' 和 %s' % path_mix)
            else:
                raise Exception("模型不存在")
        # Start target networks from the eval networks' parameters
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
        # Everything the main optimizer updates
        self.eval_params = list(self.eval_rnn.parameters()) + list(
            self.eval_mix_net.parameters())
        # Hidden states are (re)allocated per batch by init_hidden()
        self.eval_hidden = None
        self.target_hidden = None
        if self.wqmix > 0:
            self.qstar_target_rnn.load_state_dict(
                self.qstar_eval_rnn.state_dict())
            self.qstar_target_mix.load_state_dict(
                self.qstar_eval_mix.state_dict())
            self.qstar_params = list(self.qstar_eval_rnn.parameters()) + list(
                self.qstar_eval_mix.parameters())
            self.qstar_eval_hidden = None
            self.qstar_target_hidden = None
        if args.optim == 'RMS':
            self.optimizer = torch.optim.RMSprop(self.eval_params, lr=args.lr)
            if self.wqmix > 0:
                self.qstar_optimizer = torch.optim.RMSprop(self.qstar_params,
                                                           lr=args.lr)
        else:
            # NOTE(review): Adam ignores args.lr and uses PyTorch's default
            # learning rate — confirm this is intended.
            self.optimizer = torch.optim.Adam(self.eval_params)
            if self.wqmix > 0:
                self.qstar_optimizer = torch.optim.Adam(self.qstar_params)
        print("值分解算法 " + self.args.alg + " 初始化")

    def learn(self, batch, max_episode_len, train_step, epsilon=None):
        """Run one gradient update on a batch of padded episodes.

        Every array in ``batch`` is 4-D: (episode, transition, agent, feature).
        Because the agent networks are recurrent, whole episodes are sampled
        and the same transition index of every episode is fed at once.

        :param batch: dict of array-likes; converted to tensors in place
            (``'a'`` becomes a LongTensor, everything else a float Tensor).
        :param max_episode_len: number of transitions to unroll.
        :param train_step: global step; drives the periodic hard target update.
        :param epsilon: unused; kept for interface compatibility.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():
            if key == 'a':
                batch[key] = torch.LongTensor(batch[key])
            else:
                batch[key] = torch.Tensor(batch[key])
        s, next_s, a = batch['s'], batch['next_s'], batch['a']
        r, next_avail_a, done = batch['r'], batch['next_avail_a'], batch['done']
        # Zero the TD error of padded transitions so they do not affect training
        mask = 1 - batch["padded"].float()
        # Per-agent Q values: (episode, max_episode_len, n_agents, n_actions)
        eval_qs, target_qs = self.get_q(batch, episode_num, max_episode_len)
        if self.args.cuda:
            a = a.cuda()
            r = r.cuda()
            done = done.cuda()
            mask = mask.cuda()
            if 'qmix' in self.args.alg:
                s = s.cuda()
                next_s = next_s.cuda()
        # Q value of the action each agent actually took
        eval_qs = torch.gather(eval_qs, dim=3, index=a).squeeze(3)
        eval_q_total = self.eval_mix_net(eval_qs, s)
        qstar_q_total = None
        # Mask unavailable next actions before max/argmax; a large finite value
        # is used because -inf produced NaNs.
        target_qs[next_avail_a == 0.0] = -9999999
        if self.wqmix > 0:
            # QMIX is monotonic, so each agent's argmax also maximizes Q_tot;
            # the joint greedy action is the per-agent argmax.
            argmax_u = target_qs.argmax(dim=3).unsqueeze(3)
            qstar_eval_qs, qstar_target_qs = self.get_q(
                batch, episode_num, max_episode_len, True)
            qstar_eval_qs = torch.gather(qstar_eval_qs, dim=3,
                                         index=a).squeeze(3)
            qstar_target_qs = torch.gather(qstar_target_qs,
                                           dim=3,
                                           index=argmax_u).squeeze(3)
            # Q* totals from the unconstrained feed-forward mixers
            qstar_q_total = self.qstar_eval_mix(qstar_eval_qs, s)
            next_q_total = self.qstar_target_mix(qstar_target_qs, next_s)
        else:
            target_qs = target_qs.max(dim=3)[0]
            next_q_total = self.target_mix_net(target_qs, next_s)

        target_q_total = r + self.args.gamma * next_q_total * (1 - done)
        # Allocate the weights on eval_q_total's device: with args.cuda the
        # mask below lives on the GPU, and boolean-indexing a CPU tensor with
        # a CUDA mask raises.
        weights = torch.ones_like(eval_q_total)
        if self.wqmix > 0:
            # Weights lie in (0, 1]: alpha by default, 1 where error > 0
            weights = torch.full_like(eval_q_total, self.alpha)
            if self.args.alg == 'cwqmix':
                error = mask * (target_q_total - qstar_q_total)
            elif self.args.alg == 'owqmix':
                error = mask * (target_q_total - eval_q_total)
            else:
                raise Exception("模型不存在")
            weights[error > 0] = 1.
            # Q* critic update
            qstar_error = mask * (qstar_q_total - target_q_total.detach())

            qstar_loss = (qstar_error**2).sum() / mask.sum()
            self.qstar_optimizer.zero_grad()
            qstar_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.qstar_params,
                                           self.args.clip_norm)
            self.qstar_optimizer.step()

        # Weighted TD loss averaged over real (non-padded) transitions
        td_error = mask * (eval_q_total - target_q_total.detach())
        loss = (weights.detach() * td_error**2).sum() / mask.sum()
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_params, self.args.clip_norm)
        self.optimizer.step()

        if train_step > 0 and train_step % self.args.target_update_period == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
            if self.wqmix > 0:
                self.qstar_target_rnn.load_state_dict(
                    self.qstar_eval_rnn.state_dict())
                self.qstar_target_mix.load_state_dict(
                    self.qstar_eval_mix.state_dict())

    def init_hidden(self, episode_num):
        """Allocate zero eval/target hidden states for every episode and agent.

        Tensors are created on the CPU; ``get_q`` moves them to the GPU.

        :param episode_num: number of episodes in the batch.
        """
        self.eval_hidden = torch.zeros(
            (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros(
            (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))
        if self.wqmix > 0:
            self.qstar_eval_hidden = torch.zeros(
                (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))
            self.qstar_target_hidden = torch.zeros(
                (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))

    def get_q(self, batch, episode_num, max_episode_len, wqmix=False):
        """Unroll eval/target agent RNNs over ``max_episode_len`` transitions.

        :param wqmix: if True use the Q* RNNs and hidden states instead of the
            regular ones.
        :return: ``(eval_qs, target_qs)``, each shaped
            ``(episode_num, max_episode_len, n_agents, -1)``.
        """
        eval_qs, target_qs = [], []
        for trans_idx in range(max_episode_len):
            # Observation + last action + agent id, flattened per agent
            inputs, next_inputs = self.get_inputs(batch, episode_num,
                                                  trans_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                next_inputs = next_inputs.cuda()
                if wqmix:
                    self.qstar_eval_hidden = self.qstar_eval_hidden.cuda()
                    self.qstar_target_hidden = self.qstar_target_hidden.cuda()
                else:
                    self.eval_hidden = self.eval_hidden.cuda()
                    self.target_hidden = self.target_hidden.cuda()
            if wqmix:
                eval_q, self.qstar_eval_hidden = self.qstar_eval_rnn(
                    inputs, self.qstar_eval_hidden)
                target_q, self.qstar_target_hidden = self.qstar_target_rnn(
                    next_inputs, self.qstar_target_hidden)
            else:
                eval_q, self.eval_hidden = self.eval_rnn(
                    inputs, self.eval_hidden)
                target_q, self.target_hidden = self.target_rnn(
                    next_inputs, self.target_hidden)
            # Undo the (episode * agent) flattening
            eval_q = eval_q.view(episode_num, self.args.n_agents, -1)
            target_q = target_q.view(episode_num, self.args.n_agents, -1)
            eval_qs.append(eval_q)
            target_qs.append(target_q)
        # max_episode_len x (episode, n_agents, -1) -> (episode, max_episode_len, n_agents, -1)
        eval_qs = torch.stack(eval_qs, dim=1)
        target_qs = torch.stack(target_qs, dim=1)
        return eval_qs, target_qs

    def get_inputs(self, batch, episode_num, trans_idx):
        """Build per-agent RNN inputs for one transition index.

        :return: ``(inputs, next_inputs)``, each shaped
            ``(episode_num * n_agents, obs_dim [+ n_actions] [+ n_agents])``.
        """
        # All of onehot_a is taken because the previous step's action is needed
        obs, next_obs, onehot_a = batch['o'][:, trans_idx], \
                                  batch['next_o'][:, trans_idx], batch['onehot_a'][:]
        inputs, next_inputs = [], []
        inputs.append(obs)
        next_inputs.append(next_obs)
        if self.args.last_action:
            # No previous action at the first step: use a zero vector
            if trans_idx == 0:
                inputs.append(torch.zeros_like(onehot_a[:, trans_idx]))
            else:
                inputs.append(onehot_a[:, trans_idx - 1])
            next_inputs.append(onehot_a[:, trans_idx])
        if self.args.reuse_network:
            # One-hot agent ids; stacked over agents this is an identity matrix
            inputs.append(
                torch.eye(self.args.n_agents).unsqueeze(0).expand(
                    episode_num, -1, -1))
            next_inputs.append(
                torch.eye(self.args.n_agents).unsqueeze(0).expand(
                    episode_num, -1, -1))
        # Merge episode/agent axes and concatenate features column-wise
        inputs = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1) for x in inputs],
            dim=1)
        next_inputs = torch.cat(
            [x.reshape(episode_num * self.args.n_agents, -1)
             for x in next_inputs],
            dim=1)
        return inputs, next_inputs

    def save_model(self, train_step):
        """Save eval network weights under ``self.model_dir``.

        :param train_step: either a str used verbatim as the file prefix, or
            an int step whose prefix is ``train_step // args.save_model_period``.
        """
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        if isinstance(train_step, str):
            num = train_step
        else:
            num = str(train_step // self.args.save_model_period)

        torch.save(
            self.eval_mix_net.state_dict(), self.model_dir + '/' + num + '_' +
            self.args.alg + '_net_params.pkl')
        # NOTE(review): __init__ loads 'rnn_net_params.pkl' but this writes
        # '<num>_rnn_params.pkl'; the name is kept so existing checkpoints and
        # scripts still match — confirm which spelling is intended.
        torch.save(self.eval_rnn.state_dict(),
                   self.model_dir + '/' + num + '_rnn_params.pkl')
        if self.wqmix > 0:
            # The Q* networks are required to resume weighted QMIX; the names
            # mirror the files __init__ loads.
            torch.save(self.qstar_eval_rnn.state_dict(),
                       self.model_dir + '/' + num + '_rnn_net_params2.pkl')
            torch.save(self.qstar_eval_mix.state_dict(),
                       self.model_dir + '/' + num + '_qstar_net_params.pkl')
Example #6
0
class DMAQ_qattenLearner:
    """Trainer for the DMAQ and DMAQ-Qatten value-decomposition algorithms.

    Keeps an eval/target pair of per-agent recurrent Q networks (``RNN``) and an
    eval/target pair of mixing networks, and updates them with a one-step
    Q-learning target computed over whole (padded) episodes.
    """

    def __init__(self, args, itr):
        """Build the networks, optionally restore weights, and create the optimizer.

        :param args: configuration namespace; here it reads ``obs_shape``,
            ``n_actions``, ``n_agents``, ``last_action``, ``reuse_network``,
            ``alg``, ``cuda``, ``load_model``, ``model_dir``, ``optim`` and ``lr``.
        :param itr: run index; not used by the current implementation.
        :raises Exception: if ``args.alg`` is neither ``'dmaq_qatten'`` nor
            ``'dmaq'``, or if ``args.load_model`` is set but no saved RNN exists.
        """
        self.args = args
        # RNN input = observation (+ one-hot previous action) (+ one-hot agent id)
        input_shape = self.args.obs_shape
        if args.last_action:
            input_shape += self.args.n_actions
        if args.reuse_network:
            input_shape += self.args.n_agents
        self.eval_rnn = RNN(input_shape, args)
        self.target_rnn = RNN(input_shape, args)

        # Mixing network; 'dmaq_qatten' is the default value-decomposition variant
        if self.args.alg == 'dmaq_qatten':
            self.mixer = DMAQ_QattenMixer()
            self.target_mixer = DMAQ_QattenMixer()
        elif self.args.alg == 'dmaq':
            self.mixer = DMAQer(args)
            self.target_mixer = DMAQer(args)
        else:
            raise Exception('Unsupported!')

        if args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.mixer.cuda()
            self.target_mixer.cuda()

        if args.load_model:
            # NOTE(review): these file names ('rnn_net_params.pkl',
            # '<alg>_net_params.pkl') differ from what save_model writes
            # ('<num>_rnn_params.pkl', '<num>_<alg>_net_params.pkl'), so saved
            # checkpoints must be renamed before they can be loaded here.
            if os.path.exists(self.args.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.args.model_dir + '/rnn_net_params.pkl'
                path_mix = self.args.model_dir + '/' + self.args.alg + '_net_params.pkl'

                map_location = 'cuda:0' if args.cuda else 'cpu'

                self.eval_rnn.load_state_dict(
                    torch.load(path_rnn, map_location=map_location))
                self.mixer.load_state_dict(
                    torch.load(path_mix, map_location=map_location))

                print('成功加载模型 %s' % path_rnn + ' 和 %s' % path_mix)
            else:
                raise Exception("模型不存在")

        # Target networks start from the same weights as the eval networks
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        # Everything the optimizer updates
        self.eval_params = list(self.eval_rnn.parameters()) + list(
            self.mixer.parameters())
        # One hidden state per (episode, agent) during learning; set by init_hidden()
        self.eval_hidden = None
        self.target_hidden = None

        if args.optim == 'RMS':
            self.optimizer = torch.optim.RMSprop(self.eval_params, lr=args.lr)
        else:
            # Pass lr explicitly: Adam previously ignored args.lr and silently
            # used its own default learning rate.
            self.optimizer = torch.optim.Adam(self.eval_params, lr=args.lr)
        print("值分解算法 " + self.args.alg + " 初始化")

    def learn(self, batch, max_episode_len, train_step):
        """Run one gradient step on a batch of padded episodes.

        Per-agent entries of ``batch`` are indexed (episode, transition, agent,
        feature). Because the agent networks are recurrent, whole episodes are
        sampled and the same transition index of every episode is fed through
        the RNN at each step.

        Note: ``batch`` is modified in place; its values are replaced by tensors.

        :param batch: dict with keys 'o', 'next_o', 's', 'next_s', 'a',
            'onehot_a', 'r', 'avail_a', 'next_avail_a', 'done', 'padded'.
        :param max_episode_len: number of transitions to unroll per episode.
        :param train_step: training-step counter; target networks are synced
            whenever it is a positive multiple of ``args.target_update_period``.
        """
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():
            if key == 'a':
                batch[key] = torch.LongTensor(batch[key])
            else:
                batch[key] = torch.Tensor(batch[key])
        s, next_s = batch['s'], batch['next_s']
        a, r = batch['a'], batch['r']
        avail_a, next_avail_a = batch['avail_a'], batch['next_avail_a']
        done, actions_onehot = batch['done'], batch['onehot_a']
        # Zero out TD-errors coming from padded transitions
        mask = 1 - batch["padded"].float()
        # (episode, max_episode_len, n_agents, n_actions) for current / next inputs
        eval_qs, target_qs = self.get_q(batch, episode_num, max_episode_len)
        if self.args.cuda:
            a = a.cuda()
            r = r.cuda()
            done = done.cuda()
            mask = mask.cuda()
            s = s.cuda()
            next_s = next_s.cuda()
            actions_onehot = actions_onehot.cuda()
            avail_a = avail_a.cuda()
            next_avail_a = next_avail_a.cuda()

        # Q value of the action each agent actually took
        eval_qsa = torch.gather(eval_qs, dim=3, index=a).squeeze(3)
        # Greedy per-agent Q restricted to *available* actions. It only serves as
        # the max_q_i baseline of the advantage term, so it is detached (as in
        # the reference DMAQ learner).
        masked_eval_qs = eval_qs.clone().detach()
        masked_eval_qs[avail_a == 0.0] = -9999999
        max_action_qvals = masked_eval_qs.max(dim=3)[0]

        target_qs[next_avail_a == 0.0] = -9999999
        target_qsa = target_qs.max(dim=3)[0]

        # Q_tot = V + A
        q_attend_regs = None
        if self.args.alg == "dmaq_qatten":
            # The Qatten mixer returns (value, attention regulariser, head entropies)
            ans_chosen, q_attend_regs, _ = self.mixer(eval_qsa, s, is_v=True)
            ans_adv, _, _ = self.mixer(eval_qsa,
                                       s,
                                       actions=actions_onehot,
                                       max_q_i=max_action_qvals,
                                       is_v=False)
            # The target mixer has the same tuple return, so unpack it too
            target_max_qvals, _, _ = self.target_mixer(target_qsa, next_s, is_v=True)
        else:
            ans_chosen = self.mixer(eval_qsa, s, is_v=True)
            ans_adv = self.mixer(eval_qsa,
                                 s,
                                 actions=actions_onehot,
                                 max_q_i=max_action_qvals,
                                 is_v=False)
            target_max_qvals = self.target_mixer(target_qsa, next_s, is_v=True)
        eval_qsa = ans_chosen + ans_adv

        # One-step Q-learning target
        targets = r + self.args.gamma * (1 - done) * target_max_qvals
        td_error = eval_qsa - targets.detach()
        mask = mask.expand_as(td_error)
        masked_td_error = td_error * mask
        # L2 loss averaged over real (unpadded) transitions only
        loss = (masked_td_error ** 2).sum() / mask.sum()
        if self.args.alg == "dmaq_qatten":
            loss = loss + q_attend_regs

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_params, self.args.clip_norm)
        self.optimizer.step()

        if train_step > 0 and train_step % self.args.target_update_period == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_mixer.load_state_dict(self.mixer.state_dict())

    def init_hidden(self, episode_num):
        """Reset eval/target hidden states to zeros, one per (episode, agent).

        :param episode_num: number of episodes in the batch.
        """
        self.eval_hidden = torch.zeros(
            (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros(
            (episode_num, self.args.n_agents, self.args.rnn_hidden_dim))

    def get_q(
        self,
        batch,
        episode_num,
        max_episode_len,
    ):
        """Unroll the eval and target RNNs over every transition of the batch.

        Uses and advances ``self.eval_hidden`` / ``self.target_hidden``; call
        ``init_hidden(episode_num)`` first.

        :return: ``(eval_qs, target_qs)``, each stacked along dim 1 into
            (episode_num, max_episode_len, n_agents, n_actions). The eval net
            sees current observations and the target net the next observations.
        """
        if self.args.cuda:
            self.eval_hidden = self.eval_hidden.cuda()
            self.target_hidden = self.target_hidden.cuda()
        eval_qs, target_qs = [], []
        for trans_idx in range(max_episode_len):
            # Observations with agent id and last action appended
            inputs, next_inputs = self.get_inputs(batch, episode_num, trans_idx)
            if self.args.cuda:
                inputs = inputs.cuda()
                next_inputs = next_inputs.cuda()
            eval_q, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
            target_q, self.target_hidden = self.target_rnn(
                next_inputs, self.target_hidden)
            eval_qs.append(eval_q.view(episode_num, self.args.n_agents, -1))
            target_qs.append(target_q.view(episode_num, self.args.n_agents, -1))
        return torch.stack(eval_qs, dim=1), torch.stack(target_qs, dim=1)

    def get_inputs(self, batch, episode_num, trans_idx):
        """Build the RNN inputs for one transition index of every episode.

        Each agent's observation is concatenated with (if ``args.last_action``)
        its previous one-hot action and (if ``args.reuse_network``) a one-hot
        agent id. Episodes and agents are then flattened into rows.

        :return: ``(inputs, next_inputs)``, each shaped
            (episode_num * n_agents, obs_dim [+ n_actions] [+ n_agents]).
        """
        obs = batch['o'][:, trans_idx]
        next_obs = batch['next_o'][:, trans_idx]
        # The whole action history is kept because the previous step is needed
        onehot_a = batch['onehot_a']
        inputs, next_inputs = [obs], [next_obs]
        if self.args.last_action:
            # There is no previous action at the first step, so use zeros
            if trans_idx == 0:
                inputs.append(torch.zeros_like(onehot_a[:, trans_idx]))
            else:
                inputs.append(onehot_a[:, trans_idx - 1])
            next_inputs.append(onehot_a[:, trans_idx])
        if self.args.reuse_network:
            # All agents share one network, so append a one-hot agent id. Stacked
            # over agents this is the identity matrix, repeated per episode.
            agent_ids = torch.eye(self.args.n_agents).unsqueeze(0).expand(
                episode_num, -1, -1)
            inputs.append(agent_ids)
            next_inputs.append(agent_ids)
        rows = episode_num * self.args.n_agents
        inputs = torch.cat([x.reshape(rows, -1) for x in inputs], dim=1)
        next_inputs = torch.cat([x.reshape(rows, -1) for x in next_inputs], dim=1)
        return inputs, next_inputs

    def save_model(self, train_step):
        """Save the mixer and agent-RNN weights into ``args.model_dir``.

        :param train_step: a string used verbatim as the file prefix, or an
            integer step converted to ``train_step // args.save_model_period``.
        """
        # exist_ok=True avoids the race between checking for and creating the dir
        os.makedirs(self.args.model_dir, exist_ok=True)

        if isinstance(train_step, str):
            num = train_step
        else:
            num = str(train_step // self.args.save_model_period)

        torch.save(
            self.mixer.state_dict(), self.args.model_dir + '/' + num + '_' +
            self.args.alg + '_net_params.pkl')
        torch.save(self.eval_rnn.state_dict(),
                   self.args.model_dir + '/' + num + '_rnn_params.pkl')
Example #7
0
class QMIX_PG():
    def __init__(self, agent, args):
        """Build the QMIX critic plus the policy and critic optimisers.

        :param agent: object whose ``policy`` attribute is the torch module
            trained by ``train_actor``.
        :param args: configuration namespace (actor_lr, critic_lr, n_actions,
            n_agents, state_shape, obs_shape, last_action, reuse_network, cuda,
            model_dir, map, load_model, optimizer, and the fields used to name
            the model directory).
        :raises Exception: if ``args.load_model`` is set and no saved RNN
            weights are found in the model directory.
        """
        self.agent = agent
        # Slowly-tracking copy of the policy (updated by soft_update)
        self.target_policy = copy.deepcopy(self.agent.policy)

        self.policy_params = list(agent.policy.parameters())
        self.policy_optimiser = torch.optim.RMSprop(params=self.policy_params,
                                                    lr=args.actor_lr)
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        # RNN input = observation (+ one-hot last action) (+ one-hot agent id)
        input_shape = self.obs_shape
        if args.last_action:
            input_shape += self.n_actions
        if args.reuse_network:
            input_shape += self.n_agents

        # Per-agent Q network (shared by all agents) and the QMIX mixing networks
        self.eval_rnn = RNN(input_shape, args)
        self.target_rnn = RNN(input_shape, args)
        self.eval_qmix_net = QMixNet(args)
        self.target_qmix_net = QMixNet(args)
        self.args = args
        if self.args.cuda:
            self.eval_rnn.cuda()
            self.target_rnn.cuda()
            self.eval_qmix_net.cuda()
            self.target_qmix_net.cuda()
            self.agent.policy.cuda()
            self.target_policy.cuda()

        # The model directory name encodes the main hyper-parameters of the run
        run_tag = (f'clamp2-5_{args.loss_coeff_entropy}_{args.actor_buffer_size}_'
                   f'{args.critic_buffer_size}_{args.actor_train_steps}_'
                   f'{args.critic_train_steps}_{args.actor_update_delay}_{args.critic_lr}')
        self.model_dir = 'linear_mix/' + args.model_dir + '/qmix_ac_total_counterfactual' + '/' + run_tag + '/' + args.map

        if self.args.load_model:
            # NOTE(review): save_model writes '<num>_*_params.pkl' files, which
            # do not match the un-prefixed names looked up here.
            if os.path.exists(self.model_dir + '/rnn_net_params.pkl'):
                path_rnn = self.model_dir + '/rnn_net_params.pkl'
                path_qmix = self.model_dir + '/qmix_net_params.pkl'
                path_policy = self.model_dir + '/policy_net_params.pkl'
                map_location = 'cuda:0' if self.args.cuda else 'cpu'
                self.eval_rnn.load_state_dict(torch.load(path_rnn, map_location=map_location))
                self.eval_qmix_net.load_state_dict(torch.load(path_qmix, map_location=map_location))
                # Load into the live policy. target_policy is synced from it
                # below; loading into target_policy was overwritten by that sync.
                self.agent.policy.load_state_dict(torch.load(path_policy, map_location=map_location))
                print('Successfully load the model: {} and {}'.format(path_rnn, path_qmix))
            else:
                raise Exception("No model!")

        # Target networks start from the same weights as the eval networks
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())
        self.target_policy.load_state_dict(self.agent.policy.state_dict())

        self.eval_parameters = list(self.eval_qmix_net.parameters()) + list(self.eval_rnn.parameters())
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.critic_lr)
        else:
            # Previously no critic optimiser was created here, so train_critic
            # failed with AttributeError for any other setting.
            self.optimizer = torch.optim.Adam(self.eval_parameters, lr=args.critic_lr)

        # Recurrent states; (re)created per batch by init_hidden()
        self.eval_hidden = None
        self.target_hidden = None
        self.policy_hidden = None
        self.target_policy_hidden = None
        print('Init alg QMIX')

    def train_actor(self, batch, max_episode_len, update_agent_id=0):  # EpisodeBatch
        # Get the relevant quantities
        # bs = batch.batch_size
        # max_t = batch.max_seq_length
        # actions = batch["u"][:, :-1]
        # terminated = batch["terminated"][:, :-1].float()
        # avail_actions = batch["avail_u"][:, :-1]
        # # mask = batch["filled"][:, :-1].float()
        # mask = 1 - batch["padded"].float()
        # mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        # mask = mask.repeat(1, 1, self.n_agents).view(-1)
        # states = batch["s"][:, :-1]

        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # 把batch里的数据转化成tensor
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        # s, s_next, u, r, avail_u, avail_u_next, terminated = batch['s'], batch['s_next'], batch['u'], \
        #                                                      batch['r'], batch['avail_u'], batch['avail_u_next'], \
        #                                                      batch['terminated']
        s = batch['s']
        terminated = batch['terminated']
        mask = 1 - batch["padded"].float()  # 用来把那些填充的经验的TD-error置0,从而不让它们影响到学习
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        # actions = batch["u"][:, :-1]
        actions = batch["u"]
        avail_u = batch["avail_u"]
        # terminated = batch["terminated"][:, :-1].float()
        # avail_actions = batch["avail_u"][:, :-1]

        if self.args.cuda:
            s = s.cuda()
            # u = u.cuda()
            # r = r.cuda()
            # s_next = s_next.cuda()
            # terminated = terminated.cuda()
            actions = actions.cuda()
            avail_u = avail_u.cuda()
            mask = mask.cuda()

        # # build q
        # inputs = self.critic._build_inputs(batch, bs, max_t)
        # q_vals = self.critic.forward(inputs).detach()[:, :-1]

        # episode_num = batch['o'].shape[0]
        q_evals, q_targets = [], []

        actions_prob = []
        # self.agent.init_hidden(batch.batch_size)
        # self.policy_hidden= self.policy.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
            q_eval, self.eval_hidden = self.eval_rnn(inputs,
                                                     self.eval_hidden)  # inputs维度为(40,96),得到的q_eval维度为(40,n_actions) Q_i
            # 把q_eval维度重新变回(8, 5,n_actions)
            q_eval = q_eval.view(episode_num, self.n_agents, -1)
            q_evals.append(q_eval)

            agent_outs, self.policy_hidden = self.agent.policy(inputs, self.policy_hidden)
            avail_actions_ = avail_u[:, transition_idx]
            reshaped_avail_actions = avail_actions_.reshape(episode_num * self.n_agents, -1)
            agent_outs[reshaped_avail_actions == 0] = -1e11
            # agent_outs = torch.nn.functional.softmax(agent_outs, dim=-1)
            # agent_outs = gumbel_softmax(agent_outs, hard=False) #todo
            agent_outs = F.softmax(agent_outs / 1, dim=1)  # 概率分布

            # If hard=True, then the returned sample will be one-hot, otherwise it will
            # be a probabilitiy distribution that sums to 1 across classes
            agent_outs = agent_outs.view(episode_num, self.n_agents, -1)
            actions_prob.append(agent_outs)  # 每个动作的概率

        # 得的q_eval和q_target是一个列表,列表里装着max_episode_len个数组,数组的的维度是(episode个数, n_agents,n_actions)
        # 把该列表转化成(episode个数, max_episode_len, n_agents,n_actions)的数组
        q_vals = torch.stack(q_evals, dim=1)

        # mac_out = []
        # self.mac.init_hidden(batch.batch_size)
        # for t in range(batch.max_seq_length - 1):
        #     agent_outs = self.mac.forward(batch, t=t)
        #     mac_out.append(agent_outs)
        actions_prob = torch.stack(actions_prob, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        # actions_prob[avail_actions == 0] = 0
        # actions_prob = actions_prob / actions_prob.sum(dim=-1, keepdim=True)
        # actions_prob[avail_actions == 0] = 0

        # Calculated baseline
        q_taken_actions = torch.gather(q_vals, dim=3, index=actions).squeeze(3)  # (1,60,3)  TODO NOTE
        q_taken = self.eval_qmix_net(q_taken_actions, s).detach()  # TODO # (1,60,1)
        Q_taken = q_taken.repeat(repeats=(1, 1, self.n_agents))  # (1,60,3)

        pi = actions_prob.view(-1, self.n_actions)
        # q_targets_mean = torch.sum(actions_prob * q_vals, dim=-1).view(-1).detach() # (180)
        q_mean_actions = torch.sum(actions_prob * q_vals, dim=-1).detach()  # (1,60,3)
        q_i_mean_negi_taken = []
        # torch.stack( [torch.cat((q_targets_mean[:, :, i].unsqueeze(2), torch.cat((q_taken[:, :, :i], q_taken[:, :, i + 1:]), dim=2)),
        # dim=2) for i in range(3)],dim=1) # 顺序不对
        q_taken_actions_detached = q_taken_actions.detach()
        for i in range(self.n_agents):
            q_temp = copy.deepcopy(q_taken_actions_detached)
            q_temp[:, :, i] = q_mean_actions[:, :, i]  # q_i(mean_action) q_-i(tacken_action),
            q_i_mean_negi_taken.append(q_temp)
        q_i_mean_negi_taken = torch.stack(q_i_mean_negi_taken, dim=2)  # [1, 60, 3, 3]
        q_i_mean_negi_taken = q_i_mean_negi_taken.view(episode_num, -1, self.n_agents)  # [1, 60*3, 3]
        # TODO
        # s_repeat = s.repeat(repeats=(1, self.n_agents, 1))  # (1,60,48)-> # (1,60*3,48) #TODO 顺序错误
        s_repeat = s.repeat(repeats=(1, 1, self.n_agents))  # (1,60,48)-> # (1,60,48*3)
        s_repeat = s_repeat.view(episode_num, self.n_agents, -1)  # (1,60,48*3)-> # (1,60*3,48)

        Q_i_mean_negi_taken = self.eval_qmix_net(q_i_mean_negi_taken, s_repeat).detach()  # TODO #  (1,60*3,1)
        # q_total_target = q_total_target.repeat(repeats=(1, 1, self.n_agents)) # (1,60,3)
        Q_i_mean_negi_taken = Q_i_mean_negi_taken.view(episode_num, -1, self.n_agents)  # (1,60*3,1)->(1,60,3)

        # Calculate policy grad with mask
        pi_taken = torch.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = torch.log(pi_taken)

        # # TODO normalnize
        # advantages_original=(q_taken.view(-1,self.n_agents)-baseline.view(-1,self.n_agents))
        # advantages = (advantages_original - advantages_original.mean(dim=0)) / advantages_original.std(dim=0) #
        # # advantages = (advantages_original-advantages_original.mean(dim=1).repeat(repeats=(1, max_episode_len, 1)))/advantages_original.std(dim=1).repeat(repeats=(1, max_episode_len, 1))
        # advantages=advantages.view(-1).detach()

        # TODO no normalnize
        advantages = (Q_taken.view(-1) - Q_i_mean_negi_taken.view(-1)).detach()
        # advantages = (Q_taken.view(-1) ).detach()#TODO

        if self.args.advantage_norm:  # TODO
            EPS = 1e-10
            advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)
            # advantages = (q_taken.view(-1) - baseline.view(-1)).detach()
            # advantages = (-log_pi_taken + q_taken.view(-1).detach() - baseline) # TODO
        # import numpy as np
        # update_agent_id = np.random.randint(0, self.n_agents)  # TODO
        # update_agent_id=update_agent_id+1 if update_agent_id<self.n_agents-1 else 0
        # policy_loss = - ((advantages[update_agent_id::self.n_agents] * log_pi_taken[update_agent_id::self.n_agents]) * mask[update_agent_id::self.n_agents]).sum() / mask[update_agent_id::self.n_agents].sum()
        policy_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum()

        # policy_entropy = torch.mean(torch.exp(log_pi_taken) * log_pi_taken)
        # regu_policy_loss = policy_loss + self.args.loss_coeff_entropy * policy_entropy

        # Optimise agents
        self.policy_optimiser.zero_grad()
        policy_loss.backward()  # policy gradient
        grad_norm = torch.nn.utils.clip_grad_norm_(self.policy_params, self.args.grad_norm_clip)
        self.policy_optimiser.step()
        self.soft_update()

        if policy_loss.item() > 1e3:
            print("policy_loss:", policy_loss.item())

    def train_critic(self, batch, max_episode_len, train_step):  # train_step表示是第几次学习,用来控制更新target_net网络的参数
        '''
        在learn的时候,抽取到的数据是四维的,四个维度分别为 1——第几个episode 2——episode中第几个transition
        3——第几个agent的数据 4——具体obs维度。因为在选动作时不仅需要输入当前的inputs,还要给神经网络输入hidden_state,
        hidden_state和之前的经验相关,因此就不能随机抽取经验进行学习。所以这里一次抽取多个episode,然后一次给神经网络
        传入每个episode的同一个位置的transition
        '''
        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        for key in batch.keys():  # 把batch里的数据转化成tensor
            if key == 'u':
                batch[key] = torch.tensor(batch[key], dtype=torch.long)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)
        s, s_next, u, r, avail_u, avail_u_next, terminated = batch['s'], batch['s_next'], batch['u'], \
                                                             batch['r'], batch['avail_u'], batch['avail_u_next'], \
                                                             batch['terminated']
        mask = 1 - batch["padded"].float()  # 用来把那些填充的经验的TD-error置0,从而不让它们影响到学习
        r = (r - r.mean()) / (r.std() + 1e-6)  # todo
        # 得到每个agent对应的Q值,维度为(episode个数, max_episode_len, n_agents, n_actions)
        q_evals, q_targets = self.get_q_values(batch, max_episode_len)

        if self.args.cuda:
            s = s.cuda()
            u = u.cuda()
            r = r.cuda()
            s_next = s_next.cuda()
            terminated = terminated.cuda()
            mask = mask.cuda()
        # 取每个agent动作对应的Q值,并且把最后不需要的一维去掉,因为最后一维只有一个值了
        q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)

        episode_num = batch['o'].shape[0]
        q_targets_sample = []

        actions_prob = []
        actions_next_sample = []
        # self.agent.init_hidden(batch.batch_size)
        # self.policy_hidden= self.policy.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
            # agent_outs, self.target_policy_hidden = self.target_policy(inputs_next, self.target_policy_hidden) TODO
            agent_outs, self.policy_hidden = self.agent.policy(inputs_next, self.policy_hidden)

            avail_actions_ = avail_u[:, transition_idx]
            reshaped_avail_actions = avail_actions_.reshape(episode_num * self.n_agents, -1)
            agent_outs[reshaped_avail_actions == 0] = -1e11

            # agent_outs = gumbel_softmax(agent_outs, hard=True)  # one-hot TODO ac_sample1
            # agent_outs = agent_outs.view(episode_num, self.n_agents, -1) # (1,3,9)
            # action_next = agent_outs.max(dim=2, keepdim=True)[1] # (1,3,1)
            # # action_next = torch.nonzero(agent_outs).squeeze()
            # actions_next_sample.append(action_next) # 选择动作的序号

            # agent_outs = gumbel_softmax(agent_outs, hard=True)  # 概率 TODO ac_mean
            agent_outs = F.softmax(agent_outs / 1, dim=1)  # 概率分布
            agent_outs = agent_outs.view(episode_num, self.n_agents, -1)
            actions_prob.append(agent_outs)  # 每个动作的概率

        actions_prob = torch.stack(actions_prob, dim=1)  # Concat over time
        # pi = actions_prob.view(-1, self.n_actions)
        # Calculated baseline
        baseline = torch.sum(actions_prob * q_targets, dim=-1).view(episode_num, max_episode_len,
                                                                    -1).detach()  # (1,60,3)
        q_targets_sample = baseline

        # actions_next_sample = torch.stack(actions_next_sample, dim=1)  # Concat over time # (1,60,3,1) TODO
        # # q_targets_sample = q_targets[:,:,actions_next]
        # q_targets_sample = torch.gather(q_targets, dim=3, index=actions_next_sample).squeeze(3) # (1,60,3)

        # # 得到target_q
        # q_targets[avail_u_next == 0.0] = - 9999999
        # q_targets = q_targets.max(dim=3)[0]

        q_total_eval = self.eval_qmix_net(q_evals, s)  # [1, 60, 1]
        q_total_target = self.target_qmix_net(q_targets_sample, s_next)

        targets = r + self.args.gamma * q_total_target * (1 - terminated)

        td_error = (q_total_eval - targets.detach())
        masked_td_error = mask * td_error  # 抹掉填充的经验的td_error

        # 不能直接用mean,因为还有许多经验是没用的,所以要求和再比真实的经验数,才是真正的均值
        critic_loss = (masked_td_error ** 2).sum() / mask.sum()
        self.optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)
        self.optimizer.step()

        if critic_loss.item() > 1e3:
            print("critic_loss:", critic_loss.item())

        # if train_step > 0 and train_step % self.args.target_update_cycle == 0: # 200
        #     self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        #     self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())
        #     self.target_policy.load_state_dict(self.agent.policy.state_dict()) #TODO
        # Update the frozen target models

        if train_step % self.args.save_cycle == 0:  # how often to save the model args.save_cycle = 5000
            self.save_model(train_step)  # TODO

    def _get_inputs(self, batch, transition_idx):
        # 取出所有episode上该transition_idx的经验,u_onehot要取出所有,因为要用到上一条
        obs, obs_next, u_onehot = batch['o'][:, transition_idx], \
                                  batch['o_next'][:, transition_idx], batch['u_onehot'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [], []
        inputs.append(obs)
        inputs_next.append(obs_next)
        # 给obs添加上一个动作、agent编号 30+9+3

        if self.args.last_action:
            if transition_idx == 0:  # 如果是第一条经验,就让前一个动作为0向量
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # 因为当前的obs三维的数据,每一维分别代表(episode编号,agent编号,obs维度),直接在dim_1上添加对应的向量
            # 即可,比如给agent_0后面加(1, 0, 0, 0, 0),表示5个agent中的0号。而agent_0的数据正好在第0行,那么需要加的
            # agent编号恰好就是一个单位矩阵,即对角线为1,其余为0
            inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
            inputs_next.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
        # 要把obs中的三个拼起来,并且要把episode_num个episode、self.args.n_agents个agent的数据拼成40条(40,96)的数据,
        # 因为这里所有agent共享一个神经网络,每条数据中带上了自己的编号,所以还是自己的数据
        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)
        return inputs, inputs_next  # (3,42)

    def get_q_values(self, batch, max_episode_len):
        episode_num = batch['o'].shape[0]
        q_evals, q_targets = [], []
        for transition_idx in range(max_episode_len):
            inputs, inputs_next = self._get_inputs(batch, transition_idx)  # 给obs加last_action、agent_id
            if self.args.cuda:
                inputs = inputs.cuda()
                inputs_next = inputs_next.cuda()
                self.eval_hidden = self.eval_hidden.cuda()
                self.target_hidden = self.target_hidden.cuda()
            q_eval, self.eval_hidden = self.eval_rnn(inputs,
                                                     self.eval_hidden)  # inputs维度为(40,96),得到的q_eval维度为(40,n_actions)
            q_target, self.target_hidden = self.target_rnn(inputs_next, self.target_hidden)

            # 把q_eval维度重新变回(8, 5,n_actions)
            q_eval = q_eval.view(episode_num, self.n_agents, -1)
            q_target = q_target.view(episode_num, self.n_agents, -1)
            q_evals.append(q_eval)
            q_targets.append(q_target)
        # 得的q_eval和q_target是一个列表,列表里装着max_episode_len个数组,数组的的维度是(episode个数, n_agents,n_actions)
        # 把该列表转化成(episode个数, max_episode_len, n_agents,n_actions)的数组
        q_evals = torch.stack(q_evals, dim=1)
        q_targets = torch.stack(q_targets, dim=1)
        return q_evals, q_targets

    def init_hidden(self, episode_num):
        # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden
        self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.policy_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_policy_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))  # TODO
        if self.args.cuda:
            self.eval_hidden = self.eval_hidden.cuda()
            self.target_hidden = self.target_hidden.cuda()
            self.policy_hidden = self.policy_hidden.cuda()
            self.target_policy_hidden = self.target_policy_hidden.cuda()

    def save_model(self, train_step):
        num = str(train_step // self.args.save_cycle)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        torch.save(self.eval_qmix_net.state_dict(), self.model_dir + '/' + num + '_qmix_net_params.pkl')
        torch.save(self.eval_rnn.state_dict(), self.model_dir + '/' + num + '_rnn_net_params.pkl')
        torch.save(self.agent.policy.state_dict(), self.model_dir + '/' + num + '_policy_net_params.pkl')

    def soft_update(self):
        self.tau = 0.005
        for param, target_param in zip(self.eval_rnn.parameters(), self.target_rnn.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.eval_qmix_net.parameters(), self.target_qmix_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.agent.policy.parameters(), self.target_policy.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)