Example #1
0
    def testAttentionCritic(self):
        """Check that AttentionCritic.forward yields an entry for every agent key.

        Three replay frames are batched with ``preprocess_to_batch``; the third
        frame has no ``AgentKey(0, '0-3')`` entry, so the batch mixes agents
        that are present in every frame with one that is not.
        """
        critic = AttentionCritic([(5, 3), (5, 2)], attend_heads=4)
        raw_frames = \
            [{AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
              AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
              AgentKey(0, '0-3'): AgentReplayFrame([2, 0, 3, 0, 2], [1, 0, 0], 1, False, [3, 0, 1, 3, 4]),
              AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])},
             {AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
              AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
              AgentKey(0, '0-3'): AgentReplayFrame([2, 0, 3, 0, 2], [1, 0, 0], 0, True, [3, 0, 1, 3, 4]),
              AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])},
             {AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
              AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
              AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])}]

        # Keep the raw list and the batched dict under separate names so each
        # variable has one accurate type.
        batched_frames: Dict[AgentKey, BatchedAgentReplayFrame] = \
            preprocess_to_batch(raw_frames)

        results: Dict[AgentKey, List[float]] = critic.forward(batched_frames)

        # assertIn reports the missing key and the container on failure,
        # unlike assertTrue(k in results).
        for key in batched_frames:
            self.assertIn(key, results)
Example #2
0
    def __init__(self,
                 agent_init_params,
                 sa_size,
                 gamma=0.95,
                 tau=0.01,
                 attend_tau=0.002,
                 pi_lr=0.01,
                 q_lr=0.01,
                 reward_scale=10.,
                 pol_hidden_dim=128,
                 critic_hidden_dim=128,
                 attend_heads=4,
                 **kwargs):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent; each dict
                                              is passed as keyword arguments
                                              to AttentionAgent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
            sa_size (list of (int, int)): Size of state and action space for
                                          each agent
            gamma (float): Discount factor
            tau (float): Target update rate
            attend_tau (float): Stored as ``self.attend_tau``; not used by this
                                constructor (presumably a separate target
                                update rate for the attention parameters --
                                confirm against the update code)
            pi_lr (float): Learning rate for policy
            q_lr (float): Learning rate for critic
            reward_scale (float): Scaling for reward (affects the entropy of
                                  the optimal policy)
            pol_hidden_dim (int): Number of hidden dimensions for the policy
                                  networks
            critic_hidden_dim (int): Number of hidden dimensions for the
                                     critic networks
            attend_heads (int): Passed to AttentionCritic as ``attend_heads``
            **kwargs: Accepted and ignored
        """
        self.nagents = len(sa_size)

        self.agents = [
            AttentionAgent(lr=pi_lr, hidden_dim=pol_hidden_dim, **params)
            for params in agent_init_params
        ]
        # Online and target critics are built from one shared configuration
        # so their architectures cannot drift apart.
        critic_kwargs = dict(hidden_dim=critic_hidden_dim,
                             attend_heads=attend_heads)
        self.critic = AttentionCritic(sa_size, **critic_kwargs)
        self.target_critic = AttentionCritic(sa_size, **critic_kwargs)
        # presumably copies the online critic's weights into the target critic
        hard_update(self.target_critic, self.critic)
        # NOTE(review): only q_parameters() are handed to this optimizer;
        # presumably the remaining (attention) parameters are updated
        # elsewhere (cf. attend_tau) -- confirm.
        self.critic_optimizer = Adam(self.critic.q_parameters(),
                                     lr=q_lr,
                                     weight_decay=1e-3)
        self.agent_init_params = agent_init_params
        self.gamma = gamma
        self.tau = tau
        self.attend_tau = attend_tau
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.reward_scale = reward_scale
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0
Example #3
0
    def __init__(self, algo_config: List[Tuple[int, int]],
                 gamma=0.95, tau=0.01, pi_lr=0.01, q_lr=0.01,
                 reward_scale=10.,
                 pol_hidden_dim=128,
                 critic_hidden_dim=128, attend_heads=4,
                 **kwargs):
        """
        Inputs:
            algo_config (List[Tuple[int, int]]): Agent types which will exist in
                this environment, one ``(sdim, adim)`` pair per type; each pair
                is passed positionally to AttentionAgent.
                Ex. [(20, 8), (20, 2)]
            gamma (float): Discount factor
            tau (float): Target update rate
            pi_lr (float): Learning rate for policy
            q_lr (float): Learning rate for critic
            reward_scale (float): Scaling for reward (affects the entropy of
                                  the optimal policy)
            pol_hidden_dim (int): Number of hidden dimensions for the policy
                                  networks
            critic_hidden_dim (int): Number of hidden dimensions for the
                                     critic networks
            attend_heads (int): Passed to AttentionCritic as ``attend_heads``
            **kwargs: Accepted and ignored
        """
        # List (not dict) of agents, one per entry of algo_config and in the
        # same order -- presumably indexed by agent type.
        self.agents = [AttentionAgent(sdim, adim, lr=pi_lr, hidden_dim=pol_hidden_dim)
                       for sdim, adim in algo_config]
        # Online and target critics share one configuration.
        critic_kwargs = dict(hidden_dim=critic_hidden_dim, attend_heads=attend_heads)
        self.critic = AttentionCritic(algo_config, **critic_kwargs)
        self.target_critic = AttentionCritic(algo_config, **critic_kwargs)
        # presumably copies the online critic's weights into the target critic
        hard_update(self.target_critic, self.critic)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=q_lr, weight_decay=1e-3)
        self.gamma = gamma
        self.tau = tau
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.reward_scale = reward_scale
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0

        # Constructor arguments, kept so the instance can be re-created later.
        self.init_dict = {'gamma': gamma, 'tau': tau,
                          'pi_lr': pi_lr, 'q_lr': q_lr,
                          'reward_scale': reward_scale,
                          'pol_hidden_dim': pol_hidden_dim,
                          'critic_hidden_dim': critic_hidden_dim,
                          'attend_heads': attend_heads,
                          'algo_config': algo_config}
Example #4
0
    def __init__(self,
                 agent_init_params,
                 sa_size,
                 gamma=0.95,
                 tau=0.01,
                 pi_lr=0.01,
                 q_lr=0.01,
                 reward_scale=10.,
                 pol_hidden_dim=128,
                 critic_hidden_dim=128,
                 attend_heads=4,
                 **kwargs):
        """
        Build the per-agent policies plus an online and a target critic.

        Inputs:
            agent_init_params (list of dict): One dict per agent, passed as
                keyword arguments to AttentionAgent
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
            sa_size (list of (int, int)): Size of state and action space for
                each agent; its length is stored as ``self.nagents``
            gamma (float): Discount factor
            tau (float): Target update rate
            pi_lr (float): Learning rate for the policies
            q_lr (float): Learning rate for the critic (Adam, weight decay 1e-3)
            reward_scale (float): Scaling for reward
            pol_hidden_dim (int): Hidden size for the policy networks
            critic_hidden_dim (int): Hidden size for the critic networks
            attend_heads (int): Passed to AttentionCritic as ``attend_heads``
            **kwargs: Accepted and ignored
        """
        self.nagents = len(sa_size)

        # Policies: one AttentionAgent per parameter dict, in order.
        agents = []
        for agent_kwargs in agent_init_params:
            agents.append(AttentionAgent(lr=pi_lr,
                                         hidden_dim=pol_hidden_dim,
                                         **agent_kwargs))
        self.agents = agents

        # Critics: online and target are built from the same configuration.
        shared_critic_cfg = {'hidden_dim': critic_hidden_dim,
                             'attend_heads': attend_heads}
        self.critic = AttentionCritic(sa_size, **shared_critic_cfg)
        self.target_critic = AttentionCritic(sa_size, **shared_critic_cfg)
        hard_update(self.target_critic, self.critic)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=q_lr,
                                     weight_decay=1e-3)

        # Hyperparameters and bookkeeping.
        self.agent_init_params = agent_init_params
        self.gamma, self.tau = gamma, tau
        self.pi_lr, self.q_lr = pi_lr, q_lr
        self.reward_scale = reward_scale
        self.niter = 0
Example #5
0
class AttentionSAC(object):
    """
    Wrapper class for SAC agents with central attention critic in multi-agent
    task.

    Agents are addressed positionally: ``self.agents[i]`` and the i-th entry of
    every per-agent sequence passed to the update methods refer to the same
    agent.
    """
    def __init__(self,
                 agent_init_params,
                 sa_size,
                 gamma=0.95,
                 tau=0.01,
                 pi_lr=0.01,
                 q_lr=0.01,
                 reward_scale=10.,
                 pol_hidden_dim=128,
                 critic_hidden_dim=128,
                 attend_heads=4,
                 **kwargs):
        """
        Inputs:
            agent_init_params (list of dict): List of dicts with parameters to
                                              initialize each agent (passed as
                                              keyword arguments to
                                              AttentionAgent)
                num_in_pol (int): Input dimensions to policy
                num_out_pol (int): Output dimensions to policy
            sa_size (list of (int, int)): Size of state and action space for
                                          each agent
            gamma (float): Discount factor
            tau (float): Target update rate
            pi_lr (float): Learning rate for policy
            q_lr (float): Learning rate for critic
            reward_scale (float): Scaling for reward (affects the entropy of
                                  the optimal policy)
            pol_hidden_dim (int): Number of hidden dimensions for the policies
            critic_hidden_dim (int): Number of hidden dimensions for the critic
            attend_heads (int): Passed to AttentionCritic as ``attend_heads``
            **kwargs: Accepted and ignored
        """
        self.nagents = len(sa_size)

        self.agents = [
            AttentionAgent(lr=pi_lr, hidden_dim=pol_hidden_dim, **params)
            for params in agent_init_params
        ]
        self.critic = AttentionCritic(sa_size,
                                      hidden_dim=critic_hidden_dim,
                                      attend_heads=attend_heads)
        self.target_critic = AttentionCritic(sa_size,
                                             hidden_dim=critic_hidden_dim,
                                             attend_heads=attend_heads)
        hard_update(self.target_critic, self.critic)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=q_lr,
                                     weight_decay=1e-3)
        self.agent_init_params = agent_init_params
        self.gamma = gamma
        self.tau = tau
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.reward_scale = reward_scale
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0
        # save() reads self.init_dict; record the constructor arguments here
        # so that saving also works for directly-constructed instances
        # (init_from_env / init_from_save still assign it explicitly).
        self.init_dict = {
            'gamma': gamma,
            'tau': tau,
            'pi_lr': pi_lr,
            'q_lr': q_lr,
            'reward_scale': reward_scale,
            'pol_hidden_dim': pol_hidden_dim,
            'critic_hidden_dim': critic_hidden_dim,
            'attend_heads': attend_heads,
            'agent_init_params': agent_init_params,
            'sa_size': sa_size
        }

    @property
    def policies(self):
        """List of each agent's current policy, in agent order."""
        return [a.policy for a in self.agents]

    @property
    def target_policies(self):
        """List of each agent's target policy, in agent order."""
        return [a.target_policy for a in self.agents]

    def step(self, observations, explore=False):
        """
        Take a step forward in environment with all agents
        Inputs:
            observations: List of observations for each agent
            explore (bool): Forwarded to each agent's ``step``
        Outputs:
            actions: List of actions for each agent
        """
        return [
            a.step(obs, explore=explore)
            for a, obs in zip(self.agents, observations)
        ]

    def update_critic(self, sample, soft=True, logger=None, **kwargs):
        """
        Update central critic for all agents.

        Inputs:
            sample: Tuple ``(obs, acs, rews, next_obs, dones)``; each element
                    is a per-agent sequence indexed by agent position.
            soft (bool): If True, subtract ``log_pi / reward_scale`` of the
                         target policies' next actions from the Q target.
            logger: Optional object with ``add_scalar``; also forwarded to the
                    critic. Scalar logging is skipped when None.
        """
        obs, acs, rews, next_obs, dones = sample
        # Q loss: sample next actions (with log-probs) from the target policies
        next_acs = []
        next_log_pis = []
        for pi, ob in zip(self.target_policies, next_obs):
            curr_next_ac, curr_next_log_pi = pi(ob, return_log_pi=True)
            next_acs.append(curr_next_ac)
            next_log_pis.append(curr_next_log_pi)
        trgt_critic_in = list(zip(next_obs, next_acs))
        critic_in = list(zip(obs, acs))
        next_qs = self.target_critic(trgt_critic_in)
        critic_rets = self.critic(critic_in,
                                  regularize=True,
                                  logger=logger,
                                  niter=self.niter)
        q_loss = 0
        for a_i, nq, log_pi, (pq, regs) in zip(range(self.nagents), next_qs,
                                               next_log_pis, critic_rets):
            # Bootstrapped target; (1 - done) removes the bootstrap term for
            # terminal transitions.
            target_q = (rews[a_i].view(-1, 1) + self.gamma * nq *
                        (1 - dones[a_i].view(-1, 1)))
            if soft:
                target_q -= log_pi / self.reward_scale
            q_loss += MSELoss(pq, target_q.detach())
            for reg in regs:
                q_loss += reg  # regularizing attention
        q_loss.backward()
        self.critic.scale_shared_grads()
        # clip_grad_norm_ replaces the deprecated clip_grad_norm; it returns
        # the total gradient norm before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                                   10 * self.nagents)
        self.critic_optimizer.step()
        self.critic_optimizer.zero_grad()

        if logger is not None:
            logger.add_scalar('losses/q_loss', q_loss, self.niter)
            logger.add_scalar('grad_norms/q', grad_norm, self.niter)
        self.niter += 1

    def update_policies(self, sample, soft=True, logger=None, **kwargs):
        """
        Update every agent's policy using the central critic.

        Inputs:
            sample: Tuple ``(obs, acs, rews, next_obs, dones)``; only ``obs``
                    is used here.
            soft (bool): If True, include the ``log_pi / reward_scale`` term
                         in the policy-gradient weight.
            logger: Optional object with ``add_scalar``; logging is skipped
                    when None.
        """
        obs, acs, rews, next_obs, dones = sample
        samp_acs = []
        all_probs = []
        all_log_pis = []
        all_pol_regs = []

        for a_i, pi, ob in zip(range(self.nagents), self.policies, obs):
            curr_ac, probs, log_pi, pol_regs, ent = pi(ob,
                                                       return_all_probs=True,
                                                       return_log_pi=True,
                                                       regularize=True,
                                                       return_entropy=True)
            # logger defaults to None, so guard like every other log call.
            if logger is not None:
                logger.add_scalar('agent%i/policy_entropy' % a_i, ent,
                                  self.niter)
            samp_acs.append(curr_ac)
            all_probs.append(probs)
            all_log_pis.append(log_pi)
            all_pol_regs.append(pol_regs)

        critic_in = list(zip(obs, samp_acs))
        critic_rets = self.critic(critic_in, return_all_q=True)
        for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(
                range(self.nagents), all_probs, all_log_pis, all_pol_regs,
                critic_rets):
            curr_agent = self.agents[a_i]
            # Baseline: expected Q under the policy's action probabilities
            v = (all_q * probs).sum(dim=1, keepdim=True)
            pol_target = q - v
            if soft:
                pol_loss = (
                    log_pi *
                    (log_pi / self.reward_scale - pol_target).detach()).mean()
            else:
                pol_loss = (log_pi * (-pol_target).detach()).mean()
            for reg in pol_regs:
                pol_loss += 1e-3 * reg  # policy regularization
            # don't want critic to accumulate gradients from policy loss
            disable_gradients(self.critic)
            pol_loss.backward()
            enable_gradients(self.critic)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                curr_agent.policy.parameters(), 0.5)
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()

            if logger is not None:
                logger.add_scalar('agent%i/losses/pol_loss' % a_i, pol_loss,
                                  self.niter)
                logger.add_scalar('agent%i/grad_norms/pi' % a_i, grad_norm,
                                  self.niter)

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        soft_update(self.target_critic, self.critic, self.tau)
        for a in self.agents:
            soft_update(a.target_policy, a.policy, self.tau)

    def prep_training(self, device='gpu'):
        """
        Put all networks in train mode and move them to ``device``.

        ``device == 'gpu'`` moves modules with ``.cuda()``; any other value
        uses ``.cpu()``. Each group of networks is only moved when its
        recorded device differs from ``device``.
        """
        self.critic.train()
        self.target_critic.train()
        for a in self.agents:
            a.policy.train()
            a.target_policy.train()
        if device == 'gpu':
            fn = lambda x: x.cuda()
        else:
            fn = lambda x: x.cpu()
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device
        if not self.critic_dev == device:
            self.critic = fn(self.critic)
            self.critic_dev = device
        if not self.trgt_pol_dev == device:
            for a in self.agents:
                a.target_policy = fn(a.target_policy)
            self.trgt_pol_dev = device
        if not self.trgt_critic_dev == device:
            self.target_critic = fn(self.target_critic)
            self.trgt_critic_dev = device

    def prep_rollouts(self, device='cpu'):
        """
        Put the main policies in eval mode and move them to ``device``
        (``'gpu'`` -> ``.cuda()``, anything else -> ``.cpu()``).
        """
        for a in self.agents:
            a.policy.eval()
        if device == 'gpu':
            fn = lambda x: x.cuda()
        else:
            fn = lambda x: x.cpu()
        # only need main policy for rollouts
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device

    def save(self, filename):
        """
        Save trained parameters of all agents into one file
        """
        self.prep_training(
            device='cpu')  # move parameters to CPU before saving
        save_dict = {
            'init_dict': self.init_dict,
            'agent_params': [a.get_params() for a in self.agents],
            'critic_params': {
                'critic': self.critic.state_dict(),
                'target_critic': self.target_critic.state_dict(),
                'critic_optimizer': self.critic_optimizer.state_dict()
            }
        }
        torch.save(save_dict, filename)

    @classmethod
    def init_from_env(cls,
                      env,
                      gamma=0.95,
                      tau=0.01,
                      pi_lr=0.01,
                      q_lr=0.01,
                      reward_scale=10.,
                      pol_hidden_dim=128,
                      critic_hidden_dim=128,
                      attend_heads=4,
                      **kwargs):
        """
        Instantiate instance of this class from multi-agent environment

        env: Multi-agent environment exposing ``action_space`` and
             ``observation_space`` sequences (one entry per agent); the
             policy input size is taken from ``obsp[0]`` and the output size
             from the action-space entry itself
        gamma: discount factor
        tau: rate of update for target networks
        pi_lr / q_lr: learning rates for the policies / critic
        pol_hidden_dim / critic_hidden_dim: hidden sizes for policies / critic
        attend_heads: passed to AttentionCritic
        """
        agent_init_params = []
        sa_size = []
        for acsp, obsp in zip(env.action_space, env.observation_space):
            agent_init_params.append({
                'num_in_pol': obsp[0],
                'num_out_pol': acsp
            })
            sa_size.append((obsp[0], acsp))

        init_dict = {
            'gamma': gamma,
            'tau': tau,
            'pi_lr': pi_lr,
            'q_lr': q_lr,
            'reward_scale': reward_scale,
            'pol_hidden_dim': pol_hidden_dim,
            'critic_hidden_dim': critic_hidden_dim,
            'attend_heads': attend_heads,
            'agent_init_params': agent_init_params,
            'sa_size': sa_size
        }
        instance = cls(**init_dict)
        instance.init_dict = init_dict
        return instance

    @classmethod
    def init_from_save(cls, filename, load_critic=False):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip(instance.agents, save_dict['agent_params']):
            a.load_params(params)

        if load_critic:
            critic_params = save_dict['critic_params']
            instance.critic.load_state_dict(critic_params['critic'])
            instance.target_critic.load_state_dict(
                critic_params['target_critic'])
            instance.critic_optimizer.load_state_dict(
                critic_params['critic_optimizer'])
        return instance
Example #6
0
import numpy as np
from torch import Tensor
from typing import List
from utils.critics import AttentionCritic
from utils.core import *
from utils.buffer import AgentReplayFrame
from torchviz import *

# Module-level variable annotations are evaluated at runtime, so Dict must be
# in scope here rather than relying on it leaking through a star import.
from typing import Dict

# Run the attention critic on three hand-written replay frames and build an
# autograd graph object for one agent's output.
critic = AttentionCritic([(5, 3), (5, 2)], attend_heads=4)
raw_frames = \
    [{AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
      AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
      AgentKey(0, '0-3'): AgentReplayFrame([2, 0, 3, 0, 2], [1, 0, 0], 1, False, [3, 0, 1, 3, 4]),
      AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])},
     {AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
      AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
      AgentKey(0, '0-3'): AgentReplayFrame([2, 0, 3, 0, 2], [1, 0, 0], 0, True, [3, 0, 1, 3, 4]),
      AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])},
     {AgentKey(0, '0-1'): AgentReplayFrame([2, 1, 2, 2, 3], [0, 1, 0], 3, False, [3, 1, 1, 2, 3]),
      AgentKey(0, '0-2'): AgentReplayFrame([1, 1, 3, 2, 1], [0, 1, 0], 5, False, [2, 1, 1, 2, 2]),
      AgentKey(1, '0-1'): AgentReplayFrame([2, 0, 3, 1, 2], [0, 1], 3, False, [3, 0, 1, 3, 4])}]

# Raw frames and the batched dict live under separate names so each variable
# has a single, accurate type.
sample_frames: Dict[AgentKey, BatchedAgentReplayFrame] = preprocess_to_batch(
    raw_frames)

results: Dict[AgentKey, List[float]] = critic.forward(sample_frames)

print(results)

# NOTE: `dot` is only constructed here; this script does not render or save
# it (e.g. with dot.render(...)).
dot = make_dot(results[AgentKey(1, '0-1')][0].mean(),
               params=dict(critic.named_parameters()))
Example #7
0
class AttentionSAC(object):
    """
    Wrapper class for SAC agents with central attention critic in multi-agent
    task

    Had to change a couple things in this class to make agent networks work by
    type as opposed to a constant number of agents: policies are looked up by
    ``self.agents[k.type]`` and all per-agent data is keyed by ``AgentKey``.
    """

    def __init__(self, algo_config: List[Tuple[int, int]],
                 gamma=0.95, tau=0.01, pi_lr=0.01, q_lr=0.01,
                 reward_scale=10.,
                 pol_hidden_dim=128,
                 critic_hidden_dim=128, attend_heads=4,
                 **kwargs):
        """
        Inputs:
            algo_config (List[Tuple[int, int]]): Agent types which will exist in
                this environment, one ``(sdim, adim)`` pair per type.
                Ex. [(20, 8), (20, 2)]
            gamma (float): Discount factor
            tau (float): Target update rate
            pi_lr (float): Learning rate for policy
            q_lr (float): Learning rate for critic
            reward_scale (float): Scaling for reward (affects the entropy of
                                  the optimal policy)
            pol_hidden_dim (int): Number of hidden dimensions for the policies
            critic_hidden_dim (int): Number of hidden dimensions for the critic
            attend_heads (int): Passed to AttentionCritic as ``attend_heads``
            **kwargs: Accepted and ignored
        """
        # List (not dict) with one AttentionAgent per agent type; position i
        # serves every agent whose ``AgentKey.type == i``.
        self.agents = [AttentionAgent(sdim, adim, lr=pi_lr, hidden_dim=pol_hidden_dim)
                       for sdim, adim in algo_config]
        self.critic = AttentionCritic(algo_config, hidden_dim=critic_hidden_dim, attend_heads=attend_heads)
        self.target_critic = AttentionCritic(algo_config, hidden_dim=critic_hidden_dim, attend_heads=attend_heads)
        hard_update(self.target_critic, self.critic)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=q_lr, weight_decay=1e-3)
        self.gamma = gamma
        self.tau = tau
        self.pi_lr = pi_lr
        self.q_lr = q_lr
        self.reward_scale = reward_scale
        self.pol_dev = 'cpu'  # device for policies
        self.critic_dev = 'cpu'  # device for critics
        self.trgt_pol_dev = 'cpu'  # device for target policies
        self.trgt_critic_dev = 'cpu'  # device for target critics
        self.niter = 0

        # Constructor arguments; written by save() and replayed by init_from_save().
        self.init_dict = {'gamma': gamma, 'tau': tau,
                          'pi_lr': pi_lr, 'q_lr': q_lr,
                          'reward_scale': reward_scale,
                          'pol_hidden_dim': pol_hidden_dim,
                          'critic_hidden_dim': critic_hidden_dim,
                          'attend_heads': attend_heads,
                          'algo_config': algo_config}

    @property
    def policies(self):
        """Current policy of each agent type, indexed by type."""
        return [a.policy for a in self.agents]

    @property
    def target_policies(self):
        """Target policy of each agent type, indexed by type."""
        return [a.target_policy for a in self.agents]

    def step(self, observations: Dict[AgentKey, AgentObservation], explore=False) -> Dict[AgentKey, AgentAction]:
        """
        Choose an action for every observed agent using its type's policy.

        Each observation is wrapped as a batch of one before being passed to
        the policy, and the first row of the result becomes the action.
        """
        actions = {}
        for k, v in observations.items():
            # torch.Tensor already yields a tensor with requires_grad=False;
            # the deprecated Variable(..., requires_grad=False) wrapper added
            # nothing.
            obs_batch = torch.Tensor(np.array([v.obs]))
            action_batch = self.agents[k.type].step(obs_batch, explore=explore)
            actions[k] = AgentAction(action_batch.tolist()[0])
        return actions

    def update_critic(self,
                      finalized_frames: Dict[AgentKey, BatchedAgentReplayFrame],
                      soft=True, logger=None, **kwargs):
        """
        Update central critic for all agents

        Inputs:
            finalized_frames: Batched replay data keyed by agent.
            soft (bool): If True, subtract ``log_pi / reward_scale`` of the
                         target policies' next actions from the Q target.
            logger: Optional object with ``add_scalar``; also forwarded to the
                    critic. Scalar logging is skipped when None.
        """
        if not finalized_frames:
            # Nothing to learn from: q_loss would stay the int 0 and
            # q_loss.backward() would raise.
            return

        # Q loss
        next_acs: Dict[AgentKey, Tensor] = {}
        next_log_pis: Dict[AgentKey, Tensor] = {}
        for k, v in finalized_frames.items():
            pi = self.target_policies[k.type]
            curr_next_ac, curr_next_log_pi = pi(v.next_obs, return_log_pi=True)
            next_acs[k] = curr_next_ac
            next_log_pis[k] = curr_next_log_pi
        trgt_critic_in = {k: BatchedAgentObservationAction(v.next_obs, next_acs[k]) for k, v in finalized_frames.items()}
        critic_in = {k: BatchedAgentObservationAction(v.obs, v.acs) for k, v in finalized_frames.items()}
        next_qs = self.target_critic(trgt_critic_in) # calls "forward", also TODO this doesn't need to be computed for frames in which done = True
        curr_qs = self.critic(critic_in, regularize=True, logger=logger, niter=self.niter)
        q_loss = 0
        for k, v in finalized_frames.items():
            (nq,) = next_qs[k]
            log_pi = next_log_pis[k]
            (pq, regs) = curr_qs[k]

            # NOTE(review): v.rews / v.dones are used without reshaping; this
            # assumes preprocess_to_batch already gives them a shape that
            # broadcasts element-wise with nq (e.g. both (batch, 1)) -- confirm.
            target_q = (v.rews +
                        self.gamma * nq *
                        (1 - v.dones))
            if soft:
                target_q -= log_pi / self.reward_scale
            q_loss += MSELoss(pq, target_q.detach())
            for reg in regs:
                q_loss += reg  # regularizing attention
        q_loss.backward()
        num_agents = len(finalized_frames)
        self.critic.scale_shared_grads(num_agents)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.critic.parameters(), 10 * num_agents)
        self.critic_optimizer.step()
        self.critic_optimizer.zero_grad()

        if logger is not None:
            logger.add_scalar('losses/q_loss', q_loss, self.niter)
            logger.add_scalar('grad_norms/q', grad_norm, self.niter)
        self.niter += 1

    def update_policies(self,
                        finalized_frames: Dict[AgentKey, BatchedAgentReplayFrame],
                        soft=True, logger=None, **kwargs):
        """
        Update the per-type policies from one batch of frames.

        Every agent's policy loss is back-propagated first, so gradients from
        all agents of one type accumulate on that type's policy; afterwards
        each type's optimizer takes a single step.

        Inputs:
            finalized_frames: Batched replay data keyed by agent.
            soft (bool): If True, include the ``log_pi / reward_scale`` term
                         in the policy-gradient weight.
            logger: Optional object with ``add_scalar``; skipped when None.
        """
        samp_acs = {}
        all_probs = {}
        all_log_pis = {}
        all_pol_regs = {}

        for k, v in finalized_frames.items():
            pi = self.policies[k.type]
            curr_ac, probs, log_pi, pol_regs, ent = pi(
                v.obs, return_all_probs=True, return_log_pi=True,
                regularize=True, return_entropy=True)
            if logger is not None:
                logger.add_scalar('agent%s/policy_entropy' % k.id, ent,
                                  self.niter)
            samp_acs[k] = curr_ac
            all_probs[k] = probs
            all_log_pis[k] = log_pi
            all_pol_regs[k] = pol_regs

        critic_in = {k: BatchedAgentObservationAction(v.obs, samp_acs[k]) for k, v in finalized_frames.items()}
        critic_rets = self.critic(critic_in, return_all_q=True)

        # don't want critic to accumulate gradients from policy loss; the
        # finally clause re-enables them even if a backward pass raises.
        disable_gradients(self.critic)
        try:
            for k in finalized_frames:
                probs = all_probs[k]
                log_pi = all_log_pis[k]
                (q, all_q) = critic_rets[k]

                # Baseline: expected Q under the policy's action probabilities
                v = (all_q * probs).sum(dim=1, keepdim=True)
                pol_target = q - v
                if soft:
                    pol_loss = (log_pi * (log_pi / self.reward_scale - pol_target).detach()).mean()
                else:
                    pol_loss = (log_pi * (-pol_target).detach()).mean()
                for reg in all_pol_regs[k]:
                    pol_loss += 1e-3 * reg  # policy regularization
                # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch
                pol_loss.backward()
        finally:
            enable_gradients(self.critic)

        # NOTE: no policy gradient-norm clipping and no per-agent pol_loss
        # logging are performed here.
        for curr_agent in self.agents:
            curr_agent.policy_optimizer.step()
            curr_agent.policy_optimizer.zero_grad()

    def update_all_targets(self):
        """
        Update all target networks (called after normal updates have been
        performed for each agent)
        """
        soft_update(self.target_critic, self.critic, self.tau)
        for a in self.agents:
            soft_update(a.target_policy, a.policy, self.tau)

    def prep_training(self, device='gpu'):
        """
        Put all networks in train mode and move them to ``device``
        (``'gpu'`` -> ``.cuda()``, anything else -> ``.cpu()``); each group is
        moved only when its recorded device differs.
        """
        self.critic.train()
        self.target_critic.train()
        for a in self.agents:
            a.policy.train()
            a.target_policy.train()
        if device == 'gpu':
            fn = lambda x: x.cuda()
        else:
            fn = lambda x: x.cpu()
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device
        if not self.critic_dev == device:
            self.critic = fn(self.critic)
            self.critic_dev = device
        if not self.trgt_pol_dev == device:
            for a in self.agents:
                a.target_policy = fn(a.target_policy)
            self.trgt_pol_dev = device
        if not self.trgt_critic_dev == device:
            self.target_critic = fn(self.target_critic)
            self.trgt_critic_dev = device

    def prep_rollouts(self, device='cpu'):
        """
        Put the main policies in eval mode and move them to ``device``
        (``'gpu'`` -> ``.cuda()``, anything else -> ``.cpu()``).
        """
        for a in self.agents:
            a.policy.eval()
        if device == 'gpu':
            fn = lambda x: x.cuda()
        else:
            fn = lambda x: x.cpu()
        # only need main policy for rollouts
        if not self.pol_dev == device:
            for a in self.agents:
                a.policy = fn(a.policy)
            self.pol_dev = device

    def save(self, filename):
        """
        Save trained parameters of all agents into one file
        """
        self.prep_training(device='cpu')  # move parameters to CPU before saving
        save_dict = {'init_dict': self.init_dict,
                     'agent_params': [a.get_params() for a in self.agents],
                     'critic_params': {'critic': self.critic.state_dict(),
                                       'target_critic': self.target_critic.state_dict(),
                                       'critic_optimizer': self.critic_optimizer.state_dict()}}
        torch.save(save_dict, filename)

    @classmethod
    def init_from_save(cls, filename, load_critic=False):
        """
        Instantiate instance of this class from file created by 'save' method
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        save_dict = torch.load(filename)
        instance = cls(**save_dict['init_dict'])
        instance.init_dict = save_dict['init_dict']
        for a, params in zip_equal(instance.agents, save_dict['agent_params']):
            a.load_params(params)

        if load_critic:
            critic_params = save_dict['critic_params']
            instance.critic.load_state_dict(critic_params['critic'])
            instance.target_critic.load_state_dict(critic_params['target_critic'])
            instance.critic_optimizer.load_state_dict(critic_params['critic_optimizer'])
        return instance