def test_glvq_loss_int_labels(self):
    # Two prototypes: the one labelled 0 sits at distance 1, the one
    # labelled 1 at distance 0, for every one of the 100 samples.
    d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
    labels = torch.tensor([0, 1])
    # All samples belong to class 1, so the closer prototype is always the correct one.
    targets = torch.ones(100)
    batch_loss = losses.glvq_loss(distances=d,
                                  target_labels=targets,
                                  prototype_labels=labels)
    loss_value = torch.sum(batch_loss, dim=0)
    # Per-sample loss is (0 - 1) / (0 + 1) = -1, summed over 100 samples: -100.
    self.assertEqual(loss_value, -100)
def test_glvq_loss_one_hot_unequal(self):
    # Three prototypes with one-hot labels; the two closest ones share class [1, 0].
    dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
    d = torch.stack(dlist, dim=1)
    labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
    # Every sample carries the one-hot target [1, 0].
    wl = torch.tensor([1, 0])
    targets = torch.stack([wl for _ in range(100)], dim=0)
    batch_loss = losses.glvq_loss(distances=d,
                                  target_labels=targets,
                                  prototype_labels=labels)
    loss_value = torch.sum(batch_loss, dim=0)
    # Closest correct distance is 0, closest wrong distance is 1:
    # (0 - 1) / (0 + 1) = -1 per sample, hence -100 in total.
    self.assertEqual(loss_value, -100)
def forward(self, outputs, targets):
    # outputs is a tuple of (distances, prototype labels).
    distances, plabels = outputs
    # Classifier error mu; negative values correspond to correctly classified samples.
    mu = glvq_loss(distances, targets, prototype_labels=plabels)
    # Apply the margin offset and the configured squashing function, then sum over the batch.
    batch_loss = self.squashing(mu + self.margin, beta=self.beta)
    return torch.sum(batch_loss, dim=0)
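# Minimal usage sketch (not part of the test suite above): it exercises glvq_loss
# directly, mirroring how the tests call `losses.glvq_loss`. The import path
# `prototorch.functions.losses` is an assumption and may need to be adjusted to
# the actual package layout.
import torch
from prototorch.functions import losses

distances = torch.tensor([[0.2, 1.0],
                          [1.5, 0.3]])   # shape (batch, num_prototypes)
plabels = torch.tensor([0, 1])           # one prototype per class
targets = torch.tensor([0, 1])           # each sample is closest to its own class prototype

mu = losses.glvq_loss(distances, targets, prototype_labels=plabels)
# Per sample, mu = (d_correct - d_wrong) / (d_correct + d_wrong),
# so both entries evaluate to (0.2 - 1.0) / (0.2 + 1.0) = -2/3 here.
print(torch.sum(mu, dim=0))              # ≈ -1.3333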