def main():
    """Train Feature_class on place features and checkpoint the best model.

    Dataset paths and hyperparameters come from the module-level ``cfg``.
    Trains for ``cfg.epoch`` epochs, saves the model with the best test AP
    to ``cfg.store + '.pth'``, and writes loss/AP/accuracy curves as PNGs.
    """
    feature_train_file = cfg.feature_train_file
    feature_test_file = cfg.feature_test_file
    train_dataset = PlaceDateset(feature_train_file)
    test_dataset = PlaceDateset(feature_test_file)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size,
                             shuffle=False)
    model = Feature_class(cfg).cuda()
    optimizer = optim.__dict__[cfg.optim.name](model.parameters(),
                                               **cfg.optim.setting)
    # Decay the learning rate at the epochs given by the config stepper.
    scheduler = optim.lr_scheduler.__dict__[cfg.stepper.name](
        optimizer, **cfg.stepper.setting)
    criterion1 = nn.CrossEntropyLoss()
    distance = CosineSimilarity()
    criterion2 = losses.TripletMarginLoss(distance=distance)
    total_loss = list()
    total_epoch = list()
    total_ap = list()
    total_acc = list()
    max_ap = 0
    for epoch in range(0, cfg.epoch):
        train(cfg, model, train_loader, optimizer, scheduler, epoch,
              criterion1, criterion2)
        loss, ap, acc = test(cfg, model, test_loader, criterion1, criterion2)
        total_loss.append(loss)
        total_ap.append(ap)
        total_epoch.append(epoch)
        total_acc.append(acc)
        print('Test Epoch: {} \tloss: {:.6f}\tap: {:.6f}\t acc: {:.6f}'.format(
            epoch, loss, ap, acc))
        if ap > max_ap:
            # BUG FIX: max_ap was never updated, so every epoch with ap > 0
            # overwrote the "best" checkpoint; track the running maximum so
            # only strictly better epochs are saved.
            max_ap = ap
            best_model = model
            save_path = cfg.store + '.pth'
            torch.save(best_model.state_dict(), save_path)
    plt.figure()
    plt.plot(total_epoch, total_loss, 'b-', label=u'loss')
    plt.legend()
    loss_path = cfg.store + "_loss.png"
    plt.savefig(loss_path)
    plt.figure()
    plt.plot(total_epoch, total_ap, 'b-', label=u'AP')
    plt.legend()
    AP_path = cfg.store + "_AP.png"
    plt.savefig(AP_path)
    plt.figure()
    plt.plot(total_epoch, total_acc, 'b-', label=u'acc')
    plt.legend()
    acc_path = cfg.store + "_acc.png"
    plt.savefig(acc_path)
def __init__(
    self,
    model,
    margin=0.2,
    lr=1e-3,
    lr_patience=2,
    lr_decay_ratio=0.5,
    memory_batch_max_num=2048,
):
    """Store hyperparameters and build a cross-batch-memory contrastive loss.

    Args:
        model: backbone module; must expose ``feature_dim``.
        margin: epsilon passed to the multi-similarity miner.
        lr, lr_patience, lr_decay_ratio: learning-rate settings stored on the
            instance for use elsewhere (e.g. optimizer configuration).
        memory_batch_max_num: size of the cross-batch embedding memory.
    """
    super().__init__()
    self.save_hyperparameters()
    self.model = model
    self.margin = margin
    self.lr = lr
    self.lr_patience = lr_patience
    self.lr_decay_ratio = lr_decay_ratio
    self.memory_batch_max_num = memory_batch_max_num
    # Contrastive loss over cosine similarity, fed by a multi-similarity
    # miner and backed by a memory bank of embeddings from past batches.
    pair_loss = losses.ContrastiveLoss(
        pos_margin=1, neg_margin=0, distance=CosineSimilarity())
    pair_miner = miners.MultiSimilarityMiner(epsilon=self.margin)
    self.loss_func = losses.CrossBatchMemory(
        pair_loss,
        self.model.feature_dim,
        memory_size=self.memory_batch_max_num,
        miner=pair_miner,
    )
def setUpClass(self):
    """Build BatchHardMiner variants and the expected hard-mined indices."""
    self.device = torch.device('cuda')
    # One miner per distance configuration under test.
    self.dist_miner = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=False))
    self.normalized_dist_miner = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=True))
    self.normalized_dist_miner_squared = BatchHardMiner(
        distance=LpDistance(normalize_embeddings=True, power=2))
    self.sim_miner = BatchHardMiner(distance=CosineSimilarity())
    self.labels = torch.LongTensor([0, 0, 1, 1, 0, 2, 1, 1, 1])
    # Hand-computed expected anchor / positive / negative indices.
    dev = self.device
    self.correct_a = torch.LongTensor([0, 1, 2, 3, 4, 6, 7, 8]).to(dev)
    self.correct_p = torch.LongTensor([4, 4, 8, 8, 0, 2, 2, 2]).to(dev)
    self.correct_n = [
        torch.LongTensor([2, 2, 1, 4, 3, 5, 5, 5]).to(dev),
        torch.LongTensor([2, 2, 1, 4, 5, 5, 5, 5]).to(dev),
    ]
def test_pair_margin_miner(self):
    """PairMarginMiner must return exactly the pairs violating its margins."""
    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embedding_angles = torch.arange(0, 16)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=2, size=(16,))
            mat = distance(embeddings)
            count = len(embeddings)
            # Every ordered same-label / different-label pair, tagged with
            # its pairwise distance (or similarity).
            pos_pairs = [(i, j, mat[i, j])
                         for i in range(count)
                         for j in range(count)
                         if j != i and labels[j] == labels[i]]
            neg_pairs = [(i, j, mat[i, j])
                         for i in range(count)
                         for j in range(count)
                         if j != i and labels[j] != labels[i]]
            for pos_margin_int in range(-1, 4):
                pos_margin = pos_margin_int * 0.05
                for neg_margin_int in range(2, 7):
                    neg_margin = neg_margin_int * 0.05
                    miner = PairMarginMiner(pos_margin, neg_margin,
                                            distance=distance)
                    # For an inverted measure (similarity) hard positives are
                    # *below* the margin and hard negatives *above* it; for a
                    # true distance it is the opposite.
                    if distance.is_inverted:
                        correct_pos = {(i, j) for i, j, d in pos_pairs
                                       if d < pos_margin}
                        correct_neg = {(i, j) for i, j, d in neg_pairs
                                       if d > neg_margin}
                    else:
                        correct_pos = {(i, j) for i, j, d in pos_pairs
                                       if d > pos_margin}
                        correct_neg = {(i, j) for i, j, d in neg_pairs
                                       if d < neg_margin}
                    a1, p1, a2, n2 = miner(embeddings, labels)
                    mined_pos = {(a.item(), p.item()) for a, p in zip(a1, p1)}
                    mined_neg = {(a.item(), n.item()) for a, n in zip(a2, n2)}
                    self.assertTrue(mined_pos == correct_pos)
                    self.assertTrue(mined_neg == correct_neg)
def test_backward(self):
    """Gradients must flow through ContrastiveLoss for both distance types."""
    loss_funcs = [ContrastiveLoss(),
                  ContrastiveLoss(distance=CosineSimilarity())]
    for dtype in TEST_DTYPES:
        for loss_func in loss_funcs:
            # A single embedding yields no pairs, but backward must still work.
            lone_angle = [0]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in lone_angle],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0])
            loss_func(embeddings, labels).backward()
def test_with_no_valid_pairs(self):
    """With a single embedding there are no pairs, so both losses are zero."""
    loss_funcA = ContrastiveLoss()
    loss_funcB = ContrastiveLoss(distance=CosineSimilarity())
    for dtype in TEST_DTYPES:
        lone_angle = [0]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in lone_angle],
            requires_grad=True,
            dtype=dtype).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0])
        for loss_func in (loss_funcA, loss_funcB):
            self.assertEqual(loss_func(embeddings, labels), 0)
def setUpClass(self):
    """Create HDCMiner variants and the expected mined index pairs."""
    self.device = torch.device('cuda')
    pct = 0.3  # shared filter_percentage for every miner variant
    self.dist_miner = HDCMiner(
        filter_percentage=pct,
        distance=LpDistance(normalize_embeddings=False))
    self.normalized_dist_miner = HDCMiner(
        filter_percentage=pct,
        distance=LpDistance(normalize_embeddings=True))
    self.normalized_dist_miner_squared = HDCMiner(
        filter_percentage=pct,
        distance=LpDistance(normalize_embeddings=True, power=2))
    self.sim_miner = HDCMiner(filter_percentage=pct,
                              distance=CosineSimilarity())
    self.labels = torch.LongTensor([0, 0, 1, 1, 1, 0])
    # Expected (anchor, positive) and (anchor, negative) pairs, stacked
    # column-wise into shape (num_pairs, 2).
    self.correct_pos_pairs = torch.stack(
        [torch.LongTensor([0, 5, 1, 5]),
         torch.LongTensor([5, 0, 5, 1])],
        dim=1).to(self.device)
    self.correct_neg_pairs = torch.stack(
        [torch.LongTensor([1, 2, 4, 5, 0, 2]),
         torch.LongTensor([2, 1, 5, 4, 2, 0])],
        dim=1).to(self.device)
# NOTE(review): this chunk begins mid-way through an argument-parser helper
# whose `def` line is outside the visible source; code left byte-identical,
# only reformatted and commented.
                        default=3,
                        help='number of patches to compare with cam1')
    parser.add_argument('--patches_to_compare_c2',
                        type=int,
                        default=5,
                        help='number of patches to compare with cam2')
    return parser.parse_args(args)


if __name__ == '__main__':
    args = parse_args()
    # Load the trained trunk + embedder pair and wrap them for inference;
    # matches are found by cosine similarity (presumably thresholded at
    # args.thr — confirm against MatchFinder's documentation).
    trunk, embedder = load_trunk_embedder(args.trunk_model, args.embedder_model)
    match_finder = MatchFinder(distance=CosineSimilarity(), threshold=args.thr)
    inference_model = InferenceModel(trunk, embedder,
                                     match_finder=match_finder,
                                     batch_size=64)
    labels = pd.read_csv(args.det_csv)
    test_data, _ = get_data_loader('test', labels, args.det_patches)
    # Group detection indices by their camera id.
    indices_cameras = c_f.get_labels_to_indices(test_data.camera)
    id_frames_cams = {}
    for cam in ['c010', 'c011', 'c012', 'c013', 'c014', 'c015']:
        id_frames_cams[cam] = get_id_frames_cam(test_data, indices_cameras, cam)
    # Seed the merged id->frames map from camera c010 (script continues
    # beyond this chunk).
    id_frames_all = deepcopy(id_frames_cams['c010'])
def test_multi_similarity_miner(self):
    """Brute-force check of MultiSimilarityMiner against its definition.

    For each anchor, a positive pair is kept when it is harder (within
    ``epsilon``) than that anchor's hardest negative, and symmetrically for
    negative pairs. Checked for both an inverted measure (cosine
    similarity) and a true distance (Lp).
    """
    epsilon = 0.1
    for dtype in TEST_DTYPES:
        for distance in [CosineSimilarity(), LpDistance()]:
            miner = MultiSimilarityMiner(epsilon, distance=distance)
            embedding_angles = torch.arange(0, 64)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10, size=(64, ))
            mat = distance(embeddings)
            # Enumerate every ordered pair together with its pairwise value.
            pos_pairs = []
            neg_pairs = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j != i:
                        other_label = labels[j]
                        if anchor_label == other_label:
                            pos_pairs.append((i, j, mat[i, j]))
                        if anchor_label != other_label:
                            neg_pairs.append((i, j, mat[i, j]))
            correct_a1, correct_p = [], []
            correct_a2, correct_n = [], []
            # Keep a positive pair when it is within epsilon of the hardest
            # negative sharing the same anchor. "Hardest" flips direction
            # depending on whether the measure is inverted (similarity).
            for a1, p, ap_sim in pos_pairs:
                most_difficult = (c_f.neg_inf(dtype)
                                  if distance.is_inverted
                                  else c_f.pos_inf(dtype))
                for a2, n, an_sim in neg_pairs:
                    if a2 == a1:
                        condition = ((an_sim > most_difficult)
                                     if distance.is_inverted
                                     else (an_sim < most_difficult))
                        if condition:
                            most_difficult = an_sim
                condition = ((ap_sim < most_difficult + epsilon)
                             if distance.is_inverted
                             else (ap_sim > most_difficult - epsilon))
                if condition:
                    correct_a1.append(a1)
                    correct_p.append(p)
            # Symmetric rule: keep a negative pair when it is within epsilon
            # of the hardest positive sharing the same anchor.
            for a2, n, an_sim in neg_pairs:
                most_difficult = (c_f.pos_inf(dtype)
                                  if distance.is_inverted
                                  else c_f.neg_inf(dtype))
                for a1, p, ap_sim in pos_pairs:
                    if a2 == a1:
                        condition = ((ap_sim < most_difficult)
                                     if distance.is_inverted
                                     else (ap_sim > most_difficult))
                        if condition:
                            most_difficult = ap_sim
                condition = ((an_sim > most_difficult - epsilon)
                             if distance.is_inverted
                             else (an_sim < most_difficult + epsilon))
                if condition:
                    correct_a2.append(a2)
                    correct_n.append(n)
            correct_pos_pairs = set([
                (a, p) for a, p in zip(correct_a1, correct_p)
            ])
            correct_neg_pairs = set([
                (a, n) for a, n in zip(correct_a2, correct_n)
            ])
            # The miner must return exactly the brute-force pair sets.
            a1, p1, a2, n2 = miner(embeddings, labels)
            pos_pairs = set([(a.item(), p.item()) for a, p in zip(a1, p1)])
            neg_pairs = set([(a.item(), n.item()) for a, n in zip(a2, n2)])
            self.assertTrue(pos_pairs == correct_pos_pairs)
            self.assertTrue(neg_pairs == correct_neg_pairs)
def test_triplet_margin_miner(self):
    """Each type_of_triplets setting must return exactly its triplet subset.

    A triplet's difficulty is (an - ap) for distances and (ap - an) for
    inverted measures; "easy" exceeds the margin, "semihard" is positive but
    within it, "hard" is non-positive, and "all" is everything non-easy.
    """
    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in torch.arange(0, 16)],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=2, size=(16, ))
            mat = distance(embeddings)
            # Single sign factor replaces the inverted/non-inverted branch.
            sign = 1 if distance.is_inverted else -1
            count = len(embeddings)
            triplets = []
            for i in range(count):
                for j in range(count):
                    if j == i or labels[j] != labels[i]:
                        continue
                    for k in range(count):
                        if k in (i, j) or labels[k] == labels[j]:
                            continue
                        triplets.append(
                            (i, j, k, sign * (mat[i, j] - mat[i, k])))
            for margin_int in range(-1, 11):
                margin = margin_int * 0.05
                expected = {"all": [], "hard": [], "semihard": [], "easy": []}
                for i, j, k, diff in triplets:
                    if diff > margin:
                        expected["easy"].append((i, j, k))
                    else:
                        expected["all"].append((i, j, k))
                        if diff > 0:
                            expected["semihard"].append((i, j, k))
                        if diff <= 0:
                            expected["hard"].append((i, j, k))
                for triplet_type, correct in expected.items():
                    miner = TripletMarginMiner(margin,
                                               type_of_triplets=triplet_type,
                                               distance=distance)
                    a1, p1, n1 = miner(embeddings, labels)
                    mined = {(a.item(), p.item(), n.item())
                             for a, p, n in zip(a1, p1, n1)}
                    self.assertTrue(mined == set(correct))
def test_triplet_margin_loss(self):
    """Compare TripletMarginLoss against a hand-computed reference.

    Covers the default reducer (divides by non-zero triplets) and
    MeanReducer (divides by all triplets), for both L2 distance and cosine
    similarity.
    """
    margin = 0.2
    loss_funcA = TripletMarginLoss(margin=margin)
    loss_funcB = TripletMarginLoss(margin=margin, reducer=MeanReducer())
    loss_funcC = TripletMarginLoss(margin=margin, distance=CosineSimilarity())
    loss_funcD = TripletMarginLoss(margin=margin, reducer=MeanReducer(),
                                   distance=CosineSimilarity())
    for dtype in TEST_DTYPES:
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in [0, 20, 40, 60, 80]],
            requires_grad=True,
            dtype=dtype,
        ).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])
        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)
        # All valid (anchor, positive, negative) index triplets.
        triplets = [
            (0, 1, 2), (0, 1, 3), (0, 1, 4),
            (1, 0, 2), (1, 0, 3), (1, 0, 4),
            (2, 3, 0), (2, 3, 1), (2, 3, 4),
            (3, 2, 0), (3, 2, 1), (3, 2, 4),
        ]
        correct_loss = 0
        correct_loss_cosine = 0
        num_non_zero = 0
        num_non_zero_cosine = 0
        for a, p, n in triplets:
            anchor = embeddings[a]
            positive = embeddings[p]
            negative = embeddings[n]
            ap_dist = torch.sqrt(torch.sum((anchor - positive)**2))
            an_dist = torch.sqrt(torch.sum((anchor - negative)**2))
            curr_loss = torch.relu(ap_dist - an_dist + margin)
            ap_sim = torch.sum(anchor * positive)
            an_sim = torch.sum(anchor * negative)
            curr_loss_cosine = torch.relu(an_sim - ap_sim + margin)
            if curr_loss > 0:
                num_non_zero += 1
            if curr_loss_cosine > 0:
                num_non_zero_cosine += 1
            correct_loss += curr_loss
            correct_loss_cosine += curr_loss_cosine
        # float16 needs a looser tolerance.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(
            torch.isclose(lossA, correct_loss / num_non_zero, rtol=rtol))
        self.assertTrue(
            torch.isclose(lossB, correct_loss / len(triplets), rtol=rtol))
        self.assertTrue(
            torch.isclose(lossC,
                          correct_loss_cosine / num_non_zero_cosine,
                          rtol=rtol))
        self.assertTrue(
            torch.isclose(lossD,
                          correct_loss_cosine / len(triplets),
                          rtol=rtol))
def main():
    """Train the ViT place-feature model and checkpoint the best epoch.

    Configuration comes from the module-level ``cfg``; the model with the
    best test AP is saved to ``cfg.store + '.pth'`` and loss/AP/accuracy
    curves are written as PNGs next to it.
    """
    feature_train_file = cfg.feature_train_file
    feature_test_file = cfg.feature_test_file
    train_dataset = PlaceDateset(feature_train_file)
    test_dataset = PlaceDateset(feature_test_file)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size,
                             shuffle=False)
    model = ViT(cfg=cfg, feature_seq=16, num_classes=1, dim=2048, depth=8,
                heads=8, mlp_dim=1024, dropout=0.1, emb_dropout=0.1).cuda()
    optimizer = optim.__dict__[cfg.optim.name](model.parameters(),
                                               **cfg.optim.setting)
    # Decay the learning rate at the epochs given by the config stepper.
    scheduler = optim.lr_scheduler.__dict__[cfg.stepper.name](
        optimizer, **cfg.stepper.setting)
    criterion1 = FocalLoss(logits=True)  # focal loss applied to raw logits
    distance = CosineSimilarity()
    criterion2 = losses.TripletMarginLoss(distance=distance)
    total_loss = list()
    total_epoch = list()
    total_ap = list()
    total_acc = list()
    max_ap = 0
    for epoch in range(0, cfg.epoch):
        train(cfg, model, train_loader, optimizer, scheduler, epoch,
              criterion1, criterion2)
        loss, ap, acc = test(cfg, model, test_loader, criterion1, criterion2)
        total_loss.append(loss)
        total_ap.append(ap)
        total_epoch.append(epoch)
        total_acc.append(acc)
        print('Test Epoch: {} \tloss: {:.6f}\tap: {:.6f}\tacc: {:.6f}'.format(
            epoch, loss, ap, acc))
        if ap > max_ap:
            # BUG FIX: max_ap was never updated, so the checkpoint was
            # overwritten on every epoch with ap > 0, not just improvements.
            max_ap = ap
            best_model = model
            save_path = cfg.store + '.pth'
            torch.save(best_model.state_dict(), save_path)
    plt.figure()
    plt.plot(total_epoch, total_loss, 'b-', label=u'loss')
    plt.legend()
    loss_path = cfg.store + "_loss.png"
    plt.savefig(loss_path)
    plt.figure()
    plt.plot(total_epoch, total_ap, 'b-', label=u'AP')
    plt.legend()
    AP_path = cfg.store + "_AP.png"
    plt.savefig(AP_path)
    plt.figure()
    plt.plot(total_epoch, total_acc, 'b-', label=u'acc')
    plt.legend()
    acc_path = cfg.store + "_acc.png"
    plt.savefig(acc_path)
def get_pos_neg_vals(self, use_pairwise):
    """Return the (pos, neg, distance) triple for the requested mode.

    Pairwise mode uses squared L2 distance with (pos=0, neg=1); otherwise
    cosine similarity with (pos=1, neg=0).
    """
    # LpDistance is constructed first in either case, matching the
    # original evaluation order.
    pairwise_output = (0, 1, LpDistance(power=2))
    if use_pairwise:
        return pairwise_output
    return (1, 0, CosineSimilarity())
def main():
    """Train the three-model pipeline (dense classifier + two ViTs).

    Reads configuration from the module-level ``cfg``, trains all three
    models jointly via ``train_mult``, checkpoints ``model3`` at its best
    test AP to ``cfg.store + '.pth'``, and plots per-head loss/AP plus
    accuracy curves as PNGs.
    """
    muti_train_file = cfg.muti_train_file
    muti_test_file = cfg.muti_test_file
    train_dataset = MutiDateset(muti_train_file)
    test_dataset = MutiDateset(muti_test_file)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size,
                             shuffle=False)
    model1 = Dense_fenlei(num_classes=2, dim=2048, dropout=0.5).cuda()
    model2 = ViT(cfg=cfg, feature_seq=16, num_classes=1, dim=2048, depth=8,
                 heads=8, mlp_dim=1024, dropout=0.1, emb_dropout=0.1,
                 batch_normalization=False).cuda()
    model3 = ViT(cfg=cfg, feature_seq=16, num_classes=1, dim=2048, depth=8,
                 heads=8, mlp_dim=1024, dropout=0.1, emb_dropout=0.1).cuda()
    optimizer1 = optim.__dict__[cfg.optim1.name](model1.parameters(),
                                                 **cfg.optim1.setting)
    optimizer2 = optim.__dict__[cfg.optim2.name](model2.parameters(),
                                                 **cfg.optim2.setting)
    optimizer3 = optim.__dict__[cfg.optim3.name](model3.parameters(),
                                                 **cfg.optim3.setting)
    # Decay the learning rate at the epochs given by the config stepper.
    # NOTE(review): only optimizer1 is attached to the scheduler — confirm
    # that optimizer2/3 are intentionally unscheduled.
    scheduler = optim.lr_scheduler.__dict__[cfg.stepper.name](
        optimizer1, **cfg.stepper.setting)
    criterion3 = nn.CrossEntropyLoss(torch.Tensor(cfg.loss.weight).cuda())
    criterion1 = FocalLoss(logits=True)  # focal loss applied to raw logits
    distance = CosineSimilarity()
    criterion2 = losses.TripletMarginLoss(distance=distance)
    total_loss, total_loss_place, total_loss_tea = list(), list(), list()
    total_epoch = list()
    total_ap, total_ap_place, total_ap_tea = list(), list(), list()
    total_acc = list()
    max_ap = 0
    for epoch in range(0, cfg.epoch):
        train_mult(cfg, model1, model2, model3, train_loader, optimizer1,
                   optimizer2, optimizer3, scheduler, epoch,
                   criterion1, criterion2, criterion3)
        loss, loss_place, loss_tea, ap, ap_place, ap_tea, acc = test_mult(
            cfg, model1, model2, model3, test_loader,
            criterion1, criterion2, criterion3)
        total_loss.append(loss)
        total_ap.append(ap)
        total_loss_place.append(loss_place)
        total_ap_place.append(ap_place)
        total_loss_tea.append(loss_tea)
        total_ap_tea.append(ap_tea)
        total_epoch.append(epoch)
        total_acc.append(acc)
        print('Test Epoch: {} \tloss: {:.6f}\tap: {:.6f}\tacc: {:.6f}'.format(
            epoch, loss, ap, acc))
        if ap > max_ap:
            # BUG FIX: max_ap was never updated, so every epoch with ap > 0
            # overwrote the checkpoint; only strictly better epochs save now.
            max_ap = ap
            best_model = model3
            save_path = cfg.store + '.pth'
            torch.save(best_model.state_dict(), save_path)
    plt.figure(figsize=(20, 20))
    plt.plot(total_epoch, total_loss, 'b^', label=u'loss')
    plt.plot(total_epoch, total_loss_place, 'y^', label=u'loss_place')
    plt.plot(total_epoch, total_loss_tea, 'r^', label=u'loss_tea')
    plt.legend()
    loss_path = cfg.store + "_loss.png"
    plt.savefig(loss_path)
    plt.figure(figsize=(20, 20))
    plt.plot(total_epoch, total_ap, 'b^', label=u'AP')
    plt.plot(total_epoch, total_ap_place, 'y^', label=u'AP_place')
    plt.plot(total_epoch, total_ap_tea, 'r^', label=u'AP_tea')
    plt.legend()
    AP_path = cfg.store + "_AP.png"
    plt.savefig(AP_path)
    plt.figure()
    plt.plot(total_epoch, total_acc, 'b^', label=u'acc')
    plt.legend()
    acc_path = cfg.store + "_acc.png"
    plt.savefig(acc_path)
def test_contrastive_loss(self):
    """Compare ContrastiveLoss against hand-computed pair losses.

    Four configurations are checked in parallel, indexed 0..3 throughout:
    [0]=A squared-L2 + default reducer, [1]=B cosine + default reducer,
    [2]=C squared-L2 + MeanReducer, [3]=D cosine + MeanReducer.
    """
    loss_funcA = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5,
                                 distance=LpDistance(power=2))
    loss_funcB = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6,
                                 distance=CosineSimilarity())
    loss_funcC = ContrastiveLoss(pos_margin=0.25, neg_margin=1.5,
                                 distance=LpDistance(power=2),
                                 reducer=MeanReducer())
    loss_funcD = ContrastiveLoss(pos_margin=1.5, neg_margin=0.6,
                                 distance=CosineSimilarity(),
                                 reducer=MeanReducer())
    for dtype in TEST_DTYPES:
        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])
        lossA = loss_funcA(embeddings, labels)
        lossB = loss_funcB(embeddings, labels)
        lossC = loss_funcC(embeddings, labels)
        lossD = loss_funcD(embeddings, labels)
        # All ordered same-label and different-label index pairs.
        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                     (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                     (4, 0), (4, 1), (4, 2), (4, 3)]
        # Accumulators indexed [A, B, C, D]; A/C share the squared-L2
        # reference value, B/D share the cosine one.
        correct_pos_losses = [0, 0, 0, 0]
        correct_neg_losses = [0, 0, 0, 0]
        num_non_zero_pos = [0, 0, 0, 0]
        num_non_zero_neg = [0, 0, 0, 0]
        for a, p in pos_pairs:
            anchor, positive = embeddings[a], embeddings[p]
            correct_lossA = torch.relu(
                torch.sum((anchor - positive)**2) - 0.25)
            correct_lossB = torch.relu(1.5 -
                                       torch.matmul(anchor, positive))
            correct_pos_losses[0] += correct_lossA
            correct_pos_losses[1] += correct_lossB
            correct_pos_losses[2] += correct_lossA
            correct_pos_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_pos[0] += 1
                num_non_zero_pos[2] += 1
            if correct_lossB > 0:
                num_non_zero_pos[1] += 1
                num_non_zero_pos[3] += 1
        for a, n in neg_pairs:
            anchor, negative = embeddings[a], embeddings[n]
            correct_lossA = torch.relu(1.5 -
                                       torch.sum((anchor - negative)**2))
            correct_lossB = torch.relu(
                torch.matmul(anchor, negative) - 0.6)
            correct_neg_losses[0] += correct_lossA
            correct_neg_losses[1] += correct_lossB
            correct_neg_losses[2] += correct_lossA
            correct_neg_losses[3] += correct_lossB
            if correct_lossA > 0:
                num_non_zero_neg[0] += 1
                num_non_zero_neg[2] += 1
            if correct_lossB > 0:
                num_non_zero_neg[1] += 1
                num_non_zero_neg[3] += 1
        # A/B expectations average over non-zero pairs only.
        for i in range(2):
            if num_non_zero_pos[i] > 0:
                correct_pos_losses[i] /= num_non_zero_pos[i]
            if num_non_zero_neg[i] > 0:
                correct_neg_losses[i] /= num_non_zero_neg[i]
        # C/D expectations (MeanReducer) average over all pairs.
        for i in range(2, 4):
            correct_pos_losses[i] /= len(pos_pairs)
            correct_neg_losses[i] /= len(neg_pairs)
        correct_losses = [0, 0, 0, 0]
        for i in range(4):
            correct_losses[i] = correct_pos_losses[i] + correct_neg_losses[i]
        # float16 needs a looser tolerance.
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(lossA, correct_losses[0], rtol=rtol))
        self.assertTrue(torch.isclose(lossB, correct_losses[1], rtol=rtol))
        self.assertTrue(torch.isclose(lossC, correct_losses[2], rtol=rtol))
        self.assertTrue(torch.isclose(lossD, correct_losses[3], rtol=rtol))