def test_input_indices_tuple(self):
    for dtype in TEST_DTYPES:
        batch_size = 32
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(
            loss=ContrastiveLoss(),
            embedding_size=self.embedding_size,
            memory_size=self.memory_size,
        )
        for i in range(30):
            embeddings = (
                torch.randn(batch_size, self.embedding_size)
                .to(TEST_DEVICE)
                .type(dtype)
            )
            labels = torch.arange(batch_size).to(TEST_DEVICE)
            self.loss(embeddings, labels)
            for curr_miner in [pair_miner, triplet_miner]:
                input_indices_tuple = curr_miner(embeddings, labels)
                all_labels = torch.cat([labels, self.loss.label_memory], dim=0)
                a1ii, pii, a2ii, nii = lmu.convert_to_pairs(
                    input_indices_tuple, labels
                )
                indices_tuple = lmu.get_all_pairs_indices(
                    labels, self.loss.label_memory
                )
                a1i, pi, a2i, ni = self.loss.remove_self_comparisons(indices_tuple)
                a1, p, a2, n = self.loss.create_indices_tuple(
                    batch_size,
                    embeddings,
                    labels,
                    self.loss.embedding_memory,
                    self.loss.label_memory,
                    input_indices_tuple,
                    True,
                )
                self.assertTrue(
                    not torch.any((all_labels[a1] - all_labels[p]).bool())
                )
                self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))
                self.assertTrue(len(a1) == len(a1i) + len(a1ii))
                self.assertTrue(len(p) == len(pi) + len(pii))
                self.assertTrue(len(a2) == len(a2i) + len(a2ii))
                self.assertTrue(len(n) == len(ni) + len(nii))
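
# A minimal plain-torch sketch (hypothetical helper, not part of the library) of
# the idea behind lmu.get_all_pairs_indices as used above: for two label sets it
# returns the index pairs where labels match (positives) and where they differ
# (negatives). The real implementation may differ in details, e.g. handling of
# self-comparisons, which is why the test calls remove_self_comparisons separately.
def demo_all_pairs_indices(labels, ref_labels):
    matches = labels.unsqueeze(1) == ref_labels.unsqueeze(0)  # (N, M) bool
    a1, p = torch.where(matches)     # anchor-positive index pairs
    a2, n = torch.where(~matches)    # anchor-negative index pairs
    return a1, p, a2, n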
def test_shift_indices_tuple(self):
    batch_size = 32
    pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
    triplet_miner = TripletMarginMiner(margin=1)
    self.loss = CrossBatchMemory(
        loss=ContrastiveLoss(),
        embedding_size=self.embedding_size,
        memory_size=self.memory_size,
    )
    for i in range(30):
        embeddings = torch.randn(batch_size, self.embedding_size)
        labels = torch.arange(batch_size)
        loss = self.loss(embeddings, labels)
        all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

        # pairs derived directly from the labels
        indices_tuple = lmu.get_all_pairs_indices(labels, self.loss.label_memory)
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
        self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
        self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
        self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
        a1, p, a2, n = shifted
        self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
        self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

        # pairs produced by the pair miner
        indices_tuple = pair_miner(
            embeddings, labels, self.loss.embedding_memory, self.loss.label_memory
        )
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
        self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
        self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
        self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
        a1, p, a2, n = shifted
        self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
        self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

        # triplets produced by the triplet miner
        indices_tuple = triplet_miner(
            embeddings, labels, self.loss.embedding_memory, self.loss.label_memory
        )
        shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
        self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
        self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
        self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
        a, p, n = shifted
        self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
        self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
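
# A minimal sketch (hypothetical helper; the real c_f.shift_indices_tuple may
# differ) of the behavior verified above: anchor indices refer to the current
# batch and stay put, while positive/negative indices refer to the memory bank
# and are shifted by batch_size so they index into torch.cat([batch, memory]).
def demo_shift_indices_tuple(indices_tuple, batch_size):
    if len(indices_tuple) == 4:
        a1, p, a2, n = indices_tuple
        return a1, p + batch_size, a2, n + batch_size
    a, p, n = indices_tuple
    return a, p + batch_size, n + batch_size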
def test_triplet_margin_miner(self):
    for dtype in TEST_DTYPES:
        for distance in [LpDistance(), CosineSimilarity()]:
            embedding_angles = torch.arange(0, 16)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(self.device)  # 2D embeddings
            labels = torch.randint(low=0, high=2, size=(16,))
            mat = distance(embeddings)

            # build the reference set of triplets, each with its distance
            # difference (anchor-negative minus anchor-positive, with the sign
            # flipped for inverted distances such as cosine similarity)
            triplets = []
            for i in range(len(embeddings)):
                anchor_label = labels[i]
                for j in range(len(embeddings)):
                    if j == i:
                        continue
                    positive_label = labels[j]
                    if positive_label == anchor_label:
                        ap_dist = mat[i, j]
                        for k in range(len(embeddings)):
                            if k == j or k == i:
                                continue
                            negative_label = labels[k]
                            if negative_label != positive_label:
                                an_dist = mat[i, k]
                                if distance.is_inverted:
                                    triplets.append((i, j, k, ap_dist - an_dist))
                                else:
                                    triplets.append((i, j, k, an_dist - ap_dist))

            for margin_int in range(-1, 11):
                margin = float(margin_int) * 0.05
                minerA = TripletMarginMiner(
                    margin, type_of_triplets="all", distance=distance
                )
                minerB = TripletMarginMiner(
                    margin, type_of_triplets="hard", distance=distance
                )
                minerC = TripletMarginMiner(
                    margin, type_of_triplets="semihard", distance=distance
                )
                minerD = TripletMarginMiner(
                    margin, type_of_triplets="easy", distance=distance
                )
                correctA, correctB, correctC, correctD = [], [], [], []
                for i, j, k, distance_diff in triplets:
                    if distance_diff > margin:
                        correctD.append((i, j, k))
                    else:
                        correctA.append((i, j, k))
                        if distance_diff > 0:
                            correctC.append((i, j, k))
                        if distance_diff <= 0:
                            correctB.append((i, j, k))
                for correct, miner in [
                    (correctA, minerA),
                    (correctB, minerB),
                    (correctC, minerC),
                    (correctD, minerD),
                ]:
                    correct_triplets = set(correct)
                    a1, p1, n1 = miner(embeddings, labels)
                    mined_triplets = {
                        (a.item(), p.item(), n.item())
                        for a, p, n in zip(a1, p1, n1)
                    }
                    self.assertTrue(mined_triplets == correct_triplets)
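
# A minimal sketch of the margin rules exercised above (it mirrors the reference
# loop in the test, not the library internals). With diff = an_dist - ap_dist
# for ordinary distances (or ap_sim - an_sim for similarities):
#   easy:     diff >  margin   (the margin is satisfied)
#   all:      diff <= margin   (the margin is violated)
#   semihard: 0 < diff <= margin
#   hard:     diff <= 0        (the negative is at least as close as the positive)
def demo_classify_triplet(diff, margin):
    if diff > margin:
        return "easy"
    # both of the following fall under "all"
    return "semihard" if diff > 0 else "hard"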
def test_sanity_check(self):
    # cross batch memory with batch_size == memory_size should be equivalent
    # to just using the inner loss function
    for dtype in TEST_DTYPES:
        for test_enqueue_idx in [False, True]:
            for memory_size in range(20, 40, 5):
                inner_loss = NTXentLoss(temperature=0.1)
                inner_miner = TripletMarginMiner(margin=0.1)
                loss = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                )
                loss_with_miner = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                    miner=inner_miner,
                )
                for i in range(10):
                    if test_enqueue_idx:
                        enqueue_idx = torch.arange(memory_size, memory_size * 2)
                        not_enqueue_idx = torch.arange(memory_size)
                        batch_size = memory_size * 2
                    else:
                        enqueue_idx = None
                        batch_size = memory_size
                    embeddings = (
                        torch.randn(batch_size, self.embedding_size)
                        .to(TEST_DEVICE)
                        .type(dtype)
                    )
                    labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                    if test_enqueue_idx:
                        pairs = lmu.get_all_pairs_indices(
                            labels[not_enqueue_idx], labels[enqueue_idx]
                        )
                        pairs = c_f.shift_indices_tuple(pairs, memory_size)
                        inner_loss_val = inner_loss(embeddings, labels, pairs)
                    else:
                        inner_loss_val = inner_loss(embeddings, labels)
                    loss_val = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                    if test_enqueue_idx:
                        triplets = inner_miner(
                            embeddings[not_enqueue_idx],
                            labels[not_enqueue_idx],
                            embeddings[enqueue_idx],
                            labels[enqueue_idx],
                        )
                        triplets = c_f.shift_indices_tuple(triplets, memory_size)
                        inner_loss_val = inner_loss(embeddings, labels, triplets)
                    else:
                        triplets = inner_miner(embeddings, labels)
                        inner_loss_val = inner_loss(embeddings, labels, triplets)
                    loss_val = loss_with_miner(
                        embeddings, labels, enqueue_idx=enqueue_idx
                    )
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))
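
# A minimal usage sketch of CrossBatchMemory, assuming the same API exercised
# by the tests above; the embedding size, memory size, and loop are illustrative
# only, not a recommended configuration.
def demo_xbm_usage():
    xbm = CrossBatchMemory(
        loss=NTXentLoss(temperature=0.1), embedding_size=128, memory_size=1024
    )
    for _ in range(3):
        embeddings = torch.randn(32, 128)
        labels = torch.randint(0, 4, (32,))
        # each call enqueues the batch into the memory bank and computes the
        # loss over the batch against everything currently in memory
        loss_val = xbm(embeddings, labels)
    return loss_val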
def train_stage1(args, train_data_loader, validate_data_loader, teacher, student,
                 model_save_path1):
    print('===== training stage 1 =====')

    # build a loss function
    loss_function = nn.KLDivLoss(reduction='batchmean')

    # build an optimizer
    optimizer1 = SGD(params=student.parameters(), lr=args.lr1,
                     weight_decay=args.wd, momentum=args.mo, nesterov=True)

    # build a scheduler
    scheduler1 = CosineAnnealingLR(optimizer1, args.n_training_epochs1, 0.1 * args.lr1)

    # build a semi-hard triplet miner
    miner = TripletMarginMiner(margin=0.2, type_of_triplets='semihard')

    training_loss_list1 = []
    validating_accuracy_list1 = []
    best_validating_accuracy = 0

    for epoch in range(1, args.n_training_epochs1 + 1):
        # initialize the training loss and triplet count for this epoch
        training_loss = 0
        n_triplets = 0

        # build a progress bar
        if not args.flag_no_bar:
            total = len(train_data_loader)
            bar = tqdm(total=total, desc='stage1: epoch %d' % epoch, unit='batch')

        student.train()
        for images, labels in train_data_loader:
            images = images.float().cuda(args.devices[0])
            labels = labels.long().cuda(args.devices[0])

            # teacher embedding (the teacher is frozen, so no gradients)
            with torch.no_grad():
                teacher_embedding = teacher(images, flag_embedding=True)
                teacher_embedding = F.normalize(teacher_embedding, p=2, dim=1)

            # student embedding
            student_embedding = student(images, flag_embedding=True)
            student_embedding = F.normalize(student_embedding, p=2, dim=1)

            # mine semi-hard triplets from the student embedding
            with torch.no_grad():
                anchor_id, positive_id, negative_id = miner(student_embedding, labels)

            # gather the teacher embeddings of the mined triplets
            teacher_anchor = teacher_embedding[anchor_id]
            teacher_positive = teacher_embedding[positive_id]
            teacher_negative = teacher_embedding[negative_id]

            # gather the student embeddings of the mined triplets
            student_anchor = student_embedding[anchor_id]
            student_positive = student_embedding[positive_id]
            student_negative = student_embedding[negative_id]

            # anchor-positive and anchor-negative distances in the teacher embedding
            teacher_ap_dist = torch.norm(teacher_anchor - teacher_positive, p=2, dim=1)
            teacher_an_dist = torch.norm(teacher_anchor - teacher_negative, p=2, dim=1)

            # anchor-positive and anchor-negative distances in the student embedding
            student_ap_dist = torch.norm(student_anchor - student_positive, p=2, dim=1)
            student_an_dist = torch.norm(student_anchor - student_negative, p=2, dim=1)

            # per-triplet probability that the anchor is closer to the positive,
            # in the teacher embedding; [p, 1 - p] stacked along the class dimension
            teacher_prob = torch.sigmoid(
                (teacher_an_dist - teacher_ap_dist) / args.tau1)
            teacher_prob_aug = torch.cat(
                [teacher_prob.unsqueeze(1), 1 - teacher_prob.unsqueeze(1)], dim=1)

            # the same probability in the student embedding
            student_prob = torch.sigmoid(
                (student_an_dist - student_ap_dist) / args.tau1)
            student_prob_aug = torch.cat(
                [student_prob.unsqueeze(1), 1 - student_prob.unsqueeze(1)], dim=1)

            # KLDivLoss expects log-probabilities as its input distribution;
            # 1000 is a fixed loss scaling factor
            loss_value = 1000 * loss_function(
                torch.log(student_prob_aug), teacher_prob_aug)

            optimizer1.zero_grad()
            loss_value.backward()
            optimizer1.step()

            training_loss += loss_value.cpu().item() * student_prob.size(0)
            n_triplets += student_prob.size(0)

            if not args.flag_no_bar:
                bar.update(1)

        # average training loss over all mined triplets
        training_loss /= n_triplets
        training_loss_list1.append(training_loss)
        if not args.flag_no_bar:
            bar.close()

        if epoch % 10 == 0:
            # get validating accuracy
            validating_accuracy = test_ncm(args, validate_data_loader, student,
                                           description='validating')
            validating_accuracy_list1.append(validating_accuracy)

            # output after each epoch
            print('epoch %d finished: training_loss = %f, validating_accuracy = %f'
                  % (epoch, training_loss, validating_accuracy))

            # save a checkpoint if we find a better model
            if not args.flag_debug:
                if validating_accuracy > best_validating_accuracy:
                    best_validating_accuracy = validating_accuracy
                    record = {
                        'state_dict': student.state_dict(),
                        'validating_accuracy': validating_accuracy,
                        'epoch': epoch,
                    }
                    torch.save(record, model_save_path1)
        else:
            # output after each epoch
            print('epoch %d finished: training_loss = %f' % (epoch, training_loss))

        # adjust learning rate
        scheduler1.step()

    return training_loss_list1, validating_accuracy_list1
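
# A minimal sketch of the per-triplet distillation target built in train_stage1:
# each mined triplet becomes a two-outcome distribution [p, 1 - p], where p is
# the probability that the anchor is closer to the positive than to the
# negative, and the student's distribution is pulled toward the teacher's with
# KL divergence. The helper name and shapes are illustrative.
def demo_triplet_prob(an_dist, ap_dist, tau):
    p = torch.sigmoid((an_dist - ap_dist) / tau)  # shape (N,)
    return torch.stack([p, 1 - p], dim=1)         # shape (N, 2), class dimension

# usage sketch: KLDivLoss expects log-probabilities for the input distribution
#   kl = nn.KLDivLoss(reduction='batchmean')
#   loss = kl(torch.log(demo_triplet_prob(s_an, s_ap, tau)),
#             demo_triplet_prob(t_an, t_ap, tau))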