Example #1
    def update_policies(self, sample, soft=True, logger=None, **kwargs):
        obs, acs, rews, next_obs, dones = sample
        samp_acs = []
        all_probs = []
        all_log_pis = []
        all_pol_regs = []

        each_input = self.attention_medium(obs)

        for a_i, pi, inp in zip(range(self.nagents), self.policies,
                                each_input):
            curr_ac, probs, log_pi, pol_regs, ent = pi(inp,
                                                       return_all_probs=True,
                                                       return_log_pi=True,
                                                       regularize=True,
                                                       return_entropy=True)
            if logger is not None:
                logger.add_scalar('agent%i/policy_entropy' % a_i, ent,
                                  self.niter)
            samp_acs.append(curr_ac)
            all_probs.append(probs)
            all_log_pis.append(log_pi)
            all_pol_regs.append(pol_regs)

        critic_in = list(zip(obs, samp_acs))
        critic_rets = self.critic(critic_in, return_all_q=True)

        for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(
                range(self.nagents), all_probs, all_log_pis, all_pol_regs,
                critic_rets):
            curr_agent = self.agents[a_i]
            v = (all_q * probs).sum(dim=1, keepdim=True)
            pol_target = q - v
            if soft:
                pol_loss = (
                    log_pi *
                    (log_pi / self.reward_scale - pol_target).detach()).mean()
            else:
                pol_loss = (log_pi * (-pol_target).detach()).mean()
            for reg in pol_regs:
                pol_loss += 1e-3 * reg  # policy regularization
            # don't want critic to accumulate gradients from policy loss
            disable_gradients(self.critic)
            pol_loss.backward(retain_graph=(a_i != self.nagents - 1))
            enable_gradients(self.critic)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                curr_agent.policy.parameters(), 0.5)
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()

            if logger is not None:
                logger.add_scalar('agent%i/losses/pol_loss' % a_i, pol_loss,
                                  self.niter)
                # logger.add_scalar('agent%i/grad_norms/pi' % a_i,
                #                  grad_norm, self.niter)

        # self.attention_medium.scale_shared_grads()
        torch.nn.utils.clip_grad_norm_(
            self.attention_medium.parameters(), 0.5)
        self.attention_medium_optimizer.step()
        self.attention_medium_optimizer.zero_grad()
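
Note: all four examples rely on disable_gradients / enable_gradients helpers that are not shown here. A minimal sketch of what they presumably do (assumed, not taken from the source) is to toggle requires_grad on the critic's parameters so the policy backward pass cannot accumulate gradients into the critic:

def disable_gradients(module):
    # Freeze the module: backward() no longer accumulates gradients
    # into its parameters.
    for p in module.parameters():
        p.requires_grad = False


def enable_gradients(module):
    # Re-enable gradient accumulation once the policy update is done.
    for p in module.parameters():
        p.requires_grad = True
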
Example #2
    def update_policies(self,
                        finalized_frames: Dict[AgentKey, BatchedAgentReplayFrame],
                        soft=True, logger=None, **kwargs):
        samp_acs = {}
        all_probs = {}
        all_log_pis = {}
        all_pol_regs = {}

        for k, v in finalized_frames.items():
            pi = self.policies[k.type]
            ob = v.obs
            curr_ac, probs, log_pi, pol_regs, ent = pi(
                ob, return_all_probs=True, return_log_pi=True,
                regularize=True, return_entropy=True)
            if logger is not None:
                logger.add_scalar('agent%s/policy_entropy' % k.id, ent,
                                  self.niter)
            samp_acs[k] = curr_ac
            all_probs[k] = probs
            all_log_pis[k] = log_pi
            all_pol_regs[k] = pol_regs

        critic_in = {k: BatchedAgentObservationAction(v.obs, samp_acs[k])
                     for k, v in finalized_frames.items()}
        critic_rets = self.critic(critic_in, return_all_q=True)

        for k in finalized_frames:
            probs = all_probs[k]
            log_pi = all_log_pis[k]
            pol_regs = all_pol_regs[k]
            (q, all_q) = critic_rets[k]

            curr_agent = self.agents[k.type]
            v = (all_q * probs).sum(dim=1, keepdim=True)
            pol_target = q - v
            if soft:
                pol_loss = (log_pi * (log_pi / self.reward_scale - pol_target).detach()).mean()
            else:
                pol_loss = (log_pi * (-pol_target).detach()).mean()
            for reg in pol_regs:
                pol_loss += 1e-3 * reg  # policy regularization
            # don't want critic to accumulate gradients from policy loss
            disable_gradients(self.critic)
            # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch
            pol_loss.backward()
            enable_gradients(self.critic)

        for curr_agent in self.agents:
            # grad_norm = torch.nn.utils.clip_grad_norm_(
            #     curr_agent.policy.parameters(), 0.5)
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()
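
Example #2 keys everything by AgentKey and wraps batches in BatchedAgentReplayFrame / BatchedAgentObservationAction containers whose definitions are not shown. A minimal sketch of the field layout the calls above seem to assume (hypothetical, inferred only from the attribute accesses):

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class AgentKey:
    # Hypothetical layout: .type selects the shared policy / agent entry,
    # .id labels the individual agent for logging.
    type: str
    id: str


@dataclass
class BatchedAgentReplayFrame:
    # Hypothetical layout: only .obs is used by update_policies above; the
    # remaining replay fields are listed for completeness.
    obs: torch.Tensor
    acs: torch.Tensor
    rews: torch.Tensor
    next_obs: torch.Tensor
    dones: torch.Tensor


@dataclass
class BatchedAgentObservationAction:
    # Observation / sampled-action pair handed to the attention critic.
    obs: torch.Tensor
    acs: torch.Tensor
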
Example #3
    def update_policies(self, sample, soft=True, logger=None, **kwargs):
        obs, acs, rews, penalties_1, next_obs, dones = sample
        samp_acs = []
        all_probs = []
        all_log_pis = []
        all_pol_regs = []

        for a_i, pi, ob in zip(range(self.nagents), self.policies, obs):
            curr_ac, probs, log_pi, pol_regs, ent = pi(ob,
                                                       return_all_probs=True,
                                                       return_log_pi=True,
                                                       regularize=True,
                                                       return_entropy=True)
            if logger is not None:
                logger.add_scalar('agent%i/policy_entropy' % a_i, ent,
                                  self.niter)
            samp_acs.append(curr_ac)
            all_probs.append(probs)
            all_log_pis.append(log_pi)
            all_pol_regs.append(pol_regs)

        critic_in = list(zip(obs, samp_acs))
        critic_rets = self.critic(critic_in, return_all_q=True)
        for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(
                range(self.nagents), all_probs, all_log_pis, all_pol_regs,
                critic_rets):
            curr_agent = self.agents[a_i]
            v = (all_q * probs).sum(dim=1, keepdim=True)
            pol_target = q - v
            if soft:
                pol_loss = -(
                    log_pi *
                    (log_pi / self.reward_scale - pol_target).detach()).mean()
            else:
                pol_loss = -(log_pi * (-pol_target).detach()).mean()
            for reg in pol_regs:
                pol_loss += 1e-3 * reg  # policy regularization

            disable_gradients(self.critic)
            pol_loss.backward()
            enable_gradients(self.critic)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                curr_agent.policy.parameters(), 0.5)
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()

            if logger is not None:
                logger.add_scalar('agent%i/losses/pol_loss' % a_i, pol_loss,
                                  self.niter)
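
In every example the baseline v is the probability-weighted sum of the per-action Q-values, so pol_target = q - v is an advantage. A small self-contained illustration of that computation, with made-up numbers and names chosen to mirror the code above:

import torch

# pi(a|o) over 3 discrete actions, and Q(o, a) for every action
probs = torch.tensor([[0.2, 0.5, 0.3]])
all_q = torch.tensor([[1.0, 2.0, 0.5]])

# Q-value of the sampled action (here action index 1)
q = all_q.gather(1, torch.tensor([[1]]))

# state value = expected Q under the policy: 0.2*1.0 + 0.5*2.0 + 0.3*0.5 = 1.35
v = (all_q * probs).sum(dim=1, keepdim=True)

pol_target = q - v  # advantage of the sampled action: 2.0 - 1.35 = 0.65
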
Example #4
    def update_policies(self, sample, soft=True, logger=None, **kwargs):
        state, obs, acs, rews, next_state, next_obs, dones = sample
        samp_acs = []
        all_probs = []
        all_log_probs = []
        all_log_pis = []
        all_pol_regs = []
        n_pol_heads = self.n_intr_rew_types + int(self.sep_extr_head)
        for a_i, pi, ob in zip(range(self.nagents), self.policies, obs):
            pi_outs = pi(ob,
                         return_all_probs=True,
                         return_all_log_probs=True,
                         return_log_pi=True,
                         regularize=True,
                         return_entropy=True,
                         head=None)
            curr_ac, probs, log_probs, log_pi, pol_regs, ent = list(
                zip(*pi_outs))
            if logger is not None:
                for j in range(n_pol_heads):
                    logger.add_scalar('agent%i/pol%i_entropy' % (a_i, j),
                                      ent[j], self.niter)
            samp_acs.append(curr_ac)
            all_probs.append(probs)
            all_log_probs.append(log_probs)
            all_log_pis.append(log_pi)
            all_pol_regs.append(pol_regs)

        critic_in = (state, obs, samp_acs)
        critic_rets = self.critic(critic_in, return_all_q=True)
        for a_i, probs, log_pi, pol_regs, ((eqs, iqs),
                                           (all_eqs, all_iqs)) in zip(
                                               range(self.nagents), all_probs,
                                               all_log_pis, all_pol_regs,
                                               critic_rets):
            curr_agent = self.agents[a_i]
            pol_loss = 0.0
            if self.sep_extr_head:
                all_q = all_eqs[0]
                q = eqs[0]
                v = (all_q * probs[0]).sum(dim=1, keepdim=True)
                pol_target = q - v
                if soft:
                    pol_loss += (log_pi[0] * (log_pi[0] / self.reward_scale -
                                              pol_target).detach()).mean()
                else:
                    pol_loss += (log_pi[0] * (-pol_target).detach()).mean()
                for reg in pol_regs[0]:
                    pol_loss += 1e-3 * reg  # policy regularization
                eqs = eqs[1:]
                all_eqs = all_eqs[1:]
                probs = probs[1:]
                log_pi = log_pi[1:]
                pol_regs = pol_regs[1:]
            for j in range(self.n_intr_rew_types):
                q = eqs[j] + self.beta * iqs[j]
                all_q = all_eqs[j] + self.beta * all_iqs[j]
                v = (all_q * probs[j]).sum(dim=1, keepdim=True)
                pol_target = q - v
                if soft:
                    pol_loss += (log_pi[j] * (log_pi[j] / self.reward_scale -
                                              pol_target).detach()).mean()
                else:
                    pol_loss += (log_pi[j] * (-pol_target).detach()).mean()
                for reg in pol_regs[j]:
                    pol_loss += 1e-3 * reg  # policy regularization
            # don't want critic to accumulate gradients from policy loss
            disable_gradients(self.critic)
            pol_loss.backward()
            curr_agent.policy.scale_shared_grads()
            enable_gradients(self.critic)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                curr_agent.policy.parameters(), self.grad_norm_clip / 100.)
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()

            if logger is not None:
                logger.add_scalar('grad_norms/agent%i_policy' % a_i, grad_norm,
                                  self.niter)
                logger.add_scalar('agent%i/losses/pol_loss' % a_i, pol_loss,
                                  self.niter)
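
Example #4 calls curr_agent.policy.scale_shared_grads() (Example #1 has the analogous call on attention_medium commented out), but the method itself is not shown. A plausible sketch, assuming a shared_parameters() helper and n_heads policy heads that each contribute a loss term:

def scale_shared_grads(module, n_heads):
    # Divide the gradients of parameters shared across all heads by the number
    # of heads, so summing one loss term per head does not multiply the shared
    # trunk's effective learning rate by n_heads.
    # (shared_parameters() is an assumed helper that lists those parameters.)
    for p in module.shared_parameters():
        if p.grad is not None:
            p.grad.data.mul_(1.0 / n_heads)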