Example #1
    def compute(self, rewards, next_obs):
        with torch.no_grad():
            next_obs = [Variable(torch.Tensor(np.vstack(next_obs[:, i])),
                      requires_grad=False) for i in range(rewards.shape[1])]

            acs_src = []
            prob_src = []
            for no, source in zip(next_obs, self.source):
                acs_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=True))
                prob_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=False))

            trans_in = torch.cat((*next_obs, *acs_src), dim=1)
            trans_out = self.transition(trans_in)
            prob_plan = []
            start = 0
            for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
                length = no.shape[1]
                nno = trans_out[:, start:start + length]
                acs_ = []
                for j, ac in enumerate(acs_src):
                    if j != i: acs_.append(ac)
                acs_ = torch.cat(acs_, dim=1)
                plan_in = torch.cat((no, nno, acs_), dim=1)

                prob_plan.append(gumbel_softmax(planning(plan_in), device=self.plan_dev, hard=False))
                start += length
            prob_plan = torch.cat(prob_plan, dim=1)
            prob_src = torch.cat(prob_src, dim=1)
            acs_src = torch.cat(acs_src, dim=1)

            E = acs_src * prob_plan - acs_src * prob_src
            i_rews = E.mean() * torch.ones((1, rewards.shape[1]))
            return i_rews.numpy()
    def update(self, sample, logger):
        obs, acs, rews, emps, next_obs, dones = sample

        self.transition_optimizer.zero_grad()
        trans_in = torch.cat((*obs, *acs), dim=1)
        next_obs_pred = self.transition(trans_in)
        trans_loss = MSELoss(next_obs_pred, torch.cat(next_obs, dim=1))
        trans_loss.backward()
        self.transition_optimizer.step()

        self.source_optimizer.zero_grad()
        acs_src = []
        prob_src = []
        for no, source in zip(next_obs, self.source):
            acs_src.append(
                gumbel_softmax(source(no), device=self.source_dev, hard=True))
            prob_src.append(
                gumbel_softmax(source(no), device=self.source_dev, hard=False))
        with torch.no_grad():
            trans_in = torch.cat((*next_obs, *acs_src), dim=1)
            trans_out = self.transition(trans_in)
        prob_plan = []
        start = 0
        for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
            length = no.shape[1]
            nno = trans_out[:, start:start + length]
            acs_pi = gumbel_softmax(self.agents[i].policy(nno),
                                    device=self.source_dev,
                                    hard=True)
            plan_in = torch.cat((no, nno, acs_pi), dim=1)
            prob_plan.append(
                gumbel_softmax(planning(plan_in),
                               device=self.plan_dev,
                               hard=False))
            start += length
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)
        acs_src = torch.cat(acs_src, dim=1)

        E = acs_src * prob_plan - acs_src * prob_src
        i_rews = -E.mean()
        i_rews.backward()
        self.source_optimizer.step()

        if logger is not None:
            logger.add_scalars('empowerment/losses', {
                'trans_loss': trans_loss.detach(),
                'i_rews': i_rews.detach()
            }, self.niter)
        self.niter += 1
    def step(self, obs, explore=False):
        """
        Take a step forward in the environment for a minibatch of observations
        (equivalent to `act` or `compute_actions`)
        Arguments:
            obs: (B,O)
            explore: Whether or not to add exploration noise
        Returns:
            action: dict of actions for this agent, (B,A)
        """
        with torch.no_grad():
            action, hidden_states = self.policy(obs, self.policy_hidden_states)
            self.policy_hidden_states = hidden_states  # if MLP, this stays at its default of None

            if self.discrete_action:
                for k in action:
                    if explore:
                        action[k] = gumbel_softmax(action[k], hard=True)
                    else:
                        action[k] = onehot_from_logits(action[k])
            else:  # continuous action
                idx = 0
                noise = Variable(Tensor(self.exploration.noise()),
                                 requires_grad=False)
                for k in action:
                    if explore:
                        dim = action[k].shape[-1]
                        action[k] += noise[idx:idx + dim]
                        idx += dim
                    action[k] = action[k].clamp(-1, 1)
        return action
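A minimal usage sketch for the `step` API above. This is not from the original repository: `DummyPolicy`, the observation/action sizes, and the use of `torch.nn.functional.gumbel_softmax` / `one_hot` in place of the repo's `gumbel_softmax` / `onehot_from_logits` helpers are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyPolicy(nn.Module):
    """Toy stand-in for self.policy: maps (B,O) obs to a dict of per-head logits plus a hidden state."""
    def __init__(self, obs_dim=8, act_dim=5):
        super().__init__()
        self.fc = nn.Linear(obs_dim, act_dim)

    def forward(self, obs, hidden_states=None):
        return {"move": self.fc(obs)}, hidden_states  # MLP case: hidden state stays None

policy = DummyPolicy()
obs = torch.randn(4, 8)  # minibatch of 4 observations, (B,O)
with torch.no_grad():
    logits, _ = policy(obs)
    # explore=True branch: straight-through Gumbel-Softmax sample (one-hot forward pass)
    explore_action = F.gumbel_softmax(logits["move"], hard=True)
    # explore=False branch: greedy one-hot from logits
    greedy_action = F.one_hot(logits["move"].argmax(dim=-1), num_classes=5).float()
print(explore_action.shape, greedy_action.shape)  # torch.Size([4, 5]) torch.Size([4, 5])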
Example #4
 def step(self, obs, explore=False):
     """
      Take a step forward in the environment for a minibatch of observations
     Inputs:
         obs (PyTorch Variable): Observations for this agent
         explore (boolean): Whether or not to add exploration noise
     Outputs:
         action (PyTorch Variable): Actions for this agent
     """
     # print('----', obs)
     action = self.policy(obs)
     # print('>>>>>>>>', action)
     if self.discrete_action:
         # print('agents.py discrete_action yes')
         if explore:
             action = gumbel_softmax(action, hard=True)
         else:
             action = onehot_from_logits(action)
     else:  # continuous action
         # print('agents.py continuous_action yes')
         if explore:
             # print('agents.py explore yes')
             # print('action before noise', action)
             action += Variable(Tensor(self.exploration.noise()),
                                requires_grad=False)
             # print('action after noise', action)
         action = action.clamp(-1, 1)
         # print('action after clamp', action)
     # print('>>>>>>>>>>>>>>>>', action)
     return action
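The explore branch above relies on the straight-through Gumbel-Softmax trick: the forward value is one-hot, but gradients still reach the logits. A small sketch of that behaviour, using PyTorch's built-in `torch.nn.functional.gumbel_softmax` as a stand-in for the repository's `gumbel_softmax` helper:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 4, requires_grad=True)

hard_sample = F.gumbel_softmax(logits, tau=1.0, hard=True)   # exactly one-hot per row
soft_sample = F.gumbel_softmax(logits, tau=1.0, hard=False)  # relaxed probability vectors

print(hard_sample.sum(dim=-1))  # tensor of ones: each row is one-hot
hard_sample.sum().backward()    # gradients flow to the logits despite the hard sample
print(logits.grad.shape)        # torch.Size([3, 4])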
 def _soft_act(self, x, requires_grad=True):  # x: (B,A)
     if not self.discrete_action:
         return x
     if requires_grad:
         return gumbel_softmax(x, hard=True)
     else:
         return onehot_from_logits(x)
 def _soft_act(self, x, requires_grad=True):
     """ soften action if discrete, x: (B,A) """
     if not self.discrete_action:
         return x
     if requires_grad:
         return gumbel_softmax(x, hard=True)
     else:
         return onehot_from_logits(x)
    def compute(self, next_obs):
        acs_src = []
        prob_src = []
        for no, source in zip(next_obs, self.source):
            acs_src.append(
                gumbel_softmax(source(no),
                               device=self.device.get_device(),
                               hard=True))
            prob_src.append(
                gumbel_softmax(source(no),
                               device=self.device.get_device(),
                               hard=False))

        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
        prob_plan = []
        end_idx = [0] + np.cumsum([ne_ob.shape[1]
                                   for ne_ob in next_obs]).tolist()
        start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
        for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
            nno = trans_out[:, start_end[i][0]:start_end[i][1]]
            acs_ = []
            for j, ac in enumerate(acs_src):
                if j == i: continue
                nno_other = trans_out[:, start_end[j][0]:start_end[j][1]]
                acs_.append(
                    gumbel_softmax(self.agents[j].policy(nno_other),
                                   device=self.device.get_device(),
                                   hard=True))
            acs_ = torch.cat(acs_, dim=1)
            plan_in = torch.cat((no, nno, acs_), dim=1)

            prob_plan.append(
                gumbel_softmax(planning(plan_in),
                               device=self.device.get_device(),
                               hard=False))
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)
        acs_src = torch.cat(acs_src, dim=1)

        return acs_src * prob_plan - acs_src * prob_src
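The `end_idx` / `start_end` bookkeeping above recovers the per-agent blocks of the flat transition output. A small worked example (the observation widths 3, 5 and 2 and the batch size are made up for illustration):

import numpy as np
import torch

next_obs = [torch.randn(4, 3), torch.randn(4, 5), torch.randn(4, 2)]  # three agents

end_idx = [0] + np.cumsum([no.shape[1] for no in next_obs]).tolist()
start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
print(start_end)  # [(0, 3), (3, 8), (8, 10)]

trans_out = torch.randn(4, sum(no.shape[1] for no in next_obs))  # flat (B, sum_i O_i)
per_agent = [trans_out[:, s:e] for s, e in start_end]
print([t.shape[1] for t in per_agent])  # [3, 5, 2] -- same widths as next_obs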
    def compute(self, rewards, next_obs):
        with torch.no_grad():
            next_obs = [
                Variable(torch.Tensor(np.vstack(next_obs[:, i])),
                         requires_grad=False) for i in range(rewards.shape[1])
            ]

            acs_src = []
            prob_src = []
            for no in next_obs:
                acs_src.append(
                    gumbel_softmax(self.source(no),
                                   device=self.device.get_device(),
                                   hard=True))
                prob_src.append(
                    gumbel_softmax(self.source(no),
                                   device=self.device.get_device(),
                                   hard=False))

            trans_in = torch.cat((*next_obs, *acs_src), dim=1)
            trans_out = self.transition(trans_in)
            prob_plan = []
            n_obs = len(next_obs[0][0])
            for i, no in enumerate(next_obs):
                nno = trans_out[:, i * n_obs:(i + 1) * n_obs]
                plan_in = torch.cat((no, nno), dim=1)
                prob_plan.append(
                    gumbel_softmax(self.planning(plan_in),
                                   device=self.device.get_device(),
                                   hard=False))
            prob_plan = torch.cat(prob_plan, dim=1)
            prob_src = torch.cat(prob_src, dim=1)
            acs_src = torch.cat(acs_src, dim=1)

            E = acs_src * prob_plan - acs_src * prob_src
            i_rews = E.mean() * torch.ones((1, rewards.shape[1]))
            return i_rews.numpy()
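Since `acs_src` is one-hot (`hard=True`) while `prob_plan` and `prob_src` are relaxed distributions, the product `acs_src * prob_plan - acs_src * prob_src` keeps only the probability each network assigns to the sampled action. A toy sketch with hand-picked numbers (all values are illustrative):

import torch

acs_src   = torch.tensor([[0., 1., 0.]])        # sampled one-hot source action
prob_plan = torch.tensor([[0.10, 0.70, 0.20]])  # planning network's distribution
prob_src  = torch.tensor([[0.30, 0.40, 0.30]])  # source network's distribution

E = acs_src * prob_plan - acs_src * prob_src
print(E)         # tensor([[0.0000, 0.3000, 0.0000]]) -- only the chosen action survives
print(E.mean())  # tensor(0.1000), averaged as in i_rews above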
Example #9
 def get_actions(self, obs, noise=True, batch=False, hard=False):
     acts = self.ac_update.get_action(obs, batch)
     if self.args.discrete_action:
         if noise:
             assert batch is False
             acts = gumbel_softmax(acts,
                                   hard).cpu().detach().numpy().squeeze()
         else:
             acts = onehot(acts)
     else:
         acts = acts.cpu().detach().numpy().squeeze()
         if noise:
             assert batch is False
             acts = self.noise.get_action(acts)
     return acts
Example #10
    def cast_embedding(self, emb, explore=True):

        # The terminology here is a bit misleading: explore==True is used both for roll-outs
        # (exploring the embedding space with Boltzmann exploration) and for selecting an
        # unbiased, backpropagable action (in contrast to a backpropagable argmax, which would
        # be biased because it always returns the mode of the distribution rather than the mean,
        # unlike a Gaussian where mode == mean). explore==False is only used at evaluation.
        # We could imagine having three cases instead:
        # 1 - epsilon-greedy exploration for roll-outs (or a tunable temperature)
        # 2 - gumbel_softmax for backprop
        # 3 - argmax for evaluation


        if explore:
            emb = gumbel_softmax(emb, hard=True)
        else:
            emb = differentiable_onehot_from_logits(emb)
        return emb
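The comment above suggests splitting the single `explore` flag into three regimes. A hedged sketch of that idea (the epsilon value and the use of PyTorch's built-in `gumbel_softmax` / `one_hot` are assumptions, not the repository's code):

import torch
import torch.nn.functional as F

def cast_embedding_three_cases(emb_logits, mode="rollout", eps=0.1):
    """Illustrative three-way split of the explore flag described in the comment above."""
    n = emb_logits.shape[-1]
    if mode == "rollout":      # 1 - epsilon-greedy exploration for roll-outs
        greedy = emb_logits.argmax(dim=-1)
        random = torch.randint(n, greedy.shape)
        pick = torch.where(torch.rand(greedy.shape) < eps, random, greedy)
        return F.one_hot(pick, n).float()
    elif mode == "backprop":   # 2 - straight-through Gumbel-Softmax, keeps gradients
        return F.gumbel_softmax(emb_logits, hard=True)
    else:                      # 3 - plain argmax one-hot for evaluation
        return F.one_hot(emb_logits.argmax(dim=-1), n).float()

emb = torch.randn(2, 6)
for mode in ("rollout", "backprop", "eval"):
    print(mode, cast_embedding_three_cases(emb, mode).shape)  # torch.Size([2, 6]) each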
Example #11
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]  # a'=mu'(o') Have all agents'
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # ==========================Adding noise====================
            if self.noisy_sharing == True:
                #noisy_all_trgt_acs = self.noisy_sharing_discrete(all_trgt_acs,agent_i)
                #all_trgt_acs = noisy_all_trgt_acs
                noisy_acs = self.noisy_sharing_discrete(acs, agent_i)
                acs = noisy_acs
                # print(self.noisy_SNR)

            # ==================End of Adding noise====================
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)

            # =========================Differential Obs========================
            # ============== Dedicate for simple_speaker_listener =============
            # The est_action is used to replace acs[1]
            if self.game_id == 'simple_speaker_listener' and self.est_ac == True:
                diff_pos = (next_obs[0] - obs[0])[:, -2:]
                tmp_p = torch.transpose(diff_pos.ge(torch.max(diff_pos) * 0.8),
                                        0, 1)
                tmp_p[0] = tmp_p[0] * 1
                tmp_p[1] = tmp_p[1] * 3
                tmp_n = torch.transpose(diff_pos.le(torch.min(diff_pos) * 0.8),
                                        0, 1)
                tmp_n[0] = tmp_n[0] * 2
                tmp_n[1] = tmp_n[1] * 4
                mask = torch.transpose(tmp_p, 0, 1) + torch.transpose(
                    tmp_n, 0, 1)
                est_action = mask.sum(dim=1)
                est_action = torch.zeros(len(est_action),
                                         acs[1].shape[1]).scatter_(
                                             dim=1,
                                             index=est_action.view(-1, 1),
                                             value=1)
                acs[1] = est_action

            # =======================End of differential Obs ==================
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:  # a'=mu(o') only have current agent's
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))  #y^j

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)

        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        # ============== Here for policy network training =====================
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                # Is it correct to train mu using all others' policies???
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       0.5)  # Constraints on the grad.
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
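The critic update above follows the standard DDPG target, y = r + gamma * Q'(o', a') * (1 - done), with the target detached before the MSE. A toy sketch of just that computation (the tiny linear critics and all shapes are stand-ins, not the repository's models):

import torch
import torch.nn as nn

B, obs_dim, act_dim, gamma = 5, 4, 2, 0.95
critic        = nn.Linear(obs_dim + act_dim, 1)  # stand-in for curr_agent.critic
target_critic = nn.Linear(obs_dim + act_dim, 1)  # stand-in for curr_agent.target_critic

obs, next_obs = torch.randn(B, obs_dim), torch.randn(B, obs_dim)
acs, next_acs = torch.randn(B, act_dim), torch.randn(B, act_dim)
rews, dones   = torch.randn(B), torch.randint(0, 2, (B,)).float()

trgt_vf_in   = torch.cat((next_obs, next_acs), dim=1)
target_value = rews.view(-1, 1) + gamma * target_critic(trgt_vf_in) * (1 - dones.view(-1, 1))

vf_in   = torch.cat((obs, acs), dim=1)
vf_loss = nn.MSELoss()(critic(vf_in), target_value.detach())  # detach: no grads into the target
vf_loss.backward()
print(vf_loss.item())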
Example #12
    def update(self, sample, agent_i):
        """
        TODO: to make the sample into valid inputs
        :param sample: the batch of experiences
        :param agent_i: the agent to be updated
        """
        obs, acs, rews, next_obs, dones = sample
        tensor_acs = torch.from_numpy(acs).float()
        tensor_obs = torch.from_numpy(obs).float()
        tensor_next_obs = torch.from_numpy(next_obs).float()
        tensor_rews = torch.from_numpy(rews).float()
        tensor_dones = torch.from_numpy(~dones).float()

        current_agent = self.agents[agent_i]

        current_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            all_target_actors = self.target_actors()
            if self.discrete_action:
                all_target_actions = [
                    onehot_from_logits(
                        pi(tensor_obs[:,
                                      self.observation_index[agent_index][0]:
                                      self.observation_index[agent_index][1]]))
                    for pi, agent_index in zip(all_target_actors,
                                               range(self.num_agent))
                ]
            else:
                all_target_actions = [
                    pi(tensor_obs[:,
                                  self.observation_index[agent_index][0]:self.
                                  observation_index[agent_index][1]])
                    for pi, agent_index in zip(all_target_actors,
                                               range(self.num_agent))
                ]
            target_critic_input = torch.cat(
                (tensor_next_obs, torch.cat(all_target_actions, dim=1)), dim=1)
        else:
            if self.discrete_action:
                target_critic_input = torch.cat((
                    tensor_next_obs[:, self.observation_index[agent_i][0]:self.
                                    observation_index[agent_i][1]],
                    onehot_from_logits(
                        current_agent.target_actor(
                            tensor_next_obs[:, self.
                                            observation_index[agent_i][0]:self.
                                            observation_index[agent_i][1]]))),
                                                dim=1)
            else:
                target_critic_input = torch.cat((
                    tensor_next_obs[:, self.observation_index[agent_i][0]:self.
                                    observation_index[agent_i][1]],
                    current_agent.target_actor(
                        tensor_next_obs[:,
                                        self.observation_index[agent_i][0]:self
                                        .observation_index[agent_i][1]])),
                                                dim=1)
        target_critic_value = current_agent.target_critic(target_critic_input)
        target_value = \
            tensor_rews[:, agent_i].unsqueeze(1) + \
            self.gamma * target_critic_value * tensor_dones[:, agent_i].unsqueeze(1)

        if self.alg_types[agent_i] == 'MADDPG':
            critic_input = torch.cat((tensor_obs, tensor_acs), dim=1)
        else:  # DDPG
            critic_input = torch.cat(
                (tensor_obs[:, self.observation_index[agent_i][0]:self.
                            observation_index[agent_i][1]],
                 tensor_acs[:, self.action_index[agent_i][0]:self.
                            action_index[agent_i][1]]),
                dim=1)
        actual_value = current_agent.critic(critic_input)
        critic_loss = MSELoss(actual_value, target_value.detach())
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(current_agent.critic.parameters(), 0.5)
        current_agent.critic_optimizer.step()

        current_agent.actor_optimizer.zero_grad()
        if self.discrete_action:
            current_action_out = current_agent.actor(
                tensor_obs[:, self.observation_index[agent_i][0]:self.
                           observation_index[agent_i][1]])
            current_action_input_critic = gumbel_softmax(current_action_out,
                                                         hard=True)
        else:
            current_action_out = current_agent.actor(
                tensor_obs[:, self.observation_index[agent_i][0]:self.
                           observation_index[agent_i][1]])
            current_action_input_critic = current_action_out

        if self.alg_types[agent_i] == 'MADDPG':
            all_actor_action = []
            all_target_actors = self.target_actors()
            for i, pi in zip(range(self.num_agent), all_target_actors):
                if i == agent_i:
                    all_actor_action.append(current_action_input_critic)
                else:
                    if self.discrete_action:
                        all_actor_action.append(
                            onehot_from_logits(all_target_actors[i](
                                tensor_obs[:,
                                           self.observation_index[i][0]:self.
                                           observation_index[i][1]])))
                    else:
                        all_actor_action.append(all_target_actors[i](
                            tensor_obs[:, self.observation_index[i][0]:self.
                                       observation_index[i][1]]))
            critic_input = torch.cat(
                (tensor_obs, torch.cat(all_actor_action, dim=1)), dim=1)
        else:
            critic_input = torch.cat(
                (tensor_obs[:, self.observation_index[agent_i][0]:self.
                            observation_index[agent_i][1]],
                 current_action_input_critic),
                dim=1)

        actor_loss = -current_agent.critic(critic_input).mean()
        actor_loss += (current_action_out**2).mean() * 1e-3
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(current_agent.actor.parameters(), 0.5)
        current_agent.actor_optimizer.step()
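The repeated `observation_index[i][0]:observation_index[i][1]` slices above assume the joint observation is one flat tensor. A small sketch of how such an index table could be built from per-agent sizes (the sizes and the helper are illustrative, not the repository's construction):

import numpy as np
import torch

obs_dims = [4, 6, 3]  # per-agent observation sizes (illustrative)
ends = np.cumsum(obs_dims).tolist()
observation_index = [(end - dim, end) for dim, end in zip(obs_dims, ends)]
print(observation_index)  # [(0, 4), (4, 10), (10, 13)]

tensor_obs = torch.randn(8, sum(obs_dims))  # flat joint observation, (B, sum_i O_i)
obs_agent_1 = tensor_obs[:, observation_index[1][0]:observation_index[1][1]]
print(obs_agent_1.shape)  # torch.Size([8, 6])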
    def update(self, sample, logger):
        obs, acs, rews, emps, next_obs, dones = sample

        self.transition_optimizer.zero_grad()
        trans_in = torch.cat((*obs, *acs), dim=1)
        next_obs_pred = self.transition(trans_in)
        trans_loss = MSELoss(next_obs_pred, torch.cat(next_obs, dim=1))
        trans_loss.backward()
        self.transition_optimizer.step()

        self.planning_optimizer.zero_grad()
        acs_plan = []
        for o, no in zip(obs, next_obs):
            plan_in = torch.cat((o, no), dim=1)
            acs_plan.append(
                gumbel_softmax(self.planning(plan_in),
                               device=self.device.get_device(),
                               hard=True))
        acs_plan = torch.cat(acs_plan, dim=1)
        acs_torch = torch.cat(acs, dim=1)
        plan_loss = MSELoss(acs_plan, acs_torch)
        plan_loss.backward()
        self.planning_optimizer.step()

        self.source_optimizer.zero_grad()
        acs_src = []
        prob_src = []
        for no in next_obs:
            acs_src.append(
                gumbel_softmax(self.source(no),
                               device=self.device.get_device(),
                               hard=True))
            prob_src.append(
                gumbel_softmax(self.source(no),
                               device=self.device.get_device(),
                               hard=False))
        with torch.no_grad():
            trans_in = torch.cat((*next_obs, *acs_src), dim=1)
            trans_out = self.transition(trans_in)
        prob_plan = []
        n_obs = len(next_obs[0][0])
        for i, no in enumerate(next_obs):
            nno = trans_out[:, i * n_obs:(i + 1) * n_obs]
            plan_in = torch.cat((no, nno), dim=1)
            prob_plan.append(
                gumbel_softmax(self.planning(plan_in),
                               device=self.device.get_device(),
                               hard=False))
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)
        acs_src = torch.cat(acs_src, dim=1)

        E = acs_src * prob_plan - acs_src * prob_src
        i_rews = -E.mean()
        i_rews.backward()
        self.source_optimizer.step()

        if logger is not None:
            logger.add_scalars(
                'empowerment/losses', {
                    'trans_loss': trans_loss.detach(),
                    'plan_loss': plan_loss.detach(),
                    'i_rews': i_rews.detach()
                }, self.niter)
        self.niter += 1
Example #14
 def cast_embedding(self, emb, explore=True):
     if explore:
         emb = gumbel_softmax(emb, hard=True)
     else:
         emb = differentiable_onehot_from_logits(emb)
     return emb
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: EpisodeBatch, use sample[key_i] to get a specific 
                    array of obs, action, etc for agent i
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        obs, acs, rews, next_obs, dones = self.parse_sample(sample) # [(B,T,D)]*N  
        bs, ts, _ = obs[0].shape
        self.init_hidden(bs)  # use pre-defined init hiddens 
        curr_agent = self.agents[agent_i]

        # NOTE: critic update
        curr_agent.critic_optimizer.zero_grad()

        # compute target actions
        if self.alg_types[agent_i] == 'MADDPG':
            all_trgt_acs = []   # [(B,T,A)]*N
            for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs)):
                # nobs: (B,T,O)
                act_i = self.compute_action(i, pi, nobs, bs=bs, ts=ts)
                all_trgt_acs.append(act_i)  # [(B,T,A)]
            
            if self.discrete_action:    # one-hot encode action
                all_trgt_acs = [onehot_from_logits(
                    act_i.reshape(bs*ts,-1)
                ).reshape(bs,ts,-1) for act_i in all_trgt_acs] 

            # critic input, [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
        else:  # DDPG
            act_i = self.compute_action(agent_i, curr_agent.target_policy, 
                        next_obs[agent_i], bs=bs, ts=ts)

            if self.discrete_action:
                act_i = onehot_from_logits(
                    act_i.reshape(bs*ts, -1)
                ).reshape(bs, ts, -1) 

            # (B,T,O) + (B,T,A) -> (B,T,O+A)
            trgt_vf_in = torch.cat((next_obs[agent_i], act_i), dim=-1)

        # bellman targets
        target_q = self.compute_q_val(agent_i, curr_agent.target_critic, 
                        trgt_vf_in, bs=bs, ts=ts)   # (B*T,1)
        target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                            (1 - dones[agent_i].view(-1, 1)))   # (B*T,1)

        # Q func
        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=-1)  # (B,T,O*N+A*N), concat along feature dim
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=-1)  # (B,T,O+A)
        actual_value = self.compute_q_val(agent_i, curr_agent.critic, 
                            vf_in, bs=bs, ts=ts)    # (B*T,1)

        # bellman errors
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
        curr_agent.critic_optimizer.step()

        # NOTE: policy update
        curr_agent.policy_optimizer.zero_grad()

        # current agent action (deterministic, softened)
        curr_pol_out = self.compute_action(agent_i, curr_agent.policy,
                                obs[agent_i], bs=bs, ts=ts) # (B,T,A)
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_vf_in = gumbel_softmax(
                                curr_pol_out.reshape(bs*ts, -1), hard=True
                            ).reshape(bs, ts, -1) 
        else:
            curr_pol_vf_in = curr_pol_out

        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    # insert current agent act to q input 
                    all_pol_acs.append(curr_pol_vf_in)
                else: 
                    p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts) # (B,T,A)
                    if self.discrete_action:
                        p_act_i = onehot_from_logits(
                                    p_act_i.reshape(bs*ts, -1)
                                ).reshape(bs, ts, -1) 
                    all_pol_acs.append(p_act_i)
            p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1) # (B,T,O*N+A*N)
        else:  # DDPG
            p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1) # (B,T,O+A)
        
        # value function to update current policy
        p_value = self.compute_q_val(agent_i, curr_agent.critic, 
                        p_vf_in, bs=bs, ts=ts)   # (B*T,1)
        pol_loss = -p_value.mean()
        # p regularization, scale down output (gaussian mean,std or logits)
        # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
        pol_loss += ((curr_pol_out.reshape(bs*ts, -1))**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
        curr_agent.policy_optimizer.step()

        # collect training stats
        results = {
            "agent_{}_critic_loss".format(agent_i): vf_loss,
            "agent_{}_policy_loss".format(agent_i): pol_loss
        }
        return results
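`compute_action` and `compute_q_val` above are the repository's own helpers; from their use they flatten the time dimension before the network and restore it afterwards. A plausible sketch of that reshape for an MLP policy (the linear policy and all sizes are assumptions):

import torch
import torch.nn as nn

B, T, obs_dim, act_dim = 3, 7, 5, 4
policy = nn.Linear(obs_dim, act_dim)  # stand-in for an MLP policy

def compute_action(policy, obs, bs, ts):
    """Flatten (B,T,O) -> (B*T,O), run the policy, restore (B,T,A)."""
    flat = obs.reshape(bs * ts, -1)
    return policy(flat).reshape(bs, ts, -1)

obs = torch.randn(B, T, obs_dim)
act = compute_action(policy, obs, bs=B, ts=T)
print(act.shape)  # torch.Size([3, 7, 4])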
Example #16
    def update_coach(self, sample):

        if any(['Coach' in alg for alg in self.alg_types]):

            observations, actions, rewards, next_obs, dones = sample

            # Computes coach embedding

            coach_embed_logits = self.coach.model(
                torch.cat((*observations, ), dim=1))

            coach_embed = gumbel_softmax(coach_embed_logits, hard=True)

            ## EMBEDDING MATCHING REGULARIZATION

            # Computes agents embeddings regularization

            J_E = 0
            for i, pi, ob in zip(range(self.nagents), self.policies,
                                 observations):
                if "Coach" in self.alg_types[i]:
                    _, agent_embed_logits = pi(ob, return_embed_logits=True)
                    J_E += self.coach.get_regularization_loss(
                        coach_embed_logits, agent_embed_logits)
            J_E = J_E / self.nagents

            ## POLICY GRADIENT WITH EMBEDDING REGULARIZATION

            # Gets actions of all agents when computed from the coach-embedding (coordinated actions)

            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies,
                                 observations):
                if "Coach" in self.alg_types[i]:
                    if self.use_discrete_action:
                        # we need this trick to be able to differentiate
                        all_pol_acs.append(
                            differentiable_onehot_from_logits(
                                pi.partial_forward(ob, coach_embed)))
                    else:
                        all_pol_acs.append(pi.partial_forward(ob, coach_embed))

            # Gets evaluations from all critics

            vf_in = torch.cat((*observations, *all_pol_acs), dim=1)
            all_critics_eval = []
            for i, critic in enumerate(self.critics):
                if "Coach" in self.alg_types[i]:
                    all_critics_eval.append(critic(vf_in))

            J_PGE = -torch.mean(torch.stack(all_critics_eval).squeeze())

            ## BACKPROP, we backprop in two steps because the agents and the coach do not have the same weighting

            i = 0
            for loss, lam in zip([J_E, J_PGE],
                                 [self.lambdac_1, self.lambdac_2]):

                # Resets gradient buffers

                self.coach.optimizer.zero_grad()

                for agent in self.agents:
                    agent.policy.zero_grad()

                loss.backward(retain_graph=i == 0)

                # Apply coach update to coach
                if i == 0:
                    multiply_gradient(self.coach.model.parameters(),
                                      self.lambdac_3)
                torch.nn.utils.clip_grad_norm_(self.coach.model.parameters(),
                                               self.grad_clip_value)

                self.coach.optimizer.step()

                # Apply coach update to all agents

                for j, agent in enumerate(self.agents):
                    if "Coach" in self.alg_types[j]:
                        multiply_gradient(agent.policy.parameters(),
                                          lam * self.nagents)
                        torch.nn.utils.clip_grad_norm_(
                            agent.policy.parameters(), self.grad_clip_value)
                        agent.policy_optimizer.step()
                i += 1

        return J_E.data.cpu().numpy(), J_PGE.data.cpu().numpy()
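`self.coach.get_regularization_loss` is not shown here. One plausible form of the embedding-matching term J_E, a KL divergence between the coach's and the agent's embedding distributions, is sketched below; this is an assumption for illustration, not the repository's definition:

import torch
import torch.nn.functional as F

def embedding_matching_loss(coach_embed_logits, agent_embed_logits):
    """Hypothetical J_E term: KL(coach || agent) over the embedding categories."""
    coach_probs     = F.softmax(coach_embed_logits, dim=-1)
    coach_log_probs = F.log_softmax(coach_embed_logits, dim=-1)
    agent_log_probs = F.log_softmax(agent_embed_logits, dim=-1)
    return torch.mean(torch.sum(coach_probs * (coach_log_probs - agent_log_probs), dim=-1))

coach_logits = torch.randn(4, 8)
agent_logits = torch.randn(4, 8, requires_grad=True)
print(embedding_matching_loss(coach_logits, agent_logits))  # scalar, differentiable w.r.t. agent_logits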
Example #17
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer
        curr_agent.critic_optimizer.zero_grad()
        curr_agent.f_e_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)
        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [
                    onehot_from_logits(
                        pi(f_e(nobs)).view(-1, self.action_spaces[0],
                                           self.nagents)[:, :, i])
                    for i, (pi, f_e, nobs) in enumerate(
                        zip(self.target_policies, self.f_es, next_obs))
                ]
            else:
                all_target_actions = [
                    pi(f_e(nobs)).view(-1, self.action_spaces[0],
                                       self.nagents)[:, :, i]
                    for i, (pi, f_e, nobs) in enumerate(
                        zip(self.target_policies, self.f_es, next_obs))
                ]

            if self.critic_concat_all_obs:
                target_vf_in = torch.cat((*next_obs, *all_target_actions),
                                         dim=1)
            else:
                target_vf_in = torch.cat(
                    (curr_agent.f_e(next_obs[agent_i]), *all_target_actions),
                    dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            next_fe = curr_agent.f_e(next_obs[agent_i])
            if self.use_discrete_action:
                target_vf_in = torch.cat(
                    (next_fe,
                     onehot_from_logits(curr_agent.target_policy(next_fe))),
                    dim=1)
            else:
                target_vf_in = torch.cat(
                    (next_fe, curr_agent.target_policy(next_fe)), dim=1)

        else:
            raise NotImplementedError

        # Computes target value
        target_value = (rewards[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value

        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            if self.critic_concat_all_obs:
                critic_obs = [
                    agent.f_e(obs)
                    for obs, agent in zip(observations, self.agents)
                ]
            else:
                critic_obs = [curr_agent.f_e(observations[agent_i])]
            vf_in = torch.cat((*critic_obs, *actions), dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            critic_obs = curr_agent.f_e(observations[agent_i])
            vf_in = torch.cat((critic_obs, actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates
        vf_loss = MSELoss(actual_value, target_value.detach())
        # we have to retain the graph because we reuse the critic_obs for the following actor loss
        vf_loss.backward(
            retain_graph=True
        )  ##todo:make sure there is no leakage between the two losses

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(),
                                       self.grad_clip_value)

        # Apply critic update
        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer
        curr_agent.policy_optimizer.zero_grad()

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            pol_out_all_heads = curr_agent.policy(
                curr_agent.f_e(observations[agent_i])).view(
                    -1, self.action_spaces[0], self.nagents)
            curr_pol_out = pol_out_all_heads[:, :, agent_i]
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            pol_out_all_heads = curr_agent.policy(
                curr_agent.f_e(observations[agent_i])).view(
                    -1, self.action_spaces[0], self.nagents)
            curr_pol_out = pol_out_all_heads[:, :, agent_i]
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic
        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            all_pol_logits = []
            all_pol_acs = []
            for i, pi, f_e, ob in zip(range(self.nagents), self.policies,
                                      self.f_es, observations):
                if i == agent_i:
                    all_pol_logits.append(curr_pol_out)
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    logits = pi(f_e(ob)).view(-1, self.action_spaces[0],
                                              self.nagents)[:, :, i]
                    all_pol_logits.append(logits)
                    all_pol_acs.append(onehot_from_logits(logits).detach())
                else:
                    all_pol_logits.append(None)
                    all_pol_acs.append(
                        pi(f_e(ob)).detach().view(-1, self.action_spaces[0],
                                                  self.nagents)[:, :, i])
            vf_in = torch.cat((*critic_obs, *all_pol_acs),
                              dim=1)  # Centralized critic for MADDPG agent

        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((critic_obs, curr_pol_vf_in), dim=1)

        else:
            raise NotImplementedError

        # Computes the loss
        pol_loss = -torch.mean(curr_agent.critic(vf_in))

        # Backpropagates
        pol_loss.backward(retain_graph=True if self.alg_types[agent_i] ==
                          "TeamMADDPG" else False)

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       self.grad_clip_value)

        # Apply actor update
        curr_agent.policy_optimizer.step()

        # Team Spirit (TS) regularization
        if self.alg_types[agent_i] == "TeamMADDPG":

            if self.use_discrete_action:
                real_action_logits = torch.stack(all_pol_logits, dim=2)

                real_action_probs = F.softmax(real_action_logits, dim=1)
                real_action_log_probs = F.log_softmax(real_action_logits,
                                                      dim=1)
                predicted_action_log_probs = F.log_softmax(pol_out_all_heads,
                                                           dim=1)

                # KL-divergence KL(real_action_dist || predicted_action_dist)

                ts_loss = torch.mean(
                    torch.sum(
                        real_action_probs *
                        (real_action_log_probs - predicted_action_log_probs),
                        dim=1))
            else:
                ts_loss = torch.mean(
                    torch.sum((pol_out_all_heads -
                               torch.stack(all_pol_acs, dim=2))**2,
                              dim=1))

            # Resets gradient buffer of all agents

            for agent in self.agents:
                agent.policy_optimizer.zero_grad()

            # Backpropagates through every agent of the team (including curr_agent)

            ts_loss.backward()

            for i, agent in enumerate(self.agents):
                if i == agent_i:
                    coeff = self.lambdat_1
                else:
                    coeff = self.lambdat_2

                for p in agent.policy.parameters():
                    p.grad *= coeff

                # Apply gradients

                torch.nn.utils.clip_grad_norm_(agent.policy.parameters(),
                                               self.grad_clip_value)
                agent.policy_optimizer.step()

            ts_loss = ts_loss.data.cpu().numpy()

        else:
            ts_loss = None

        ## Backpropagates for feature extractor from both backprops
        torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(),
                                       self.grad_clip_value)
        curr_agent.f_e_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy(), ts_loss
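The team-spirit term above computes KL(real || predicted) by hand from softmax / log-softmax outputs. A quick check that this manual form matches `torch.distributions.kl_divergence` (the random logits are stand-ins for the per-agent action heads):

import torch
import torch.nn.functional as F
from torch.distributions import Categorical, kl_divergence

real_logits      = torch.randn(6, 5)
predicted_logits = torch.randn(6, 5)

real_probs     = F.softmax(real_logits, dim=1)
real_log_probs = F.log_softmax(real_logits, dim=1)
pred_log_probs = F.log_softmax(predicted_logits, dim=1)
manual_kl = torch.sum(real_probs * (real_log_probs - pred_log_probs), dim=1)

builtin_kl = kl_divergence(Categorical(logits=real_logits),
                           Categorical(logits=predicted_logits))
print(torch.allclose(manual_kl, builtin_kl, atol=1e-6))  # True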
Example #18
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer
        curr_agent.critic_optimizer.zero_grad()
        curr_agent.f_e_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [
                    onehot_from_logits(pi(f_e(nobs))) for pi, f_e, nobs in zip(
                        self.target_policies, self.f_es, next_obs)
                ]
            else:
                all_target_actions = [
                    pi(f_e(nobs)) for pi, f_e, nobs in zip(
                        self.target_policies, self.f_es, next_obs)
                ]

            if self.alg_types[agent_i] == 'SharedMADDPG':
                next_obs.insert(0, next_obs.pop(agent_i))
                all_target_actions.insert(0, all_target_actions.pop(agent_i))

            if self.critic_concat_all_obs:
                target_vf_in = torch.cat(
                    (*next_obs, *all_target_actions), dim=1
                )  # TODO: WARNING: should probably feed that in feature_extractor
            else:
                target_vf_in = torch.cat(
                    (curr_agent.f_e(next_obs[agent_i]), *all_target_actions),
                    dim=1)

        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            next_fe = curr_agent.f_e(next_obs[agent_i])
            if self.use_discrete_action:
                target_vf_in = torch.cat(
                    (next_fe,
                     onehot_from_logits(curr_agent.target_policy(next_fe))),
                    dim=1)
            else:
                target_vf_in = torch.cat(
                    (next_fe, curr_agent.target_policy(next_fe)), dim=1)

        else:
            raise NotImplementedError

        # Computes target value
        target_value = (rewards[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            if self.alg_types[agent_i] == 'SharedMADDPG':
                observations.insert(0, observations.pop(agent_i))
                actions.insert(0, actions.pop(agent_i))
            vf_in = torch.cat((*observations, *actions), dim=1)
        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates
        vf_loss = MSELoss(actual_value, target_value.detach())
        # we have to retain the graph because we reuse the critic_obs for the following actor loss
        vf_loss.backward(
            retain_graph=True
        )  ##todo:make sure there is no leakage between the two losses

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(),
                                       self.grad_clip_value)

        # Apply critic update
        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer
        curr_agent.policy_optimizer.zero_grad()

        # We put experience data back in the general point of view
        if self.alg_types[agent_i] == 'SharedMADDPG':
            observations.insert(agent_i, observations.pop(0))

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(
                curr_agent.f_e(observations[agent_i]))
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(
                curr_agent.f_e(observations[agent_i]))
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            all_pol_acs = []
            for i, pi, f_e, ob in zip(range(self.nagents), self.policies,
                                      self.f_es, observations):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    all_pol_acs.append(
                        onehot_from_logits(pi(f_e(ob))).detach())
                else:
                    all_pol_acs.append(pi(f_e(ob)).detach())

            if self.alg_types[agent_i] == 'SharedMADDPG':
                # the critic must take the point of vue of agent i
                observations.insert(0, observations.pop(agent_i))
                all_pol_acs.insert(0, all_pol_acs.pop(agent_i))

            vf_in = torch.cat((*observations, *all_pol_acs),
                              dim=1)  # Centralized critic for MADDPG agent
        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
        else:
            raise NotImplementedError

        # Computes the loss
        pol_loss = -torch.mean(curr_agent.critic(vf_in))

        # Backpropagates
        pol_loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       self.grad_clip_value)
        # Apply actor update
        curr_agent.policy_optimizer.step()

        ## Backpropagates for feature extractor from both backprops
        torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(),
                                       self.grad_clip_value)
        curr_agent.f_e_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
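The `insert(0, list.pop(agent_i))` calls above rotate the per-agent lists so the shared critic always sees the updating agent's entries first, and the later `insert(agent_i, list.pop(0))` restores the original order. A tiny worked example with labels instead of tensors:

obs_labels = ["obs_agent0", "obs_agent1", "obs_agent2"]
agent_i = 2

obs_labels.insert(0, obs_labels.pop(agent_i))  # agent_i's entry moves to the front
print(obs_labels)  # ['obs_agent2', 'obs_agent0', 'obs_agent1']

obs_labels.insert(agent_i, obs_labels.pop(0))  # back to the general point of view
print(obs_labels)  # ['obs_agent0', 'obs_agent1', 'obs_agent2']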
Example #19
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer

        curr_agent.critic_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [onehot_from_logits(pi(nobs)) for pi, nobs in zip(self.target_policies, next_obs)]
            else:
                all_target_actions = [pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)]
            target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            if self.use_discrete_action:
                target_vf_in = torch.cat((next_obs[agent_i],
                                          onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))),
                                         dim=1)
            else:
                target_vf_in = torch.cat((next_obs[agent_i],
                                          curr_agent.target_policy(next_obs[agent_i])),
                                         dim=1)

        else:
            raise NotImplementedError

        # Computes target value

        target_value = (rewards[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            vf_in = torch.cat((*observations, *actions), dim=1)
        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates

        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()

        # Clip gradients

        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)

        # Apply critic update

        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer

        curr_agent.policy_optimizer.zero_grad()

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, observations):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob).detach()))
                else:
                    all_pol_acs.append(pi(ob).detach())
            vf_in = torch.cat((*observations, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
        else:
            raise NotImplementedError

        # Computes the loss
        J_PG = -torch.mean(curr_agent.critic(vf_in))

        pol_loss = J_PG

        # Backpropagates

        pol_loss.backward()

        # Clip gradients

        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)

        # Apply actor update

        curr_agent.policy_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
    def compute(self, next_obs):
        # acs_pi_k = []
        # prob_pi_k = []
        # for no, pi in zip(next_obs, self.agents):
        #     acs_pi_k.append(gumbel_softmax(pi.policy(no), device=self.device.get_device(), hard=True))
        #     prob_pi_k.append(F.softmax(pi.policy(no), dim=1).unsqueeze(1))  # for stacking later, dim = [B, 1, A]
        #
        # final_obs = self.transition(torch.cat((*next_obs, *acs_pi_k), dim=1))
        #
        # end_idx = [0] + np.cumsum([ne_ob.shape[1] for ne_ob in next_obs]).tolist()
        # start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
        #
        # # P(action distribution of agent j | k takes action taken)
        # action_dist = []    # action_dist dim = [num agents, batch_size, num agents - 1, action_dim]
        # for k, no in enumerate(next_obs):
        #     prob_pi_j = []
        #     for j, pi_j in enumerate(self.agents):
        #         if j == k: continue     # computing effect on other agents
        #         final_obs_j = final_obs[:, start_end[j][0]:start_end[j][1]]
        #         prob_pi_j.append(F.softmax(pi_j.policy(final_obs_j), dim=1).unsqueeze(1))
        #     prob_pi_j = torch.cat(prob_pi_j, dim=1)
        #     action_dist.append(prob_pi_j)
        #
        # # [P(k takes action 0) * (action distribution of agent j | k takes action 0) + ...]
        # marginal_action_dists = []
        # for k, (no, pi) in enumerate(zip(next_obs, self.agents)):
        #     batch_size, action_dim = acs_pi_k[k].shape
        #     all_acs_pi_k = torch.nn.functional.one_hot(torch.arange(action_dim)).float()
        #     for one_hot_ac in all_acs_pi_k:
        #         # replace inside the original acs_pi_k, k's action
        #         acs_pi_k_modified = acs_pi_k
        #         acs_pi_k_modified[k] = one_hot_ac.unsqueeze(0).repeat(batch_size, 1)
        #         tilde_final_obs = self.transition(torch.cat((*next_obs, *acs_pi_k_modified), dim=1))
        #
        #         for j, pi_j in enumerate(self.agents):
        #             if j == k: continue  # computing effect on other agent
        #             tilde_final_obs_j = tilde_final_obs[:, start_end[j][0]:start_end[j][1]]
        #             mrgn_dist_acs_j = F.softmax(pi_j.policy(tilde_final_obs_j), dim=1).unsqueeze(1)
        #         marginal_action_dists.append()
        acs_src = []
        prob_src = []
        for no, source in zip(next_obs, self.agents):
            acs_src.append(gumbel_softmax(source.policy(no), device=self.device.get_device(), hard=True))
            prob_src.append(gumbel_softmax(source.policy(no), device=self.device.get_device(), hard=False))

        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
        prob_plan = []
        end_idx = [0] + np.cumsum([ne_ob.shape[1] for ne_ob in next_obs]).tolist()
        start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
        for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
            nno = trans_out[:, start_end[i][0]:start_end[i][1]]
            acs_ = []
            for j, ac in enumerate(acs_src):
                if j == i: continue
                nno_other = trans_out[:, start_end[j][0]:start_end[j][1]]
                acs_.append(
                    gumbel_softmax(self.agents[j].policy(nno_other), device=self.device.get_device(), hard=True))
            acs_ = torch.cat(acs_, dim=1)
            plan_in = torch.cat((no, nno, acs_), dim=1)

            prob_plan.append(gumbel_softmax(planning(plan_in), device=self.device.get_device(), hard=False))
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)

        # for returning the si for individual agents
        end_idx = [0] + np.cumsum([ne_ac.shape[1] for ne_ac in acs_src]).tolist()
        start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]

        acs_src = torch.cat(acs_src, dim=1)
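        # difference between the planning and source action probabilities, masked by the
        # sampled one-hot actions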
        si = acs_src * prob_plan - acs_src * prob_src
        result = torch.cat([si[:, start:end] for (start, end) in start_end], dim=0)
        return result
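The compute method above rebuilds (start, end) index pairs from cumulative per-agent feature widths whenever it needs to slice a concatenated tensor back into per-agent blocks. A minimal, self-contained sketch of that pattern follows; the widths and batch size are illustrative, not taken from the snippet:

import numpy as np
import torch

widths = [18, 18, 18]                                       # per-agent feature sizes (assumed)
end_idx = [0] + np.cumsum(widths).tolist()                  # [0, 18, 36, 54]
start_end = [(s, e) for s, e in zip(end_idx, end_idx[1:])]  # [(0, 18), (18, 36), (36, 54)]

joint = torch.randn(32, sum(widths))                        # e.g. concatenated predicted observations
per_agent = [joint[:, s:e] for s, e in start_end]           # split back into per-agent blocks
assert all(chunk.shape == (32, w) for chunk, w in zip(per_agent, widths))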
Example #21
0
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
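        # the target is detached so gradients only flow through the online critic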
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
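        # small L2 penalty on the raw policy outputs to keep logits/actions from drifting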
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
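Both update methods here rely on the straight-through Gumbel-Softmax trick described in the comment: the forward pass yields a one-hot action while gradients flow through the underlying soft sample. Below is a minimal sketch of that behaviour using torch.nn.functional.gumbel_softmax (the project's own gumbel_softmax helper may differ in signature):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5, requires_grad=True)          # batch of 4, 5 discrete actions
action = F.gumbel_softmax(logits, tau=1.0, hard=True)   # one-hot in the forward pass
print(action.sum(dim=1))                                # each row sums to 1

loss = -(action * torch.randn(4, 5)).sum()              # any differentiable objective
loss.backward()
print(logits.grad.abs().sum() > 0)                      # gradients reached the logits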
Example #22
0
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        # For RNN, the obs and next_obs both have histories
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action

                # Original one-liner ('pi' is a target policy, 'nobs' its next observations):
                # all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                #                 zip(self.target_policies, next_obs)]

                #-------- Expanding out for debugging --------#
                all_trgt_acs = []
                for pi, nobs in zip(self.target_policies, next_obs):
                    temp = onehot_from_logits(pi(nobs))
                    #print(temp)
                    all_trgt_acs.append(temp)
                # -------- End debug -------------------------#

            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # Get the most recent observation from the history to compute the target value
            # (next_obs[0].shape[0] is the batch size; 18 is the per-step observation size)
            # TODO: make the per-step observation size a parameter
            t0_next_obs = [None] * self.nagents
            for a in range(self.nagents):
                t0_next_obs[a] = torch.tensor(np.zeros(
                    (next_obs[0].shape[0], 18)),
                                              dtype=torch.float)
            # Only keep the current obs for the critic value function
            for n in range(self.nagents):  # for each agent
                for b in range(next_obs[0].shape[0]):  # for each sample in the batch
                    t0_next_obs[n][b][:] = next_obs[n][b][0:18]

            # Originally: trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
            # Only the current observations are kept for the critic value function
            trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            # DDPG only knows this agent's own observation and policy,
            # whereas MADDPG has access to all other agents' policies
            # Grab only the current observation (first 18 features) to send to the critic
            t0_next_obs = [None] * self.nagents
            for a in range(self.nagents):
                t0_next_obs[a] = torch.tensor(np.zeros(
                    (next_obs[0].shape[0], 18)),
                                              dtype=torch.float)
            for n in range(self.nagents):  # for each agent
                for b in range(next_obs[0].shape[0]):  # for each sample in the batch
                    t0_next_obs[n][b][:] = next_obs[n][b][0:18]

            # Originally it would be next_obs[agent_i] instead of t0_next_obs[agent_i]
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)

        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        # Get only the current observation (i.e., without history) for the critic,
        # mirroring t0_next_obs above, since both obs and next_obs carry histories
        t0_obs = [None] * self.nagents
        for a in range(self.nagents):
            t0_obs[a] = torch.tensor(np.zeros((obs[0].shape[0], 18)),
                                     dtype=torch.float)
        for n in range(self.nagents):  # for each agent
            for b in range(obs[0].shape[0]):  # for each sample in the batch
                t0_obs[n][b][:] = obs[n][b][0:18]

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*t0_obs, *acs), dim=1)
        else:  # DDPG: the critic only sees this agent's current observation and action
            vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            # We are forwarding the policy again here, so it uses the obs with history
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
            # Seems to be working fine till here
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            # Originally:
            #vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
            vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)
        # TODO: FIX THIS
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
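The nested per-sample copy loops in this example can be collapsed into a single slice per agent, assuming each obs[n] is a float tensor of shape [batch, history * 18] with the most recent 18-dimensional frame stored first. A minimal sketch under that assumption (sizes are illustrative):

import torch

nagents, batch, frame_dim, history = 3, 32, 18, 3      # illustrative sizes
obs = [torch.randn(batch, history * frame_dim) for _ in range(nagents)]

t0_obs = [o[:, :frame_dim] for o in obs]                # keep only the most recent frame
assert all(t.shape == (batch, frame_dim) for t in t0_obs)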