Example #1
import torch

# Note: BaseFlow, distributions, transforms, product, and ProjectionSplit come from the
# surrounding manifold-flow package and are assumed to be imported alongside this excerpt.


class EncoderManifoldFlow(BaseFlow):
    """ Manifold-based flow with separate encoder (for MFMFE) """
    def __init__(self,
                 data_dim,
                 latent_dim,
                 encoder,
                 outer_transform,
                 inner_transform=None,
                 pie_epsilon=1.0e-2,
                 apply_context_to_outer=True):
        super(EncoderManifoldFlow, self).__init__()

        assert latent_dim < data_dim

        self.data_dim = data_dim
        self.latent_dim = latent_dim
        self.total_data_dim = product(data_dim)
        self.total_latent_dim = product(latent_dim)
        self.apply_context_to_outer = apply_context_to_outer

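        # Base density for the manifold latents, a narrow (std = pie_epsilon) density for the
        # off-manifold directions, and a projection that splits / pads between the two spaces.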
        self.manifold_latent_distribution = distributions.StandardNormal(
            (self.total_latent_dim, ))
        self.orthogonal_latent_distribution = distributions.RescaledNormal(
            (self.total_data_dim - self.total_latent_dim, ),
            std=pie_epsilon,
            clip=5.0 * pie_epsilon)
        self.projection = ProjectionSplit(self.total_data_dim,
                                          self.total_latent_dim)

        self.encoder = encoder
        self.outer_transform = outer_transform
        if inner_transform is None:
            self.inner_transform = transforms.IdentityTransform()
        else:
            self.inner_transform = inner_transform

        self._report_model_parameters()

    def forward(self, x, mode="mf", context=None):
        """
        Transforms a data point to latent space, evaluates the likelihood, and transforms it back to data space.

        mode can be "mf" (calculating the exact manifold density based on the full Jacobian of the outer transform),
        "pie" (calculating the density in x), "projection" (calculating no density at all), or "mf-fixed-manifold"
        (the manifold density with the outer Jacobian determinant term dropped).
        """

        assert mode in ["mf", "projection", "pie", "mf-fixed-manifold"]

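        # The exact "mf" density needs the full Jacobian of the outer transform, which is
        # computed by differentiating through it, so x must track gradients here.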
        if mode == "mf" and not x.requires_grad:
            x.requires_grad = True

        # Encode
        u, h_manifold, log_det_inner = self._encode(x, context)

        # Decode
        x_reco, inv_log_det_inner, inv_log_det_outer, inv_jacobian_outer, h_manifold_reco = self._decode(
            u, mode=mode, context=context)

        # Log prob
        log_prob = self._log_prob(mode, u, log_det_inner, inv_log_det_inner,
                                  inv_log_det_outer, inv_jacobian_outer)

        return x_reco, log_prob, u

    def encode(self, x, context=None):
        """ Transforms a data point to latent space. """

        u, _, _ = self._encode(x, context=context)
        return u

    def decode(self, u, u_orthogonal=None, context=None):
        """ Decodes a latent variable to data space. """

        x, _, _, _, _ = self._decode(u,
                                     mode="projection",
                                     u_orthogonal=u_orthogonal,
                                     context=context)
        return x

    def log_prob(self, x, mode="mf", context=None):
        """ Evaluates the log likelihood for a given data point. """

        return self.forward(x, mode, context)[1]

    def sample(self, u=None, n=1, context=None, sample_orthogonal=False):
        """
        Generates samples from the model.

        Note: this is PIE / MF sampling. Sampling from the slice of PIE cannot be done efficiently.
        """

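        # Draw manifold latents from the base distribution unless they are given; optionally add
        # small off-manifold components from the epsilon-scaled orthogonal distribution.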
        if u is None:
            u = self.manifold_latent_distribution.sample(n, context=None)
        u_orthogonal = self.orthogonal_latent_distribution.sample(
            n, context=None) if sample_orthogonal else None
        x = self.decode(u, u_orthogonal=u_orthogonal, context=context)
        return x

    def _encode(self, x, context=None):
        # Encode
        h_manifold = self.encoder(
            x, context=context if self.apply_context_to_outer else None)
        u, log_det_inner = self.inner_transform(h_manifold,
                                                full_jacobian=False,
                                                context=context)

        return u, h_manifold, log_det_inner

    def _decode(self, u, mode, u_orthogonal=None, context=None):
        if mode == "mf" and not u.requires_grad:
            u.requires_grad = True

        h, inv_log_det_inner = self.inner_transform.inverse(
            u, full_jacobian=False, context=context)

        if u_orthogonal is not None:
            h = self.projection.inverse(h, orthogonal_inputs=u_orthogonal)
        else:
            h = self.projection.inverse(h)

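        # Only the exact "mf" density needs the full Jacobian of the outer transform; the other
        # modes get by with its log determinant (or with no density at all).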
        if mode in ["pie", "slice", "projection", "mf-fixed-manifold"]:
            x, inv_log_det_outer = self.outer_transform.inverse(
                h,
                full_jacobian=False,
                context=context if self.apply_context_to_outer else None)
            inv_jacobian_outer = None
        else:
            x, inv_jacobian_outer = self.outer_transform.inverse(
                h,
                full_jacobian=True,
                context=context if self.apply_context_to_outer else None)
            inv_log_det_outer = None

        return x, inv_log_det_inner, inv_log_det_outer, inv_jacobian_outer, h

    def _log_prob(self, mode, u, log_det_inner, inv_log_det_inner,
                  inv_log_det_outer, inv_jacobian_outer):
        if mode == "mf":
            # inv_jacobian_outer is dx / du, but still need to restrict this to the manifold latents
            inv_jacobian_outer = inv_jacobian_outer[:, :, :self.latent_dim]
            # And finally calculate log det (J^T J)
            jtj = torch.bmm(torch.transpose(inv_jacobian_outer, -2, -1),
                            inv_jacobian_outer)

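            # Manifold change of variables: base log density of the latents, minus half the log
            # determinant of the induced metric J^T J, minus the inner transform's log det.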
            log_prob = self.manifold_latent_distribution._log_prob(
                u, context=None)
            log_prob = log_prob - 0.5 * torch.slogdet(
                jtj)[1] - inv_log_det_inner

        elif mode == "mf-fixed-manifold":
            log_prob = self.manifold_latent_distribution._log_prob(
                u, context=None)
            log_prob = log_prob - inv_log_det_inner

        else:
            log_prob = None

        return log_prob
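
Below is a minimal, hypothetical usage sketch. It assumes the surrounding manifold-flow package (which provides BaseFlow, distributions, transforms, product, and ProjectionSplit) is importable next to the class above; the encoder and transform classes here are illustrative stand-ins that only implement the call signatures EncoderManifoldFlow relies on, not the project's real components, and the dimensions are arbitrary.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """ Stand-in encoder: maps data vectors to the manifold latent dimension, ignoring context. """

    def __init__(self, data_dim, latent_dim):
        super().__init__()
        self.net = nn.Linear(data_dim, latent_dim)

    def forward(self, x, context=None):
        return self.net(x)


class ToyLinearTransform(nn.Module):
    """ Stand-in invertible transform exposing the (output, log det / full Jacobian) interface used above. """

    def __init__(self, dim):
        super().__init__()
        # A fixed random orthogonal matrix: its inverse is its transpose and log|det| = 0.
        q, _ = torch.linalg.qr(torch.randn(dim, dim))
        self.register_buffer("q", q)

    def forward(self, x, full_jacobian=False, context=None):
        return x @ self.q.t(), torch.zeros(x.shape[0])

    def inverse(self, z, full_jacobian=False, context=None):
        x = z @ self.q
        if full_jacobian:
            # dx/dz is the constant matrix q^T, repeated over the batch.
            return x, self.q.t().unsqueeze(0).repeat(z.shape[0], 1, 1)
        return x, torch.zeros(z.shape[0])


data_dim, latent_dim = 8, 3
flow = EncoderManifoldFlow(
    data_dim=data_dim,
    latent_dim=latent_dim,
    encoder=ToyEncoder(data_dim, latent_dim),
    outer_transform=ToyLinearTransform(data_dim),
    inner_transform=ToyLinearTransform(latent_dim),
)

x = torch.randn(16, data_dim)
x_reco, log_prob, u = flow(x, mode="mf")   # reconstruction, manifold log density, manifold latents
samples = flow.sample(n=16)                # decode draws from the manifold latent prior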