Ejemplo n.º 1
0
 def generator_loss_calculation(self, fake_examples, unlabeled_examples):
     """Compute the generator's loss on a batch of generated examples.

     ``self.D`` returns a pair whose second element holds per-class logits.
     Those logits are reduced to a single logit per example with
     ``logsumexp`` over dim 1. ``self.gan_criterion`` is then evaluated
     against an all-zeros target of the same shape, and the result is
     negated. Minimizing this loss therefore maximizes the criterion's value
     for the zeros target. The zeros target presumably denotes the "fake"
     label; confirm against how ``gan_criterion`` is configured.

     ``fake_examples`` is NOT detached, so gradients flow back through it to
     whatever produced it.

     Args:
         fake_examples: Generated batch passed to ``self.D``.
         unlabeled_examples: Not used by this implementation. It is kept so
             the signature stays compatible with existing callers.

     Returns:
         The negated value returned by ``self.gan_criterion``.
     """
     _features, class_logits = self.D(fake_examples)
     # Collapse the per-class logits into one logit per example.
     collapsed_logits = logsumexp(class_logits, dim=1)
     zero_targets = collapsed_logits.new_zeros(collapsed_logits.size())
     return -self.gan_criterion(collapsed_logits, zero_targets)
Ejemplo n.º 2
0
 def fake_loss_calculation(self, fake_examples):
     """Compute the discriminator's loss on a batch of generated examples.

     The examples are detached before being passed to ``self.D``, so no
     gradient from this loss flows back through ``fake_examples``. The
     per-class logits (second element of ``self.D``'s output) are reduced to
     one logit per example with ``logsumexp`` over dim 1. They are then
     scored by ``self.gan_criterion`` against an all-zeros target, which
     presumably denotes "fake"; confirm.

     Args:
         fake_examples: Generated batch; only its detached copy is used.

     Returns:
         The criterion value scaled in place by
         ``self.settings.unlabeled_loss_multiplier``.
     """
     _features, class_logits = self.D(fake_examples.detach())
     collapsed_logits = logsumexp(class_logits, dim=1)
     loss = self.gan_criterion(
         collapsed_logits, collapsed_logits.new_zeros(collapsed_logits.size())
     )
     # The same multiplier as the unlabeled loss weights this term.
     loss *= self.settings.unlabeled_loss_multiplier
     return loss
Ejemplo n.º 3
0
 def interpolate_loss_calculation(self, interpolates):
     """Compute the loss on interpolated examples used by the gradient penalty.

     ``interpolates`` is passed to ``self.D`` without detaching, so the
     result stays differentiable with respect to it. The per-class logits
     (second element of ``self.D``'s output) are reduced to one logit per
     example with ``logsumexp`` over dim 1. They are then scored by
     ``self.gan_criterion`` against an all-zeros target.

     Args:
         interpolates: Batch of interpolated examples.

     Returns:
         The criterion value scaled in place by
         ``self.settings.gradient_penalty_multiplier``.
     """
     _features, class_logits = self.D(interpolates)
     collapsed_logits = logsumexp(class_logits, dim=1)
     loss = self.gan_criterion(
         collapsed_logits, collapsed_logits.new_zeros(collapsed_logits.size())
     )
     # NOTE(review): the multiplier scales the loss itself, i.e. before any
     # gradient the caller presumably takes for the penalty. Confirm this is
     # the intended placement of the weight.
     loss *= self.settings.gradient_penalty_multiplier
     return loss
Ejemplo n.º 4
0
 def unlabeled_loss_calculation(self, unlabeled_examples):
     """Compute the discriminator's loss on a batch of unlabeled examples.

     The per-class logits (second element of ``self.D``'s output) are
     reduced to one logit per example with ``logsumexp`` over dim 1. They
     are then scored by ``self.gan_criterion`` against an all-ones target,
     which presumably denotes "real"; confirm.

     Args:
         unlabeled_examples: Batch of unlabeled examples passed to ``self.D``.

     Returns:
         The criterion value scaled in place by
         ``self.settings.unlabeled_loss_multiplier``.
     """
     _features, class_logits = self.D(unlabeled_examples)
     collapsed_logits = logsumexp(class_logits, dim=1)
     loss = self.gan_criterion(
         collapsed_logits, collapsed_logits.new_ones(collapsed_logits.size())
     )
     loss *= self.settings.unlabeled_loss_multiplier
     return loss