Example #1
    def update_moa(self, sample, agent_i, parallel=False, grad_norm=0.5):
        """ Update parameters of moa networks based on lastest sample from replay buffer
        Arguments:
            sample: [(B,D)]*N, obs, next_obs, action can be [dict (B,D)]*N
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        # [(B,1,D)]*N or [dict (B,1,D)]*N
        interm_sample = self.add_virtual_dim(sample)   
        # place current agent subsample to first in sample batch
        obs, acs, rews, next_obs, dones = [
            switch_list(s, agent_i) for s in interm_sample
        ]
        bs, ts, _ = obs[0].shape 
        curr_agent = self.agents[agent_i]
        curr_agent.init_moa_hidden(bs)  # use pre-defined init hiddens 
        results = {}

        # perform update on each moa agent 
        for agent_j in range(1, self.nagents):
            # current agent's j-th moa 
            pi_j = curr_agent.moa_policies[agent_j]
            curr_agent.moa_optimizers[agent_j].zero_grad()

            log_prob_j, entropy_j = curr_agent.evaluate_moa_action(
                agent_j, self.wrap_action(acs[agent_j]), obs[agent_j]
            )   # (B,T,1)
            log_prob_loss = -log_prob_j.reshape(bs*ts, -1).mean()
            entropy_loss = -entropy_j.reshape(bs*ts, -1).mean()
            
            moa_loss_j = log_prob_loss + self.moa_entropy_coeff * entropy_loss 
            moa_loss_j.backward()
            if parallel:
                average_gradients(pi_j)
            if grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(pi_j.parameters(), grad_norm)
            curr_agent.moa_optimizers[agent_j].step()

            # logging (might be verbose)
            for k, v in zip(
                ["log_prob_loss", "entropy_loss"], 
                [log_prob_loss, entropy_loss]
            ):
                key = "agent_{}/moa_{}/{}".format(agent_i, agent_j, k)
                value = v.data.cpu().numpy()
                results[key] = value
                self.agent_losses[key].append(value)

        return results  
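
The loop above fits each MOA head by maximum likelihood on the other agent's observed actions, with an entropy bonus. A standalone sketch of that loss on plain tensors, assuming a categorical action head (`moa_logits`, `actions`, and `entropy_coeff` are illustrative names, not this class's API):

import torch
from torch.distributions import Categorical

def moa_imitation_loss(moa_logits, actions, entropy_coeff=0.01):
    """Negative log-likelihood of observed actions plus an entropy bonus.

    moa_logits: (B*T, A) logits the MOA head predicts for another agent
    actions:    (B*T,)   the actions that agent actually took in the batch
    """
    dist = Categorical(logits=moa_logits)
    log_prob_loss = -dist.log_prob(actions).mean()   # imitation (maximum likelihood) term
    entropy_loss = -dist.entropy().mean()            # minimizing this maximizes entropy
    return log_prob_loss + entropy_coeff * entropy_loss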
Example #2
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
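
The `gumbel_softmax(curr_pol_out, hard=True)` call in the discrete branch is the straight-through trick the comment describes: the forward pass sees a one-hot action, while gradients flow through the soft Gumbel-Softmax sample. A sketch of that idea with PyTorch's built-in functional (the helper used in this repo may differ in details):

import torch
import torch.nn.functional as F

def straight_through_gumbel_softmax(logits, tau=1.0):
    """One-hot in the forward pass, differentiable soft sample in the backward pass."""
    y_soft = F.gumbel_softmax(logits, tau=tau, hard=False)          # (B, A), sums to 1
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)      # hard one-hot
    # forward value is y_hard; gradient is taken w.r.t. y_soft
    return (y_hard - y_soft).detach() + y_soft

Passing `hard=True` to `F.gumbel_softmax` performs the same straight-through estimation in one call.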
Example #3
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]  # a' = mu'(o') for all agents
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # ==========================Adding noise====================
            if self.noisy_sharing:
                #noisy_all_trgt_acs = self.noisy_sharing_discrete(all_trgt_acs,agent_i)
                #all_trgt_acs = noisy_all_trgt_acs
                noisy_acs = self.noisy_sharing_discrete(acs, agent_i)
                acs = noisy_acs
                # print(self.noisy_SNR)

            # ==================End of Adding noise====================
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)

            # =========================Differential Obs========================
            # ============== Dedicate for simple_speaker_listener =============
            # The est_action is used to replace acs[1]
            if self.game_id == 'simple_speaker_listener' and self.est_ac:
                diff_pos = (next_obs[0] - obs[0])[:, -2:]
                tmp_p = torch.transpose(diff_pos.ge(torch.max(diff_pos) * 0.8),
                                        0, 1)
                tmp_p[0] = tmp_p[0] * 1
                tmp_p[1] = tmp_p[1] * 3
                tmp_n = torch.transpose(diff_pos.le(torch.min(diff_pos) * 0.8),
                                        0, 1)
                tmp_n[0] = tmp_n[0] * 2
                tmp_n[1] = tmp_n[1] * 4
                mask = torch.transpose(tmp_p, 0, 1) + torch.transpose(
                    tmp_n, 0, 1)
                est_action = mask.sum(dim=1)
                est_action = torch.zeros(len(est_action),
                                         acs[1].shape[1]).scatter_(
                                             dim=1,
                                             index=est_action.view(-1, 1),
                                             value=1)
                acs[1] = est_action

            # =======================End of differential Obs ==================
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:  # a' = mu'(o'), only the current agent's target policy
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))  #y^j

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)

        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        # ============== Here for policy network training =====================
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                # Is it correct to train mu using all others' policies???
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       0.5)  # Constraints on the grad.
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
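
`onehot_from_logits` is used throughout these examples to turn target-policy logits into hard one-hot actions for the critic input. A minimal greedy (argmax) version for reference; the actual helper may also mix in epsilon-greedy exploration:

import torch

def onehot_from_logits(logits):
    """Greedy one-hot encoding of the argmax action. logits: (B, A) -> one-hot (B, A)."""
    argmax_idx = logits.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(logits).scatter_(-1, argmax_idx, 1.0)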
Example #4
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5, norm_rewards=False):
        """ Update parameters of agent model based on sample from replay buffer
        Arguments:
            sample: [(B,D)]*N, obs, next_obs, action can be [dict (B,D)]*N
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        def switch_idx(idx, curr_agent_idx):
            return idx if idx > curr_agent_idx else idx + 1

        # [(B,1,D)]*N or [dict (B,1,D)]*N
        obs, acs, rews, next_obs, dones = self.add_virtual_dim(sample)   
        # preprocess rewards to reduce variance 
        if norm_rewards:
            rews = self.normalize_rewards(rews)

        bs, ts, _ = obs[agent_i].shape 
        curr_agent = self.agents[agent_i]

        # NOTE: Critic update
        curr_agent.critic_optimizer.zero_grad()

        # compute target actions
        if self.alg_types[agent_i] == 'MADDPG':
            all_trgt_acs = []   # [dict (B,1,A)]*N
            for i, nobs in enumerate(next_obs): # (B,1,O)

                if self.model_of_agents:
                    if i == agent_i:    # use current agent target
                        act_i = curr_agent.compute_action(
                            nobs, target=True, requires_grad=False)
                    else:   # use moa agent target 
                        agent_j = switch_idx(i, agent_i)
                        act_i = curr_agent.compute_moa_action(
                            agent_j, nobs, target=True, requires_grad=False, return_logits=True)
                else:   # use each agents' target directly
                    act_i = self.agents[i].compute_action(
                        nobs, target=True, requires_grad=False)

                all_trgt_acs.append(act_i)
            # [(B,1,O)_i, ..., (B,1,A)_i, ...] -> (B,1,O*N+A*N)
            trgt_vf_in = torch.cat([
                *self.flatten_obs(next_obs, ma=True), 
                *self.flatten_act(all_trgt_acs, ma=True)
            ], dim=-1)
        else:  # DDPG
            act_i = curr_agent.compute_action(next_obs[agent_i], target=True, requires_grad=False)
            # (B,1,O) + (B,1,A) -> (B,1,O+A)
            trgt_vf_in = torch.cat([
                self.flatten_obs(next_obs[agent_i]), 
                self.flatten_act(act_i)
            ], dim=-1)

        # bellman targets   # (B*T,1) -> (B*1,1) -> (B,1)
        target_q = curr_agent.compute_value(trgt_vf_in, target=True) 
        target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                            (1 - dones[agent_i].view(-1, 1)))   

        # Q func
        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat([
                *self.flatten_obs(obs, ma=True), 
                *self.flatten_act(acs, ma=True)
            ], dim=-1)
        else:  # DDPG
            vf_in = torch.cat([
                self.flatten_obs(obs[agent_i]),
                self.flatten_act(acs[agent_i])
            ], dim=-1)
        actual_value = curr_agent.compute_value(vf_in, target=False) # (B*T,1)

        # bellman errors
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
        curr_agent.critic_optimizer.step()

        # NOTE: Policy update
        curr_agent.policy_optimizer.zero_grad()

        # current agent action (deterministic, softened), dict (B,T,A)
        curr_pol_out = curr_agent.compute_action(obs[agent_i], target=False, requires_grad=True) 

        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:    # insert current agent act to q input 
                    all_pol_acs.append(self.flatten_act(curr_pol_out))
                    # all_pol_acs.append(curr_pol_out)
                else: 
                    # p_act_i = self.agents[i].compute_action(ob, target=False, requires_grad=False) 
                    p_act_i = self.flatten_act(acs[i])
                    all_pol_acs.append(p_act_i)
            # (B,T,O*N+A*N)s
            p_vf_in = torch.cat([
                *self.flatten_obs(obs, ma=True),
                *self.flatten_act(all_pol_acs, ma=True)
            ], dim=-1) 
        else:  # DDPG
            # (B,T,O+A)
            p_vf_in = torch.cat([
                self.flatten_obs(obs[agent_i]),
                self.flatten_act(curr_pol_out)
            ], dim=-1) 
        
        # value function to update current policy
        p_value = curr_agent.compute_value(p_vf_in, target=False) # (B*T,1)
        pol_loss = -p_value.mean()

        # p regularization, scale down output (gaussian mean,std or logits)
        # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
        pol_reg_loss = torch.tensor(0.0)
        for k, v in curr_pol_out.items():
            pol_reg_loss += ((v.reshape(bs*ts, -1))**2).mean() * 1e-3

        pol_loss_total = pol_loss + pol_reg_loss
        pol_loss_total.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
        curr_agent.policy_optimizer.step()

        # NOTE: collect training stats
        results = {}
        for k, v in zip(
            ["critic_loss", "policy_loss", "policy_reg_loss"], 
            [vf_loss, pol_loss, pol_reg_loss]
        ):
            key = "agent_{}/{}".format(agent_i, k)
            value = v.data.cpu().numpy()
            results[key] = value
            self.agent_losses[key].append(value)        
        return results
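
`average_gradients`, called in the `parallel` branches, is presumably the usual all-reduce of per-worker gradients before the optimizer step. A sketch in the style of the PyTorch distributed tutorial (assumes the `torch.distributed` process group is already initialized):

import torch.distributed as dist

def average_gradients(model):
    """Average each parameter's gradient across all workers in the process group."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size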
Example #5
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
        """ Update parameters of agent model based on sample from replay buffer
        Arguments:
            sample: [(B,T,D)]*N, obs, next_obs, action, logprobs can be [dict (B,T,D)]*N
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        obs, acs, rews, next_obs, dones, logprobs = sample  # each [(B,T,D)]*N
        bs, ts, _ = obs[agent_i].shape
        self.init_hidden(bs)  # use pre-defined init hiddens
        curr_agent = self.agents[agent_i]

        # entropy temperature param
        alpha = curr_agent.log_alpha.get_alpha().detach()

        # NOTE: Critic update
        curr_agent.critic_optimizer.zero_grad()

        # compute target actions
        if self.alg_types[agent_i] == 'MASAC':
            all_trgt_acs = []  # [dict (B,T,A)]*N
            all_trgt_logprobs = []  # [dict (B,T,1)]*N

            for i, nobs in enumerate(next_obs):  # (B,T,O)
                with torch.no_grad():
                    act_i, log_prob_i, _ = self.agents[
                        i].compute_action_logprob(nobs)
                all_trgt_acs.append(act_i)
                all_trgt_logprobs.append(log_prob_i)

            # [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
            trgt_vf_in = torch.cat([
                *self.flatten_obs(next_obs, ma=True),
                *self.flatten_act(all_trgt_acs, ma=True)
            ],
                                   dim=-1)

            # log prob of target action, [(B,T,1)]*N
            target_a_logprob = self.contract_logprob(all_trgt_logprobs,
                                                     ma=True)
            # [(B,T,1)]*N -> (B,T,N) -> (B,T,1)
            target_a_logprob = torch.sum(torch.cat(target_a_logprob, -1), -1)

        else:  # SAC
            with torch.no_grad():
                act_i, log_prob_i, _ = curr_agent.compute_action_logprob(
                    next_obs[agent_i])
            # (B,T,O) + (B,T,A) -> (B,T,O+A)
            trgt_vf_in = torch.cat(
                [self.flatten_obs(next_obs[agent_i]),
                 self.flatten_act(act_i)],
                dim=-1)

            # log prob of target action, (B,T,1)
            target_a_logprob = self.contract_logprob(log_prob_i)

        # bellman targets
        target_q1, target_q2 = curr_agent.compute_value(trgt_vf_in,
                                                        target=True)  # (B*T,1)
        target_a_logprob = target_a_logprob.reshape(-1, 1).detach()  # (B*T,1)

        target_q = torch.min(target_q1, target_q2) - alpha * target_a_logprob
        target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                        (1.0 - dones[agent_i].view(-1, 1)))  # (B*T,1)

        # Q func
        if self.alg_types[agent_i] == 'MASAC':
            vf_in = torch.cat([
                *self.flatten_obs(obs, ma=True),
                *self.flatten_act(acs, ma=True)
            ],
                              dim=-1)
        else:  # SAC
            vf_in = torch.cat([
                self.flatten_obs(obs[agent_i]),
                self.flatten_act(acs[agent_i])
            ],
                              dim=-1)
        q1, q2 = curr_agent.compute_value(vf_in, target=False)  # (B*T,1)

        # bellman errors
        vf_loss1 = MSELoss(q1, target_value.detach())
        vf_loss1.backward()
        vf_loss2 = MSELoss(q2, target_value.detach())
        vf_loss2.backward()
        if parallel:
            average_gradients(curr_agent.critic1)
            average_gradients(curr_agent.critic2)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic1.parameters(),
                                           grad_norm)
            torch.nn.utils.clip_grad_norm_(curr_agent.critic2.parameters(),
                                           grad_norm)
        curr_agent.critic1_optimizer.step()
        curr_agent.critic2_optimizer.step()

        # NOTE: Policy update
        curr_agent.policy_optimizer.zero_grad()

        # current agent action and its log prob (reparameterized sample), dict (B,T,A)
        curr_pol_out, curr_log_prob, _ = curr_agent.compute_action_logprob(
            obs[agent_i])
        a_log_prob = self.contract_logprob(curr_log_prob)

        if self.alg_types[agent_i] == 'MASAC':
            all_pol_acs = []
            all_pol_logprobs = []

            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:  # insert current agent act to q input
                    all_pol_acs.append(self.flatten_act(curr_pol_out))
                    # current agent log prob (backprop-able)
                    all_pol_logprobs.append(curr_log_prob)
                else:
                    # TODO: need other agents' log probs as well
                    # p_act_i = self.agents[i].compute_action(ob, target=False, requires_grad=False)
                    p_act_i = self.flatten_act(acs[i])
                    all_pol_acs.append(p_act_i)
                    # other agents' log probs (during sampling)
                    all_pol_logprobs.append(logprobs[i])

            # (B,T,O*N+A*N)s
            p_vf_in = torch.cat([
                *self.flatten_obs(obs, ma=True),
                *self.flatten_act(all_pol_acs, ma=True)
            ],
                                dim=-1)

            # [dict (B,T,1)]*N -> [(B,T,1)]*N -> (B,T,1)
            a_log_prob = self.contract_logprob(all_pol_logprobs, ma=True)
            a_log_prob = torch.sum(torch.cat(a_log_prob, -1), -1)

        else:  # SAC
            # (B,T,O+A)
            p_vf_in = torch.cat([
                self.flatten_obs(obs[agent_i]),
                self.flatten_act(curr_pol_out)
            ],
                                dim=-1)

            # dict (B,T,1) -> (B,T,1)
            a_log_prob = self.contract_logprob(curr_log_prob)

        # KL loss between alpha log prob & target policy value function
        p_value1, p_value2 = curr_agent.compute_value(p_vf_in,
                                                      target=False)  # (B*T,1)
        p_value_target = torch.min(p_value1, p_value2)
        pol_loss = (alpha * a_log_prob.reshape(-1, 1) - p_value_target).mean()

        # NOTE: this is optional (not in SAC)
        # p regularization, scale down output (gaussian mean,std or logits)
        # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
        pol_reg_loss = torch.tensor(0.0)
        for k, v in curr_pol_out.items():
            pol_reg_loss += ((v.reshape(bs * ts, -1))**2).mean() * 1e-3

        pol_loss_total = pol_loss + pol_reg_loss
        pol_loss_total.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                           grad_norm)
        curr_agent.policy_optimizer.step()

        # NOTE: Alpha (entropy) update
        alpha_loss = -(curr_agent.log_alpha() *
                       (a_log_prob.reshape(-1, 1).detach() + self.target_entropy)).mean()
        alpha_loss.backward()
        if parallel:
            average_gradients(curr_agent.log_alpha)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.log_alpha.parameters(),
                                           grad_norm)
        curr_agent.alpha_optimizer.step()

        # NOTE: collect training stats
        results = {}
        for k, v in zip([
                "critic1_loss", "critic2_loss", "policy_loss",
                "policy_reg_loss", "alpha_loss"
        ], [vf_loss1, vf_loss2, pol_loss, pol_reg_loss, alpha_loss]):
            key = "agent_{}/{}".format(agent_i, k)
            value = v.data.cpu().numpy()
            results[key] = value
            self.agent_losses[key].append(value)
        return results
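
The alpha update at the end of Example #5 is the automatic entropy-temperature tuning from the SAC papers: log-alpha is a learnable scalar pushed so that the policy's entropy tracks a target entropy. A self-contained sketch with illustrative names (here `log_alpha` is a bare parameter rather than the wrapper module used above):

import torch

log_alpha = torch.zeros(1, requires_grad=True)      # learnable log temperature
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -4.0                               # e.g. -dim(A) for a 4-d continuous action

def update_alpha(action_log_prob):
    """action_log_prob: (B, 1) log pi(a|s) for actions sampled from the current policy."""
    alpha_loss = -(log_alpha * (action_log_prob.detach() + target_entropy)).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()                 # the alpha used in the critic/policy losses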
Example #6
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5, contract_keys=None):
        """ Update parameters of agent model based on sample from replay buffer
        Arguments:
            sample: [(B,T,D)]*N, obs, next_obs, action can be [dict (B,T,D)]*N
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        # each is [(B,T,D)]*N 
        obs, acs, rews, next_obs, dones, old_logits, advantages, vf_preds = sample 
        bs, ts, _ = obs[agent_i].shape 
        self.init_hidden(bs)  # use pre-defined init hiddens 
        curr_agent = self.agents[agent_i]

        # NOTE: Critic update
        curr_agent.critic_optimizer.zero_grad()

        # value func
        if self.alg_types[agent_i] == 'CCPPO':
            # [(B,T,O)_i, ...] -> (B,T,O*N)
            vf_in = torch.cat([
                *self.flatten_obs(obs, ma=True), 
            ], dim=-1)
        else:  # PPO
            vf_in = self.flatten_obs(obs[agent_i]) # (B,T,O)
        actual_value = curr_agent.compute_value(vf_in) # (B,T,1)
        
        # bellman errors (PPO clipped value loss)
        # NOTE: assumes GAE-style value targets (advantages + old value predictions),
        # since the sample does not carry explicit value targets
        value_targets = advantages + vf_preds
        vf_loss1 = (actual_value - value_targets) ** 2
        vf_clipped = vf_preds + (actual_value - vf_preds).clamp(
                                -self.vf_clip_param, self.vf_clip_param)
        vf_loss2 = (vf_clipped - value_targets) ** 2
        vf_loss = torch.max(vf_loss1, vf_loss2).mean()

        critic_loss = self.vf_loss_coeff * vf_loss
        critic_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
        curr_agent.critic_optimizer.step()


        # NOTE: Policy update
        curr_agent.policy_optimizer.zero_grad()
        
        # ppo policy update 
        # NOTE: actions must be wrapped because `evaluate_action` takes dicts (the policy outputs dicts)
        act_eval_out = curr_agent.evaluate_action(
            self.wrap_action(old_logits[agent_i]), 
            self.wrap_action(acs[agent_i]), 
            obs[agent_i], 
            contract_keys=contract_keys
        ) # all (B,T,1)
        curr_log_prob, old_log_prob, entropy, kl = act_eval_out

        logp_ratio = torch.exp(curr_log_prob - old_log_prob)
        policy_loss = -torch.min(
            advantages * logp_ratio, 
            advantages * logp_ratio.clamp(1-self.clip_param, 1+self.clip_param)
        )   # (B,T,1)
        policy_loss = policy_loss.mean()

        # kl loss on current & previous policy outputs
        kl_loss = kl.mean()
        # update kl coefficient per update (with mean/expected kl)
        curr_agent.kl_coeff.update_kl(kl_loss)

        # entropy loss on current policy outputs
        entropy_loss = entropy.mean()

        actor_loss = policy_loss
        actor_loss += curr_agent.kl_coeff() * kl_loss
        actor_loss += self.entropy_coeff * entropy_loss
        actor_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
        curr_agent.policy_optimizer.step()

        # NOTE: collect training stats
        results = {}
        key_list = [
            "critic_loss", 
            "policy_loss", 
            "kl_loss", 
            "entropy_loss",
            "explained_variance"
        ]
        val_list = [
            vf_loss,
            policy_loss,
            kl_loss,
            entropy_loss,
            explained_variance(vf_preds, actual_value)
        ]

        for k, v in zip(key_list, val_list):
            key = "agent_{}/{}".format(agent_i, k)
            value = v.data.cpu().numpy()
            results[key] = value
            self.agent_losses[key].append(value)        
        return results
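
Stripped of the dict wrapping and the KL/entropy bookkeeping, the policy term in Example #6 is the standard PPO clipped surrogate. A compact version with illustrative names:

import torch

def ppo_surrogate_loss(curr_log_prob, old_log_prob, advantages, clip_param=0.2):
    """Clipped PPO surrogate; all inputs share the same shape, e.g. (B*T, 1)."""
    ratio = torch.exp(curr_log_prob - old_log_prob)
    unclipped = advantages * ratio
    clipped = advantages * ratio.clamp(1.0 - clip_param, 1.0 + clip_param)
    return -torch.min(unclipped, clipped).mean()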
Example #7
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: EpisodeBatch, use sample[key_i] to get a specific 
                    array of obs, action, etc for agent i
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        obs, acs, rews, next_obs, dones = self.parse_sample(sample) # [(B,T,D)]*N  
        bs, ts, _ = obs[0].shape
        self.init_hidden(bs)  # use pre-defined init hiddens 
        curr_agent = self.agents[agent_i]

        # NOTE: critic update
        curr_agent.critic_optimizer.zero_grad()

        # compute target actions
        if self.alg_types[agent_i] == 'MADDPG':
            all_trgt_acs = []   # [(B,T,A)]*N
            for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs)):
                # nobs: (B,T,O)
                act_i = self.compute_action(i, pi, nobs, bs=bs, ts=ts)
                all_trgt_acs.append(act_i)  # [(B,T,A)]
            
            if self.discrete_action:    # one-hot encode action
                all_trgt_acs = [onehot_from_logits(
                    act_i.reshape(bs*ts,-1)
                ).reshape(bs,ts,-1) for act_i in all_trgt_acs] 

            # critic input, [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
        else:  # DDPG
            act_i = self.compute_action(agent_i, curr_agent.target_policy, 
                        next_obs[agent_i], bs=bs, ts=ts)

            if self.discrete_action:
                act_i = onehot_from_logits(
                    act_i.reshape(bs*ts, -1)
                ).reshape(bs, ts, -1) 

            # (B,T,O) + (B,T,A) -> (B,T,O+A)
            trgt_vf_in = torch.cat((next_obs[agent_i], act_i), dim=-1)

        # bellman targets
        target_q = self.compute_q_val(agent_i, curr_agent.target_critic, 
                        trgt_vf_in, bs=bs, ts=ts)   # (B*T,1)
        target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                            (1 - dones[agent_i].view(-1, 1)))   # (B*T,1)

        # Q func
        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=-1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=-1)
        actual_value = self.compute_q_val(agent_i, curr_agent.critic, 
                            vf_in, bs=bs, ts=ts)    # (B*T,1)

        # bellman errors
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
        curr_agent.critic_optimizer.step()

        # NOTE: policy update
        curr_agent.policy_optimizer.zero_grad()

        # current agent action (deterministic, softened)
        curr_pol_out = self.compute_action(agent_i, curr_agent.policy, 
                                obs[agent_i], bs=bs, ts=ts) # (B,T,A)
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_vf_in = gumbel_softmax(
                                curr_pol_out.reshape(bs*ts, -1), hard=True
                            ).reshape(bs, ts, -1) 
        else:
            curr_pol_vf_in = curr_pol_out

        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    # insert current agent act to q input 
                    all_pol_acs.append(curr_pol_vf_in)
                else: 
                    p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts) # (B,T,A)
                    if self.discrete_action:
                        p_act_i = onehot_from_logits(
                                    p_act_i.reshape(bs*ts, -1)
                                ).reshape(bs, ts, -1) 
                    all_pol_acs.append(p_act_i)
            p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1) # (B,T,O*N+A*N)
        else:  # DDPG
            p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1) # (B,T,O+A)
        
        # value function to update current policy
        p_value = self.compute_q_val(agent_i, curr_agent.critic, 
                        p_vf_in, bs=bs, ts=ts)   # (B*T,1)
        pol_loss = -p_value.mean()
        # p regularization, scale down output (gaussian mean,std or logits)
        # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
        pol_loss += ((curr_pol_out.reshape(bs*ts, -1))**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
        curr_agent.policy_optimizer.step()

        # collect training stats
        results = {
            "agent_{}_critic_loss".format(agent_i): vf_loss,
            "agent_{}_policy_loss".format(agent_i): pol_loss
        }
        return results
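
The `compute_action` / `compute_q_val` wrappers above take care of the (B, T, D) sequence layout; for a non-recurrent head this presumably reduces to folding time into the batch dimension, as in this minimal sketch (an illustrative helper, not this class's API):

def eval_per_step(module, x):
    """Apply a feed-forward module to a (B, T, D) batch by folding time into the batch dim.

    Returns (B*T, out_dim); call .reshape(B, T, -1) on the result to recover the sequence layout.
    """
    b, t, d = x.shape
    return module(x.reshape(b * t, d))

A recurrent policy would instead keep the time axis and carry the hidden state set up by `init_hidden`.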
Example #8
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        # For RNN, the obs and next_obs both have histories
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action

                # This is the original version; 'pi': policy, 'nobs': next observations
                #all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                #                zip(self.target_policies, next_obs)]

                # Original till here

                #-------- Expanding out for debugging --------#
                all_trgt_acs = []
                for pi, nobs in zip(self.target_policies, next_obs):
                    temp = onehot_from_logits(pi(nobs))
                    #print(temp)
                    all_trgt_acs.append(temp)
                # -------- End debug -------------------------#

            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # Get the most current observation from the history to calculate the target value
            # next_obs[0].shape[0] is the batch size; the first 18 features are the
            # current observation (TODO: make 18 a parameter instead of hard-coding it)
            # Only keep the current obs for the critic VF
            t0_next_obs = [nobs[:, :18] for nobs in next_obs]

            # ORIGINAL was \/
            #trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
            trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
            # It is working till here. Only kept the current obs for critic VF
        else:  # DDPG
            # DDPG only knows the particular agent's observation and policy
            # Whereas, MADDPG has access to all other agents' policies
            # Grab only the current observation (first 18 features) to send to the critic
            t0_next_obs = [nobs[:, :18] for nobs in next_obs]

            # Originally it would be next_obs[agent_i] instead of t0_next_obs[agent_i]
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)

        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        ##### Just get the current observation (i.e., without history) ##########
        # Reason: Critic VF does not need history
        # Copied the same as in t0_next_obs, since BOTH obs and next_obs have HISTORIES.
        t0_obs = [ob[:, :18] for ob in obs]
        ###################################################

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*t0_obs, *acs), dim=1)
        else:  # DDPG  #TODO: below, might have to change obs to t0_obs, when using DDPG
            vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            # Now that we are forwarding the policy again, we need to use obs with history
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
            # Seems to be working fine till here
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            # Originally:
            #vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
            vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)
        # TODO: FIX THIS
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
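
The critic target used here (and in all the examples above) is the one-step TD backup y = r + gamma * Q'(s', a') * (1 - done), regressed against with the target detached. A compact helper expressing just that step (illustrative names):

import torch
import torch.nn.functional as F

def td_target(rewards, dones, next_q, gamma=0.95):
    """One-step TD target: y = r + gamma * Q'(s', a') * (1 - done)."""
    return rewards.view(-1, 1) + gamma * next_q * (1.0 - dones.view(-1, 1))

def critic_loss(q_values, rewards, dones, next_q, gamma=0.95):
    # regress Q(s, a) toward the (detached) TD target
    return F.mse_loss(q_values, td_target(rewards, dones, next_q, gamma).detach())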