Example #1
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask,
            nstep=self._n_step_value).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        cql1_loss = (th.logsumexp(q, dim=-1, keepdim=True) - q).mean()  # 1
        loss = q_loss + self._cql_weight * cql1_loss
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/q_loss': q_loss,
            'LOSS/cql1_loss': cql1_loss,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
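
A note on the loss above: it augments the TD loss with a conservative penalty, logsumexp(Q) - Q, scaled by self._cql_weight, i.e. the discrete-action CQL(H) regularizer. A small standalone sketch of that penalty on toy tensors (shapes and names are illustrative, not from RLs); the canonical CQL(H) form subtracts the Q-value of the dataset action, whereas the snippet above averages logsumexp(q) - q over all actions:

import torch as th
import torch.nn.functional as F

# Toy tensors following the [T, B, A] convention of the examples.
T, B, A = 4, 8, 3
q = th.randn(T, B, A)                                  # Q(s, .)
action = F.one_hot(th.randint(A, (T, B)), A).float()   # dataset actions, one-hot

q_data = (q * action).sum(-1, keepdim=True)            # Q(s, a_data), [T, B, 1]
logsumexp_q = th.logsumexp(q, dim=-1, keepdim=True)    # [T, B, 1]

cql_penalty_data = (logsumexp_q - q_data).mean()  # canonical CQL(H): dataset action only
cql_penalty_all = (logsumexp_q - q).mean()        # variant used above: averaged over all actions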
Example #2
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        # [T, B, A] * [T, B, A] => [T, B, 1]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            # [T, B, A] => [T, B, 1]
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        # mask_dist = td.Bernoulli(probs=self._probs)  # TODO:
        # mask = mask_dist.sample([batch_size]).T   # [H, B]
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
Example #3
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]

        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        next_max_action = next_q.argmax(-1)  # [T, B]
        next_max_action_one_hot = F.one_hot(next_max_action, self.a_dim).float()  # [T, B, A]

        q_target_next_max = (q_target * next_max_action_one_hot).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 q_target_next_max,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)

        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
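
Example #3 is a Double DQN update: the online network chooses the next action, the target network evaluates it. A standalone sketch of just that selection/evaluation split on toy tensors (names are illustrative):

import torch as th
import torch.nn.functional as F

T, B, A = 4, 8, 3
next_q_online = th.randn(T, B, A)   # online network at s'
next_q_target = th.randn(T, B, A)   # target network at s'

# Double DQN: pick the action with the online net, evaluate it with the target net.
a_star = next_q_online.argmax(-1)                                        # [T, B]
a_star_one_hot = F.one_hot(a_star, A).float()                            # [T, B, A]
double_q_next = (next_q_target * a_star_one_hot).sum(-1, keepdim=True)   # [T, B, 1]

# Plain DQN would instead bootstrap from the target net's own maximum.
plain_q_next = next_q_target.max(-1, keepdim=True)[0]                    # [T, B, 1]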
Example #4
File: ddpg.py  Project: StepNeverStop/RLs
    def _train(self, BATCH):
        if self.is_continuous:
            action_target = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            if self.use_target_action_noise:
                action_target = self.target_noised_action(
                    action_target)  # [T, B, A]
        else:
            target_logits = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi,
                                      self.a_dim).float()  # [T, B, A]
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = self.critic.t(BATCH.obs_,
                                 action_target,
                                 begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = dc_r - q  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(q_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
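
The discrete branch of the actor update above uses a straight-through Gumbel-softmax estimator: sample a relaxed action with Gumbel noise, replace it in the forward pass with its one-hot argmax, and let gradients flow through the relaxed sample. A standalone sketch of that trick (temperature and shapes are illustrative):

import torch as th
import torch.distributions as td
import torch.nn.functional as F

T, B, A = 4, 8, 3
tau = 1.0                                            # relaxation temperature
logits = th.randn(T, B, A, requires_grad=True)

logp_all = logits.log_softmax(-1)                    # [T, B, A]
gumbel = td.Gumbel(0., 1.).sample(logp_all.shape)    # [T, B, A]
soft_pi = ((logp_all + gumbel) / tau).softmax(-1)    # relaxed sample, differentiable
hard_pi = F.one_hot(soft_pi.argmax(-1), A).float()   # discrete one-hot, non-differentiable

# Straight-through: the forward value equals hard_pi, gradients flow through soft_pi.
pi = (hard_pi - soft_pi).detach() + soft_pi
pi.sum().backward()
print(logits.grad.shape)                             # gradients reach the logits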
Example #5
    def _train(self, BATCH):
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        if self.is_continuous:
            next_mu, _ = self.actor(BATCH.obs_,
                                    begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_q_next = self.critic(
                BATCH.obs_, next_mu,
                begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_a = logits.argmax(-1)  # [T, B]
            max_a_one_hot = F.one_hot(max_a, self.a_dim).float()  # [T, B, N]
            max_q_next = self.critic(BATCH.obs_, max_a_one_hot,
                                     begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q - n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     max_q_next,
                                     BATCH.begin_mask).detach()  # [T, B, 1]
        critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_prob = dist.log_prob(BATCH.action)  # [T, B]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            logp_all = logits.log_softmax(-1)  # [T, B, *]
            log_prob = (logp_all * BATCH.action).sum(-1)  # [T, B]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (log_prob - BATCH.log_prob).exp().detach()  # [T, B]
        actor_loss = -(ratio * log_prob *
                       q.squeeze(-1).detach()).mean()  # [T, B] => 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_max': q.max(),
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/ratio': ratio.mean(),
            'Statistics/entropy': entropy
        }
Example #6
    def _train(self, BATCH):
        q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

        q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean()  # 1

        q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_target_max = q1_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True)  # [T, B, 1]

        q2_target_max = q2_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        # q2_target_log_probs = q2_target.log_softmax(-1)
        # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0]

        q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1_eval - dc_r  # [T, B, 1]
        td_error2 = q2_eval - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        loss = 0.5 * (q1_loss + q2_loss)
        self.critic_oplr.optimize(loss)
        summaries = {
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/loss': loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/q1_entropy': q1_entropy,
            'Statistics/q_min': th.minimum(q1, q2).mean(),
            'Statistics/q_mean': q1.mean(),
            'Statistics/q_max': th.maximum(q1, q2).mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
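
The q1_entropy / q1_target_entropy terms above are the entropy of the Boltzmann policy softmax(Q / alpha), which is then added (scaled by alpha) to the bootstrapped value. A standalone sketch of that computation on toy tensors (the alpha value is illustrative):

import torch as th

T, B, A = 4, 8, 3
alpha = 0.2
q = th.randn(T, B, A)

log_probs = (q / (alpha + th.finfo().eps)).log_softmax(-1)       # [T, B, A]
entropy = -(log_probs.exp() * log_probs).sum(-1, keepdim=True)   # [T, B, 1]
soft_target = q.max(-1, keepdim=True)[0] + alpha * entropy       # [T, B, 1]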
Example #7
File: c51.py  Project: StepNeverStop/RLs
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)

        q_eval = (q_dist * self._z).sum(-1)  # [T, B, N] * [N,] => [T, B]

        target_q_dist = self.q_net.t(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [1, N] => [T, B, A]
        target_q = (target_q_dist * self._z).sum(-1)
        a_ = target_q.argmax(-1)  # [T, B]
        a_onehot = F.one_hot(a_, self.a_dim).float()  # [T, B, A]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2)

        target = n_step_return(
            BATCH.reward.repeat(1, 1, self._atoms), self.gamma,
            BATCH.done.repeat(1, 1, self._atoms), target_q_dist,
            BATCH.begin_mask.repeat(1, 1, self._atoms)).detach()  # [T, B, N]
        target = target.clamp(self._v_min, self._v_max)  # [T, B, N]
        # An amazing trick for calculating the projection gracefully.
        # ref: https://github.com/ShangtongZhang/DeepRL
        target_dist = (
            1 - (target.unsqueeze(-1) - self._z.view(1, 1, -1, 1)).abs() /
            self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze(
                -1)  # [T, B, N, 1]
        target_dist = target_dist.sum(-1)  # [T, B, N]

        _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum(
            -1, keepdim=True)  # [T, B, 1]
        loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return _cross_entropy, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
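
The target_dist lines above use the broadcasting trick credited to ShangtongZhang/DeepRL for projecting the shifted atoms back onto the fixed support. A standalone sketch of that categorical projection with explicit atom and probability tensors (toy shapes, no time dimension; names are illustrative):

import torch as th

N, v_min, v_max = 5, -1.0, 1.0
z = th.linspace(v_min, v_max, N)            # support atoms, [N]
delta_z = (v_max - v_min) / (N - 1)

B = 4
Tz = th.randn(B, N).clamp(v_min, v_max)     # shifted atoms per sample, [B, N]
p = th.softmax(th.randn(B, N), -1)          # next-state atom probabilities, [B, N]

# weight_ij = clip(1 - |Tz_j - z_i| / delta_z, 0, 1); projected_i = sum_j weight_ij * p_j
weight = (1 - (Tz.unsqueeze(-2) - z.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)  # [B, N_i, N_j]
projected = (weight * p.unsqueeze(-2)).sum(-1)                                    # [B, N_i]
print(projected.sum(-1))   # each row sums to ~1 while Tz stays inside [v_min, v_max]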
Example #8
File: qrdqn.py  Project: StepNeverStop/RLs
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

        target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
        _a = target_q.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
        # [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * next_max_action).sum(-2)

        target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                               self.gamma,
                               BATCH.done.repeat(1, 1, self.nums),
                               target_q_dist,
                               BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

        q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
        q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1], used for PER

        target = target.unsqueeze(-2)  # [T, B, 1, N]
        q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]

        # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
        quantile_error = target - q_dist
        huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [N,] - [T, B, N, N] => [T, B, N, N]
        huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
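
The pairwise quantile_error / huber block above is the quantile-regression Huber loss from QR-DQN. A standalone sketch with explicit quantile midpoints (toy shapes, no time dimension; taus here is an assumption about what self.quantiles holds):

import torch as th
import torch.nn.functional as F

B, N = 4, 8                                       # batch size, number of quantiles
taus = (th.arange(N).float() + 0.5) / N           # quantile midpoints, [N]
pred = th.randn(B, N, 1)                          # predicted quantiles theta_i, [B, N, 1]
target = th.randn(B, 1, N)                        # target quantiles, [B, 1, N]

u = target - pred                                                   # pairwise errors, [B, N, N]
huber = F.huber_loss(target.expand_as(u), pred.expand_as(u),
                     reduction='none', delta=1.0)                   # [B, N, N]
# |tau - 1{u < 0}| weights over- and under-estimation asymmetrically.
weight = (taus.view(1, -1, 1) - u.detach().le(0.).float()).abs()    # [B, N, N]
loss = (weight * huber).mean(-1).sum(-1).mean()                     # scalar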
Example #9
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        v_next = self._get_v(q_next)  # [T, B, 1]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 v_next,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
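
Every example bootstraps through n_step_return(reward, gamma, done, value, begin_mask, ...). A minimal one-step reference of what such a target computes, assuming [T, B, 1] tensors; the real helper additionally accumulates multi-step rewards and uses begin_mask to avoid bootstrapping across episode boundaries, both of which this sketch omits:

import torch as th

def one_step_return(reward, gamma, done, next_value):
    # r + gamma * (1 - done) * V(s'); shapes are all [T, B, 1].
    return reward + gamma * (1. - done) * next_value

T, B = 4, 8
reward = th.randn(T, B, 1)
done = th.randint(0, 2, (T, B, 1)).float()
next_value = th.randn(T, B, 1)
target = one_step_return(reward, 0.99, done, next_value)  # [T, B, 1]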
Example #10
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_next = self.termination_net(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]

        qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_s_ = (beta_next * BATCH.options).sum(-1,
                                                  keepdim=True)  # [T, B, 1]
        q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        if self.double_q:
            q_ = self.q_net(BATCH.obs_,
                            begin_mask=BATCH.begin_mask)  # [T, B, P]
            max_a_idx = F.one_hot(q_.argmax(-1),
                                  self.options_num).float()  # [T, B, P]
            q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
        else:
            q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
        u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
        qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                  u_target,
                                  BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = qu_target - qu_eval  # [T, B, 1] gradient : q
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.q_oplr.optimize(q_loss)

        q_s = qu_eval.detach()  # [T, B, 1]
        pi = self.intra_option_net(BATCH.obs,
                                   begin_mask=BATCH.begin_mask)  # [T, B, P, A]

        if self.use_baseline:
            adv = (qu_target - q_s).detach()  # [T, B, 1]
        else:
            adv = qu_target.detach()  # [T, B, 1]
        # [T, B, P] => [T, B, P, 1]
        options_onehot_expanded = BATCH.options.unsqueeze(-1)
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        if self.is_continuous:
            mu = pi.tanh()  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            pi = pi / self.boltzmann_temperature  # [T, B, A]
            log_pi = pi.log_softmax(-1)  # [T, B, A]
            entropy = -(log_pi.exp() * log_pi).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
            log_p = (log_pi * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1
        self.intra_option_oplr.optimize(pi_loss)

        beta = self.termination_net(BATCH.obs,
                                    begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]

        interests = self.interest_net(BATCH.obs,
                                      begin_mask=BATCH.begin_mask)  # [T, B, P]
        # [T, B, P] or q.softmax(-1)
        pi_op = (interests * q.detach()).softmax(-1)
        interest_loss = -(beta_s.detach() *
                          (pi_op * BATCH.options).sum(-1, keepdim=True) *
                          q_s).mean()  # 1
        self.interest_oplr.optimize(interest_loss)

        v_s = (q * pi_op).sum(-1, keepdim=True)  # [T, B, 1]
        beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1
        self.termination_oplr.optimize(beta_loss)

        return td_error, {
            'LEARNING_RATE/q_lr': self.q_oplr.lr,
            'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
            'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
            # 'Statistics/option': self.options[0],
            'LOSS/q_loss': q_loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/beta_loss': beta_loss,
            'LOSS/interest_loss': interest_loss,
            'Statistics/q_option_max': q_s.max(),
            'Statistics/q_option_min': q_s.min(),
            'Statistics/q_option_mean': q_s.mean()
        }
Example #11
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_actions[aid] = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                target_logits = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(
                    target_pi, self.a_dims[aid]).float()  # [T, B, A]
                target_actions[aid] = action_target  # [T, B, A]
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs, q_targets = {}, {}
        for mid in self.model_ids:
            qs[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_targets[mid],
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error = dc_r - qs[mid]  # [T, B, 1]
            td_errors += td_error
            q_loss[aid] = 0.5 * td_error.square().mean()  # 1
            summaries[aid].update({
                'Statistics/q_min': qs[mid].min(),
                'Statistics/q_mean': qs[mid].mean(),
                'Statistics/q_max': qs[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                mu = _pi_diff + _pi  # [T, B, A]

            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = mu
            q_actor = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
            actor_loss[aid] = -q_actor.mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })
        return td_errors / self.n_agents_percopy, summaries
Example #12
    def _train(self, BATCH_DICT):
        summaries = {}
        reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
        done = 0.

        q_evals = []
        q_rnncs_s = []
        q_actions = []
        q_maxs = []
        q_max_actions = []

        q_target_next_choose_maxs = []
        q_target_rnncs_s = []
        q_target_actions = []

        for aid, mid in zip(self.agent_ids, self.model_ids):
            done += BATCH_DICT[aid].done  # [T, B, 1]

            q = self.q_nets[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            q_rnncs = self.q_nets[mid].get_rnncs()  # [T, B, *]
            q_eval = (q * BATCH_DICT[aid].action).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_evals.append(q_eval)  # N * [T, B, 1]
            q_rnncs_s.append(q_rnncs)  # N * [T, B, *]
            q_actions.append(BATCH_DICT[aid].action)  # N * [T, B, A]
            q_maxs.append(q.max(-1, keepdim=True)[0])  # [T, B, 1]
            q_max_actions.append(
                F.one_hot(q.argmax(-1), self.a_dims[aid]).float())  # [T, B, A]

            q_target = self.q_nets[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            # [T, B, *]
            q_target_rnncs = self.q_nets[mid].target.get_rnncs()
            if self._use_double:
                next_q = self.q_nets[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]

                next_max_action = next_q.argmax(-1)  # [T, B]
                next_max_action_one_hot = F.one_hot(
                    next_max_action, self.a_dims[aid]).float()  # [T, B, A]

                q_target_next_max = (q_target * next_max_action_one_hot).sum(
                    -1, keepdim=True)  # [T, B, 1]
            else:
                next_max_action = q_target.argmax(-1)  # [T, B]
                next_max_action_one_hot = F.one_hot(
                    next_max_action, self.a_dims[aid]).float()  # [T, B, A]
                # [T, B, 1]
                q_target_next_max = q_target.max(-1, keepdim=True)[0]

            q_target_next_choose_maxs.append(
                q_target_next_max)  # N * [T, B, 1]
            q_target_rnncs_s.append(q_target_rnncs)  # N * [T, B, *]
            q_target_actions.append(next_max_action_one_hot)  # N * [T, B, A]

        joint_qs, vs = self.mixer(
            BATCH_DICT['global'].obs,
            q_rnncs_s,
            q_actions,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        target_joint_qs, target_vs = self.mixer.t(
            BATCH_DICT['global'].obs_,
            q_target_rnncs_s,
            q_target_actions,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]

        q_target_tot = n_step_return(
            reward, self.gamma, (done > 0.).float(), target_joint_qs,
            BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = q_target_tot - joint_qs  # [T, B, 1]
        td_loss = td_error.square().mean()  # 1

        # opt loss
        max_joint_qs, _ = self.mixer(
            BATCH_DICT['global'].obs,
            q_rnncs_s,
            q_max_actions,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        max_actions_qvals = sum(q_maxs)  # [T, B, 1]
        opt_loss = (max_actions_qvals - max_joint_qs.detach() +
                    vs).square().mean()  # 1

        # nopt loss
        nopt_error = sum(q_evals) - joint_qs.detach() + vs  # [T, B, 1]
        nopt_error = nopt_error.clamp(max=0)  # [T, B, 1]
        nopt_loss = nopt_error.square().mean()  # 1

        loss = td_loss + self.opt_loss * opt_loss + self.nopt_min_loss * nopt_loss

        self.oplr.optimize(loss)

        summaries['model'] = {
            'LOSS/q_loss': td_loss,
            'LOSS/loss': loss,
            'Statistics/q_max': joint_qs.max(),
            'Statistics/q_min': joint_qs.min(),
            'Statistics/q_mean': joint_qs.mean()
        }
        return td_error, summaries
Example #13
File: oc.py  Project: StepNeverStop/RLs
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_next = self.termination_net(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]

        qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_s_ = (beta_next * BATCH.options).sum(-1,
                                                  keepdim=True)  # [T, B, 1]
        q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
        if self.double_q:
            q_ = self.q_net(BATCH.obs_,
                            begin_mask=BATCH.begin_mask)  # [T, B, P]
            # [T, B, P] => [T, B] => [T, B, P]
            max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float()
            q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
        else:
            q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
        u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
        qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                  u_target,
                                  BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = qu_target - qu_eval  # gradient : q   [T, B, 1]
        q_loss = (td_error.square() *
                  BATCH.get('isw', 1.0)).mean()  # [T, B, 1] => 1
        self.q_oplr.optimize(q_loss)

        q_s = qu_eval.detach()  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
        if self.use_baseline:
            adv = (qu_target - q_s).detach()  # [T, B, 1]
        else:
            adv = qu_target.detach()  # [T, B, 1]
        # [T, B, P] => [T, B, P, 1]
        options_onehot_expanded = BATCH.options.unsqueeze(-1)
        pi = self.intra_option_net(BATCH.obs,
                                   begin_mask=BATCH.begin_mask)  # [T, B, P, A]
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        if self.is_continuous:
            mu = pi.tanh()  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            pi = pi / self.boltzmann_temperature  # [T, B, A]
            log_pi = pi.log_softmax(-1)  # [T, B, A]
            entropy = -(log_pi.exp() * log_pi).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
            log_p = (BATCH.action * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
        pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1

        beta = self.termination_net(BATCH.obs,
                                    begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
        if self.use_eps_greedy:
            v_s = q.max(
                -1,
                keepdim=True)[0] - self.termination_regularizer  # [T, B, 1]
        else:
            v_s = (1 - beta_s) * q_s + beta_s * q.max(
                -1, keepdim=True)[0]  # [T, B, 1]
            # v_s = q.mean(-1, keepdim=True)  # [T, B, 1]
        beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
        # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1

        self.intra_option_oplr.optimize(pi_loss)
        self.termination_oplr.optimize(beta_loss)

        return td_error, {
            'LEARNING_RATE/q_lr': self.q_oplr.lr,
            'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
            'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
            # 'Statistics/option': self.options[0],
            'LOSS/q_loss': q_loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/beta_loss': beta_loss,
            'Statistics/q_option_max': q_s.max(),
            'Statistics/q_option_min': q_s.min(),
            'Statistics/q_option_mean': q_s.mean()
        }
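
Examples #10 and #13 bootstrap the option value through U(s', w) = (1 - beta(s', w)) * Q(s', w) + beta(s', w) * max_w' Q(s', w'): keep the current option's value with probability (1 - beta), otherwise switch to the best option. A standalone sketch of that target on toy tensors (P options; names are illustrative):

import torch as th
import torch.nn.functional as F

T, B, P = 4, 8, 3
q_next = th.randn(T, B, P)                              # Q_Omega(s', .)
beta_next = th.rand(T, B, P)                            # termination probabilities beta(s', .)
options = F.one_hot(th.randint(P, (T, B)), P).float()   # current option, one-hot

q_s_ = (q_next * options).sum(-1, keepdim=True)         # Q(s', w), [T, B, 1]
beta_s_ = (beta_next * options).sum(-1, keepdim=True)   # beta(s', w), [T, B, 1]
q_s_max = q_next.max(-1, keepdim=True)[0]               # max_w' Q(s', w'), [T, B, 1]

u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max     # [T, B, 1]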
Example #14
    def _train(self, BATCH):
        if self.is_continuous:
            target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(target_pi, dist.log_prob(
                target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index)  # [T, B, 1]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             (q_target - self.alpha * target_log_pi),
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                            self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
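
squash_action and tsallis_entropy_log_q are repo-specific helpers. Under the assumption that squash_action applies the usual SAC tanh change of variables, a generic sketch of that correction looks like the following (not copied from RLs):

import torch as th
import torch.distributions as td

B, A = 8, 3
mu, log_std = th.randn(B, A), th.zeros(B, A)
dist = td.Independent(td.Normal(mu, log_std.exp()), 1)

u = dist.rsample()                           # pre-squash sample, [B, A]
a = th.tanh(u)                               # squashed action in (-1, 1)
# log pi(a) = log N(u) - sum_i log(1 - tanh(u_i)^2)
log_pi = dist.log_prob(u) - th.log(1 - a.pow(2) + 1e-6).sum(-1)   # [B]
log_pi = log_pi.unsqueeze(-1)                # [B, 1], matching the [T, B, 1] convention above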
Example #15
    def _train(self, BATCH):
        time_step = BATCH.reward.shape[0]
        batch_size = BATCH.reward.shape[1]

        quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.online_quantiles)
        # [T*B, N, 1] => [T, B, N, 1]
        quantiles = quantiles.view(time_step, batch_size, -1, 1)
        quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
        quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
        q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

        _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.select_quantiles)
        select_quantiles_tiled = select_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        q_values = self.q_net(
            BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
        next_max_action = q_values.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(
            next_max_action, self.a_dim).float()  # [T, B, A]

        _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.target_quantiles)
        target_quantiles_tiled = target_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
        target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled,
                                              begin_mask=BATCH.begin_mask)  # [T, N', B, A]
        target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
        target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True)  # [N', T, B, 1]

        target_q = target_quantiles_value.mean(0)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                                 self.gamma,
                                 BATCH.done,  # [T, B, 1]
                                 target_q,  # [T, B, 1]
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        # [N', T, B, 1] => [N', T, B]
        target_quantiles_value = target_quantiles_value.squeeze(-1)
        target_quantiles_value = target_quantiles_value.permute(
            1, 2, 0)  # [N', T, B] => [T, B, N']
        quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles),
                                               self.gamma,
                                               BATCH.done.repeat(1, 1, self.target_quantiles),
                                               target_quantiles_value,
                                               BATCH.begin_mask.repeat(1, 1,
                                                                       self.target_quantiles)).detach()  # [T, B, N']
        # [T, B, N'] => [T, B, 1, N']
        quantiles_value_target = quantiles_value_target.unsqueeze(-2)
        quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
        # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
        quantile_error = quantiles_value_online - quantiles_value_target
        huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                             reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
        huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]

        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
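
_generate_quantiles is a repo helper; the standard IQN recipe it presumably follows samples tau ~ U(0, 1) and embeds it with cosine features before mixing with the state embedding. A sketch of that embedding under this assumption (shapes and the flattening order are illustrative):

import math
import torch as th

batch, n_quantiles, embed_dim = 8, 16, 64
taus = th.rand(batch, n_quantiles, 1)                    # [B, N, 1]
i = th.arange(1, embed_dim + 1).float().view(1, 1, -1)   # [1, 1, X]
phi = th.cos(math.pi * i * taus)                         # cosine features, [B, N, X]
phi_flat = phi.view(batch * n_quantiles, embed_dim)      # flattened like the [N*B, X] tensors above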
Example #16
    def _train(self, BATCH):
        for _ in range(self.delay_num):
            if self.is_continuous:
                action_target = self.target_noised_action(
                    self.actor.t(BATCH.obs_,
                                 begin_mask=BATCH.begin_mask))  # [T, B, A]
            else:
                target_logits = self.actor.t(
                    BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(target_pi,
                                          self.a_dim).float()  # [T, B, A]
            q1 = self.critic(BATCH.obs,
                             BATCH.action,
                             begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q2 = self.critic2(BATCH.obs,
                              BATCH.action,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q_target = th.minimum(
                self.critic.t(BATCH.obs_,
                              action_target,
                              begin_mask=BATCH.begin_mask),
                self.critic2.t(BATCH.obs_,
                               action_target,
                               begin_mask=BATCH.begin_mask))  # [T, B, 1]
            dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                 q_target,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
            td_error1 = q1 - dc_r  # [T, B, 1]
            td_error2 = q2 - dc_r  # [T, B, 1]

            q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
            q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
            critic_loss = 0.5 * (q1_loss + q2_loss)
            self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q1_actor = self.critic(BATCH.obs, mu,
                               begin_mask=BATCH.begin_mask)  # [T, B, 1]

        actor_loss = -q1_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)
        return (td_error1 + td_error2) / 2, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
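
target_noised_action is a repo helper; under the assumption that it performs TD3-style target policy smoothing, a generic sketch of that step (noise scale and clip range are illustrative):

import torch as th

B, A = 8, 3
action_target = th.randn(B, A).tanh()                            # target-policy action in (-1, 1)
noise = (th.randn_like(action_target) * 0.2).clamp(-0.5, 0.5)    # clipped Gaussian noise
action_target = (action_target + noise).clamp(-1., 1.)           # smoothed target action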
Example #17
File: qplex.py  Project: StepNeverStop/RLs
    def _train(self, BATCH_DICT):
        summaries = {}
        reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
        done = 0.

        q_evals = []
        q_actions = []
        q_maxs = []

        q_target_next_choose_maxs = []
        q_target_actions = []
        q_target_next_maxs = []

        for aid, mid in zip(self.agent_ids, self.model_ids):
            done += BATCH_DICT[aid].done  # [T, B, 1]

            q = self.q_nets[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            q_eval = (q * BATCH_DICT[aid].action).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_evals.append(q_eval)  # N * [T, B, 1]
            q_actions.append(BATCH_DICT[aid].action)  # N * [T, B, A]
            q_maxs.append(q.max(-1, keepdim=True)[0])  # [T, B, 1]

            q_target = self.q_nets[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]

            # use double
            next_q = self.q_nets[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]

            next_max_action = next_q.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action, self.a_dims[aid]).float()  # [T, B, A]

            q_target_next_max = (q_target * next_max_action_one_hot).sum(
                -1, keepdim=True)  # [T, B, 1]

            q_target_next_choose_maxs.append(
                q_target_next_max)  # N * [T, B, 1]
            q_target_actions.append(next_max_action_one_hot)  # N * [T, B, A]
            q_target_next_maxs.append(q_target.max(
                -1, keepdim=True)[0])  # N * [T, B, 1]

        q_evals = th.stack(q_evals, -1)  # [T, B, 1, N]
        q_maxs = th.stack(q_maxs, -1)  # [T, B, 1, N]
        q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs,
                                             -1)  # [T, B, 1, N]
        q_target_next_maxs = th.stack(q_target_next_maxs, -1)  # [T, B, 1, N]

        q_eval_tot = self.mixer(
            BATCH_DICT['global'].obs,
            q_evals,
            q_actions,
            q_maxs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        q_target_next_max_tot = self.mixer.t(
            BATCH_DICT['global'].obs_,
            q_target_next_choose_maxs,
            q_target_actions,
            q_target_next_maxs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]

        q_target_tot = n_step_return(
            reward, self.gamma, (done > 0.).float(), q_target_next_max_tot,
            BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = q_target_tot - q_eval_tot  # [T, B, 1]
        q_loss = td_error.square().mean()  # 1
        self.oplr.optimize(q_loss)

        summaries['model'] = {
            'LOSS/q_loss': q_loss,
            'Statistics/q_max': q_eval_tot.max(),
            'Statistics/q_min': q_eval_tot.min(),
            'Statistics/q_mean': q_eval_tot.mean()
        }
        return td_error, summaries
Example #18
File: masac.py  Project: StepNeverStop/RLs
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        target_log_pis = 1.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_mu, target_log_std = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(
                    td.Normal(target_mu, target_log_std.exp()), 1)
                target_pi = dist.sample()  # [T, B, A]
                target_pi, target_log_pi = squash_action(
                    target_pi,
                    dist.log_prob(target_pi).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                target_logits = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(
                    -1)  # [T, B, 1]
                target_pi = F.one_hot(target_pi,
                                      self.a_dims[aid]).float()  # [T, B, A]
            target_actions[aid] = target_pi
            target_log_pis *= target_log_pi

        target_log_pis += th.finfo().eps
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {}
        for mid in self.model_ids:
            qs1[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            qs2[mid] = self.critics2[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets1[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]
            q_targets2[mid] = self.critics2[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_target = th.minimum(q_targets1[mid],
                                  q_targets2[mid])  # [T, B, 1]
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_target - self.alpha * target_log_pis,
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error1 = qs1[mid] - dc_r  # [T, B, 1]
            td_error2 = qs2[mid] - dc_r  # [T, B, 1]
            td_errors += (td_error1 + td_error2) / 2
            q1_loss = td_error1.square().mean()  # 1
            q2_loss = td_error2.square().mean()  # 1
            q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss
            summaries[aid].update({
                'Statistics/q_min': qs1[mid].min(),
                'Statistics/q_mean': qs1[mid].mean(),
                'Statistics/q_max': qs1[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        log_pi_actions = {}
        log_pis = {}
        sample_pis = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu, log_std = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
                pi = dist.rsample()  # [T, B, A]
                pi, log_pi = squash_action(
                    pi,
                    dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
                pi_action = BATCH_DICT[aid].action.arctanh()
                _, log_pi_action = squash_action(
                    pi_action,
                    dist.log_prob(pi_action).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                pi = _pi_diff + _pi  # [T, B, A]
                log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
                log_pi_action = (logp_all * BATCH_DICT[aid].action).sum(
                    -1, keepdim=True)  # [T, B, 1]
            log_pi_actions[aid] = log_pi_action
            log_pis[aid] = log_pi
            sample_pis[aid] = pi

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = sample_pis[aid]
            all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids}
            all_log_pis[aid] = log_pis[aid]

            q_s_pi = th.minimum(
                self.critics[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask),
                self.critics2[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask))  # [T, B, 1]

            _log_pis = 1.
            for _log_pi in all_log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps
            actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })

        if self.auto_adaption:
            _log_pis = 1.
            for _log_pi in log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps

            alpha_loss = -(
                self.alpha *
                (_log_pis + self.target_entropy).detach()).mean()  # 1

            self.alpha_oplr.optimize(alpha_loss)
            summaries['model'].update({
                'LOSS/alpha_loss':
                alpha_loss,
                'LEARNING_RATE/alpha_lr':
                self.alpha_oplr.lr
            })
        return td_errors / self.n_agents_percopy, summaries
Example #19
    def _train_discrete(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_v = v - (th.minimum((logp_all.exp() * q1_all).sum(-1, keepdim=True),
                               (logp_all.exp() * q2_all).sum(
                                   -1, keepdim=True))).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        entropy = -(logp_all.exp() * logp_all).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
        q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
        actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(
            -1)  # [T, B, A] => [T, B]
        actor_loss = actor_loss.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy.mean(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
            # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()
            alpha_loss = -(self.alpha * corr)  # [T, B, 1]
            alpha_loss = alpha_loss.mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
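
n_step_return appears in every example here, but its body does not. A minimal 1-step sketch of the Bellman target it produces, ignoring the begin_mask handling and multi-step bootstrapping that the real helper supports:

def one_step_return(reward, gamma, done, next_value):
    # 1-step TD target: r + gamma * (1 - done) * V(s').  All tensors are [T, B, 1].
    return reward + gamma * (1. - done) * next_value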
Example #20
0
    def _train_continuous(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        q1 = self.q_net(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.q_net2(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_pi = self.q_net2(BATCH.obs, pi,
                            begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        v_from_q_stop = (th.minimum(q1_pi, q2_pi) -
                         self.alpha * log_pi).detach()  # [T, B, 1]
        td_v = v - v_from_q_stop  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1

        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha *
                           (log_pi.detach() + self.target_entropy)).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
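
The continuous branch bounds the Gaussian sample with tanh via squash_action and corrects the log-probability for the change of variables. A sketch of that correction under the usual SAC assumptions (an illustration, not the library's exact implementation):

import torch as th


def squash_action_sketch(pre_tanh_action, log_prob, eps=1e-6):
    # Squash to (-1, 1) and subtract sum_i log(1 - tanh(u_i)^2) from the log-prob.
    action = th.tanh(pre_tanh_action)                                     # [T, B, A]
    correction = th.log(1. - action.pow(2) + eps).sum(-1, keepdim=True)   # [T, B, 1]
    return action, log_prob - correction                                  # [T, B, A], [T, B, 1]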
Example #21
0
    def _train_discrete(self, BATCH):
        q1_all = self.critic(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.critic2(BATCH.obs,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]

        q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        target_logits = self.actor(BATCH.obs_,
                                   begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_log_probs = target_logits.log_softmax(-1)  # [T, B, A]
        q1_target = self.critic.t(BATCH.obs_,
                                  begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_target = self.critic2.t(BATCH.obs_,
                                   begin_mask=BATCH.begin_mask)  # [T, B, A]

        def v_target_function(x):
            return (target_log_probs.exp() *
                    (x - self.alpha * target_log_probs)).sum(
                        -1, keepdim=True)  # [T, B, 1]

        v1_target = v_target_function(q1_target)  # [T, B, 1]
        v2_target = v_target_function(q2_target)  # [T, B, 1]
        v_target = th.minimum(v1_target, v2_target)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        q1_all = self.critic(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.critic2(BATCH.obs,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]

        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        entropy = -(logp_all.exp() * logp_all).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
        q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
        actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(
            -1)  # [T, B, A] => [T, B]
        actor_loss = actor_loss.mean()  # 1
        # actor_loss = - (q_all + self.alpha * entropy).mean()

        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy.mean()
        }
        if self.auto_adaption:
            corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
            # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()    #[B, A] => [B,]
            # J(\alpha)=\pi_{t}\left(s_{t}\right)^{T}\left[-\alpha\left(\log \left(\pi_{t}\left(s_{t}\right)\right)+\bar{H}\right)\right]
            # \bar{H} is negative
            alpha_loss = -(self.alpha * corr)  # [T, B, 1]
            alpha_loss = alpha_loss.mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries
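
v_target_function above is the discrete soft state value: the expectation of Q under the current policy minus the entropy temperature term. Isolated as a sketch with illustrative names:

def discrete_soft_value(q_all, logp_all, alpha):
    # V(s) = sum_a pi(a|s) * (Q(s, a) - alpha * log pi(a|s)).  [T, B, A] -> [T, B, 1]
    return (logp_all.exp() * (q_all - alpha * logp_all)).sum(-1, keepdim=True)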
Example #22
0
    def _train(self, BATCH):
        if self.is_continuous:
            # Variational Auto-Encoder Training
            recon, mean, std = self.vae(BATCH.obs,
                                        BATCH.action,
                                        begin_mask=BATCH.begin_mask)
            recon_loss = F.mse_loss(recon, BATCH.action)

            KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_oplr.optimize(vae_loss)

            target_Qs = []
            for _ in range(self._train_samples):
                # Compute value of perturbed actions sampled from the VAE
                _vae_actions = self.vae.decode(BATCH.obs_,
                                               begin_mask=BATCH.begin_mask)
                _actor_actions = self.actor.t(BATCH.obs_,
                                              _vae_actions,
                                              begin_mask=BATCH.begin_mask)
                target_Q1, target_Q2 = self.critic.t(
                    BATCH.obs_, _actor_actions, begin_mask=BATCH.begin_mask)

                # Soft Clipped Double Q-learning
                target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \
                           (1. - self._lmbda) * th.max(target_Q1, target_Q2)
                target_Qs.append(target_Q)
            target_Qs = th.stack(target_Qs, dim=0)  # [N, T, B, 1]
            # Take max over each action sampled from the VAE
            target_Q = target_Qs.max(dim=0)[0]  # [T, B, 1]

            target_Q = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     target_Q,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            current_Q1, current_Q2 = self.critic(BATCH.obs,
                                                 BATCH.action,
                                                 begin_mask=BATCH.begin_mask)
            td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            self.critic_oplr.optimize(critic_loss)

            # Perturbation Model / Action Training
            sampled_actions = self.vae.decode(BATCH.obs,
                                              begin_mask=BATCH.begin_mask)
            perturbed_actions = self.actor(BATCH.obs,
                                           sampled_actions,
                                           begin_mask=BATCH.begin_mask)

            # Update through DPG
            q1, _ = self.critic(BATCH.obs,
                                perturbed_actions,
                                begin_mask=BATCH.begin_mask)
            actor_loss = -q1.mean()

            self.actor_oplr.optimize(actor_loss)

            return td_error, {
                'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
                'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
                'LEARNING_RATE/vae_lr': self.vae_oplr.lr,
                'LOSS/actor_loss': actor_loss,
                'LOSS/critic_loss': critic_loss,
                'LOSS/vae_loss': vae_loss,
                'Statistics/q_min': q1.min(),
                'Statistics/q_mean': q1.mean(),
                'Statistics/q_max': q1.max()
            }

        else:
            q_next, i_next = self.q_net(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_next = q_next - q_next.min(dim=-1, keepdim=True)[0]  # [T, B, A]
            i_next = F.log_softmax(i_next, dim=-1)  # [T, B, A]
            i_next = i_next.exp()  # [T, B, A]
            i_next = (i_next / i_next.max(-1, keepdim=True)[0] >
                      self._threshold).float()  # [T, B, A]
            q_next = i_next * q_next  # [T, B, A]
            next_max_action = q_next.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action.squeeze(), self.a_dim).float()  # [T, B, A]

            q_target_next, _ = self.q_net.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_target_next_max = (q_target_next * next_max_action_one_hot).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     q_target_next_max,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            q, i = self.q_net(BATCH.obs,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

            td_error = q_target - q_eval  # [T, B, 1]
            q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

            imt = F.log_softmax(i, dim=-1)  # [T, B, A]
            imt = imt.reshape(-1, self.a_dim)  # [T*B, A]
            action = BATCH.action.reshape(-1, self.a_dim)  # [T*B, A]
            i_loss = F.nll_loss(imt, action.argmax(-1))  # 1

            loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

            self.oplr.optimize(loss)
            return td_error, {
                'LEARNING_RATE/lr': self.oplr.lr,
                'LOSS/q_loss': q_loss,
                'LOSS/i_loss': i_loss,
                'LOSS/loss': loss,
                'Statistics/q_max': q_eval.max(),
                'Statistics/q_min': q_eval.min(),
                'Statistics/q_mean': q_eval.mean()
            }
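
The discrete branch masks out actions that the imitation head considers unlikely before taking the greedy action, which is the core of discrete BCQ. The same filtering, isolated as a sketch with illustrative names:

import torch.nn.functional as F


def bcq_filtered_greedy_action(q, imitation_logits, threshold):
    # Zero out actions whose relative likelihood under the imitation head falls below the threshold.
    q = q - q.min(-1, keepdim=True)[0]                        # shift Q so every entry is non-negative
    probs = F.log_softmax(imitation_logits, -1).exp()         # [T, B, A]
    mask = (probs / probs.max(-1, keepdim=True)[0] > threshold).float()
    return (mask * q).argmax(-1)                              # [T, B]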