Example #1
# Class-method excerpt; module-level requirements:
#   import torch
#   from torch.distributions import Normal, TransformedDistribution
# TanhBijector and SampleDist are project-specific helpers: a tanh-squashing
# transform and a sample-based wrapper that estimates mode() and log_prob()
# for the transformed distribution.
def get_action(self, belief, state, det=False, scale=None):
    action_mean, action_std = self.forward(belief, state)
    if scale:
        # Exploration (proposal) distribution: the standard deviation is
        # inflated by a detached copy scaled with (1 - scale), so the extra
        # exploration noise carries no gradient.
        dist = Normal(action_mean,
                      action_std + action_std.detach() * (1 - scale))
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        action = dist.mode() if det else dist.rsample()
        # Log-likelihood under the proposal; detached because it is used
        # only as a reference value, not as a gradient path.
        proposal_loglike = dist.log_prob(action).detach()
        # True policy distribution, evaluated at the sampled action.
        dist = Normal(action_mean, action_std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        policy_loglike = dist.log_prob(action)
        return action, policy_loglike, proposal_loglike
    else:
        # Plain policy distribution; note that a falsy scale (None or 0)
        # takes this branch.
        dist = Normal(action_mean, action_std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        action = dist.mode() if det else dist.rsample()
        return action
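
The two log-likelihoods returned by the scale branch are typically combined into an importance weight between the exploration (proposal) distribution and the true policy. A minimal standalone sketch of that correction, assuming plain diagonal Normals in place of the project's tanh-squashed SampleDist; action_mean, action_std, and scale are hypothetical stand-ins for the actor network's outputs:

import torch
from torch.distributions import Normal, Independent

action_mean = torch.zeros(4)   # hypothetical actor outputs
action_std = torch.ones(4)
scale = 0.5

# Proposal with inflated std, mirroring the snippet above.
proposal = Independent(
    Normal(action_mean, action_std + action_std.detach() * (1 - scale)), 1)
policy = Independent(Normal(action_mean, action_std), 1)

action = proposal.rsample()
proposal_loglike = proposal.log_prob(action).detach()
policy_loglike = policy.log_prob(action)

# Likelihood ratio of the sampled action: policy over proposal.
importance_weight = torch.exp(policy_loglike - proposal_loglike)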
Example #2
def get_action(self, belief, state, det=False):
    # Same construction as above, without the exploration/proposal split.
    action_mean, action_std = self.forward(belief, state)
    dist = Normal(action_mean, action_std)
    dist = TransformedDistribution(dist, TanhBijector())
    dist = torch.distributions.Independent(dist, 1)
    dist = SampleDist(dist)
    # Deterministic evaluation returns the mode; otherwise a
    # reparameterized (differentiable) sample.
    return dist.mode() if det else dist.rsample()
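
Example #2 is the plain sampling path. When only sampling and log-probabilities are needed (not SampleDist's sample-based mode()), the project-specific TanhBijector can be approximated with PyTorch's built-in TanhTransform; a sketch under that assumption, again with hypothetical action_mean and action_std:

import torch
from torch.distributions import Normal, Independent, TransformedDistribution
from torch.distributions.transforms import TanhTransform

action_mean = torch.zeros(4)   # hypothetical actor outputs
action_std = torch.ones(4)

# cache_size=1 caches the pre-tanh value, so log_prob of a freshly drawn
# sample avoids a numerically unstable atanh round-trip.
dist = TransformedDistribution(Normal(action_mean, action_std),
                               TanhTransform(cache_size=1))
dist = Independent(dist, 1)

action = dist.rsample()            # differentiable, squashed into (-1, 1)
log_prob = dist.log_prob(action)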