def forward(self, inputs, with_log_pi=True, deterministic=False):
    """Turn packed distribution parameters into a squashed action.

    ``inputs`` is split along dim 1 at ``inputs.shape[1] // 2``: the
    leading columns are used as ``mu`` and the remaining columns as
    ``log_sigma``.

    Args:
        inputs: 2-D indexable tensor holding ``mu`` and ``log_sigma``
            side by side along dim 1.
        with_log_pi: when true, return ``(action, log_pi)``; otherwise
            return only ``action``.
        deterministic: when true, use ``mu`` itself instead of drawing a
            sample via ``normal_sample``.

    Returns:
        ``(action, log_pi)`` or just ``action``, both as produced by
        ``self.squashing_layer.forward``.
    """
    half = inputs.shape[1] // 2
    mean = inputs[:, :half]
    raw_log_std = inputs[:, half:]

    # Bound the log-scale before exponentiating.
    std = torch.exp(torch.clamp(raw_log_std, LOG_SIG_MIN, LOG_SIG_MAX))

    if deterministic:
        pre_squash = mean
    else:
        pre_squash = normal_sample(mean, std)

    # The log-probability is evaluated in both modes because the squashing
    # layer takes it as an input and returns an updated value.
    # NOTE(review): presumably a tanh-style change-of-variables correction;
    # confirm against the squashing layer implementation.
    log_prob = normal_log_prob(mean, std, pre_squash)
    action, log_prob = self.squashing_layer.forward(pre_squash, log_prob)

    return (action, log_prob) if with_log_pi else action
def forward(self, inputs, with_log_pi=True):
    """Sample an action by pushing base noise through two coupling layers.

    A base sample is drawn with ``normal_sample`` using an all-zeros ``mu``
    and all-ones ``sigma`` of shape ``(inputs.shape[0], self.action_size)``.
    It is then passed through ``self.coupling1``, ``self.coupling2`` and
    ``self.squashing_layer`` in that order. Each stage receives the running
    ``log_pi`` and returns an updated one.

    Args:
        inputs: state embedding; dim 0 sets the number of rows sampled, and
            the tensor is handed to both coupling layers as conditioning.
        with_log_pi: when true, return ``(action, log_pi)``; otherwise
            return only ``action``.

    Returns:
        ``(action, log_pi)`` or just ``action``, both as produced by
        ``self.squashing_layer.forward``.
    """
    state_embedding = inputs
    batch_size = state_embedding.shape[0]

    # Allocate the base-distribution parameters directly on the embedding's
    # device instead of creating them on the default device and copying.
    mu = torch.zeros(
        (batch_size, self.action_size), device=state_embedding.device
    )
    # ones_like already inherits mu's device and dtype.
    sigma = torch.ones_like(mu)

    z = normal_sample(mu, sigma)
    log_pi = normal_log_prob(mu, sigma, z)

    # Each coupling layer transforms z conditioned on the state embedding
    # and returns the updated log_pi alongside it.
    z, log_pi = self.coupling1.forward(z, state_embedding, log_pi)
    z, log_pi = self.coupling2.forward(z, state_embedding, log_pi)

    action, log_pi = self.squashing_layer.forward(z, log_pi)

    if with_log_pi:
        return action, log_pi
    return action