def test_key_mismatch(self): lossA = ContrastiveLoss() lossB = TripletMarginLoss(0.1) self.assertRaises( AssertionError, lambda: MultipleLosses( losses={ "lossA": lossA, "lossB": lossB }, weights={ "blah": 1, "lossB": 0.23 }, ), ) minerA = MultiSimilarityMiner() self.assertRaises( AssertionError, lambda: MultipleLosses( losses={ "lossA": lossA, "lossB": lossB }, weights={ "lossA": 1, "lossB": 0.23 }, miners={"blah": minerA}, ), )
def test_input_indices_tuple(self): lossA = ContrastiveLoss() lossB = TripletMarginLoss(0.1) miner = MultiSimilarityMiner() loss_func1 = MultipleLosses(losses={ "lossA": lossA, "lossB": lossB }, weights={ "lossA": 1, "lossB": 0.23 }) loss_func2 = MultipleLosses(losses=[lossA, lossB], weights=[1, 0.23]) for loss_func in [loss_func1, loss_func2]: for dtype in TEST_DTYPES: embedding_angles = torch.arange(0, 180) embeddings = torch.tensor( [c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype, ).to(TEST_DEVICE) # 2D embeddings labels = torch.randint(low=0, high=10, size=(180, )) indices_tuple = miner(embeddings, labels) loss = loss_func(embeddings, labels, indices_tuple) loss.backward() correct_loss = ( lossA(embeddings, labels, indices_tuple) + lossB(embeddings, labels, indices_tuple) * 0.23) self.assertTrue(torch.isclose(loss, correct_loss))
def test_length_mistmatch(self): lossA = ContrastiveLoss() lossB = TripletMarginLoss(0.1) self.assertRaises( AssertionError, lambda: MultipleLosses(losses=[lossA, lossB], weights=[1])) minerA = MultiSimilarityMiner() self.assertRaises( AssertionError, lambda: MultipleLosses( losses=[lossA, lossB], weights=[1, 0.2], miners=[minerA], ), )
def test_multiple_losses(self): lossA = ContrastiveLoss() lossB = TripletMarginLoss(0.1) loss_func = MultipleLosses(losses={ "lossA": lossA, "lossB": lossB }, weights={ "lossA": 1, "lossB": 0.23 }) for dtype in TEST_DTYPES: embedding_angles = torch.arange(0, 180) embeddings = torch.tensor( [c_f.angle_to_coord(a) for a in embedding_angles], requires_grad=True, dtype=dtype).to(self.device) #2D embeddings labels = torch.randint(low=0, high=10, size=(180, )) loss = loss_func(embeddings, labels) loss.backward() correct_loss = lossA(embeddings, labels) + lossB(embeddings, labels) * 0.23 self.assertTrue(torch.isclose(loss, correct_loss))