def test_forward_good(self):
    # Positives beat negatives by exactly the margin, so the loss is zero.
    # Float fill values are required: integer tensors cannot require grad.
    pos_scores = torch.full((3,), 2.0, requires_grad=True)
    neg_scores = torch.full((3, 5), 1.0, requires_grad=True)
    loss_fn = RankingLoss(1.0)
    loss = loss_fn(pos_scores, neg_scores)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()
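# A minimal sketch of the objective these tests appear to assume: with
# margin m, RankingLoss seems to sum the hinge term max(0, m - pos + neg)
# over every positive/negative pair, which is why pos - neg = 2 - 1 = m
# makes every term clamp to zero above. The helper name below is
# hypothetical and is not part of the library.
def reference_ranking_loss(
    pos_scores: torch.Tensor, neg_scores: torch.Tensor, margin: float
) -> torch.Tensor:
    # Broadcast each positive score against its row of negative scores,
    # clamp satisfied margins at zero, and sum over all pairs.
    return torch.clamp(margin - pos_scores.unsqueeze(1) + neg_scores, min=0).sum()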
def test_no_pos(self):
    # Degenerate case: no positive scores at all should yield a zero loss
    # that can still be backpropagated through.
    pos_scores = torch.empty((0,), requires_grad=True)
    neg_scores = torch.empty((0, 3), requires_grad=True)
    loss_fn = RankingLoss(1.0)
    loss = loss_fn(pos_scores, neg_scores)
    self.assertTensorEqual(loss, torch.zeros(()))
    loss.backward()
def test_forward(self):
    pos_scores = torch.tensor([0.8181, 0.5700, 0.3506], requires_grad=True)
    neg_scores = torch.tensor(
        [
            [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
            [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
            [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
        ],
        requires_grad=True,
    )
    loss_fn = RankingLoss(1.0)
    loss = loss_fn(pos_scores, neg_scores)
    self.assertTensorEqual(loss, torch.tensor(13.4475))
    loss.backward()
    self.assertTrue((pos_scores.grad != 0).any())
    self.assertTrue((neg_scores.grad != 0).any())
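# Hypothetical cross-check (not part of the test suite): the 13.4475
# expected above matches the summed hinge sum_ij max(0, 1 - pos_i + neg_ij),
# i.e. PyTorch's margin ranking loss with target 1 and reduction="sum".
# The function name and tolerance are assumptions for illustration.
def _cross_check_expected_loss() -> None:
    pos = torch.tensor([0.8181, 0.5700, 0.3506])
    neg = torch.tensor(
        [
            [0.4437, 0.6573, 0.9986, 0.2548, 0.0998],
            [0.6175, 0.4061, 0.4582, 0.5382, 0.3126],
            [0.9869, 0.2028, 0.1667, 0.0044, 0.9934],
        ]
    )
    # Every term 1 - pos_i + neg_ij is positive here, so nothing clamps.
    loss = torch.clamp(1.0 - pos.unsqueeze(1) + neg, min=0).sum()
    assert torch.allclose(loss, torch.tensor(13.4475), atol=1e-4)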
def __init__(
    self,
    global_optimizer: Optimizer,
    loss_fn: LossFunction,
    margin: float,
    relations: List[RelationSchema],
) -> None:
    super().__init__()
    self.global_optimizer = global_optimizer
    # Optimizers for entity embeddings, keyed by (entity type, partition);
    # starts empty.
    self.entity_optimizers: Dict[Tuple[EntityName, Partition], Optimizer] = {}
    # Map the configured loss function to its module; the margin is only
    # meaningful for the ranking loss.
    if loss_fn is LossFunction.LOGISTIC:
        self.loss_fn = LogisticLoss()
    elif loss_fn is LossFunction.RANKING:
        self.loss_fn = RankingLoss(margin)
    elif loss_fn is LossFunction.SOFTMAX:
        self.loss_fn = SoftmaxLoss()
    else:
        raise NotImplementedError("Unknown loss function: %s" % loss_fn)
    self.relations = relations
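# A hypothetical construction sketch (the class name "Trainer", the
# optimizer choice, and the config values are illustrative assumptions,
# not taken from this file):
#
#     trainer = Trainer(
#         global_optimizer=torch.optim.Adagrad(model.parameters(), lr=0.1),
#         loss_fn=LossFunction.RANKING,
#         margin=0.1,
#         relations=config.relations,
#     )
#
# Passing LossFunction.RANKING makes self.loss_fn the RankingLoss(margin)
# exercised by the tests above; LOGISTIC and SOFTMAX ignore the margin.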