def _inverse(self, y):
    """
    :param y: the output of the bijection
    :type y: torch.Tensor

    Computes the inverse map y => x. The first ``self.split_dim`` entries of
    ``y`` along ``self.dim`` pass through unchanged. The remaining entries are
    un-shifted and un-scaled using a mean and log-scale that ``self.nn``
    predicts from the unchanged part.

    This method always recomputes the inverse; it does not read any cache
    itself. As a side effect, the clamped log-scale is stored on
    ``self._cached_log_scale``.
    """
    dim = self.dim
    event_dim = self.event_dim
    passthrough, transformed = y.split([self.split_dim, y.size(dim) - self.split_dim], dim=dim)

    # The split may be on an arbitrary dimension. The conditioner therefore
    # receives the passthrough half with its trailing `event_dim` dimensions
    # flattened into one. Its outputs are unflattened back to the event shape
    # of the transformed half. This assumes the conditioner's last output
    # dimension equals that flattened event size -- TODO confirm.
    loc, log_scale = self.nn(passthrough.reshape(passthrough.shape[:-event_dim] + (-1,)))
    event_shape = transformed.shape[-event_dim:]
    loc = loc.reshape(loc.shape[:-1] + event_shape)
    log_scale = log_scale.reshape(log_scale.shape[:-1] + event_shape)

    # Clamp the log-scale into the configured range and keep it on the
    # instance (presumably so the log-Jacobian computation can reuse it -- confirm).
    log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
    self._cached_log_scale = log_scale

    # Undo the forward affine map: subtract the mean, then divide by exp(log_scale).
    recovered = (transformed - loc) * torch.exp(-log_scale)
    return torch.cat([passthrough, recovered], dim=dim)
def _call(self, x):
    """
    :param x: the input into the bijection
    :type x: torch.Tensor

    Invokes the bijection x=>y. In the prototypical context of a
    :class:`~pyro.distributions.TransformedDistribution`, `x` is a sample from
    the base distribution (or the output of a previous transform).

    The first ``self.split_dim`` entries of ``x`` along ``self.dim`` are copied
    to the output unchanged. The remaining entries are scaled by
    ``exp(log_scale)`` and shifted by ``mean``. Both values are predicted by
    ``self.nn`` from the unchanged part. As a side effect, the clamped
    log-scale is stored on ``self._cached_log_scale``.
    """
    dim = self.dim
    event_dim = self.event_dim
    head, tail = x.split([self.split_dim, x.size(dim) - self.split_dim], dim=dim)

    # The split may be on an arbitrary dimension. The conditioner therefore
    # receives `head` with its trailing `event_dim` dimensions flattened into
    # one. Its outputs are then unflattened to the event shape of `tail`. This
    # assumes the conditioner's last output dimension equals that flattened
    # event size -- TODO confirm.
    shift, log_scale = self.nn(head.reshape(head.shape[:-event_dim] + (-1,)))
    target_shape = tail.shape[-event_dim:]
    shift = shift.reshape(shift.shape[:-1] + target_shape)
    log_scale = log_scale.reshape(log_scale.shape[:-1] + target_shape)

    # Clamp the log-scale into the configured range and keep it on the
    # instance (presumably so the log-Jacobian computation can reuse it -- confirm).
    log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
    self._cached_log_scale = log_scale

    # Elementwise affine map on the second half only.
    scaled_tail = torch.exp(log_scale) * tail + shift
    return torch.cat([head, scaled_tail], dim=dim)