def test_divisor_reducer(self):
    """Check DivisorReducer for every reduction type and several divisors.

    For each dtype in ``TEST_DTYPES``, random losses are passed to the reducer
    once per ``reduction_type``, with a divisor of 0 and with a non-zero
    divisor. The expected result is:

    * ``"already_reduced"``: the (scalar) loss itself;
    * divisor == 0: ``torch.sum(losses) * 0``;
    * otherwise: ``torch.sum(losses) / divisor``.

    Finally, an ``"already_reduced"`` loss dict that has no ``"divisor"`` key
    at all is reduced and must come back unchanged.
    """
    reducer = DivisorReducer()
    batch_size = 100
    embedding_size = 64
    for dtype in TEST_DTYPES:
        embeddings = (
            torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
        )
        labels = torch.randint(0, 10, (batch_size,))
        pair_indices = (
            torch.randint(0, batch_size, (batch_size,)),
            torch.randint(0, batch_size, (batch_size,)),
        )
        triplet_indices = pair_indices + (
            torch.randint(0, batch_size, (batch_size,)),
        )
        losses = torch.randn(batch_size).type(dtype).to(TEST_DEVICE)

        for indices, reduction_type in [
            (torch.arange(batch_size), "element"),
            (pair_indices, "pos_pair"),
            (pair_indices, "neg_pair"),
            (triplet_indices, "triplet"),
            (None, "already_reduced"),
        ]:
            # "already_reduced" is given a single scalar loss; every other
            # reduction type is given the full per-item loss vector.
            if reduction_type == "already_reduced":
                L = losses[0]
            else:
                L = losses

            # 0 exercises the zero-divisor branch; 32 + 15 a normal division.
            for divisor in (0, 32 + 15):
                loss_dict = {
                    "loss": {
                        "losses": L,
                        "indices": indices,
                        "reduction_type": reduction_type,
                        "divisor": divisor,
                    }
                }
                output = reducer(loss_dict, embeddings, labels)

                # Derive the expectation from the same divisor that was fed in,
                # so the assertion always matches the input under test.
                if reduction_type == "already_reduced":
                    correct_output = L
                elif divisor == 0:
                    correct_output = torch.sum(L) * 0
                else:
                    correct_output = torch.sum(L) / divisor
                self.assertTrue(output == correct_output)

        # An already-reduced loss must not require a "divisor" entry.
        loss_dict = {
            "loss": {
                "losses": losses[0],
                "indices": None,
                "reduction_type": "already_reduced",
            }
        }
        output = reducer(loss_dict, embeddings, labels)
        correct_output = losses[0]
        self.assertTrue(output == correct_output)
def test_multiple_reducers(self):
    """Check that MultipleReducers combines per-loss reducers as expected.

    ``"lossA"`` is reduced by AvgNonZeroReducer and ``"lossB"`` by
    DivisorReducer. For each dtype in ``TEST_DTYPES`` and each reduction type,
    the combined output must equal
    ``mean(lossesA[lossesA > 0]) + sum(lossesB) / divisor``.
    """
    reducer = MultipleReducers(
        {"lossA": AvgNonZeroReducer(), "lossB": DivisorReducer()}
    )
    batch_size = 100
    embedding_size = 64
    # Denominator handed to DivisorReducer through the "divisor" key.
    divisor = 32 + 15
    for dtype in TEST_DTYPES:
        embeddings = (
            torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE)
        )
        labels = torch.randint(0, 10, (batch_size,))
        pair_indices = (
            torch.randint(0, batch_size, (batch_size,)),
            torch.randint(0, batch_size, (batch_size,)),
        )
        triplet_indices = pair_indices + (
            torch.randint(0, batch_size, (batch_size,)),
        )
        lossesA = torch.randn(batch_size).type(dtype).to(TEST_DEVICE)
        lossesB = torch.randn(batch_size).type(dtype).to(TEST_DEVICE)

        for indices, reduction_type in [
            (torch.arange(batch_size), "element"),
            (pair_indices, "pos_pair"),
            (pair_indices, "neg_pair"),
            (triplet_indices, "triplet"),
        ]:
            loss_dict = {
                "lossA": {
                    "losses": lossesA,
                    "indices": indices,
                    "reduction_type": reduction_type,
                },
                "lossB": {
                    "losses": lossesB,
                    "indices": indices,
                    "reduction_type": reduction_type,
                    # Use the same "divisor" key DivisorReducer is driven with
                    # elsewhere, instead of the old "divisor_summands" dict.
                    "divisor": divisor,
                },
            }
            output = reducer(loss_dict, embeddings, labels)
            correct_output = torch.mean(lossesA[lossesA > 0]) + (
                torch.sum(lossesB) / divisor
            )
            self.assertTrue(output == correct_output)