Example #1
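The snippets below share a common set of imports, reconstructed here for reference; `coordinates`, `T` (grid-sampling transforms), and `random_pose_transform` are project-local helpers whose exact module paths are not shown in the source.

import torch
import pyro
import pyro.distributions as D
from pyro import poutine
from torch.distributions import constraints
# plus project-local helpers: coordinates, transformers as T, random_pose_transform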
def transforming_vae_mnist(
    data, transforms, decoder=None, cond=True, grid_size=28, **kwargs
):
    decoder = pyro.module("decoder", decoder)  # registered but unused in this template-only variant
    with pyro.plate("batch", data.shape[0]):
        # First attempt: literally learn a pixel template, to check that
        # the transformation logic is correct.
        template = pyro.param(
            "template",
            torch.rand(1, 32, 32, device=data.device),
            constraint=constraints.unit_interval,
        )

        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, grid_size, device=data.device),
                torch.linspace(-1, 1, grid_size, device=data.device),
                indexing="ij",  # explicit (torch >= 1.10); matches the old default
            ),
            -1,
        )
        grid = grid.expand(data.shape[0], *grid.shape)

        transform = random_pose_transform(transforms)

        transform_grid = transform(grid)

        transformed_template = T.broadcasting_grid_sample(
            template.expand(data.shape[0], 1, 32, 32), transform_grid
        )
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_template).to_event(3), obs=obs)
Example #2
def forward_model(
    data,
    transforms=None,
    cond=True,
    decoder=None,
    output_size=40,
    device=torch.device("cpu"),
    **kwargs
):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", data.shape[0]):

        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.latent_dim, device=device),
                torch.ones(decoder.latent_dim, device=device),
            ).to_event(1),
        )

        view = decoder(z)

        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(data.shape[0], *grid.shape)

        transform = random_pose_transform(transforms)

        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)
Example #3
def forward_model(data,
                  label,
                  N=-1,
                  transforms=None,
                  instantiate_label=False,
                  cond_label=True,
                  cond=True,
                  decoder=None,
                  latent_decoder=None,
                  output_size=40,
                  device=torch.device("cpu"),
                  **kwargs):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", N):

        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(N, decoder.latent_dim, device=device),
                torch.ones(N, decoder.latent_dim, device=device),
            ).to_event(1),
        )

        # Optional supervision: decode class logits from the latent code.
        if instantiate_label:
            latent_decoder = pyro.module("latent_decoder", latent_decoder)
            label_logits = latent_decoder(z)
            obs_label = label if cond_label else None
            pyro.sample("y", D.Categorical(logits=label_logits), obs=obs_label)

        view = decoder(z)

        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size],
                                         device=device)
        grid = grid.expand(N, *grid.shape)

        transform = random_pose_transform(transforms, device=device)

        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels",
                    D.Bernoulli(transformed_view).to_event(3),
                    obs=obs)
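The instantiate_label/cond_label switches let the same model serve labelled and unlabelled batches. A hedged sketch of both calls follows (the `svi` object and batch variables are placeholders); note that with cond_label=False the site "y" becomes a latent Categorical and needs a guide site or enumeration during inference.

# Labelled batch: "y" is observed.
svi.step(batch, labels, N=batch.shape[0], transforms=transforms,
         instantiate_label=True, cond_label=True,
         decoder=decoder, latent_decoder=latent_decoder, device=device)

# Unlabelled batch: "y" stays latent.
svi.step(batch, None, N=batch.shape[0], transforms=transforms,
         instantiate_label=True, cond_label=False,
         decoder=decoder, latent_decoder=latent_decoder, device=device)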
Example #4
def forward_model(
        data,
        transforms=None,
        instantiate_label=False,
        cond=True,
        decoder=None,
        output_size=128,
        device=torch.device("cpu"),
        kl_beta=1.0,
        **kwargs,
):
    decoder = pyro.module("view_decoder", decoder)
    N = data.shape[0]
    with poutine.scale(scale=1 / N):
        with pyro.plate("batch", N):
            with poutine.scale(scale=kl_beta):
                z = pyro.sample(
                    "z",
                    D.Normal(
                        torch.zeros(N, decoder.latent_dim, device=device),
                        torch.ones(N, decoder.latent_dim, device=device),
                    ).to_event(1),
                )

            view = decoder(z)

            pyro.deterministic("canonical_view", view)

            grid = coordinates.identity_grid([output_size, output_size],
                                             device=device)
            grid = grid.expand(N, *grid.shape)
            # Rescale the sampling coordinates so that one pixel of the
            # reconstruction corresponds to one pixel of the canonical view.
            scale = view.shape[-1] / output_size
            grid = grid / scale

            transform = random_pose_transform(transforms, device=device)

            transform_grid = transform(grid)

            transformed_view = T.broadcasting_grid_sample(view, transform_grid)
            obs = data if cond else None
            pyro.sample("pixels",
                        D.Laplace(transformed_view, 0.5).to_event(3),
                        obs=obs)
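The nested scaling handlers implement a per-example, β-weighted ELBO: the outer 1/N turns summed log-probabilities into an average over the batch, while the inner kl_beta reweights only the prior term for z, β-VAE style. A self-contained check of the scaling behaviour (illustration, not from the source):

import torch
import pyro
import pyro.distributions as D
from pyro import poutine

def tiny_model():
    with poutine.scale(scale=0.1):
        pyro.sample("z", D.Normal(0.0, 1.0))

tr = poutine.trace(tiny_model).get_trace()
tr.compute_log_prob()
z = tr.nodes["z"]["value"]
assert torch.isclose(tr.nodes["z"]["log_prob"],
                     0.1 * D.Normal(0.0, 1.0).log_prob(z))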
Example #5
    def forward(self, x):
        output = {}
        transform_output = self.transformer(x)
        output["transform"] = transform_output["transform"]
        output["transform_params"] = transform_output["params"]

        grid = coordinates.identity_grid([self.input_size, self.input_size],
                                         device=x.device)
        grid = grid.expand(x.shape[0], *grid.shape)

        # Apply the last predicted transform to the grid and resample the
        # input into its canonical pose.
        transformed_grid = output["transform"][-1](grid)
        view = T.broadcasting_grid_sample(x, transformed_grid)
        out = self.view_encoder(view)
        z_mu, z_std = torch.split(out, self.latent_dim, -1)
        output["z_mu"] = z_mu
        output["z_std"] = z_std

        return output
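One detail worth noting: torch.split is called with a chunk size rather than a number of chunks, so the encoder head must emit exactly 2 * latent_dim features. A quick illustration with latent_dim = 16:

import torch

out = torch.rand(8, 32)                 # batch of 8, 2 * latent_dim features
z_mu, z_std = torch.split(out, 16, -1)  # chunks of size latent_dim
assert z_mu.shape == z_std.shape == (8, 16)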
Example #6
    def forward(self, x):
        output = {}
        # Apply input normalisation here rather than in the dataloader:
        # the Bernoulli likelihood needs the raw values to stay in [0, 1].
        x = x - 0.2222
        x = x / 0.156
        transform_output = self.transformer(x)
        output["transform"] = transform_output["transform"]
        output["transform_params"] = transform_output["params"]

        grid = coordinates.identity_grid([128, 128], device=x.device)
        grid = grid.expand(x.shape[0], *grid.shape)

        transformed_grid = output["transform"][-1](grid)
        view = T.broadcasting_grid_sample(x, transformed_grid)
        out = self.view_encoder(view)
        z_mu, z_std = torch.split(out, self.latent_dim, -1)
        output["z_mu"] = z_mu
        output["z_std"] = z_std
        output["view"] = view
        return output