def forward(self, batch):
    images = batch['images']
    y = batch['targets']
    device = trw.train.get_device(self)

    y_one_hot = one_hot(y, self.y_size).to(device)
    recon, mu, logvar = self.autoencoder.forward(images, y_one_hot)

    with torch.no_grad():
        random_recon_y = one_hot(y, self.y_size)  # keep same distribution as data
        random_recon = self.autoencoder.sample_given_y(random_recon_y)

    loss = AutoencoderConvolutionalVariational.loss_function(
        recon,
        images,
        mu,
        logvar,
        kullback_leibler_weight=0.1)

    return {
        'loss': trw.train.OutputLoss(loss),
        'recon': trw.train.OutputEmbedding(recon),
        'random_recon': trw.train.OutputEmbedding(random_recon)
    }
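This forward (and every snippet below) relies on a one_hot helper that is not part of the listing. A minimal sketch of the behavior the code assumes: a 1-D tensor of class indices mapped to an [N, num_classes] float matrix on the same device.

import torch

def one_hot(targets, num_classes):
    # targets: LongTensor of shape [N]; returns FloatTensor of shape [N, num_classes]
    encoded = torch.zeros(
        targets.shape[0], num_classes,
        dtype=torch.float32, device=targets.device)
    # place a 1.0 in the column given by each class index
    encoded.scatter_(1, targets.unsqueeze(1), 1.0)
    return encoded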
def forward(self, batch, latent):
    digits = batch['targets']
    real = batch['images']
    assert len(digits.shape) == 1

    # introduce the target as a one-hot encoded input to the generator
    digits_one_hot = one_hot(digits, self.nb_digits).unsqueeze(2).unsqueeze(3)
    latent = latent.unsqueeze(2).unsqueeze(3)
    full_latent = torch.cat((digits_one_hot, latent), dim=1)
    o = self.convs_t(full_latent)
    return o, collections.OrderedDict([
        ('image', OutputEmbedding(o)),
        # optional pixel-wise prior: on average, the generated & real images should match
        # ('l1', OutputLoss(0.5 * torch.nn.L1Loss()(o, real)))
    ])
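The conditioning here is pure shape bookkeeping: the one-hot class and the latent are both reshaped to [N, C, 1, 1] and concatenated on the channel axis before the transposed convolutions. A self-contained sketch of that flow, where the layer sizes are illustrative stand-ins for self.convs_t rather than the tutorial's actual architecture:

import torch
import torch.nn as nn

nb_digits, latent_size, batch = 10, 32, 4

# illustrative transposed-convolution stack standing in for self.convs_t
convs_t = nn.Sequential(
    nn.ConvTranspose2d(nb_digits + latent_size, 64, kernel_size=7),  # 1x1 -> 7x7
    nn.ReLU(),
    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=4),              # 7x7 -> 28x28
)

digits = torch.randint(0, nb_digits, [batch])
digits_one_hot = one_hot(digits, nb_digits).unsqueeze(2).unsqueeze(3)  # [N, 10, 1, 1]
latent = torch.randn([batch, latent_size]).unsqueeze(2).unsqueeze(3)   # [N, 32, 1, 1]
full_latent = torch.cat((digits_one_hot, latent), dim=1)               # [N, 42, 1, 1]
print(convs_t(full_latent).shape)  # torch.Size([4, 1, 28, 28])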
def forward(self, batch, image, is_real):
    digits = batch['targets']

    # introduce the target as a one-hot encoded input to the discriminator,
    # broadcast over the spatial dimensions of the image
    input_class = torch.ones(
        [image.shape[0], self.nb_digits, image.shape[2], image.shape[3]],
        device=image.device) * one_hot(digits, self.nb_digits).unsqueeze(2).unsqueeze(3)

    o = self.convs(torch.cat((image, input_class), dim=1))
    o_expected = int(is_real) * torch.ones(len(image), device=image.device, dtype=torch.long)

    return {
        'classification': OutputClassification2(
            o,
            o_expected,
            criterion_fn=LossMsePacked,  # LSGAN loss function
        )
    }
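Using LossMsePacked as the criterion turns the discriminator into a least-squares GAN (LSGAN): the real/fake label is regressed with a mean-squared error instead of cross-entropy. Assuming the "packed" criterion one-hot-encodes the expected label and that the discriminator output is flattened to [N, 2], it reduces to something like this sketch:

import torch

def lsgan_loss(logits, expected):
    # logits: [N, 2] discriminator scores; expected: [N] with 1 = real, 0 = fake
    target = one_hot(expected, logits.shape[1])  # pack the label as one-hot
    # least-squares objective (Mao et al., LSGAN), one value per sample
    return ((logits - target) ** 2).mean(dim=1)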
        batch_size=1024),
    run_prefix='mnist_autoencoder_variational_conditional',
    model_fn=lambda options: Net(),
    optimizers_fn=lambda datasets, model: trw.train.create_adam_optimizers_scheduler_step_lr_fn(
        datasets=datasets,
        model=model,
        learning_rate=0.001,
        step_size=120,
        gamma=0.1))

model.eval()  # put the model (and all submodules) in evaluation mode
nb_images = 40

device = trw.train.get_device(model)

# sample random latents, all conditioned on the digit '7'
latent = torch.randn([nb_images, model.latent_size], device=device)
y = one_hot(torch.ones([nb_images], dtype=torch.long, device=device) * 7, 10)
latent_y = torch.cat([latent, y], dim=1)
latent_y = latent_y.view(latent_y.shape[0], latent_y.shape[1], 1, 1)

generated = model.autoencoder.decoder(latent_y)

fig, axes = plt.subplots(nrows=1, ncols=nb_images, figsize=(nb_images, 2.5), sharey=True)
decoded_images = crop_or_pad_fun(generated, [28, 28])
image_width = decoded_images.shape[2]
for ax, img in zip(axes, decoded_images):
    curr_img = img.detach().to(torch.device('cpu'))
    ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')
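Because the decoder is conditioned on the class, the same sampling code can sweep every digit to check that the conditioning is respected. A sketch reusing model, device, one_hot and crop_or_pad_fun from above, with one row of samples per digit:

nb_cols = 10  # samples per digit
fig, axes = plt.subplots(nrows=10, ncols=nb_cols, figsize=(nb_cols, 10), sharex=True, sharey=True)
for digit in range(10):
    latent = torch.randn([nb_cols, model.latent_size], device=device)
    y = one_hot(torch.ones([nb_cols], dtype=torch.long, device=device) * digit, 10)
    latent_y = torch.cat([latent, y], dim=1).view(nb_cols, -1, 1, 1)
    images = crop_or_pad_fun(model.autoencoder.decoder(latent_y), [28, 28])
    for ax, img in zip(axes[digit], images):
        ax.imshow(img.detach().cpu().view(28, 28), cmap='binary')
        ax.axis('off')
plt.savefig('conditional_samples.png')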