Example #1
    def _inverse(self, y):
        """
        :param y: the output of the bijection
        :type y: torch.Tensor

        Inverts y => x. Caching of previously computed inverses is handled by
        the base :class:`~torch.distributions.transforms.Transform` when a
        nonzero ``cache_size`` is used; this method performs the inversion
        afresh.
        """
        y1, y2 = y.split([self.split_dim, y.size(self.dim) - self.split_dim], dim=self.dim)
        x1 = y1

        # Now that we can split on an arbitrary dimension, we have to do a bit of reshaping...
        mean, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
        mean = mean.reshape(mean.shape[:-1] + y2.shape[-self.event_dim:])
        log_scale = log_scale.reshape(log_scale.shape[:-1] + y2.shape[-self.event_dim:])

        log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
        # Cache log_scale so log_abs_det_jacobian can reuse it
        self._cached_log_scale = log_scale

        # Invert the affine update: x2 = (y2 - mean) / exp(log_scale)
        x2 = (y2 - mean) * torch.exp(-log_scale)
        return torch.cat([x1, x2], dim=self.dim)
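
A round-trip sketch (an addition, not part of the original example): assuming these methods belong to Pyro's AffineCoupling transform with a DenseNN hypernet from pyro.nn, the inverse should recover the forward input up to floating-point error. The sizes below are illustrative, and _inverse is called directly only for demonstration, since it is a private method.

    import torch
    from pyro.distributions.transforms import AffineCoupling
    from pyro.nn import DenseNN

    # Illustrative sizes: a 10-dimensional input split after the 6th coordinate
    input_dim, split_dim = 10, 6

    # The hypernet maps x1 (size split_dim) to a mean and a log_scale for x2
    hypernet = DenseNN(split_dim, [10 * input_dim],
                       [input_dim - split_dim, input_dim - split_dim])
    transform = AffineCoupling(split_dim, hypernet)

    x = torch.randn(100, input_dim)
    y = transform(x)                     # forward pass, dispatches to _call
    x_recovered = transform._inverse(y)  # uncached inversion, for demonstration
    assert torch.allclose(x, x_recovered, atol=1e-5)

Because x1 passes through unchanged, inverting only requires recomputing mean and log_scale from y1 = x1 and solving the affine update for x2; the neural network itself never has to be inverted.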
Example #2
    def _call(self, x):
        """
        :param x: the input into the bijection
        :type x: torch.Tensor

        Invokes the bijection x => y; in the prototypical context of a
        :class:`~pyro.distributions.TransformedDistribution`, `x` is a sample
        from the base distribution (or the output of a previous transform).
        """
        x1, x2 = x.split([self.split_dim, x.size(self.dim) - self.split_dim], dim=self.dim)

        # Now that we can split on an arbitrary dimension, we have to do a bit of reshaping...
        mean, log_scale = self.nn(x1.reshape(x1.shape[:-self.event_dim] + (-1,)))
        mean = mean.reshape(mean.shape[:-1] + x2.shape[-self.event_dim:])
        log_scale = log_scale.reshape(log_scale.shape[:-1] + x2.shape[-self.event_dim:])

        log_scale = clamp_preserve_gradients(log_scale, self.log_scale_min_clip, self.log_scale_max_clip)
        # Cache log_scale so log_abs_det_jacobian can reuse it
        self._cached_log_scale = log_scale

        y1 = x1
        # Affine update on the second half, conditioned on the first
        y2 = torch.exp(log_scale) * x2 + mean
        return torch.cat([y1, y2], dim=self.dim)
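
For context, a sketch of how the forward pass is typically exercised as one step of a normalizing flow, following the usage pattern in Pyro's documentation (names and sizes are illustrative):

    import torch
    import pyro.distributions as dist
    from pyro.distributions.transforms import AffineCoupling
    from pyro.nn import DenseNN

    input_dim, split_dim = 10, 6
    hypernet = DenseNN(split_dim, [10 * input_dim],
                       [input_dim - split_dim, input_dim - split_dim])
    transform = AffineCoupling(split_dim, hypernet)

    # Wrap the coupling layer in a TransformedDistribution
    base_dist = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
    flow_dist = dist.TransformedDistribution(base_dist, [transform])
    y = flow_dist.sample()         # draws x from the base, then applies _call
    log_p = flow_dist.log_prob(y)  # Jacobian term reuses the cached log_scale

Since x1 is passed through unchanged, the Jacobian of the forward map is triangular and its log-determinant is simply the sum of log_scale over the transformed block, which is why caching log_scale in _call makes the subsequent log_prob cheap.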