import torch
import torch.nn as nn
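# The policies below depend on helpers defined elsewhere in the package:
# MODULES (a registry that resolves a string name to an nn.Module class),
# LOG_SIG_MIN / LOG_SIG_MAX, normal_sample, normal_logprob, SquashingLayer
# and CouplingLayer. What follows is a minimal sketch of the Gaussian helpers
# under the usual assumptions (a reparameterized sample and a normal
# log-density summed over the action dimension); it is an illustration, not
# the package's actual implementation.

LOG_SIG_MIN = -20  # assumed clamp bounds, commonly used for stability
LOG_SIG_MAX = 2


def normal_sample(mu, sigma):
    # reparameterization trick: mu + sigma * eps with eps ~ N(0, I)
    return mu + sigma * torch.randn_like(sigma)


def normal_logprob(mu, sigma, z):
    # log N(z; mu, sigma^2), summed over the action dimension
    return torch.distributions.Normal(mu, sigma) \
        .log_prob(z).sum(dim=-1, keepdim=True)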
class GaussPolicy(nn.Module):
    def __init__(self, squashing_fn=nn.Tanh):
        super().__init__()
        self.squashing_layer = SquashingLayer(squashing_fn)

    def forward(self, inputs, logprob=False, deterministic=False):
        # the network head emits [mu, log_sigma] concatenated along dim 1
        action_size = inputs.shape[1] // 2
        mu, log_sigma = inputs[:, :action_size], inputs[:, action_size:]
        log_sigma = torch.clamp(log_sigma, LOG_SIG_MIN, LOG_SIG_MAX)
        sigma = torch.exp(log_sigma)
        # mean action for evaluation, reparameterized sample for training
        z = mu if deterministic else normal_sample(mu, sigma)
        log_pi = normal_logprob(mu, sigma, z)
        # squash into the action bounds and correct log_pi accordingly
        action, log_pi = self.squashing_layer.forward(z, log_pi)
        if logprob:
            return action, log_pi
        return action
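# SquashingLayer is likewise defined elsewhere. A minimal sketch, assuming
# the standard tanh change-of-variables correction (an illustration of the
# idea, not the package's actual class):
class SquashingLayerSketch(nn.Module):
    def __init__(self, squashing_fn=nn.Tanh):
        super().__init__()
        self.squashing_fn = squashing_fn()

    def forward(self, z, log_pi):
        action = self.squashing_fn(z)
        # for tanh: log|det J| = sum_i log(1 - tanh(z_i)^2); the epsilon
        # guards against log(0) when the squashing saturates
        log_det = torch.log(1. - action.pow(2) + 1e-6) \
            .sum(dim=-1, keepdim=True)
        return action, log_pi - log_det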
class RealNVPPolicy(nn.Module):
    def __init__(self, action_size, layer_fn, activation_fn=nn.ReLU,
                 squashing_fn=nn.Tanh, bias=False):
        super().__init__()
        activation_fn = MODULES.get_if_str(activation_fn)
        self.action_size = action_size
        # two coupling layers of opposite parity, so that every action
        # coordinate is transformed at least once
        self.coupling1 = CouplingLayer(
            action_size=action_size, layer_fn=layer_fn,
            activation_fn=activation_fn, bias=bias, parity="odd")
        self.coupling2 = CouplingLayer(
            action_size=action_size, layer_fn=layer_fn,
            activation_fn=activation_fn, bias=bias, parity="even")
        self.squashing_layer = SquashingLayer(squashing_fn)

    def forward(self, logits, logprob=None, deterministic=False):
        # the head's output is treated as a state embedding conditioning the
        # flow; the base distribution is a standard Gaussian over actions
        state_embedding = logits
        loc = torch.zeros(
            (state_embedding.shape[0], self.action_size),
            device=state_embedding.device)
        scale = torch.ones_like(loc)
        action = loc if deterministic else normal_sample(loc, scale)

        bool_logprob = isinstance(logprob, bool) and logprob
        value_logprob = isinstance(logprob, torch.Tensor)
        assert not value_logprob, "Not implemented behaviour"

        action_logprob = normal_logprob(loc, scale, action)
        # push the base sample through the flow, accumulating the log-density
        action, action_logprob = \
            self.coupling1.forward(action, state_embedding, action_logprob)
        action, action_logprob = \
            self.coupling2.forward(action, state_embedding, action_logprob)
        action, action_logprob = \
            self.squashing_layer.forward(action, action_logprob)

        if bool_logprob:
            return action, action_logprob
        return action
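# CouplingLayer is not shown in this file. Below is a minimal sketch of a
# RealNVP-style affine coupling conditioned on the state embedding. It is an
# assumption for illustration only: the real class's internals differ (for
# one, it takes `layer_fn` rather than the explicit, hypothetical
# `embedding_size` argument used here).
class CouplingLayerSketch(nn.Module):
    def __init__(self, action_size, embedding_size, hidden_size=64,
                 parity="odd"):
        super().__init__()
        # binary mask selecting the half of the action that passes through
        mask = (torch.arange(action_size) % 2).float()
        self.register_buffer("mask", mask if parity == "odd" else 1. - mask)
        in_size = action_size + embedding_size
        self.scale_net = nn.Sequential(
            nn.Linear(in_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, action_size), nn.Tanh())
        self.translation_net = nn.Sequential(
            nn.Linear(in_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, action_size))

    def forward(self, action, state_embedding, logprob):
        # the pass-through half and the state jointly predict an affine
        # transform (scale s, translation t) for the other half
        x = torch.cat([action * self.mask, state_embedding], dim=1)
        s = self.scale_net(x) * (1. - self.mask)
        t = self.translation_net(x) * (1. - self.mask)
        action = action * self.mask \
            + (1. - self.mask) * (action * torch.exp(s) + t)
        # change of variables: subtract log|det J| = sum(s)
        logprob = logprob - s.sum(dim=-1, keepdim=True)
        return action, logprob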
class SquashingGaussPolicy(nn.Module):
    # same Gaussian head as GaussPolicy above, exposed through the richer
    # logprob API (bool flag for now; tensor-valued queries not yet supported)
    def __init__(self, squashing_fn=nn.Tanh):
        super().__init__()
        self.squashing_layer = SquashingLayer(squashing_fn)

    def forward(self, logits, logprob=None, deterministic=False):
        action_size = logits.shape[1] // 2
        loc, log_scale = logits[:, :action_size], logits[:, action_size:]
        log_scale = torch.clamp(log_scale, LOG_SIG_MIN, LOG_SIG_MAX)
        scale = torch.exp(log_scale)
        action = loc if deterministic else normal_sample(loc, scale)

        bool_logprob = isinstance(logprob, bool) and logprob
        value_logprob = isinstance(logprob, torch.Tensor)
        assert not value_logprob, "Not implemented behaviour"

        action_logprob = normal_logprob(loc, scale, action)
        action, action_logprob = \
            self.squashing_layer.forward(action, action_logprob)

        if bool_logprob:
            return action, action_logprob
        return action
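# A hypothetical usage sketch. The names and numbers are illustrative; it
# assumes the package's SquashingLayer is importable (or that the sketch
# above is substituted for it).
if __name__ == "__main__":
    batch_size, action_size = 4, 6
    # the Gaussian policies expect the head to emit 2 * action_size logits
    logits = torch.randn(batch_size, 2 * action_size)
    policy = GaussPolicy()
    action, log_pi = policy(logits, logprob=True)
    assert action.shape == (batch_size, action_size)
    # greedy (mean) action at evaluation time
    action = policy(logits, deterministic=True)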