Example #1
    def _compute_forward_loss(self):
        # Closed-form KL divergence of the first style posterior
        # N(mu_1, exp(logvar_1)) against a unit Gaussian prior,
        # normalized per pixel and backpropagated immediately.
        kl_divergence_loss_1 = self.opt.kl_divergence_coef * (
            -0.5 * torch.sum(1 + self.style_logvar_1 - self.style_mu_1.pow(2) -
                             self.style_logvar_1.exp()))
        kl_divergence_loss_1 /= (self.opt.batch_size * self.opt.num_channels *
                                 self.opt.image_size * self.opt.image_size)
        kl_divergence_loss_1.backward(retain_graph=True)

        # The same KL term for the second style posterior.
        kl_divergence_loss_2 = self.opt.kl_divergence_coef * (
            -0.5 * torch.sum(1 + self.style_logvar_2 - self.style_mu_2.pow(2) -
                             self.style_logvar_2.exp()))
        kl_divergence_loss_2 /= (self.opt.batch_size * self.opt.num_channels *
                                 self.opt.image_size * self.opt.image_size)
        kl_divergence_loss_2.backward(retain_graph=True)

        # Pixel-space reconstruction errors for both decoded images.
        # (Variable is a no-op since PyTorch 0.4 and is kept here only
        # for backward compatibility.)
        reconstruction_error_1 = self.opt.reconstruction_coef * mse_loss(
            self.reconstructed_X_1, Variable(self.X_1))
        reconstruction_error_1.backward(retain_graph=True)

        reconstruction_error_2 = self.opt.reconstruction_coef * mse_loss(
            self.reconstructed_X_2, Variable(self.X_2))
        # The final backward() releases the shared graph.
        reconstruction_error_2.backward()

        # Store unweighted totals for logging.
        self.reconstruction_error = (
            reconstruction_error_1 +
            reconstruction_error_2) / self.opt.reconstruction_coef
        self.kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                    ) / self.opt.kl_divergence_coef
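
The KL term in Example #1 is the standard closed-form divergence between a diagonal Gaussian posterior N(mu, exp(logvar)) and a unit Gaussian prior. As a minimal standalone sketch (the function name and the normalization argument are illustrative, not part of the source):

import torch

def gaussian_kl(mu, logvar, normalize_by=1.0):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over all elements:
    # -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl / normalize_by
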
Example #2
            reconstruction_weight=reconstruction_weight,
            kl_weight=kl_d_weight,
        )
        # Multi-scale MSE: at each scale, halve the image size and take
        # the MSE between the downsampled batch and reconstruction.
        for i in range(num_mse):
            new_size = image_size // pow(2, i + 1)
            with torch.no_grad():
                resized_batch = nn.functional.interpolate(batch,
                                                          size=new_size,
                                                          mode="bilinear")
            resized_output = nn.functional.interpolate(reconstructed,
                                                       size=new_size,
                                                       mode="bilinear")
            mse = loss.mse_loss(resized_output, resized_batch, use_sum)
            batch_loss += mse
            loss_dict["MSE"] += mse.item()
        # Backprop
        batch_loss.backward()
        # Update our optimizer parameters
        optimizer.step()
        # Add the batch's loss to the total loss for the epoch
        train_loss += batch_loss.item()
        train_recon_loss += loss_dict["MSE"] + loss_dict["SSIM"]
        train_kl_d += loss_dict["KL Divergence"]

    if freeze_conv_for_fusions:
        models.toggle_layer_freezing(freezable_layers, trainable=False)

    if learning_rate != fusion_learning_rate:
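
The multi-scale loop in Example #2 can be factored into a standalone helper. A sketch, assuming plain torch.nn.functional.mse_loss in place of the snippet's custom loss.mse_loss wrapper:

import torch
import torch.nn.functional as F

def multiscale_mse(output, target, image_size, num_scales):
    # Sum the MSE over progressively halved resolutions; the target is
    # resized under no_grad, mirroring the snippet above.
    total = torch.zeros((), device=output.device)
    for i in range(num_scales):
        new_size = image_size // 2 ** (i + 1)
        with torch.no_grad():
            resized_target = F.interpolate(target, size=new_size,
                                           mode="bilinear", align_corners=False)
        resized_output = F.interpolate(output, size=new_size,
                                       mode="bilinear", align_corners=False)
        total = total + F.mse_loss(resized_output, resized_target)
    return total
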
Example #3
        # Decoded Fake Image (Noise)
        noise_representation = torch.randn(current_batch_size, vae_latent_dim)
        reconstructed_noise = decoder(noise_representation.to(device))

        # Run Discriminator for Real, Fake (Reconstructed), Fake (Noise) Images
        real_output, real_lth_output = discriminator(batch)
        recon_output, recon_lth_output = discriminator(reconstructed_image)
        noise_output, noise_lth_output = discriminator(reconstructed_noise)

        # Calculate losses: GAN loss over the three discriminator passes,
        # the KL prior term, and the feature-level (l-th layer)
        # reconstruction loss
        disc_real_loss = loss.bce_loss(real_output, y_real, use_sum)
        disc_recon_loss = loss.bce_loss(recon_output, y_fake, use_sum)
        disc_noise_loss = loss.bce_loss(noise_output, y_fake, use_sum)
        L_gan = disc_real_loss + disc_recon_loss + disc_noise_loss
        L_prior = loss.kl_divergence(mu, log_var, use_sum)
        L_reconstruction = loss.mse_loss(recon_lth_output, real_lth_output, use_sum)
        discriminator_loss = L_gan
        # The decoder is rewarded for fooling the discriminator (-L_gan)
        # while matching its l-th layer features (weighted by gamma)
        decoder_loss = gamma * L_reconstruction - L_gan
        encoder_loss = L_prior + L_reconstruction

        # Zero Gradients
        discriminator_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        encoder_optimizer.zero_grad()

        # Backpropagate; the first two calls retain the shared graph so
        # the remaining backward passes can reuse it
        discriminator_loss.backward(retain_graph=True)
        decoder_loss.backward(retain_graph=True)
        encoder_loss.backward()

        # Update Parameters (the step() calls below are reconstructed
        # from the zero_grad() calls above; the snippet is truncated here)
        discriminator_optimizer.step()
        decoder_optimizer.step()
        encoder_optimizer.step()
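
Example #3 uses target labels y_real and y_fake that are created earlier in the training loop and not shown. A typical construction, assuming the discriminator emits one output per sample (the shapes and placeholder values here are assumptions, not from the source):

import torch

# Placeholders standing in for the snippet's runtime values.
current_batch_size, device = 64, "cpu"
# All-ones targets for real images, all-zeros for fakes; shapes must
# match the discriminator's per-sample output.
y_real = torch.ones(current_batch_size, 1, device=device)
y_fake = torch.zeros(current_batch_size, 1, device=device)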