Ejemplo n.º 1
0
    def forward(self, architecture: Architecture, real_features: Tensor, fake_features: Tensor,
                **additional_inputs: Tensor) -> Tensor:
        """Compute the WGAN critic (discriminator) loss for one batch.

        The critic is evaluated on the real and on the fake features, both with the same
        ``additional_inputs``. The result is ``critic_loss_function(fake) - critic_loss_function(real)``.
        Minimizing it pushes real scores up and fake scores down.
        NOTE(review): ``critic_loss_function`` is defined elsewhere; presumably it reduces the
        critic outputs to a scalar (e.g. a mean) — confirm.

        Args:
            architecture: object exposing the ``discriminator`` (critic) module.
            real_features: batch of real samples.
            fake_features: batch of generated samples.
            **additional_inputs: extra keyword inputs forwarded unchanged to both critic calls.

        Returns:
            The critic loss tensor.
        """
        critic = architecture.discriminator

        real_score = critic_loss_function(critic(real_features, **additional_inputs))
        fake_score = critic_loss_function(critic(fake_features, **additional_inputs))

        # the critic maximizes the real score and minimizes the fake score
        return fake_score - real_score
Ejemplo n.º 2
0
    def forward(self, architecture: Architecture, real_features: Tensor, fake_features: Tensor,
                **additional_inputs: Tensor) -> Tensor:
        """Compute the WGAN critic loss plus a gradient penalty.

        The base critic loss comes from the parent class's ``forward``. The penalty works in
        four steps:
        1. Evaluate the critic on random per-sample interpolations between real and fake features.
        2. Take the L2 norm of the critic's gradient for each sample.
        3. Square the deviation of that norm from 1.
        4. Average over the batch and multiply by ``self.weight``.

        Features of any rank >= 2 are supported. The first dimension is the batch dimension, and
        every remaining dimension is treated as part of one sample.

        Args:
            architecture: object exposing the ``discriminator`` (critic) module.
            real_features: batch of real samples, shape ``(batch, ...)``.
            fake_features: batch of generated samples with the same shape as ``real_features``.
            **additional_inputs: extra keyword inputs (e.g. conditions) forwarded unchanged to every
                critic call; they are not interpolated.

        Returns:
            The total loss: base critic loss + weighted gradient penalty.
        """
        loss = super(WGANCriticLossWithGradientPenalty, self).forward(
            architecture, real_features, fake_features, **additional_inputs)

        # One interpolation coefficient per sample. It is shaped (batch, 1, ..., 1) so that
        # expanding it to the feature shape works for any feature rank.
        batch_size = len(real_features)
        alpha = rand(batch_size, *([1] * (real_features.dim() - 1)))
        alpha = alpha.expand(real_features.size())
        alpha = to_gpu_if_available(alpha)

        interpolates = alpha * real_features + ((1 - alpha) * fake_features)
        interpolates.requires_grad_()

        # we do not interpolate the conditions because they are the same for fake and real features
        discriminator_interpolates = architecture.discriminator(interpolates, **additional_inputs)

        gradients = grad(outputs=discriminator_interpolates,
                         inputs=interpolates,
                         grad_outputs=to_gpu_if_available(ones_like(discriminator_interpolates)),
                         # the penalty itself must be differentiable w.r.t. the critic parameters
                         create_graph=True,
                         retain_graph=True)[0]

        # Flatten all non-batch dimensions so the norm is taken per sample, not per row of a
        # higher-rank sample.
        gradient_norms = gradients.flatten(start_dim=1).norm(2, dim=1)
        gradient_penalty = ((gradient_norms - 1) ** 2).mean() * self.weight

        # return total loss
        return loss + gradient_penalty
Ejemplo n.º 3
0
 def forward(self, architecture: Architecture, fake_features: Tensor,
             **additional_inputs: Tensor) -> Tensor:
     """Compute the generator loss against a binary real/fake discriminator.

     The discriminator's predictions on the generated samples are scored with ``self.bce_loss``
     against positive ("real") labels. The generator is therefore rewarded when its samples are
     classified as real.

     Args:
         architecture: object exposing the ``discriminator`` module.
         fake_features: batch of generated samples.
         **additional_inputs: extra keyword inputs forwarded unchanged to the discriminator.

     Returns:
         The loss tensor produced by ``self.bce_loss``.
     """
     predictions = architecture.discriminator(fake_features, **additional_inputs)
     # NOTE(review): self.smooth_positive_labels presumably toggles/controls label smoothing
     # inside generate_positive_labels — confirm against that helper.
     return self.bce_loss(
         predictions,
         generate_positive_labels(len(predictions), self.smooth_positive_labels),
     )
Ejemplo n.º 4
0
    def forward(self, architecture: Architecture, real_features: Tensor,
                fake_features: Tensor, **additional_inputs: Tensor) -> Tensor:
        """Compute the binary real/fake discriminator loss for one batch.

        Predictions on real samples are scored with ``self.bce_loss`` against positive labels from
        ``generate_positive_labels``. Predictions on fake samples are scored against all-zero
        labels. The two terms are summed.

        Args:
            architecture: object exposing the ``discriminator`` module.
            real_features: batch of real samples.
            fake_features: batch of generated samples.
            **additional_inputs: extra keyword inputs forwarded unchanged to both discriminator calls.

        Returns:
            The sum of the real and fake BCE terms.
        """
        discriminator = architecture.discriminator

        # real samples should be classified as positive
        # (self.smooth_positive_labels presumably controls label smoothing — confirm)
        predicted_real = discriminator(real_features, **additional_inputs)
        loss_on_real = self.bce_loss(
            predicted_real,
            generate_positive_labels(len(predicted_real), self.smooth_positive_labels))

        # fake samples should be classified as negative (plain zero labels, no smoothing)
        predicted_fake = discriminator(fake_features, **additional_inputs)
        loss_on_fake = self.bce_loss(predicted_fake,
                                     to_gpu_if_available(zeros(len(predicted_fake))))

        return loss_on_real + loss_on_fake
Ejemplo n.º 5
0
    def forward(self, architecture: Architecture, features: Tensor,
                generated: Tensor, imputed: Tensor, hint: Tensor,
                non_missing_mask: Tensor) -> Tensor:
        """Compute the imputation generator loss: adversarial term + weighted reconstruction term.

        Args:
            architecture: object exposing the ``discriminator`` module.
            features: reconstruction target passed to ``self.reconstruction_loss``
                (presumably the original feature values — confirm).
            generated: generator output compared against ``features`` by ``self.reconstruction_loss``.
            imputed: samples fed to the discriminator.
            hint: passed to the discriminator as its ``missing_mask`` keyword argument.
            non_missing_mask: used both as the BCE target of the adversarial term and as the mask
                argument of ``self.reconstruction_loss`` (presumably 1 where a value is observed —
                confirm).

        Returns:
            ``adversarial_loss + self.reconstruction_loss_weight * reconstruction_loss``.
        """
        # The discriminator is trained to predict the missing mask, i.e. to detect which
        # positions were imputed and which ones were real. The generator wants to fool it,
        # so its adversarial target is the inverse mask.
        discriminator_output = architecture.discriminator(imputed, missing_mask=hint)
        adversarial_term = self.bce_loss(discriminator_output, non_missing_mask)

        # reconstruction of the non-missing values (the mask is handed to the loss)
        reconstruction_term = self.reconstruction_loss(generated, features, non_missing_mask)

        weighted_reconstruction = self.reconstruction_loss_weight * reconstruction_term
        return adversarial_term + weighted_reconstruction
Ejemplo n.º 6
0
 def forward(self, architecture: Architecture, fake_features: Tensor, **additional_inputs: Tensor) -> Tensor:
     """Compute the WGAN generator loss.

     The generator wants the critic to score its samples highly. The loss is therefore the
     negated ``critic_loss_function`` of the critic's output on ``fake_features``.
     NOTE(review): ``critic_loss_function`` is defined elsewhere; presumably it reduces critic
     outputs to a scalar — confirm.

     Args:
         architecture: object exposing the ``discriminator`` (critic) module.
         fake_features: batch of generated samples.
         **additional_inputs: extra keyword inputs forwarded unchanged to the critic.

     Returns:
         The negated critic score of the generated batch.
     """
     critic_output = architecture.discriminator(fake_features, **additional_inputs)
     critic_score = critic_loss_function(critic_output)
     return -critic_score
Ejemplo n.º 7
0
 def forward(self, architecture: Architecture, imputed: Tensor,
             hint: Tensor, missing_mask: Tensor) -> Tensor:
     """Compute the imputation discriminator loss.

     The discriminator receives the imputed samples and the ``hint`` (as its ``missing_mask``
     keyword). Its predictions are scored with ``self.bce_loss`` against ``missing_mask``.

     Args:
         architecture: object exposing the ``discriminator`` module.
         imputed: samples whose imputed positions the discriminator should detect.
         hint: passed to the discriminator as its ``missing_mask`` keyword argument.
         missing_mask: BCE target for the discriminator's predictions.

     Returns:
         The loss tensor produced by ``self.bce_loss``.
     """
     # The discriminator should predict the missing mask, i.e. detect which positions
     # were imputed and which ones were real.
     return self.bce_loss(architecture.discriminator(imputed, missing_mask=hint), missing_mask)