Example #1
    def __update(self, obs_n, action_n, next_obs_n, reward_n, done):
        self.model.train()

        self.memory.push(obs_n, action_n, next_obs_n, reward_n, done)

        if self.batch_size > len(self.memory):
            self.model.eval()
            return None

        # TODO: move this beta annealing into the prioritized replay memory
        beta_start = 0.4
        beta = min(
            1.0,
            beta_start + (self.__update_iter + 1) * (1.0 - beta_start) / 5000)

        transitions, indices, weights = self.memory.sample(
            self.batch_size, beta)
        batch = Transition(*zip(*transitions))

        obs_batch = torch.FloatTensor(list(batch.state)).to(self.device)
        action_batch = torch.FloatTensor(list(batch.action)).to(self.device)
        reward_batch = torch.FloatTensor(list(batch.reward)).to(self.device)
        next_obs_batch = torch.FloatTensor(list(batch.next_state)).to(
            self.device)
        weights = torch.FloatTensor(weights).to(self.device)
        non_final_mask = 1 - torch.ByteTensor(list(batch.done)).to(self.device)

        # accumulate per-agent TD losses and new priorities
        prios = 0
        overall_loss = 0

        og_thoughts, global_thoughts = self._get_thoughts(obs_batch)
        next_obs_og_thoughts, next_obs_global_thoughts = self._get_thoughts(
            next_obs_batch)
        for i in range(self.model.n_agents):
            q_val_i = self.model.agent(i)(og_thoughts[i], global_thoughts[i])
            pred_q = q_val_i.gather(1, action_batch[:, i].long().unsqueeze(1))

            target_next_obs_q = torch.zeros(pred_q.shape).to(self.device)
            non_final_next_obs_og = next_obs_og_thoughts[i, :][
                non_final_mask[:, i]]
            non_final_next_obs_global = next_obs_global_thoughts[i, :][
                non_final_mask[:, i]]

            # Double DQN update
            target_q = 0
            if non_final_next_obs_og.shape[0] != 0:
                _max_actions = self.model.agent(i)(non_final_next_obs_og,
                                                   non_final_next_obs_global)
                _max_actions = _max_actions.max(1, keepdim=True)[1].detach()
                _max_q = self.target_model.agent(i)(
                    non_final_next_obs_og,
                    non_final_next_obs_global).gather(1, _max_actions)
                target_next_obs_q[non_final_mask[:, i]] = _max_q

                target_q = target_next_obs_q.detach()

            target_q = (self.discount * target_q) + reward_batch.sum(
                dim=1, keepdim=True)
            loss = (pred_q - target_q).pow(2) * weights.unsqueeze(1)
            prios += loss + 1e-5
            loss = loss.mean()
            overall_loss += loss
            self.writer.add_scalar('agent_{}/critic_loss'.format(i),
                                   loss.item(), self._step_iter)

        # Optimize the model
        self.optimizer.zero_grad()
        overall_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
        self.memory.update_priorities(indices, prios.data.cpu().numpy())
        self.optimizer.step()

        # update target network
        if (self._step_iter % 100) == 0:
            self._get_critic_consensus()
        soft_update(self.target_model, self.model, self.tau)

        # log
        self.writer.add_scalar('_overall/critic_loss', overall_loss,
                               self._step_iter)
        self.writer.add_scalar('_overall/beta', beta, self._step_iter)

        # just keep track of update counts
        self.__update_iter += 1

        # put the model back in eval mode
        self.model.eval()

        return overall_loss.item()
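
Both examples unpack batches with a Transition namedtuple, and Example #1 additionally expects a prioritized replay memory whose sample() returns (transitions, indices, importance-sampling weights) and which accepts updated priorities via update_priorities(). A minimal sketch of those helpers follows, assuming a proportional-prioritization buffer; the class name, the alpha default, and the internal layout are illustrative guesses, not taken from the original repository.

from collections import namedtuple

import numpy as np

# Field order matches memory.push(obs_n, action_n, next_obs_n, reward_n, done)
Transition = namedtuple(
    'Transition', ('state', 'action', 'next_state', 'reward', 'done'))


class PrioritizedReplayMemory:
    """Hypothetical proportional PER buffer matching the interface above."""

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0

    def push(self, *args):
        # new transitions get the current maximum priority so they are
        # sampled at least once before their TD error is known
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(Transition(*args))
        else:
            self.buffer[self.pos] = Transition(*args)
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        transitions = [self.buffer[idx] for idx in indices]
        # importance-sampling weights, normalised so the largest weight is 1;
        # Example #1 multiplies the squared TD error by these weights
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        return transitions, indices, weights.astype(np.float32)

    def update_priorities(self, indices, priorities):
        for idx, prio in zip(indices, np.asarray(priorities).flatten()):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)
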
Example #2
    def __update(self, obs_n, action_n, next_obs_n, reward_n, done):
        self.model.train()
        self.memory.push(obs_n, action_n, next_obs_n, reward_n, done)

        if self.batch_size > len(self.memory):
            self.model.eval()
            return None

        # transitions = self.memory.sample(self.batch_size)
        # Todo: move this beta in the Prioritized Replay memory
        # beta_start = 0.4
        # beta = min(1.0, beta_start + (self.__update_iter + 1) * (1.0 - beta_start) / 5000)

        # transitions, indices, weights = self.memory.sample(self.batch_size, beta)
        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        obs_batch = torch.FloatTensor(list(batch.state)).to(self.device)
        action_batch = torch.FloatTensor(list(batch.action)).to(self.device)
        reward_batch = torch.FloatTensor(list(batch.reward)).to(self.device)
        next_obs_batch = torch.FloatTensor(list(batch.next_state)).to(
            self.device)
        non_final_mask = 1 - torch.ByteTensor(list(batch.done)).to(self.device)
        # weights = torch.FloatTensor(weights).to(self.device)

        comb_obs_batch = obs_batch.flatten(1)
        comb_action_batch = action_batch.flatten(1)
        comb_next_obs_batch = next_obs_batch.flatten(1)

        # calculate loss
        q_loss_n, actor_loss_n = 0, 0
        # prios_n = 0
        for i in range(self.model.n_agents):
            # critic
            pred_q_value = self.model.agent(i).critic(comb_obs_batch,
                                                      comb_action_batch)

            target_next_obs_q = torch.zeros(pred_q_value.shape).to(self.device)
            target_action_batch = self.__select_action(self.target_model,
                                                       next_obs_batch)
            target_action_batch = target_action_batch.flatten(1).to(
                self.device)
            _next_q = self.target_model.agent(i).critic(
                comb_next_obs_batch, target_action_batch)
            target_next_obs_q[non_final_mask[:, i]] = \
                _next_q[non_final_mask[:, i]]
            target_q_value = (self.discount * target_next_obs_q).squeeze(1) \
                + reward_batch[:, i]
            q_loss = MSELoss()(pred_q_value.squeeze(1), target_q_value).mean()
            q_loss_n += q_loss

            # q_loss = (pred_q_value.squeeze(1) - target_q_value).pow(2) * weights
            # prios_n += q_loss + 1e-5
            # q_loss = q_loss.mean()
            # q_loss_n += q_loss

            # actor
            actor_i = self.model.agent(i).actor(obs_batch[:, i])
            _action_batch = action_batch.clone()
            if self.discrete_action_space:
                _action_batch[:, i] = gumbel_softmax(actor_i, hard=True)
            else:
                _action_batch[:, i] = actor_i

            _action_batch = _action_batch.flatten(1)
            actor_loss = -self.model.agent(i).critic(comb_obs_batch,
                                                     _action_batch).mean()
            actor_loss += (actor_i**2).mean() * 1e-3
            actor_loss_n += actor_loss

            # log
            self.writer.add_scalar('agent_{}/critic_loss'.format(i), q_loss,
                                   self.__update_iter)
            self.writer.add_scalar('agent_{}/actor_loss'.format(i), actor_loss,
                                   self.__update_iter)

        # Overall loss
        loss = actor_loss_n + q_loss_n
        # prios_n /= self.model.n_agents

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
        # self.memory.update_priorities(indices, prios_n.data.cpu().numpy())
        self.optimizer.step()

        # update target network
        soft_update(self.target_model, self.model, self.tau)

        # log
        self.writer.add_scalar('_overall/critic_loss', q_loss_n,
                               self.__update_iter)
        self.writer.add_scalar('_overall/actor_loss', actor_loss_n,
                               self.__update_iter)

        # just keep track of update counts
        self.__update_iter += 1

        # put the model back in eval mode
        self.model.eval()

        return loss.item()
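
Both snippets finish by Polyak-averaging the online network into the target network through soft_update, and Example #2 uses gumbel_softmax(actor_i, hard=True) to draw a differentiable one-hot action for discrete action spaces (presumably mirroring torch.nn.functional.gumbel_softmax). Below is a minimal sketch of the soft-update helper, assuming the (target, source, tau) argument order seen in the calls above.

import torch


def soft_update(target_model, source_model, tau):
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for target_param, source_param in zip(target_model.parameters(),
                                              source_model.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * source_param.data)

With a small tau (e.g. 0.01) the target network trails the online network slowly, which keeps the bootstrapped targets in both __update methods from chasing a fast-moving estimate.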