import torch
from torch.distributions import Normal, TransformedDistribution

# TanhBijector and SampleDist are assumed to be project-local helpers: a
# tanh-squashing transform and a sample-based wrapper providing mode() /
# entropy() for the transformed distribution. These methods presumably
# belong to an actor nn.Module whose class definition is outside this
# excerpt.


def get_action(self, belief, state, det=False, scale=None):
    action_mean, action_std = self.forward(belief, state)
    if scale:
        # Exploration (proposal) distribution: widen the std by a
        # detached copy scaled by (1 - scale), then squash with tanh.
        dist = Normal(action_mean, action_std + action_std.detach() * (1 - scale))
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        action = dist.mode() if det else dist.rsample()
        proposal_loglike = dist.log_prob(action).detach()

        # True policy distribution: score the sampled action under the
        # unmodified policy so the caller can form an importance weight.
        dist = Normal(action_mean, action_std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        policy_loglike = dist.log_prob(action)
        return action, policy_loglike, proposal_loglike
    else:
        # No exploration scaling: act directly from the policy.
        dist = Normal(action_mean, action_std)
        dist = TransformedDistribution(dist, TanhBijector())
        dist = torch.distributions.Independent(dist, 1)
        dist = SampleDist(dist)
        action = dist.mode() if det else dist.rsample()
        return action
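# For context, a self-contained sketch of what the scale branch computes,
# using torch's built-in TanhTransform in place of the project-local
# TanhBijector/SampleDist (an assumption; those helpers are not shown
# here): the action is drawn from a widened proposal distribution, then
# scored under both the proposal and the true policy, yielding an
# importance weight.

def _scale_branch_sketch():
    import torch
    from torch.distributions import Normal, TransformedDistribution, Independent
    from torch.distributions.transforms import TanhTransform

    # Hypothetical parameters standing in for self.forward(belief, state).
    action_mean = torch.zeros(16, 6)
    action_std = torch.ones(16, 6) * 0.3
    scale = 0.5

    def tanh_gaussian(mean, std):
        # cache_size=1 reuses the pre-tanh value when scoring the sampled
        # action, avoiding an unstable atanh near the boundaries.
        return Independent(
            TransformedDistribution(Normal(mean, std), TanhTransform(cache_size=1)), 1)

    proposal = tanh_gaussian(action_mean, action_std + action_std.detach() * (1 - scale))
    action = proposal.rsample()
    proposal_loglike = proposal.log_prob(action).detach()

    policy = tanh_gaussian(action_mean, action_std)
    policy_loglike = policy.log_prob(action)

    # Weight correcting for sampling from the proposal rather than the policy.
    return torch.exp(policy_loglike - proposal_loglike)  # shape: (16,)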
def get_action(self, belief, state, det=False):
    # Simpler variant without the exploration/proposal split: build the
    # tanh-squashed Gaussian policy and either take its mode or sample.
    action_mean, action_std = self.forward(belief, state)
    dist = Normal(action_mean, action_std)
    dist = TransformedDistribution(dist, TanhBijector())
    dist = torch.distributions.Independent(dist, 1)
    dist = SampleDist(dist)
    if det:
        return dist.mode()
    else:
        return dist.rsample()
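# A hedged usage sketch for the simpler variant, assuming a minimal actor
# module (hypothetical name and layer sizes) whose forward emits the action
# mean and a softplus-positive std; torch's built-in TanhTransform stands
# in for the project-local helpers, and tanh(mean) approximates
# SampleDist.mode() as the deterministic action.

def _usage_sketch():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Normal, TransformedDistribution, Independent
    from torch.distributions.transforms import TanhTransform

    class TinyActor(nn.Module):
        def __init__(self, belief_size=200, state_size=30, action_size=6):
            super().__init__()
            self.fc = nn.Linear(belief_size + state_size, 2 * action_size)

        def forward(self, belief, state):
            mean, pre_std = self.fc(torch.cat([belief, state], dim=-1)).chunk(2, dim=-1)
            return mean, F.softplus(pre_std) + 1e-4  # keep std strictly positive

        def get_action(self, belief, state, det=False):
            action_mean, action_std = self.forward(belief, state)
            dist = Independent(
                TransformedDistribution(Normal(action_mean, action_std),
                                        TanhTransform(cache_size=1)), 1)
            # No closed-form mode for a tanh-Gaussian; tanh(mean) is the
            # usual deterministic stand-in.
            return torch.tanh(action_mean) if det else dist.rsample()

    actor = TinyActor()
    belief, state = torch.zeros(16, 200), torch.zeros(16, 30)
    return actor.get_action(belief, state, det=True)  # shape: (16, 6)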