예제 #1
0
 def test_glvq_loss_int_labels(self):
     """GLVQ loss with plain (non one-hot) prototype labels.

     Two prototypes are labelled 0 and 1. Every one of the 100 samples has
     target 1 and distances (1.0, 0.0), i.e. distance 0 to the prototype of
     its own class and distance 1 to the prototype of the other class.
     The expected batch total is -100, i.e. -1 per sample.
     NOTE(review): -1 per sample assumes losses.glvq_loss computes the
     relative distance (d+ - d-) / (d+ + d-); confirm against its source.
     """
     # Each row is (distance to prototype 0, distance to prototype 1).
     distances = torch.tensor([1.0, 0.0]).repeat(100, 1)
     proto_labels = torch.tensor([0, 1])
     # Float targets, as in torch.ones(100); all samples belong to class 1.
     target_labels = torch.full((100,), 1.0)
     per_sample = losses.glvq_loss(
         distances=distances,
         target_labels=target_labels,
         prototype_labels=proto_labels,
     )
     total = per_sample.sum(dim=0)
     self.assertEqual(total, -100)
예제 #2
0
 def test_glvq_loss_one_hot_unequal(self):
     """GLVQ loss with one-hot labels and unequal prototypes per class.

     Prototype 0 is one-hot class 1; prototypes 1 and 2 are one-hot class 0.
     Every one of the 100 samples has one-hot target [1, 0] (class 0) and
     distances (1.0, 0.0, 0.0). That puts it at distance 0 from both
     class-0 prototypes and at distance 1 from the class-1 prototype.
     The expected batch total is -100, i.e. -1 per sample.
     NOTE(review): -1 per sample assumes losses.glvq_loss picks the closest
     matching / non-matching prototypes and computes (d+ - d-) / (d+ + d-);
     confirm against its source.
     """
     # Each row is (distance to prototype 0, 1, 2).
     distances = torch.tensor([1.0, 0.0, 0.0]).repeat(100, 1)
     proto_labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
     # Same (100, 2) int64 tensor as stacking [1, 0] one hundred times.
     target_labels = torch.tensor([1, 0]).repeat(100, 1)
     per_sample = losses.glvq_loss(
         distances=distances,
         target_labels=target_labels,
         prototype_labels=proto_labels,
     )
     total = per_sample.sum(dim=0)
     self.assertEqual(total, -100)
예제 #3
0
 def forward(self, outputs, targets):
     """Return the summed, squashed GLVQ loss for a batch.

     Args:
         outputs: A pair ``(distances, plabels)``: the distances and the
             prototype labels passed on to ``glvq_loss``.
             NOTE(review): presumably distances is (batch, n_prototypes);
             confirm with the producing model.
         targets: Target labels for the batch, forwarded to ``glvq_loss``.

     Returns:
         The squashed per-sample losses summed over dimension 0.
     """
     distances, plabels = outputs
     relative = glvq_loss(distances, targets, prototype_labels=plabels)
     # Shift by the configured margin before applying the squashing function.
     shifted = relative + self.margin
     squashed = self.squashing(shifted, beta=self.beta)
     return torch.sum(squashed, dim=0)