import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as D
import pyro.poutine as poutine

# Project-local helpers assumed by these snippets:
# `coordinates.identity_grid`, `T.broadcasting_grid_sample`, and
# `random_pose_transform` come from this codebase, not from Pyro itself.


def transforming_vae_mnist(
    data, transforms, decoder=None, cond=True, grid_size=28, **kwargs
):
    decoder = pyro.module("decoder", decoder)
    with pyro.plate("batch", data.shape[0]):
        # first attempt - literally try to learn a template to check
        # we got the transformation logic correct
        template = pyro.param(
            "template",
            torch.rand(1, 32, 32, device=data.device),
            constraint=constraints.unit_interval,
        )
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1, 1, grid_size, device=data.device),
                torch.linspace(-1, 1, grid_size, device=data.device),
            ),
            -1,
        )
        grid = grid.expand(data.shape[0], *grid.shape)

        transform = random_pose_transform(transforms)
        transform_grid = transform(grid)

        transformed_template = T.broadcasting_grid_sample(
            template.expand(data.shape[0], 1, 32, 32), transform_grid
        )
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_template).to_event(3), obs=obs)

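# Hedged sanity check of the grid-sampling mechanics, on the assumption that
# T.broadcasting_grid_sample is a thin batching wrapper around PyTorch's
# F.grid_sample (the local `T` module is project-specific). With an identity
# grid in normalised [-1, 1] coordinates, sampling simply resamples the input.
import torch
import torch.nn.functional as F

img = torch.rand(4, 1, 32, 32)  # (N, C, H, W) batch of templates
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 28), torch.linspace(-1, 1, 28), indexing="ij"
)
# grid_sample expects shape (N, H_out, W_out, 2) with (x, y) in the last dim
grid = torch.stack((xs, ys), -1).expand(4, 28, 28, 2)
out = F.grid_sample(img, grid, align_corners=True)
assert out.shape == (4, 1, 28, 28)
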
def forward_model(
    data,
    transforms=None,
    cond=True,
    decoder=None,
    output_size=40,
    device=torch.device("cpu"),
    **kwargs
):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", data.shape[0]):
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.latent_dim, device=device),
                torch.ones(decoder.latent_dim, device=device),
            ).to_event(1),
        )
        view = decoder(z)
        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(data.shape[0], *grid.shape)

        transform = random_pose_transform(transforms)
        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)

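# Hedged sketch of a training loop for the model above using Pyro's SVI.
# `guide` is hypothetical: it must supply variational distributions for "z"
# and for any pose sites sampled inside random_pose_transform.
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def train(model, guide, loader, transforms, decoder, device, epochs=10):
    svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
    for epoch in range(epochs):
        total = 0.0
        for (batch,) in loader:
            batch = batch.to(device)
            total += svi.step(
                batch, transforms=transforms, decoder=decoder, device=device
            )
        print(f"epoch {epoch}: loss {total / len(loader.dataset):.4f}")
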
def forward_model(
    data,
    label,
    N=-1,
    transforms=None,
    instantiate_label=False,
    cond_label=True,
    cond=True,
    decoder=None,
    latent_decoder=None,
    output_size=40,
    device=torch.device("cpu"),
    **kwargs
):
    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", N):
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(N, decoder.latent_dim, device=device),
                torch.ones(N, decoder.latent_dim, device=device),
            ).to_event(1),
        )

        # use supervision
        if instantiate_label:
            latent_decoder = pyro.module("latent_decoder", latent_decoder)
            label_logits = latent_decoder(z)
            obs_label = label if cond_label else None
            pyro.sample("y", D.Categorical(logits=label_logits), obs=obs_label)

        view = decoder(z)
        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(N, *grid.shape)

        transform = random_pose_transform(transforms, device=device)
        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)

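# Hedged sketch: drawing unconditional samples from the model above by
# tracing it with cond=False / cond_label=False and reading sites off the
# trace. The dummy tensor only fixes shapes and is an assumption.
def sample_digits(model, n, transforms, decoder, latent_decoder, device):
    dummy = torch.zeros(n, 1, 40, 40, device=device)  # shape assumed
    tr = poutine.trace(model).get_trace(
        dummy,
        None,
        N=n,
        transforms=transforms,
        instantiate_label=True,
        cond=False,
        cond_label=False,
        decoder=decoder,
        latent_decoder=latent_decoder,
        device=device,
    )
    canonical = tr.nodes["canonical_view"]["value"]
    pixels = tr.nodes["pixels"]["value"]
    return canonical, pixels
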
def forward_model(
    data,
    transforms=None,
    instantiate_label=False,
    cond=True,
    decoder=None,
    output_size=128,
    device=torch.device("cpu"),
    kl_beta=1.0,
    **kwargs,
):
    decoder = pyro.module("view_decoder", decoder)
    N = data.shape[0]
    # scale the whole ELBO by 1/N so the loss is an average over the batch
    with poutine.scale_messenger.ScaleMessenger(1 / N):
        with pyro.plate("batch", N):
            # down-weight only the prior on z (beta-VAE style KL scaling)
            with poutine.scale_messenger.ScaleMessenger(kl_beta):
                z = pyro.sample(
                    "z",
                    D.Normal(
                        torch.zeros(N, decoder.latent_dim, device=device),
                        torch.ones(N, decoder.latent_dim, device=device),
                    ).to_event(1),
                )
            view = decoder(z)
            pyro.deterministic("canonical_view", view)

            grid = coordinates.identity_grid([output_size, output_size], device=device)
            grid = grid.expand(N, *grid.shape)
            # rescale the image coordinates so one pixel of the reconstruction
            # corresponds to one pixel of the canonical view
            scale = view.shape[-1] / output_size
            grid = grid * (1 / scale)

            transform = random_pose_transform(transforms, device=device)
            transform_grid = transform(grid)

            transformed_view = T.broadcasting_grid_sample(view, transform_grid)
            obs = data if cond else None
            pyro.sample(
                "pixels", D.Laplace(transformed_view, 0.5).to_event(3), obs=obs
            )

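# The nesting above can equivalently be written with the public poutine.scale
# helper, which wraps the same ScaleMessenger. A hedged, minimal restatement
# of just the prior block (not the author's code):
def scaled_prior_sketch(N, latent_dim, kl_beta, device):
    with poutine.scale(scale=1.0 / N):  # average the ELBO over the batch
        with pyro.plate("batch", N):
            with poutine.scale(scale=kl_beta):  # beta-VAE KL weighting
                return pyro.sample(
                    "z",
                    D.Normal(
                        torch.zeros(N, latent_dim, device=device),
                        torch.ones(N, latent_dim, device=device),
                    ).to_event(1),
                )
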
def forward(self, x):
    output = {}
    transform_output = self.transformer(x)
    output["transform"] = transform_output["transform"]
    output["transform_params"] = transform_output["params"]

    grid = coordinates.identity_grid(
        [self.input_size, self.input_size], device=x.device
    )
    grid = grid.expand(x.shape[0], *grid.shape)
    transformed_grid = output["transform"][-1](grid)
    view = T.broadcasting_grid_sample(x, transformed_grid)

    out = self.view_encoder(view)
    z_mu, z_std = torch.split(out, self.latent_dim, -1)
    output["z_mu"] = z_mu
    output["z_std"] = z_std
    return output

def forward(self, x):
    output = {}
    # apply normalisation here rather than in the dataloader: the Bernoulli
    # likelihood needs the raw pixel values in [0, 1], so only the encoder's
    # input is standardised
    x = x - 0.2222
    x = x / 0.156
    transform_output = self.transformer(x)
    output["transform"] = transform_output["transform"]
    output["transform_params"] = transform_output["params"]

    grid = coordinates.identity_grid([128, 128], device=x.device)
    grid = grid.expand(x.shape[0], *grid.shape)
    transformed_grid = output["transform"][-1](grid)
    view = T.broadcasting_grid_sample(x, transformed_grid)

    out = self.view_encoder(view)
    z_mu, z_std = torch.split(out, self.latent_dim, -1)
    output["z_mu"] = z_mu
    output["z_std"] = z_std
    output["view"] = view
    return output

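# Hedged sketch of the matching amortised guide: the encoder's (z_mu, z_std)
# parameterise the variational posterior over "z". The softplus is an
# assumption to keep the scale positive; the original encoder may constrain
# z_std differently. Pose sites inside random_pose_transform would need
# their own guide sites and are omitted here.
import torch.nn.functional as F

def guide(data, encoder=None, device=torch.device("cpu"), **kwargs):
    encoder = pyro.module("encoder", encoder)
    with pyro.plate("batch", data.shape[0]):
        out = encoder(data)
        pyro.sample(
            "z", D.Normal(out["z_mu"], F.softplus(out["z_std"])).to_event(1)
        )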