Example #1
 def forward(self,
             obs,
             sample=True,
             return_all_probs=False,
             return_log_pi=False,
             regularize=False,
             return_entropy=False):
     out = super(DiscretePolicy, self).forward(obs)
     probs = F.softmax(out, dim=1)
     on_gpu = next(self.parameters()).is_cuda
     if sample:
         int_act, act = categorical_sample(probs, use_cuda=on_gpu)
     else:
         act = onehot_from_logits(probs)
     rets = [act]
     if return_log_pi or return_entropy:
         log_probs = F.log_softmax(out, dim=1)
     if return_all_probs:
         rets.append(probs)
     if return_log_pi:
         # return log probability of selected action
         rets.append(log_probs.gather(1, int_act))
     if regularize:
         rets.append([(out**2).mean()])
     if return_entropy:
         rets.append(-(log_probs * probs).sum(1).mean())
     if len(rets) == 1:
         return rets[0]
     return rets
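
The helpers `categorical_sample` and `onehot_from_logits` used above are not defined in this snippet. A minimal sketch of what they typically look like in MADDPG/MAAC-style codebases follows; the exact signatures and behavior here are assumptions, not taken from this file.

import torch

def onehot_from_logits(logits):
    # Assumed helper: hard argmax over the action dimension, returned as a
    # one-hot float tensor (no gradient flows through the argmax).
    return (logits == logits.max(dim=1, keepdim=True)[0]).float()

def categorical_sample(probs, use_cuda=False):
    # Assumed helper: sample an action index from the categorical
    # distribution `probs` of shape (B, A) and also return it one-hot encoded.
    int_acs = torch.multinomial(probs, num_samples=1)  # (B, 1)
    device = 'cuda' if use_cuda else 'cpu'
    acs = torch.zeros(probs.shape, device=device).scatter_(1, int_acs, 1)
    return int_acs, acs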
Example #2
    def step(self, obs, explore=False):
        """
        Take a step forward in the environment for a minibatch of observations;
        equivalent to `act` or `compute_actions`.
        Arguments:
            obs: (B,O)
            explore: Whether or not to add exploration noise
        Returns:
            action: dict of actions for this agent, (B,A)
        """
        with torch.no_grad():
            action, hidden_states = self.policy(obs, self.policy_hidden_states)
            self.policy_hidden_states = hidden_states  # if the policy is an MLP, this stays at its default of None

            if self.discrete_action:
                for k in action:
                    if explore:
                        action[k] = gumbel_softmax(action[k], hard=True)
                    else:
                        action[k] = onehot_from_logits(action[k])
            else:  # continuous action
                idx = 0
                noise = Variable(Tensor(self.exploration.noise()),
                                 requires_grad=False)
                for k in action:
                    if explore:
                        dim = action[k].shape[-1]
                        action[k] += noise[idx:idx + dim]
                        idx += dim
                    action[k] = action[k].clamp(-1, 1)
        return action
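
The `self.exploration.noise()` call above is assumed to be an Ornstein-Uhlenbeck noise process, as is standard for DDPG-style continuous exploration. A rough sketch of such a class, with names and default parameters that are assumptions rather than taken from this file:

import numpy as np

class OUNoise:
    """Ornstein-Uhlenbeck process producing temporally correlated noise."""

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(action_dim) * mu

    def reset(self):
        self.state = np.ones(len(self.state)) * self.mu

    def noise(self):
        # dx = theta * (mu - x) + sigma * N(0, 1); integrate one step
        dx = self.theta * (self.mu - self.state) \
             + self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state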
Example #3
 def step(self, obs, explore=False):
     """
     Take a step forward in the environment for a minibatch of observations
     Inputs:
         obs (PyTorch Variable): Observations for this agent
         explore (boolean): Whether or not to add exploration noise
     Outputs:
         action (PyTorch Variable): Actions for this agent
     """
     # print('----', obs)
     action = self.policy(obs)
     # print('>>>>>>>>', action)
     if self.discrete_action:
         # print('agents.py discrete_action yes')
         if explore:
             action = gumbel_softmax(action, hard=True)
         else:
             action = onehot_from_logits(action)
     else:  # continuous action
         # print('agents.py continuous_action yes')
         if explore:
             # print('agents.py explore yes')
             # print('action before noise', action)
             action += Variable(Tensor(self.exploration.noise()),
                                requires_grad=False)
             # print('action after noise', action)
         action = action.clamp(-1, 1)
         # print('action after clamp', action)
     # print('>>>>>>>>>>>>>>>>', action)
     return action
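
The `gumbel_softmax(action, hard=True)` call used for discrete exploration is the straight-through Gumbel-Softmax estimator: the forward pass emits a one-hot action while gradients flow through the soft sample. A minimal sketch of the standard formulation (the temperature default and helper name are assumptions):

import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20, device='cpu'):
    # Gumbel(0, 1) noise via inverse transform sampling
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax(logits, temperature=1.0, hard=False):
    # Soft sample: softmax over Gumbel-perturbed logits
    y = F.softmax(
        (logits + sample_gumbel(logits.shape, device=logits.device)) / temperature,
        dim=-1)
    if hard:
        # Hard one-hot in the forward pass, soft gradients in the backward
        # pass (straight-through trick).
        y_hard = (y == y.max(dim=-1, keepdim=True)[0]).float()
        y = (y_hard - y).detach() + y
    return y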
Example #4
 def _soft_act(x):  # x: (B,A)
     # Note: nested helper; `self.discrete_action` and `requires_grad` are
     # captured from the enclosing method's scope.
     if not self.discrete_action:
         return x
     if requires_grad:
         return gumbel_softmax(x, hard=True)
     else:
         return onehot_from_logits(x)
Example #5
 def _soft_act(self, x, requires_grad=True):
     """ soften action if discrete, x: (B,A) """
     if not self.discrete_action:
         return x
     if requires_grad:
         return gumbel_softmax(x, hard=True)
     else:
         return onehot_from_logits(x)
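
For context, a `_soft_act`-style helper is typically applied when assembling the joint action fed to a centralized critic: only the agent currently being optimized keeps gradients, the others are hardened to one-hot. A hedged, illustrative sketch; the function and variable names below are assumptions:

import torch

def joint_critic_input(observations, all_policy_logits, agent_i, soft_act):
    # soft_act is a _soft_act-style callable: differentiable Gumbel-Softmax
    # when requires_grad=True, hard one-hot otherwise.
    joint_actions = [
        soft_act(logits, requires_grad=(i == agent_i))
        for i, logits in enumerate(all_policy_logits)
    ]
    # The centralized critic sees every agent's observation and action.
    return torch.cat((*observations, *joint_actions), dim=-1)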
Example #6
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
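
After per-agent updates like the one above, MADDPG implementations typically Polyak-average the target networks toward the live ones. A minimal sketch of that step; the function and attribute names are assumptions, not taken from this file:

def soft_update(target, source, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(t_param.data * (1.0 - tau) + s_param.data * tau)

def update_all_targets(self):
    # Hypothetical companion method: sync every agent's target networks and
    # advance the logging iteration counter.
    for agent in self.agents:
        soft_update(agent.target_critic, agent.critic, self.tau)
        soft_update(agent.target_policy, agent.policy, self.tau)
    self.niter += 1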
Example #7
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer
        curr_agent.critic_optimizer.zero_grad()
        curr_agent.f_e_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [
                    onehot_from_logits(pi(f_e(nobs))) for pi, f_e, nobs in zip(
                        self.target_policies, self.f_es, next_obs)
                ]
            else:
                all_target_actions = [
                    pi(f_e(nobs)) for pi, f_e, nobs in zip(
                        self.target_policies, self.f_es, next_obs)
                ]

            if self.alg_types[agent_i] == 'SharedMADDPG':
                next_obs.insert(0, next_obs.pop(agent_i))
                all_target_actions.insert(0, all_target_actions.pop(agent_i))

            if self.critic_concat_all_obs:
                target_vf_in = torch.cat(
                    (*next_obs, *all_target_actions), dim=1
                )  # TODO: WARNING: these observations should probably be fed through the feature extractor
            else:
                target_vf_in = torch.cat(
                    (curr_agent.f_e(next_obs[agent_i]), *all_target_actions),
                    dim=1)

        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            next_fe = curr_agent.f_e(next_obs[agent_i])
            if self.use_discrete_action:
                target_vf_in = torch.cat(
                    (next_fe,
                     onehot_from_logits(curr_agent.target_policy(next_fe))),
                    dim=1)
            else:
                target_vf_in = torch.cat(
                    (next_fe, curr_agent.target_policy(next_fe)), dim=1)

        else:
            raise NotImplementedError

        # Computes target value
        target_value = (rewards[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            if self.alg_types[agent_i] == 'SharedMADDPG':
                observations.insert(0, observations.pop(agent_i))
                actions.insert(0, actions.pop(agent_i))
            vf_in = torch.cat((*observations, *actions), dim=1)
        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates
        vf_loss = MSELoss(actual_value, target_value.detach())
        # Retain the graph because the feature-extractor outputs are reused
        # for the actor loss below.
        # TODO: make sure there is no leakage between the two losses
        vf_loss.backward(retain_graph=True)

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(),
                                       self.grad_clip_value)

        # Apply critic update
        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer
        curr_agent.policy_optimizer.zero_grad()

        # We put experience data back in the general point of view
        if self.alg_types[agent_i] == 'SharedMADDPG':
            observations.insert(agent_i, observations.pop(0))

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(
                curr_agent.f_e(observations[agent_i]))
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(
                curr_agent.f_e(observations[agent_i]))
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic
        if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
            all_pol_acs = []
            for i, pi, f_e, ob in zip(range(self.nagents), self.policies,
                                      self.f_es, observations):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    all_pol_acs.append(
                        onehot_from_logits(pi(f_e(ob))).detach())
                else:
                    # pass obs through the feature extractor, as in the discrete branch
                    all_pol_acs.append(pi(f_e(ob)).detach())

            if self.alg_types[agent_i] == 'SharedMADDPG':
                # the critic must take the point of view of agent i
                observations.insert(0, observations.pop(agent_i))
                all_pol_acs.insert(0, all_pol_acs.pop(agent_i))

            vf_in = torch.cat((*observations, *all_pol_acs),
                              dim=1)  # Centralized critic for MADDPG agent
        elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
            vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
        else:
            raise NotImplementedError

        # Computes the loss
        pol_loss = -torch.mean(curr_agent.critic(vf_in))

        # Backpropagates
        pol_loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       self.grad_clip_value)
        # Apply actor update
        curr_agent.policy_optimizer.step()

        # Applies the feature-extractor update (its gradients were accumulated from both losses)
        torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(),
                                       self.grad_clip_value)
        curr_agent.f_e_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
Example #8
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer

        curr_agent.critic_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [onehot_from_logits(pi(nobs)) for pi, nobs in zip(self.target_policies, next_obs)]
            else:
                all_target_actions = [pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)]
            target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            if self.use_discrete_action:
                target_vf_in = torch.cat((next_obs[agent_i],
                                          onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))),
                                         dim=1)
            else:
                target_vf_in = torch.cat((next_obs[agent_i],
                                          curr_agent.target_policy(next_obs[agent_i])),
                                         dim=1)

        else:
            raise NotImplementedError

        # Computes target value

        target_value = (rewards[agent_i].view(-1, 1) + self.gamma *
                        curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            vf_in = torch.cat((*observations, *actions), dim=1)
        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates

        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()

        # Clip gradients

        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)

        # Apply critic update

        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer

        curr_agent.policy_optimizer.zero_grad()

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic

        if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, observations):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob).detach()))
                else:
                    all_pol_acs.append(pi(ob).detach())
            vf_in = torch.cat((*observations, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
        else:
            raise NotImplementedError

        # Computes the loss
        J_PG = -torch.mean(curr_agent.critic(vf_in))

        pol_loss = J_PG

        # Backpropagates

        pol_loss.backward()

        # Update actors

        # Clip gradients

        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)

        # Apply actor update

        curr_agent.policy_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
Example #9
    def update_agent(self, sample, agent_i):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next_observations, and episode_end masks)
                    sampled randomly from the replay buffer. Each is a list with entries corresponding
                    to each agent
            agent_i (int): index of agent to update
        """
        # Extract info and agent
        observations, actions, rewards, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        # UPDATES THE CRITIC ---
        # Resets gradient buffer
        curr_agent.critic_optimizer.zero_grad()
        curr_agent.f_e_optimizer.zero_grad()

        # Gets target state-action pair (next_obs, next_actions)
        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            if self.use_discrete_action:  # one-hot encode action
                all_target_actions = [
                    onehot_from_logits(
                        pi(f_e(nobs)).view(-1, self.action_spaces[0],
                                           self.nagents)[:, :, i])
                    for i, (pi, f_e, nobs) in enumerate(
                        zip(self.target_policies, self.f_es, next_obs))
                ]
            else:
                all_target_actions = [
                    pi(f_e(nobs)).view(-1, self.action_spaces[0],
                                       self.nagents)[:, :, i]
                    for i, (pi, f_e, nobs) in enumerate(
                        zip(self.target_policies, self.f_es, next_obs))
                ]

            if self.critic_concat_all_obs:
                target_vf_in = torch.cat((*next_obs, *all_target_actions),
                                         dim=1)
            else:
                target_vf_in = torch.cat(
                    (curr_agent.f_e(next_obs[agent_i]), *all_target_actions),
                    dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            next_fe = curr_agent.f_e(next_obs[agent_i])
            if self.use_discrete_action:
                target_vf_in = torch.cat(
                    (next_fe,
                     onehot_from_logits(curr_agent.target_policy(next_fe))),
                    dim=1)
            else:
                target_vf_in = torch.cat(
                    (next_fe, curr_agent.target_policy(next_fe)), dim=1)

        else:
            raise NotImplementedError

        # Computes target value
        target_value = (rewards[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(target_vf_in) *
                        (1 - dones.view(-1, 1)))

        # Computes current state-action value

        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            if self.critic_concat_all_obs:
                critic_obs = [
                    agent.f_e(obs)
                    for obs, agent in zip(observations, self.agents)
                ]
            else:
                critic_obs = [curr_agent.f_e(observations[agent_i])]
            vf_in = torch.cat((*critic_obs, *actions), dim=1)

        elif self.alg_types[agent_i] == 'DDPG':
            critic_obs = curr_agent.f_e(observations[agent_i])
            vf_in = torch.cat((critic_obs, actions[agent_i]), dim=1)
        else:
            raise NotImplementedError
        actual_value = curr_agent.critic(vf_in)

        # Backpropagates
        vf_loss = MSELoss(actual_value, target_value.detach())
        # Retain the graph because the feature-extractor outputs are reused
        # for the actor loss below.
        # TODO: make sure there is no leakage between the two losses
        vf_loss.backward(retain_graph=True)

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(),
                                       self.grad_clip_value)

        # Apply critic update
        curr_agent.critic_optimizer.step()

        # UPDATES THE ACTOR ---
        # Resets gradient buffer
        curr_agent.policy_optimizer.zero_grad()

        if self.use_discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            pol_out_all_heads = curr_agent.policy(
                curr_agent.f_e(observations[agent_i])).view(
                    -1, self.action_spaces[0], self.nagents)
            curr_pol_out = pol_out_all_heads[:, :, agent_i]
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            pol_out_all_heads = curr_agent.policy(
                curr_agent.f_e(observations[agent_i])).view(
                    -1, self.action_spaces[0], self.nagents)
            curr_pol_out = pol_out_all_heads[:, :, agent_i]
            curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

        # Gets state-action pair value given by the critic
        if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
            all_pol_logits = []
            all_pol_acs = []
            for i, pi, f_e, ob in zip(range(self.nagents), self.policies,
                                      self.f_es, observations):
                if i == agent_i:
                    all_pol_logits.append(curr_pol_out)
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.use_discrete_action:
                    logits = pi(f_e(ob)).view(-1, self.action_spaces[0],
                                              self.nagents)[:, :, i]
                    all_pol_logits.append(logits)
                    all_pol_acs.append(onehot_from_logits(logits).detach())
                else:
                    all_pol_logits.append(None)
                    all_pol_acs.append(
                        pi(f_e(ob)).detach().view(-1, self.action_spaces[0],
                                                  self.nagents)[:, :, i])
            vf_in = torch.cat((*critic_obs, *all_pol_acs),
                              dim=1)  # Centralized critic for MADDPG agent

        elif self.alg_types[agent_i] == 'DDPG':
            vf_in = torch.cat((critic_obs, curr_pol_vf_in), dim=1)

        else:
            raise NotImplementedError

        # Computes the loss
        pol_loss = -torch.mean(curr_agent.critic(vf_in))

        # Backpropagates
        pol_loss.backward(
            retain_graph=(self.alg_types[agent_i] == "TeamMADDPG"))

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       self.grad_clip_value)

        # Apply actor update
        curr_agent.policy_optimizer.step()

        # Team Spirit (TS) regularization
        if self.alg_types[agent_i] == "TeamMADDPG":

            if self.use_discrete_action:
                real_action_logits = torch.stack(all_pol_logits, dim=2)

                real_action_probs = F.softmax(real_action_logits, dim=1)
                real_action_log_probs = F.log_softmax(real_action_logits,
                                                      dim=1)
                predicted_action_log_probs = F.log_softmax(pol_out_all_heads,
                                                           dim=1)

                # KL-divergence KL(real_action_dist || predicted_action_dist)

                ts_loss = torch.mean(
                    torch.sum(
                        real_action_probs *
                        (real_action_log_probs - predicted_action_log_probs),
                        dim=1))
            else:
                ts_loss = torch.mean(
                    torch.sum((pol_out_all_heads -
                               torch.stack(all_pol_acs, dim=2))**2,
                              dim=1))

            # Resets gradient buffer of all agents

            for agent in self.agents:
                agent.policy_optimizer.zero_grad()

            # Backpropagates through every agent of the team (including curr_agent)

            ts_loss.backward()

            for i, agent in enumerate(self.agents):
                if i == agent_i:
                    coeff = self.lambdat_1
                else:
                    coeff = self.lambdat_2

                for p in agent.policy.parameters():
                    p.grad *= coeff

                # Apply gradients

                torch.nn.utils.clip_grad_norm_(agent.policy.parameters(),
                                               self.grad_clip_value)
                agent.policy_optimizer.step()

            ts_loss = ts_loss.data.cpu().numpy()

        else:
            ts_loss = None

        # Applies the feature-extractor update (its gradients were accumulated from both losses)
        torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(),
                                       self.grad_clip_value)
        curr_agent.f_e_optimizer.step()

        return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy(), ts_loss
Example #10
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode action
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]  # a' = mu'(o'): target actions of all agents
            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # ==========================Adding noise====================
            if self.noisy_sharing:
                #noisy_all_trgt_acs = self.noisy_sharing_discrete(all_trgt_acs,agent_i)
                #all_trgt_acs = noisy_all_trgt_acs
                noisy_acs = self.noisy_sharing_discrete(acs, agent_i)
                acs = noisy_acs
                # print(self.noisy_SNR)

            # ==================End of Adding noise====================
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)

            # =========================Differential Obs========================
            # ============== Dedicate for simple_speaker_listener =============
            # The est_action is used to replace acs[1]
            if self.game_id == 'simple_speaker_listener' and self.est_ac:
                diff_pos = (next_obs[0] - obs[0])[:, -2:]
                tmp_p = torch.transpose(diff_pos.ge(torch.max(diff_pos) * 0.8),
                                        0, 1)
                tmp_p[0] = tmp_p[0] * 1
                tmp_p[1] = tmp_p[1] * 3
                tmp_n = torch.transpose(diff_pos.le(torch.min(diff_pos) * 0.8),
                                        0, 1)
                tmp_n[0] = tmp_n[0] * 2
                tmp_n[1] = tmp_n[1] * 4
                mask = torch.transpose(tmp_p, 0, 1) + torch.transpose(
                    tmp_n, 0, 1)
                est_action = mask.sum(dim=1)
                est_action = torch.zeros(len(est_action),
                                         acs[1].shape[1]).scatter_(
                                             dim=1,
                                             index=est_action.view(-1, 1),
                                             value=1)
                acs[1] = est_action

            # =======================End of differential Obs ==================
        else:  # DDPG
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:  # a' = mu(o'): only the current agent's action
                trgt_vf_in = torch.cat(
                    (next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)
        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))  #y^j

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*obs, *acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)

        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        # ============== Here for policy network training =====================
        curr_agent.policy_optimizer.zero_grad()
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                # Is it correct to train mu using all others' policies???
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       0.5)  # Constraints on the grad.
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
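
When `parallel=True`, the code above calls `average_gradients(...)`, which in these codebases usually all-reduces gradients across distributed workers. A rough sketch under that assumption:

import torch.distributed as dist

def average_gradients(model):
    # Assumed helper: average each parameter's gradient across all processes.
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size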
Example #11
def evaluate(config):
    if config.seed_num is None:
        all_seeds = list((DirectoryManager.root / config.storage_name /
                          f"experiment{config.experiment_num}").iterdir())
        config.seed_num = all_seeds[0].stem.strip('seed')

    # Creates paths and directories

    seed_path = DirectoryManager.root / config.storage_name / f"experiment{config.experiment_num}" / f"seed{config.seed_num}"
    dir_manager = DirectoryManager.init_from_seed_path(seed_path)
    if config.incremental is not None:
        model_path = dir_manager.incrementals_dir / (
            f'model_ep{config.incremental}.pt')
    elif config.last_model:
        last_models = [
            path for path in dir_manager.seed_dir.iterdir()
            if path.suffix == ".pt" and not path.stem.endswith('best')
        ]
        assert len(last_models) == 1
        model_path = last_models[0]
    else:
        best_models = [
            path for path in dir_manager.seed_dir.iterdir()
            if path.suffix == ".pt" and path.stem.endswith('best')
        ]
        assert len(best_models) == 1
        model_path = best_models[0]

    # Retrieves world_params if there were any (see make_world function in multiagent.scenarios)
    if (dir_manager.seed_dir / 'world_params.json').exists():
        world_params = load_dict_from_json(
            str(dir_manager.seed_dir / 'world_params.json'))
    else:
        world_params = {}

    # Overwrites world_params if specified
    if config.shuffle_landmarks is not None:
        world_params['shuffle_landmarks'] = config.shuffle_landmarks

    if config.color_objects is not None:
        world_params['color_objects'] = config.color_objects

    if config.small_agents is not None:
        world_params['small_agents'] = config.small_agents

    if config.individual_reward is not None:
        world_params['individual_reward'] = config.individual_reward

    if config.use_dense_rewards is not None:
        world_params['use_dense_rewards'] = config.use_dense_rewards

    # Retrieves env_params (see multiagent.environment.MultiAgentEnv)
    if (dir_manager.seed_dir / 'env_params.json').exists():
        env_params = load_dict_from_json(
            str(dir_manager.seed_dir / 'env_params.json'))
    else:
        env_params = {}
        env_params['use_max_speed'] = False

    # Initializes model and environment
    algorithm = init_from_save(model_path)
    env = make_env(scenario_name=env_params['env_name'],
                   use_discrete_action=algorithm.use_discrete_action,
                   use_max_speed=env_params['use_max_speed'],
                   world_params=world_params)

    if config.render:
        env.render()

    if config.runner_prey:
        # makes sure the environment involves a prey
        assert config.env_name.endswith('tag')
        runner_policy = RunnerPolicy()

        for agent in env.world.agents:
            if agent.adversary:
                agent.action_callback = runner_policy.action

    if config.rusher_predators:
        # makes sure the environment involves predators
        assert config.env_name.endswith('tag')
        rusher_policy = RusherPolicy()

        for agent in env.world.agents:
            if not agent.adversary:
                agent.action_callback = rusher_policy.action

    if config.pendulum_agent is not None:
        # makes sure the agent to be controlled has a valid id
        assert config.pendulum_agent in list(range(len(env.world.agents)))

        pendulum_policy = DoublePendulumPolicy()
        env.world.agents[
            config.pendulum_agent].action_callback = pendulum_policy.action

    if config.interactive_agent is not None:
        # makes sure the agent to be controlled has a valid id
        assert config.interactive_agent in list(range(len(env.world.agents)))

        interactive_policy = InteractivePolicy(env, viewer_id=0)
        env.world.agents[
            config.
            interactive_agent].action_callback = interactive_policy.action

    algorithm.prep_rollouts(device='cpu')
    ifi = 1 / config.fps  # inter-frame interval
    total_reward = []
    all_episodes_agent_embeddings = []
    all_episodes_coach_embeddings = []
    all_trajs = []

    overide_color = None

    color_agents = True

    if env_params['env_name'] == 'bounce':
        env.agents[0].size = 1. * env.agents[0].size
        env.world.overwrite = config.overwrite
    elif env_params['env_name'] == 'spread':
        color_agents = False
    elif env_params['env_name'] == 'compromise':
        env.agents[0].lightness = 0.9
        env.world.landmarks[0].lightness = 0.9
        env.agents[1].lightness = 0.5
        env.world.landmarks[1].lightness = 0.5
        # cmo = plt.cm.get_cmap('viridis')
        env.world.overwrite = config.overwrite
        # overide_color = [np.array(cmo(float(i) / float(2))[:3]) for i in range(2)]

    # set_seeds_env(2, env)
    # EPISODES LOOP
    for ep_i in range(config.n_episodes):
        # set_seeds(2)
        # set_seeds_env(2, env)
        agent_embeddings = []
        coach_embeddings = []
        traj = []
        ep_recorder = EpisodeRecorder(stuff_to_record=['reward'])

        # Resets the environment
        obs = env.reset()

        if config.save_gifs:
            frames = None
        if config.render:
            env.render('human')

        if not algorithm.soft:
            # Resets exploration noise
            algorithm.scale_noise(config.noise_scale)
            algorithm.reset_noise()

        # STEPS LOOP
        for t_i in range(config.episode_length):
            calc_start = time.time()
            # rearrange observations to be per agent, and convert to torch Variable
            torch_obs = [
                Variable(torch.Tensor(obs[i]).view(1, -1), requires_grad=False)
                for i in range(algorithm.nagents)
            ]
            # get actions as torch Variables
            torch_actions, torch_embed = algorithm.select_action(
                torch_obs,
                is_exploring=(config.noise_scale is not None),
                return_embed=True)
            torch_total_obs = torch.cat(torch_obs, dim=-1)
            coach_embed = onehot_from_logits(
                algorithm.coach.model(torch_total_obs))
            coach_embeddings.append(coach_embed.data.numpy().squeeze())
            # convert actions to numpy arrays
            actions = [ac.data.numpy().flatten() for ac in torch_actions]
            embeds = [emb.data.numpy().squeeze() for emb in torch_embed]
            agent_embeddings.append(embeds)
            # steps forward in the environment
            next_obs, rewards, dones, infos = env.step(actions)
            ep_recorder.add_step(None, None, rewards, None)
            traj.append((obs, actions, next_obs, rewards, dones))
            obs = next_obs
            colors = list(cm.get_cmap('Set1').colors[:len(embeds[0])])
            if overide_color is not None:
                colors[0] = overide_color[0]
                colors[2] = overide_color[1]
            if color_agents:
                for agent, emb in zip(env.agents, embeds):
                    agent.color = colors[np.argmax(emb)]

            # record frames
            if config.save_gifs:
                frames = [] if frames is None else frames
                frames.append(env.render('rgb_array')[0])

            if config.render or config.save_gifs:
                # Enforces the fps config
                calc_end = time.time()
                elapsed = calc_end - calc_start
                if elapsed < ifi:
                    time.sleep(ifi - elapsed)
                env.render('human')

            if all(dones) and config.interrupt_episode:
                if config.render:
                    time.sleep(2)
                break

        # print(ep_recorder.get_total_reward())
        total_reward.append(ep_recorder.get_total_reward())
        all_episodes_agent_embeddings.append(agent_embeddings)
        all_episodes_coach_embeddings.append(coach_embeddings)
        all_trajs.append(traj)

    # Saves gif of all the episodes
    if config.save_gifs:
        gif_path = dir_manager.storage_dir / 'gifs'
        gif_path.mkdir(exist_ok=True)

        gif_num = 0
        while (gif_path /
               f"{env_params['env_name']}__experiment{config.experiment_num}_seed{config.seed_num}_{gif_num}.gif"
               ).exists():
            gif_num += 1
        imageio.mimsave(str(
            gif_path /
            f"{env_params['env_name']}__experiment{config.experiment_num}_seed{config.seed_num}_{gif_num}.gif"
        ),
                        frames,
                        duration=ifi)
    env.close()

    embeddings = {
        'agents': all_episodes_agent_embeddings,
        'coach': all_episodes_coach_embeddings
    }

    save_folder = dir_manager.experiment_dir if config.save_to_exp_folder else dir_manager.seed_dir
    embeddings_path = U.directory_tree.uniquify(
        save_folder / f"{config.file_name_to_save}.pkl")
    trajs_path = osp.splitext(embeddings_path)[0] + "_trajs.pkl"

    with open(embeddings_path, 'wb') as fp:
        pickle.dump(embeddings, fp)

    with open(trajs_path, 'wb') as fp:
        pickle.dump(all_trajs, fp)

    return total_reward, str(embeddings_path)
Example #12
    def forward(self,
                obs,
                sample=True,
                return_all_probs=False,
                return_log_pi=False,
                regularize=False,
                return_entropy=False):
        out = super(DiscretePolicy, self).forward(obs)
        # _, action_dim = out.size()
        # # dim(u_aaction)=5, dim(r_action) = 2, dim(audio_action = 3)
        # r_action_dim = 2
        # audio_action_dim = 3
        # u_action_dim = action_dim - (r_action_dim + audio_action_dim)
        # assert u_action_dim == 5, "policy dimensions"
        #
        #
        # probs_u = F.softmax(out[:,0:u_action_dim], dim=1)
        # on_gpu = next(self.parameters()).is_cuda
        # if sample:
        #     int_act, act_u = categorical_sample(probs_u, use_cuda=on_gpu)
        # else:
        #     act_u = onehot_from_logits(probs_u)
        #
        # # TODO: change rotation to discrete action, and output prob_r, also change the step in environment
        # # action_r = out[:, u_action_dim].view(-1, 1)
        # probs_r = F.softmax(out[:, u_action_dim:u_action_dim+r_action_dim], dim=1)
        # # on_gpu = next(self.parameters()).is_cuda
        # if sample:
        #     _, act_r = categorical_sample(probs_r, use_cuda=on_gpu)
        # else:
        #     act_r = onehot_from_logits(probs_r)
        #
        # probs_audio = F.softmax(out[:, u_action_dim+r_action_dim:], dim=1)
        # # on_gpu = next(self.parameters()).is_cuda
        # if sample:
        #     _, act_audio = categorical_sample(probs_audio, use_cuda=on_gpu)
        # else:
        #     act_audio = onehot_from_logits(probs_audio)
        #
        # return torch.cat([act_u, act_r, act_audio], dim=1)

        probs = F.softmax(out, dim=1)
        on_gpu = next(self.parameters()).is_cuda
        if sample:
            int_act, act = categorical_sample(probs, use_cuda=on_gpu)
        else:
            act = onehot_from_logits(probs)
        rets = [act]
        if return_log_pi or return_entropy:
            log_probs = F.log_softmax(out, dim=1)
        if return_all_probs:
            rets.append(probs)
        if return_log_pi:
            # return log probability of selected action
            rets.append(log_probs.gather(1, int_act))
        if regularize:
            rets.append([(out**2).mean()])
        if return_entropy:
            rets.append(-(log_probs * probs).sum(1).mean())
        if len(rets) == 1:
            return rets[0]
        return rets
Example #13
    def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: EpisodeBatch, use sample[key_i] to get a specific 
                    array of obs, action, etc for agent i
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
        """
        obs, acs, rews, next_obs, dones = self.parse_sample(sample) # [(B,T,D)]*N  
        bs, ts, _ = obs[0].shape
        self.init_hidden(bs)  # use pre-defined init hiddens 
        curr_agent = self.agents[agent_i]

        # NOTE: critic update
        curr_agent.critic_optimizer.zero_grad()

        # compute target actions
        if self.alg_types[agent_i] == 'MADDPG':
            all_trgt_acs = []   # [(B,T,A)]*N
            for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs)):
                # nobs: (B,T,O)
                act_i = self.compute_action(i, pi, nobs, bs=bs, ts=ts)
                all_trgt_acs.append(act_i)  # [(B,T,A)]
            
            if self.discrete_action:    # one-hot encode action
                all_trgt_acs = [onehot_from_logits(
                    act_i.reshape(bs*ts,-1)
                ).reshape(bs,ts,-1) for act_i in all_trgt_acs] 

            # critic input, [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
            trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
        else:  # DDPG
            act_i = self.compute_action(agent_i, curr_agent.target_policy, 
                        next_obs[agent_i], bs=bs, ts=ts)

            if self.discrete_action:
                act_i = onehot_from_logits(
                    act_i.reshape(bs*ts, -1)
                ).reshape(bs, ts, -1) 

            # (B,T,O) + (B,T,A) -> (B,T,O+A)
            trgt_vf_in = torch.cat((next_obs[agent_i], act_i), dim=-1)

        # bellman targets
        target_q = self.compute_q_val(agent_i, curr_agent.target_critic, 
                        trgt_vf_in, bs=bs, ts=ts)   # (B*T,1)
        target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                            (1 - dones[agent_i].view(-1, 1)))   # (B*T,1)

        # Q func
        if self.alg_types[agent_i] == 'MADDPG':
            # obs/acs are (B,T,·), so concatenate along the feature dim
            vf_in = torch.cat((*obs, *acs), dim=-1)
        else:  # DDPG
            vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=-1)
        actual_value = self.compute_q_val(agent_i, curr_agent.critic, 
                            vf_in, bs=bs, ts=ts)    # (B*T,1)

        # bellman errors
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
        curr_agent.critic_optimizer.step()

        # NOTE: policy update
        curr_agent.policy_optimizer.zero_grad()

        # current agent action (deterministic, softened)
        curr_pol_out = self.compute_action(agent_i, curr_agent.policy,
                                obs[agent_i], bs=bs, ts=ts) # (B,T,A)
        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            curr_pol_vf_in = gumbel_softmax(
                                curr_pol_out.reshape(bs*ts, -1), hard=True
                            ).reshape(bs, ts, -1) 
        else:
            curr_pol_vf_in = curr_pol_out

        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    # insert current agent act to q input 
                    all_pol_acs.append(curr_pol_vf_in)
                else: 
                    p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts) # (B,T,A)
                    if self.discrete_action:
                        p_act_i = onehot_from_logits(
                                    p_act_i.reshape(bs*ts, -1)
                                ).reshape(bs, ts, -1) 
                    all_pol_acs.append(p_act_i)
            p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1) # (B,T,O*N+A*N)
        else:  # DDPG
            p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1) # (B,T,O+A)
        
        # value function to update current policy
        p_value = self.compute_q_val(agent_i, curr_agent.critic, 
                        p_vf_in, bs=bs, ts=ts)   # (B*T,1)
        pol_loss = -p_value.mean()
        # p regularization, scale down output (gaussian mean,std or logits)
        # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
        pol_loss += ((curr_pol_out.reshape(bs*ts, -1))**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
        curr_agent.policy_optimizer.step()

        # collect training stats
        results = {
            "agent_{}_critic_loss".format(agent_i): vf_loss,
            "agent_{}_policy_loss".format(agent_i): pol_loss
        }
        return results
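
This example assumes `compute_action` and `compute_q_val` helpers that run a (possibly recurrent) network over (B,T,·) batches and return values in the shapes annotated above. A hedged sketch of the non-recurrent case; the internals are assumptions, only the call signatures mirror the usage above:

def compute_q_val(self, agent_i, critic, vf_in, bs, ts):
    # Flatten the time dimension, evaluate the critic, return (B*T, 1).
    q = critic(vf_in.reshape(bs * ts, -1))
    return q.view(bs * ts, 1)

def compute_action(self, agent_i, policy, obs, bs, ts):
    # Flatten, run the policy, and restore the (B, T, A) layout.
    out = policy(obs.reshape(bs * ts, -1))
    return out.view(bs, ts, -1)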
Example #14
    def update(self, sample, agent_i):
        """
        TODO: to make the sample into valid inputs
        :param sample: the batch of experiences
        :param agent_i: the agent to be updated
        """
        obs, acs, rews, next_obs, dones = sample
        tensor_acs = torch.from_numpy(acs).float()
        tensor_obs = torch.from_numpy(obs).float()
        tensor_next_obs = torch.from_numpy(next_obs).float()
        tensor_rews = torch.from_numpy(rews).float()
        tensor_dones = torch.from_numpy(~dones).float()  # 1.0 where the episode has not ended

        current_agent = self.agents[agent_i]

        current_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            all_target_actors = self.target_actors()
            if self.discrete_action:
                all_target_actions = [
                    onehot_from_logits(
                        pi(tensor_obs[:,
                                      self.observation_index[agent_index][0]:
                                      self.observation_index[agent_index][1]]))
                    for pi, agent_index in zip(all_target_actors,
                                               range(self.num_agent))
                ]
            else:
                all_target_actions = [
                    pi(tensor_obs[:,
                                  self.observation_index[agent_index][0]:self.
                                  observation_index[agent_index][1]])
                    for pi, agent_index in zip(all_target_actors,
                                               range(self.num_agent))
                ]
            target_critic_input = torch.cat(
                (tensor_next_obs, torch.cat(all_target_actions, dim=1)), dim=1)
        else:
            if self.discrete_action:
                target_critic_input = torch.cat((
                    tensor_next_obs[:, self.observation_index[agent_i][0]:self.
                                    observation_index[agent_i][1]],
                    onehot_from_logits(
                        current_agent.target_actor(
                            tensor_next_obs[:, self.
                                            observation_index[agent_i][0]:self.
                                            observation_index[agent_i][1]]))),
                                                dim=1)
            else:
                target_critic_input = torch.cat((
                    tensor_next_obs[:, self.observation_index[agent_i][0]:self.
                                    observation_index[agent_i][1]],
                    current_agent.target_actor(
                        tensor_next_obs[:,
                                        self.observation_index[agent_i][0]:self
                                        .observation_index[agent_i][1]])),
                                                dim=1)
        target_critic_value = current_agent.target_critic(target_critic_input)
        target_value = \
            tensor_rews[:, agent_i].unsqueeze(1) + \
            self.gamma * target_critic_value * tensor_dones[:, agent_i].unsqueeze(1)

        if self.alg_types[agent_i] == 'MADDPG':
            critic_input = torch.cat((tensor_obs, tensor_acs), dim=1)
        else:  # DDPG
            critic_input = torch.cat(
                (tensor_obs[:, self.observation_index[agent_i][0]:self.
                            observation_index[agent_i][1]],
                 tensor_acs[:, self.action_index[agent_i][0]:self.
                            action_index[agent_i][1]]),
                dim=1)
        actual_value = current_agent.critic(critic_input)
        critic_loss = MSELoss(actual_value, target_value.detach())
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(current_agent.critic.parameters(), 0.5)
        current_agent.critic_optimizer.step()

        current_agent.actor_optimizer.zero_grad()
        current_action_out = current_agent.actor(
            tensor_obs[:, self.observation_index[agent_i][0]:
                       self.observation_index[agent_i][1]])
        if self.discrete_action:
            # Differentiable one-hot sample so the actor loss can backpropagate
            # through the discrete action (straight-through Gumbel-Softmax).
            current_action_input_critic = gumbel_softmax(current_action_out,
                                                         hard=True)
        else:
            current_action_input_critic = current_action_out

        if self.alg_types[agent_i] == 'MADDPG':
            all_actor_action = []
            all_target_actors = self.target_actors()
            for i, pi in enumerate(all_target_actors):
                if i == agent_i:
                    all_actor_action.append(current_action_input_critic)
                    continue
                other_obs = tensor_obs[:, self.observation_index[i][0]:
                                       self.observation_index[i][1]]
                if self.discrete_action:
                    all_actor_action.append(onehot_from_logits(pi(other_obs)))
                else:
                    all_actor_action.append(pi(other_obs))
            critic_input = torch.cat(
                (tensor_obs, torch.cat(all_actor_action, dim=1)), dim=1)
        else:
            critic_input = torch.cat(
                (tensor_obs[:, self.observation_index[agent_i][0]:self.
                            observation_index[agent_i][1]],
                 current_action_input_critic),
                dim=1)

        actor_loss = -current_agent.critic(critic_input).mean()
        # small L2 regularization on the raw policy outputs
        actor_loss += (current_action_out**2).mean() * 1e-3
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(current_agent.actor.parameters(), 0.5)
        current_agent.actor_optimizer.step()
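The update above reads from `current_agent.target_actor`, `current_agent.target_critic`, and `self.target_actors()`, but the step that keeps those target networks in sync is not shown in this example. A common companion is a soft (Polyak) update applied after each learning step; the sketch below is only an illustration and assumes each agent object exposes `actor`/`target_actor` and `critic`/`target_critic` modules, with `tau` as the interpolation rate.

import torch

def soft_update(target_net, source_net, tau=0.01):
    # Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target
    with torch.no_grad():
        for target_param, source_param in zip(target_net.parameters(),
                                              source_net.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * source_param.data)

def update_all_targets(agents, tau=0.01):
    # Assumed attribute names (actor/target_actor, critic/target_critic)
    # mirror those used in the example above.
    for agent in agents:
        soft_update(agent.target_actor, agent.actor, tau)
        soft_update(agent.target_critic, agent.critic, tau)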
Example #15
    def update(self, sample, agent_i, parallel=False, logger=None):
        """
        Update parameters of agent model based on sample from replay buffer
        Inputs:
            sample: tuple of (observations, actions, rewards, next
                    observations, and episode end masks) sampled randomly from
                    the replay buffer. Each is a list with entries
                    corresponding to each agent
            agent_i (int): index of agent to update
            parallel (bool): If true, will average gradients across threads
            logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        """
        # For RNN, the obs and next_obs both have histories
        obs, acs, rews, next_obs, dones = sample
        curr_agent = self.agents[agent_i]

        curr_agent.critic_optimizer.zero_grad()
        if self.alg_types[agent_i] == 'MADDPG':
            if self.discrete_action:  # one-hot encode target actions
                # 'pi': target policy, 'nobs': next observations
                all_trgt_acs = [
                    onehot_from_logits(pi(nobs))
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            else:
                all_trgt_acs = [
                    pi(nobs)
                    for pi, nobs in zip(self.target_policies, next_obs)
                ]

            # Keep only the most recent observation from each agent's history
            # for the target critic input (the critic does not need history).
            # TODO: make the observation size (18) a parameter
            t0_next_obs = [nobs[:, :18] for nobs in next_obs]

            # Originally: trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
            # Only the current (t0) observations are kept for the target critic.
            trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
        else:  # DDPG
            # DDPG only knows this agent's own observation and policy, whereas
            # MADDPG has access to all agents' observations and policies.
            # Keep only the most recent observation from the history for the critic.
            # TODO: make the observation size (18) a parameter
            t0_next_obs = [nobs[:, :18] for nobs in next_obs]

            # The critic input uses the current (t0) observation, while the
            # target policy still receives the full observation history.
            if self.discrete_action:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     onehot_from_logits(
                         curr_agent.target_policy(next_obs[agent_i]))),
                    dim=1)
            else:
                trgt_vf_in = torch.cat(
                    (t0_next_obs[agent_i],
                     curr_agent.target_policy(next_obs[agent_i])),
                    dim=1)

        target_value = (rews[agent_i].view(-1, 1) +
                        self.gamma * curr_agent.target_critic(trgt_vf_in) *
                        (1 - dones[agent_i].view(-1, 1)))

        # The critic VF does not need the observation history; keep only the
        # current (t0) observation for each agent. Both obs and next_obs
        # include histories.
        t0_obs = [ob[:, :18] for ob in obs]

        if self.alg_types[agent_i] == 'MADDPG':
            vf_in = torch.cat((*t0_obs, *acs), dim=1)
        else:  # DDPG: uses only this agent's own (t0) observation and action
            vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
        actual_value = curr_agent.critic(vf_in)
        vf_loss = MSELoss(actual_value, target_value.detach())
        vf_loss.backward()
        if parallel:
            average_gradients(curr_agent.critic)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
        curr_agent.critic_optimizer.step()

        curr_agent.policy_optimizer.zero_grad()

        if self.discrete_action:
            # Forward pass as if onehot (hard=True) but backprop through a differentiable
            # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
            # through discrete categorical samples, but I'm not sure if that is
            # correct since it removes the assumption of a deterministic policy for
            # DDPG. Regardless, discrete policies don't seem to learn properly without it.
            # We are forwarding the policy again here, so it needs the
            # observation *with* history.
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
            # Seems to be working fine till here
        else:
            curr_pol_out = curr_agent.policy(obs[agent_i])
            curr_pol_vf_in = curr_pol_out
        if self.alg_types[agent_i] == 'MADDPG':
            all_pol_acs = []
            for i, pi, ob in zip(range(self.nagents), self.policies, obs):
                if i == agent_i:
                    all_pol_acs.append(curr_pol_vf_in)
                elif self.discrete_action:
                    all_pol_acs.append(onehot_from_logits(pi(ob)))
                else:
                    all_pol_acs.append(pi(ob))
            # Originally:
            #vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
            vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
        else:  # DDPG
            vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)
        # TODO: FIX THIS
        pol_loss = -curr_agent.critic(vf_in).mean()
        pol_loss += (curr_pol_out**2).mean() * 1e-3
        pol_loss.backward()
        if parallel:
            average_gradients(curr_agent.policy)
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
        curr_agent.policy_optimizer.step()
        if logger is not None:
            logger.add_scalars('agent%i/losses' % agent_i, {
                'vf_loss': vf_loss,
                'pol_loss': pol_loss
            }, self.niter)
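The comment in the policy-update step above describes forwarding a hard one-hot action while backpropagating through a differentiable Gumbel-Softmax sample. As a point of reference, below is a minimal sketch of that straight-through trick; it assumes logits of shape (batch, n_actions) and a fixed temperature, and it illustrates the idea rather than reproducing the `gumbel_softmax` helper used in these examples.

import torch
import torch.nn.functional as F

def straight_through_gumbel_softmax(logits, temperature=1.0):
    # Sample Gumbel(0, 1) noise and build a differentiable soft sample.
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbel_noise) / temperature, dim=-1)
    # Hard one-hot in the forward pass ...
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # ... while gradients flow through the soft sample (straight-through estimator).
    return (y_hard - y_soft).detach() + y_soft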