Example #1
    def forward(self, x):
        output = {}
        transform_output = self.transformer(x)
        output["transform"] = transform_output["transform"]
        output["transform_params"] = transform_output["params"]

        # identity sampling grid over the input, broadcast over the batch
        grid = coordinates.identity_grid([self.insize, self.insize],
                                         device=x.device)
        grid = grid.expand(x.shape[0], *grid.shape)

        # apply the last transform in the predicted sequence to the grid
        transformed_grid = output["transform"][-1](grid)

        # normalize the input (presumably a precomputed dataset mean and std)
        x = x - 0.222
        x = x / 0.156

        view = T.broadcasting_grid_sample(x, transformed_grid)

        # encode the pose-canonicalized view into parameters of q(z|x)
        split = self.net(view)
        split = self.relu(split)
        reshape = split.view(-1, self.linear_size)
        z_loc = self.loc(reshape)
        # exp keeps the predicted scale strictly positive
        z_scale = torch.exp(self.scale(reshape))

        output["z_mu"] = z_loc
        output["z_std"] = z_scale

        return output, split
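
The forward pass above returns both the transform bookkeeping and the
parameters of q(z|x). A minimal sketch of how a Pyro guide might consume
the returned dictionary (the "encoder" module name, the "data" plate name,
and the "z" site name are assumptions, not taken from the source):

    def guide(self, data):
        # hypothetical amortized guide built on the encoder above
        encoder = pyro.module("encoder", self.encoder)
        with pyro.plate("data", data.shape[0]):
            output, _ = encoder(data)
            # reparameterized sample from q(z|x) = Normal(z_mu, z_std)
            pyro.sample(
                "z", D.Normal(output["z_mu"], output["z_std"]).to_event(1))
            # (the pose sites sampled inside random_pose_transform would
            # need matching guide sites as well; omitted here)
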
Example #2
    def model(self,
              data,
              transforms=None):

        output_size = self.encoder.insize
        decoder = pyro.module("decoder", self.decoder)
        # the decoder maps a latent z to an upright image in a canonical
        # pose; the pose transform sampled below moves it into the frame
        # of the observed data
        # pyro.plate requires a string name as its first argument; the
        # original call passed only the size and would raise at runtime
        with pyro.plate("data", data.shape[0]):
            # prior for z
            z = pyro.sample(
                "z",
                D.Normal(
                    torch.zeros(decoder.z_dim, device=data.device),
                    torch.ones(decoder.z_dim, device=data.device),
                ).to_event(1),
            )

            # given z, the decoder produces an image in its own
            # self-consistent (canonical) basis; it still has to be
            # transformed into the real-world basis of the observations
            view = decoder(z)

            # pyro.deterministic records a named value that is a
            # deterministic function of other sample sites in the trace
            pyro.deterministic("canonical_view", view)

            # everything sampled so far is independent of the input: it is
            # the prior over z and over the pose transform. During
            # inference these sites are replayed with the values proposed
            # by the guide, so the model takes the guide's theta, mu and
            # sigma and applies the inverse transform to produce the
            # output image.
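            # roughly what SVI's Trace_ELBO does under the hood (a sketch,
            # not part of this model):
            #   guide_trace = poutine.trace(guide).get_trace(data)
            #   model_trace = poutine.trace(
            #       poutine.replay(model, trace=guide_trace)).get_trace(data)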
            # identity grid over the output image, broadcast over the batch
            grid = coordinates.identity_grid(
                [output_size, output_size], device=data.device)
            grid = grid.expand(data.shape[0], *grid.shape)
            # fall back to a default pose family (translation + rotation)
            # when no transforms are supplied
            if transforms is None:
                transforms = T.TransformSequence(T.Translation(), T.Rotation())
            transform = random_pose_transform(transforms)

            transform_grid = transform(grid)

            # the decoder output is resampled into the transformed
            # coordinate system

            transformed_view = T.broadcasting_grid_sample(view, transform_grid)

            # condition the Bernoulli likelihood on the observed pixels
            pyro.sample(
                "pixels", D.Bernoulli(transformed_view).to_event(3), obs=data)