Code example #1
    def compute_loss_q(self, ob_no, ac_na, re_n, next_ob_no, terminal_n):
        q1 = self.ac.q1(ob_no, ac_na)
        q2 = self.ac.q2(ob_no, ac_na)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = self.ac.pi(next_ob_no)

            # Target Q-values
            q1_pi_targ = self.ac_target.q1(next_ob_no, a2)
            q2_pi_targ = self.ac_target.q2(next_ob_no, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = re_n + self.gamma * (1 - terminal_n) * (q_pi_targ - self.alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=ptu.to_numpy(q1),
                      Q2Vals=ptu.to_numpy(q2))

        return loss_q, q_info
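
For context, here is a minimal sketch of how a critic-loss method like `compute_loss_q` is typically driven in a SAC update step. The `agent`, `q_optimizer`, and `polyak` names are assumptions for illustration, not taken from the code above.

    import torch

    def sac_critic_update(agent, batch, polyak=0.995):
        # One gradient step on both Q-networks against the Bellman backup,
        # followed by a polyak (soft) update of the target networks that
        # compute_loss_q reads through agent.ac_target.
        ob_no, ac_na, re_n, next_ob_no, terminal_n = batch
        loss_q, q_info = agent.compute_loss_q(ob_no, ac_na, re_n, next_ob_no, terminal_n)
        agent.q_optimizer.zero_grad()
        loss_q.backward()
        agent.q_optimizer.step()

        with torch.no_grad():
            for p, p_targ in zip(agent.ac.parameters(), agent.ac_target.parameters()):
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)
        return loss_q, q_info
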
Code example #2
    def eval(self, paths, **kwargs):
        """
        Return bonus
        """
        if self.score_discrim:
            obs, obs_next, acts, path_probs = self.extract_paths(
                paths,
                keys=('observation', 'next_observation', 'action', 'a_logprob'),
            )  # log prob should already be calculated by fit
            path_probs = np.expand_dims(path_probs, axis=1)

            obs = from_numpy(obs)
            obs_next = from_numpy(obs_next)
            acts = from_numpy(acts)
            path_probs = from_numpy(path_probs)

            # GAIL/AIRL-style bonus: the log-odds (logit) of the discriminator output
            scores = self.discrim(obs, acts, obs_next, path_probs)[0]
            score = torch.log(scores) - torch.log(1 - scores)
            score = score[:, 0]
        else:
            obs, acts = self.extract_paths(paths)
            obs = from_numpy(obs)
            acts = from_numpy(acts)
            reward = self.discrim.get_reward(obs, acts)
            score = reward[:, 0]
        return to_numpy(score)
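
The score in the `score_discrim` branch is the standard GAIL/AIRL log-odds bonus, log D(s, a, s') - log(1 - D(s, a, s')), i.e. the logit of the discriminator output. A small standalone restatement (the clipping epsilon is an added safeguard, not part of the code above):

    import numpy as np

    def log_odds_bonus(d, eps=1e-8):
        # Positive when the discriminator rates the transition as expert-like,
        # negative otherwise; equivalent to the logit of d.
        d = np.clip(d, eps, 1.0 - eps)
        return np.log(d) - np.log(1.0 - d)
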
Code example #3
    def compute_loss_pi(self, ob_no):
        pi, logp_pi = self.ac.pi(ob_no)
        q1_pi = self.ac.q1(ob_no, pi)
        q2_pi = self.ac.q2(ob_no, pi)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (self.alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=ptu.to_numpy(logp_pi))

        return loss_pi, pi_info
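
A sketch of the corresponding actor step; freezing the Q-parameters during the policy update follows the usual SAC pattern, and the `pi_optimizer` attribute is an assumption about this codebase.

    def sac_actor_update(agent, ob_no):
        # Freeze the Q-networks so the policy step only backpropagates through pi.
        q_params = list(agent.ac.q1.parameters()) + list(agent.ac.q2.parameters())
        for p in q_params:
            p.requires_grad = False

        loss_pi, pi_info = agent.compute_loss_pi(ob_no)
        agent.pi_optimizer.zero_grad()
        loss_pi.backward()
        agent.pi_optimizer.step()

        # Unfreeze for the next critic update.
        for p in q_params:
            p.requires_grad = True
        return loss_pi, pi_info
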
Code example #4
    def eval_prob(self, paths, insert_key='a_logprob'):
        """Store the current policy's log-probability of each taken action in every path."""
        for path in paths:
            obs = ptu.from_numpy(path['observation'])
            ac = ptu.from_numpy(path['action'])
            path_probs = ptu.to_numpy(self.ac.pi.get_prob(obs, ac))
            path[insert_key] = path_probs
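
A sketch of how `eval_prob` and `eval` would typically be chained to relabel sampled paths with the learned bonus before a policy update. The separate `policy_agent`/`irl` objects, the `'reward'` key, and the assumption that `eval` returns one flat score per transition concatenated over paths are all illustrative, not taken from the code above.

    def relabel_with_discriminator(policy_agent, irl, paths):
        # eval() expects each path to carry the current policy's log-probs
        # under 'a_logprob' when score_discrim is enabled.
        policy_agent.eval_prob(paths, insert_key='a_logprob')

        # Split the flat score array back out per path and overwrite rewards.
        scores = irl.eval(paths)
        start = 0
        for path in paths:
            n = len(path['action'])
            path['reward'] = scores[start:start + n]
            start += n
        return paths
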
Code example #5
    def eval_single(self, obs):
        """Return the learned reward for a batch of observations (no actions)."""
        obs = from_numpy(obs)
        reward = self.discrim.get_reward(obs, None)
        score = reward[:, 0]
        return to_numpy(score)
Code example #6
    def _get_action(self, obs, deterministic=False):
        # Query the policy without tracking gradients and return a numpy action
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            return ptu.to_numpy(a)
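
Finally, a sketch of using `_get_action` to step an environment; the Gym-style `env.step` signature returning `(obs, reward, done, info)` and the tensor conversion are assumptions.

    import torch

    def collect_step(policy, env, obs, deterministic=False):
        # Convert the observation, query the policy without gradients,
        # and step the environment with the resulting numpy action.
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        action = policy._get_action(obs_t, deterministic=deterministic)
        next_obs, reward, done, info = env.step(action)
        return action, next_obs, reward, done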