    def forward(self, batch):
        images = batch['images']
        y = batch['targets']
        device = trw.train.get_device(self)
        y_one_hot = one_hot(y, self.y_size).to(device)
        recon, mu, logvar = self.autoencoder.forward(images, y_one_hot)

        with torch.no_grad():
            # sample reconstructions conditioned on the same class distribution as the data
            random_recon_y = one_hot(y, self.y_size).to(device)
            random_recon = self.autoencoder.sample_given_y(random_recon_y)

        loss = AutoencoderConvolutionalVariational.loss_function(
            recon, images, mu, logvar, kullback_leibler_weight=0.1)
        return {
            'loss': trw.train.OutputLoss(loss),
            'recon': trw.train.OutputEmbedding(recon),
            'random_recon': trw.train.OutputEmbedding(random_recon)
        }
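
The one_hot helper used throughout these snippets is not shown. If it is not already provided by the surrounding project (trw ships a similar utility), a minimal stand-in could look like the following sketch, built on torch.nn.functional.one_hot (an assumption, not the library's own implementation):

import torch

def one_hot(indices, num_classes):
    # indices: integer class labels of shape [N]; returns a float tensor of shape [N, num_classes]
    return torch.nn.functional.one_hot(indices.long(), num_classes).float()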
Example #2
    def forward(self, batch, latent):
        digits = batch['targets']
        real = batch['images']
        assert len(digits.shape) == 1

        # introduce the target as one hot encoding input to the generator
        digits_one_hot = one_hot(digits, self.nb_digits).unsqueeze(2).unsqueeze(3)
        latent = latent.unsqueeze(2).unsqueeze(3)

        full_latent = torch.cat((digits_one_hot, latent), dim=1)
        o = self.convs_t(full_latent)
        return o, collections.OrderedDict([
            ('image', OutputEmbedding(o)),
            #('l1', OutputLoss(0.5 * torch.nn.L1Loss()(o, real)))  # on average, the generated & real should match
        ])
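
A small, self-contained shape check for the conditioning scheme above (a batch of 16 and latent_size=64 are assumptions; nb_digits=10 as for MNIST):

import torch

digits_one_hot = torch.nn.functional.one_hot(torch.randint(0, 10, [16]), 10).float()
digits_one_hot = digits_one_hot.unsqueeze(2).unsqueeze(3)    # [16, 10, 1, 1]
latent = torch.randn([16, 64]).unsqueeze(2).unsqueeze(3)     # [16, 64, 1, 1]
full_latent = torch.cat((digits_one_hot, latent), dim=1)     # [16, 74, 1, 1], fed to the transposed convolutions
print(full_latent.shape)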
Example #3
    def forward(self, batch, image, is_real):
        digits = batch['targets']

        # introduce the target as one hot encoding input to the discriminator
        input_class = torch.ones(
            [image.shape[0], self.nb_digits, image.shape[2], image.shape[3]],
            device=image.device) * one_hot(digits, self.nb_digits).unsqueeze(2).unsqueeze(3)
        o = self.convs(torch.cat((image, input_class), dim=1))
        o_expected = int(is_real) * torch.ones(len(image), device=image.device, dtype=torch.long)

        return {
            'classification': OutputClassification2(
                o, o_expected,
                criterion_fn=LossMsePacked,  # least-squares (LSGAN) loss
            )
        }
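
LossMsePacked gives the least-squares (LSGAN) objective: the discriminator output is regressed towards the expected class instead of being trained with cross-entropy. A minimal standalone version of that idea (a sketch, not trw's packed implementation) is:

import torch

def lsgan_loss(output, is_real):
    # regress the discriminator output towards 1 for real images and 0 for generated ones
    target = torch.full_like(output, float(is_real))
    return torch.mean((output - target) ** 2)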
Example #4
                                                        batch_size=1024),
    run_prefix='mnist_autoencoder_variational_conditional',
    model_fn=lambda options: Net(),
    optimizers_fn=lambda datasets, model: trw.train.create_adam_optimizers_scheduler_step_lr_fn(
        datasets=datasets,
        model=model,
        learning_rate=0.001,
        step_size=120,
        gamma=0.1))
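
create_adam_optimizers_scheduler_step_lr_fn presumably builds, per dataset, an Adam optimizer with a step learning-rate schedule; in plain PyTorch the equivalent configuration would roughly be (a sketch under that assumption):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=120, gamma=0.1)  # multiply the lr by 0.1 every 120 steps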

model.eval()  # switch to evaluation mode (propagates to submodules, unlike setting .training directly)
nb_images = 40

device = trw.train.get_device(model)

# sample latent codes and condition every one of them on the digit '7'
latent = torch.randn([nb_images, model.latent_size], device=device)
y = one_hot(torch.ones([nb_images], dtype=torch.long, device=device) * 7, 10)

# concatenate latent and class conditioning, reshaped to [N, C, 1, 1] for the decoder
latent_y = torch.cat([latent, y], dim=1)
latent_y = latent_y.view(latent_y.shape[0], latent_y.shape[1], 1, 1)
generated = model.autoencoder.decoder(latent_y)

fig, axes = plt.subplots(nrows=1,
                         ncols=nb_images,
                         figsize=(nb_images, 2.5),
                         sharey=True)
# bring the decoded images back to the 28x28 MNIST resolution before plotting
decoded_images = crop_or_pad_fun(generated, [28, 28])
image_width = decoded_images.shape[2]

for ax, img in zip(axes, decoded_images):
    curr_img = img.detach().to(torch.device('cpu'))
    ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')
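    ax.axis('off')  # assumed cleanup: hide ticks around each generated digit

plt.show()  # assumed final step: display the row of generated '7's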