Example #1
    def forward(self, ob, state_in=None):
        # Unpack observations if they arrive as a PackedSequence.
        if isinstance(ob, PackedSequence):
            ob = ob.data
        # Stub policy: random logits over two actions, one row per observation.
        logits = np.random.rand(ob.shape[0], 2)
        # Initialize the recurrent state on the first call.
        if state_in is None:
            state_in = torch.from_numpy(np.zeros(10))
        # Return an action distribution and the updated recurrent state.
        return CatDist(torch.from_numpy(logits)), state_in + 1
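
The method above is a test stub: it returns random logits and a counter in place of a learned distribution and recurrent state. For contrast, a minimal trainable counterpart might look like the sketch below, which uses plain torch.nn and torch.distributions.Categorical instead of the library's CatDist; the class name, layer sizes, and GRU state are assumptions for illustration, not part of the project above.

import torch
import torch.nn as nn
from torch.distributions import Categorical


class TinyRecurrentPolicy(nn.Module):
    """Hypothetical recurrent categorical policy with a GRU state."""

    def __init__(self, ob_dim=4, hidden_dim=10, n_actions=2):
        super().__init__()
        self.cell = nn.GRUCell(ob_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, n_actions)
        self.hidden_dim = hidden_dim

    def forward(self, ob, state_in=None):
        if state_in is None:
            state_in = torch.zeros(ob.shape[0], self.hidden_dim)
        state_out = self.cell(ob, state_in)    # update the recurrent state
        logits = self.head(state_out)          # per-observation action logits
        return Categorical(logits=logits), state_out


# Usage: sample actions for a batch of 3 observations.
pi = TinyRecurrentPolicy()
dist, state = pi(torch.randn(3, 4))
actions = dist.sample()    # shape: (3,)
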
Example #2
    def loss(self, batch):
        """Loss function."""
        dist = self.pi(batch['obs']).dist
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            ent_error = dist.entropy() - self.target_entropy
            alpha_loss = self.log_alpha * ent_error.detach().mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha
            alpha_loss = 0

        # qf loss
        with torch.no_grad():
            next_dist = self.pi(batch['next_obs']).dist
            q1_next = self.target_qf1(batch['next_obs']).qvals
            q2_next = self.target_qf2(batch['next_obs']).qvals
            qmin = torch.min(q1_next, q2_next)
            # explicitly compute the expectation over next actions
            qnext = torch.sum(qmin * next_dist.probs,
                              dim=1) + alpha * next_dist.entropy()
            qtarg = batch['reward'] + (1.0 -
                                       batch['done']) * self.gamma * qnext

        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.mse_loss(q1, qtarg)
        qf2_loss = self.mse_loss(q2, qtarg)

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            with torch.no_grad():
                q1_pi = self.qf1(batch['obs']).qvals
                q2_pi = self.qf2(batch['obs']).qvals
                min_q_pi = torch.min(q1_pi, q2_pi)
            assert min_q_pi.shape == dist.logits.shape
            # Soft policy improvement: fit the policy to the Boltzmann
            # distribution induced by the Q-values, scaled by the temperature.
            target_dist = CatDist(logits=min_q_pi)
            pi_dist = CatDist(logits=alpha * dist.logits)
            pi_loss = pi_dist.kl(target_dist).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('alg/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(),
                                  self.t, time.time())
                scalars = {
                    "target": self.target_entropy,
                    "entropy":
                    dist.entropy().mean().detach().cpu().numpy().item()
                }
                logger.add_scalars('alg/entropy', scalars, self.t, time.time())
            else:
                logger.add_scalar(
                    'alg/entropy',
                    dist.entropy().mean().detach().cpu().numpy().item(),
                    self.t, time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('alg/qf1',
                              q1.mean().detach().cpu().numpy(), self.t,
                              time.time())
            logger.add_scalar('alg/qf2',
                              q2.mean().detach().cpu().numpy(), self.t,
                              time.time())
        return pi_loss, qf1_loss, qf2_loss
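
The qf-loss block above computes the discrete soft Bellman target by taking an explicit expectation of min(Q1, Q2) over the next-state action probabilities and adding an entropy bonus. Below is a self-contained numeric sketch of that target computation using plain torch.distributions.Categorical in place of the library's CatDist; every tensor and constant is made up for illustration.

import torch
from torch.distributions import Categorical

gamma, alpha = 0.99, 0.2
reward = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])

# Target-network Q-values for each next state and action (batch=2, actions=3).
q1_next = torch.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 0.4]])
q2_next = torch.tensor([[0.8, 2.2, 0.4], [0.2, 0.2, 0.5]])
next_dist = Categorical(logits=torch.tensor([[0.1, 0.9, 0.0],
                                             [0.5, 0.2, 0.3]]))

with torch.no_grad():
    qmin = torch.min(q1_next, q2_next)
    # V(s') = sum_a pi(a|s') * min_i Q_i(s', a) + alpha * H(pi(.|s'))
    qnext = torch.sum(qmin * next_dist.probs, dim=1) + alpha * next_dist.entropy()
    qtarg = reward + (1.0 - done) * gamma * qnext

print(qtarg.shape)    # torch.Size([2])
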
Example #3
File: sac.py  Project: takuma-ynd/dl
    def loss(self, batch):
        """Loss function."""
        pi_out = self.pi(batch['obs'], reparameterization_trick=self.rsample)
        if self.discrete:
            new_ac = pi_out.action
            ent = pi_out.dist.entropy()
        else:
            assert isinstance(pi_out.dist, TanhNormal), (
                "It is strongly encouraged that you use a TanhNormal "
                "action distribution for continuous action spaces.")
            if self.rsample:
                new_ac, new_pth_ac = pi_out.dist.rsample(
                                                    return_pretanh_value=True)
            else:
                new_ac, new_pth_ac = pi_out.dist.sample(
                                                    return_pretanh_value=True)
            logp = pi_out.dist.log_prob(new_ac, new_pth_ac)
        q1 = self.qf1(batch['obs'], batch['action']).value
        q2 = self.qf2(batch['obs'], batch['action']).value
        v = self.vf(batch['obs']).value

        # alpha loss
        if self.automatic_entropy_tuning:
            if self.discrete:
                ent_error = -ent + self.target_entropy
            else:
                ent_error = logp + self.target_entropy
            alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
            self.opt_alpha.zero_grad()
            alpha_loss.backward()
            self.opt_alpha.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
            alpha_loss = 0

        # qf loss
        vtarg = self.target_vf(batch['next_obs']).value
        qtarg = self.reward_scale * batch['reward'].float() + (
                    (1.0 - batch['done']) * self.gamma * vtarg)
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
        qf1_loss = self.qf_criterion(q1, qtarg.detach())
        qf2_loss = self.qf_criterion(q2, qtarg.detach())

        # vf loss
        q1_outs = self.qf1(batch['obs'], new_ac)
        q1_new = q1_outs.value
        q2_new = self.qf2(batch['obs'], new_ac).value
        q = torch.min(q1_new, q2_new)
        if self.discrete:
            vtarg = q + alpha * ent
        else:
            vtarg = q - alpha * logp
        assert v.shape == vtarg.shape
        vf_loss = self.vf_criterion(v, vtarg.detach())

        # pi loss
        pi_loss = None
        if self.t % self.policy_update_period == 0:
            if self.discrete:
                target_dist = CatDist(logits=q1_outs.qvals.detach())
                pi_dist = CatDist(logits=alpha * pi_out.dist.logits)
                pi_loss = pi_dist.kl(target_dist).mean()
            else:
                if self.rsample:
                    # Pathwise (reparameterized) gradient estimator.
                    assert q.shape == logp.shape
                    pi_loss = (alpha * logp - q1_new).mean()
                else:
                    # Score-function estimator with the value function
                    # as a baseline.
                    pi_targ = q1_new - v
                    assert pi_targ.shape == logp.shape
                    pi_loss = (logp * (alpha * logp - pi_targ).detach()).mean()

                pi_loss += self.policy_mean_reg_weight * (
                                            pi_out.dist.normal.mean**2).mean()

            # log pi loss about as frequently as other losses
            if self.t % self.log_period < self.policy_update_period:
                logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

        if self.t % self.log_period < self.update_period:
            if self.automatic_entropy_tuning:
                logger.add_scalar('ent/log_alpha',
                                  self.log_alpha.detach().cpu().numpy(), self.t,
                                  time.time())
                if self.discrete:
                    scalars = {"target": self.target_entropy,
                               "entropy": ent.mean().detach().cpu().numpy().item()}
                else:
                    scalars = {"target": self.target_entropy,
                               "entropy": -torch.mean(
                                            logp.detach()).cpu().numpy().item()}
                logger.add_scalars('ent/entropy', scalars, self.t, time.time())
            else:
                if self.discrete:
                    logger.add_scalar(
                            'ent/entropy',
                            ent.mean().detach().cpu().numpy().item(),
                            self.t, time.time())
                else:
                    logger.add_scalar(
                            'ent/entropy',
                            -torch.mean(logp.detach()).cpu().numpy().item(),
                            self.t, time.time())
            logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
            logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
            logger.add_scalar('loss/vf', vf_loss, self.t, time.time())
        return pi_loss, qf1_loss, qf2_loss, vf_loss
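
In the continuous branch above, the reparameterized (rsample) policy loss is the expectation of alpha * log pi(a|s) - Q(s, a) under actions drawn with the pathwise trick from a tanh-squashed Gaussian. The sketch below reproduces that computation with torch's Normal plus TanhTransform standing in for the library's TanhNormal (the cached transform plays the role of the pretanh value passed to log_prob above); the networks, sizes, and temperature are assumptions for illustration.

import torch
import torch.nn as nn
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

ob_dim, ac_dim, batch = 8, 2, 16
obs = torch.randn(batch, ob_dim)
alpha = 0.2

policy = nn.Linear(ob_dim, 2 * ac_dim)    # outputs mean and log_std
qf = nn.Linear(ob_dim + ac_dim, 1)        # stand-in Q-network

mean, log_std = policy(obs).chunk(2, dim=1)
dist = TransformedDistribution(Normal(mean, log_std.exp()),
                               TanhTransform(cache_size=1))

new_ac = dist.rsample()                   # pathwise sample, squashed to (-1, 1)
logp = dist.log_prob(new_ac).sum(dim=1)   # per-sample log-probability
q_new = qf(torch.cat([obs, new_ac], dim=1)).squeeze(1)

# Reparameterized SAC policy loss: E[alpha * log pi(a|s) - Q(s, a)]
pi_loss = (alpha * logp - q_new).mean()
pi_loss.backward()                        # gradients flow back into `policy`
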
Example #4
    def forward(self, ob):
        # Stub actor-critic: random logits over two actions and a random
        # value estimate, one row per observation.
        logits = np.random.rand(ob.shape[0], 2)
        v = np.random.rand(ob.shape[0], 1)
        return CatDist(torch.from_numpy(logits)), v