Ejemplo n.º 1
0
 def test_input_indices_tuple(self):
     """create_indices_tuple should merge mined in-batch pairs with memory pairs.

     For both a pair miner and a triplet miner, the tuple returned by
     ``create_indices_tuple`` must consist of the mined in-batch pairs plus
     the (self-comparison-free) batch-vs-memory pairs, and every resulting
     anchor-positive pair must share a label while anchor-negative differ.
     """
     batch_size = 32
     for dtype in TEST_DTYPES:
         miners = [
             PairMarginMiner(pos_margin=0, neg_margin=1),
             TripletMarginMiner(margin=1),
         ]
         self.loss = CrossBatchMemory(
             loss=ContrastiveLoss(),
             embedding_size=self.embedding_size,
             memory_size=self.memory_size,
         )
         for _ in range(30):
             embeddings = (
                 torch.randn(batch_size, self.embedding_size)
                 .to(TEST_DEVICE)
                 .type(dtype)
             )
             labels = torch.arange(batch_size).to(TEST_DEVICE)
             # Push the batch into memory before checking the indices.
             self.loss(embeddings, labels)
             for miner in miners:
                 input_indices_tuple = miner(embeddings, labels)
                 all_labels = torch.cat(
                     [labels, self.loss.label_memory], dim=0)
                 # Mined in-batch pairs, converted to pair form.
                 a1ii, pii, a2ii, nii = lmu.convert_to_pairs(
                     input_indices_tuple, labels)
                 # Batch-vs-memory pairs with self comparisons removed.
                 memory_pairs = lmu.get_all_pairs_indices(
                     labels, self.loss.label_memory)
                 a1i, pi, a2i, ni = self.loss.remove_self_comparisons(
                     memory_pairs)
                 a1, p, a2, n = self.loss.create_indices_tuple(
                     batch_size,
                     embeddings,
                     labels,
                     self.loss.embedding_memory,
                     self.loss.label_memory,
                     input_indices_tuple,
                     True,
                 )
                 # Anchor-positive pairs share labels; anchor-negative differ.
                 self.assertTrue(
                     not torch.any((all_labels[a1] - all_labels[p]).bool()))
                 self.assertTrue(
                     torch.all((all_labels[a2] - all_labels[n]).bool()))
                 # Merged tuple size = memory pairs + mined in-batch pairs.
                 self.assertTrue(len(a1) == len(a1i) + len(a1ii))
                 self.assertTrue(len(p) == len(pi) + len(pii))
                 self.assertTrue(len(a2) == len(a2i) + len(a2ii))
                 self.assertTrue(len(n) == len(ni) + len(nii))
Ejemplo n.º 2
0
    def test_shift_indices_tuple(self):
        """shift_indices_tuple must offset only the memory-side indices.

        For pair tuples (a1, p, a2, n) the anchors (slots 0 and 2) stay
        in-batch while positives/negatives (slots 1 and 3) are shifted by
        ``batch_size`` into memory coordinates.  For triplet tuples only the
        anchor (slot 0) stays in-batch.  Label agreement is re-checked on the
        shifted indices against the concatenated batch+memory labels.
        """
        batch_size = 32
        pair_miner = PairMarginMiner(pos_margin=0, neg_margin=1, use_similarity=False)
        triplet_miner = TripletMarginMiner(margin=1)
        self.loss = CrossBatchMemory(loss=ContrastiveLoss(), embedding_size=self.embedding_size, memory_size=self.memory_size)

        def check_shifted_pairs(indices_tuple, all_labels):
            # One check shared by both pair-tuple sources.
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[3], shifted[3] - batch_size))
            a1, p, a2, n = shifted
            self.assertTrue(not torch.any((all_labels[a1] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[a2] - all_labels[n]).bool()))

        for _ in range(30):
            embeddings = torch.randn(batch_size, self.embedding_size)
            labels = torch.arange(batch_size)
            self.loss(embeddings, labels)
            all_labels = torch.cat([labels, self.loss.label_memory], dim=0)

            check_shifted_pairs(
                lmu.get_all_pairs_indices(labels, self.loss.label_memory),
                all_labels)
            check_shifted_pairs(
                pair_miner(embeddings, labels, self.loss.embedding_memory,
                           self.loss.label_memory),
                all_labels)

            # Triplet tuple: only the anchor (slot 0) stays in-batch.
            indices_tuple = triplet_miner(embeddings, labels, self.loss.embedding_memory, self.loss.label_memory)
            shifted = c_f.shift_indices_tuple(indices_tuple, batch_size)
            self.assertTrue(torch.equal(indices_tuple[0], shifted[0]))
            self.assertTrue(torch.equal(indices_tuple[1], shifted[1] - batch_size))
            self.assertTrue(torch.equal(indices_tuple[2], shifted[2] - batch_size))
            a, p, n = shifted
            self.assertTrue(not torch.any((all_labels[a] - all_labels[p]).bool()))
            self.assertTrue(torch.all((all_labels[p] - all_labels[n]).bool()))
Ejemplo n.º 3
0
    def test_triplet_margin_miner(self):
        """Compare TripletMarginMiner against a brute-force enumeration.

        For each distance type, enumerates every (anchor, positive, negative)
        triplet, computes its anchor-negative minus anchor-positive gap, and
        checks that the miner's "all"/"hard"/"semihard"/"easy" modes return
        exactly the expected subsets over a range of margins.
        """
        for dtype in TEST_DTYPES:
            for distance in [LpDistance(), CosineSimilarity()]:
                embedding_angles = torch.arange(0, 16)
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(self.device)  # 2D embeddings
                labels = torch.randint(low=0, high=2, size=(16, ))
                mat = distance(embeddings)
                # Brute-force reference: collect (i, j, k, diff) for every
                # valid triplet, where diff <= 0 means a margin-0 violation.
                triplets = []
                for i in range(len(embeddings)):
                    anchor_label = labels[i]
                    for j in range(len(embeddings)):
                        if j == i:
                            continue
                        positive_label = labels[j]
                        if positive_label == anchor_label:
                            ap_dist = mat[i, j]
                            for k in range(len(embeddings)):
                                if k == j or k == i:
                                    continue
                                negative_label = labels[k]
                                if negative_label != positive_label:
                                    an_dist = mat[i, k]
                                    # For inverted (similarity) measures,
                                    # larger values mean "closer", so the
                                    # gap is flipped.
                                    if distance.is_inverted:
                                        triplets.append(
                                            (i, j, k, ap_dist - an_dist))
                                    else:
                                        triplets.append(
                                            (i, j, k, an_dist - ap_dist))

                # Margins from -0.05 to 0.5 in 0.05 steps (includes a
                # negative margin as an edge case).
                for margin_int in range(-1, 11):
                    margin = float(margin_int) * 0.05
                    minerA = TripletMarginMiner(margin,
                                                type_of_triplets="all",
                                                distance=distance)
                    minerB = TripletMarginMiner(margin,
                                                type_of_triplets="hard",
                                                distance=distance)
                    minerC = TripletMarginMiner(margin,
                                                type_of_triplets="semihard",
                                                distance=distance)
                    minerD = TripletMarginMiner(margin,
                                                type_of_triplets="easy",
                                                distance=distance)

                    # Partition by how the an-ap gap compares to the margin:
                    #   easy     (D): diff > margin
                    #   all      (A): diff <= margin (any margin violation)
                    #   semihard (C): 0 < diff <= margin
                    #   hard     (B): diff <= 0
                    correctA, correctB, correctC, correctD = [], [], [], []
                    for i, j, k, distance_diff in triplets:
                        if distance_diff > margin:
                            correctD.append((i, j, k))
                        else:
                            correctA.append((i, j, k))
                            if distance_diff > 0:
                                correctC.append((i, j, k))
                            if distance_diff <= 0:
                                correctB.append((i, j, k))

                    # Each miner must return exactly its expected subset.
                    for correct, miner in [
                        (correctA, minerA),
                        (correctB, minerB),
                        (correctC, minerC),
                        (correctD, minerD),
                    ]:
                        correct_triplets = set(correct)
                        a1, p1, n1 = miner(embeddings, labels)
                        mined_triplets = set([(a.item(), p.item(), n.item())
                                              for a, p, n in zip(a1, p1, n1)])
                        self.assertTrue(mined_triplets == correct_triplets)
Ejemplo n.º 4
0
    def test_sanity_check(self):
        """CrossBatchMemory with batch_size == memory_size must reduce to the inner loss.

        Also covers ``enqueue_idx``: when only half of a double-sized batch is
        enqueued, the reference value is the inner loss over pairs/triplets
        mined between the non-enqueued and enqueued halves, shifted into
        memory coordinates.
        """
        for dtype in TEST_DTYPES:
            for test_enqueue_idx in [False, True]:
                for memory_size in range(20, 40, 5):
                    inner_loss = NTXentLoss(temperature=0.1)
                    inner_miner = TripletMarginMiner(margin=0.1)
                    loss = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                    )
                    loss_with_miner = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                        miner=inner_miner,
                    )
                    for _ in range(10):
                        if test_enqueue_idx:
                            # Second half of the batch goes into memory.
                            enqueue_idx = torch.arange(memory_size,
                                                       memory_size * 2)
                            not_enqueue_idx = torch.arange(memory_size)
                            batch_size = memory_size * 2
                        else:
                            enqueue_idx = None
                            batch_size = memory_size
                        embeddings = (
                            torch.randn(batch_size, self.embedding_size)
                            .to(TEST_DEVICE)
                            .type(dtype)
                        )
                        labels = torch.randint(0, 4,
                                               (batch_size, )).to(TEST_DEVICE)

                        # Reference value without a miner.
                        if test_enqueue_idx:
                            pairs = lmu.get_all_pairs_indices(
                                labels[not_enqueue_idx], labels[enqueue_idx])
                            pairs = c_f.shift_indices_tuple(pairs, memory_size)
                            expected = inner_loss(embeddings, labels, pairs)
                        else:
                            expected = inner_loss(embeddings, labels)
                        actual = loss(embeddings,
                                      labels,
                                      enqueue_idx=enqueue_idx)
                        self.assertTrue(torch.isclose(expected, actual))

                        # Reference value with a miner.
                        if test_enqueue_idx:
                            triplets = inner_miner(
                                embeddings[not_enqueue_idx],
                                labels[not_enqueue_idx],
                                embeddings[enqueue_idx],
                                labels[enqueue_idx],
                            )
                            triplets = c_f.shift_indices_tuple(
                                triplets, memory_size)
                        else:
                            triplets = inner_miner(embeddings, labels)
                        expected = inner_loss(embeddings, labels, triplets)
                        actual = loss_with_miner(embeddings,
                                                 labels,
                                                 enqueue_idx=enqueue_idx)
                        self.assertTrue(torch.isclose(expected, actual))
Ejemplo n.º 5
0
def train_stage1(args, train_data_loader, validate_data_loader, teacher,
                 student, model_save_path1):
    """Stage-1 distillation: align student triplet probabilities with the teacher.

    Mines semi-hard triplets on the student embedding, turns each triplet's
    anchor-negative minus anchor-positive distance gap into a two-class
    probability (sigmoid with temperature ``args.tau1``) for both teacher and
    student, and minimizes the KL divergence between the two distributions.

    Args:
        args: namespace providing lr1, wd, mo, tau1, n_training_epochs1,
            devices, flag_no_bar, flag_debug.
        train_data_loader: yields (images, labels) training batches.
        validate_data_loader: loader used by ``test_ncm`` every 10 epochs.
        teacher: frozen reference network (run under ``torch.no_grad()``).
        student: network being trained.
        model_save_path1: path where the best checkpoint (by validation
            accuracy) is saved.

    Returns:
        tuple: (training_loss_list1, validating_accuracy_list1) — per-epoch
        average training losses, and validation accuracies recorded every
        10th epoch.
    """
    print('===== training stage 1 =====')
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    loss_function = nn.KLDivLoss(reduction='batchmean')
    optimizer1 = SGD(params=student.parameters(),
                     lr=args.lr1,
                     weight_decay=args.wd,
                     momentum=args.mo,
                     nesterov=True)
    # Cosine-anneal the learning rate down to 10% of its initial value.
    scheduler1 = CosineAnnealingLR(optimizer1, args.n_training_epochs1,
                                   0.1 * args.lr1)
    # Semi-hard triplet miner applied to the student embedding.
    miner = TripletMarginMiner(margin=0.2, type_of_triplets='semihard')

    training_loss_list1 = []
    validating_accuracy_list1 = []
    best_validating_accuracy = 0

    for epoch in range(1, args.n_training_epochs1 + 1):
        # Running totals for the per-triplet-weighted average loss.
        training_loss = 0
        n_triplets = 0
        if not args.flag_no_bar:
            total = train_data_loader.__len__()
            bar = tqdm(total=total,
                       desc='stage1: epoch %d' % (epoch),
                       unit='batch')

        student.train()
        for images, labels in train_data_loader:
            images = images.float().cuda(args.devices[0])
            labels = labels.long().cuda(args.devices[0])

            # Teacher embedding (frozen; L2-normalized).
            with torch.no_grad():
                teacher_embedding = teacher.forward(images,
                                                    flag_embedding=True)
                teacher_embedding = F.normalize(teacher_embedding, p=2, dim=1)

            # Student embedding (L2-normalized).
            student_embedding = student.forward(images, flag_embedding=True)
            student_embedding = F.normalize(student_embedding, p=2, dim=1)

            # Mine triplets on the (detached) student embedding.
            with torch.no_grad():
                anchor_id, positive_id, negative_id = miner(
                    student_embedding, labels)

            # Gather teacher embeddings for the mined triplets.
            teacher_anchor = teacher_embedding[anchor_id]
            teacher_positive = teacher_embedding[positive_id]
            teacher_negative = teacher_embedding[negative_id]

            # Gather student embeddings for the mined triplets.
            student_anchor = student_embedding[anchor_id]
            student_positive = student_embedding[positive_id]
            student_negative = student_embedding[negative_id]

            # Per-triplet a-p and a-n Euclidean distances (teacher).
            teacher_ap_dist = torch.norm(teacher_anchor - teacher_positive,
                                         p=2,
                                         dim=1)
            teacher_an_dist = torch.norm(teacher_anchor - teacher_negative,
                                         p=2,
                                         dim=1)

            # Per-triplet a-p and a-n Euclidean distances (student).
            student_ap_dist = torch.norm(student_anchor - student_positive,
                                         p=2,
                                         dim=1)
            student_an_dist = torch.norm(student_anchor - student_negative,
                                         p=2,
                                         dim=1)

            # Teacher two-class distribution per triplet: [p, 1 - p].
            # BUG FIX: concatenate along dim=1 so each row is a proper
            # distribution of shape (N, 2); the previous default (dim=0)
            # stacked the (N, 1) halves into a malformed (2N, 1) tensor,
            # which also made 'batchmean' divide by 2N instead of N.
            teacher_prob = torch.sigmoid(
                (teacher_an_dist - teacher_ap_dist) / args.tau1)
            teacher_prob_aug = torch.cat(
                [teacher_prob.unsqueeze(1), 1 - teacher_prob.unsqueeze(1)],
                dim=1)

            # Student two-class distribution per triplet (same fix).
            student_prob = torch.sigmoid(
                (student_an_dist - student_ap_dist) / args.tau1)
            student_prob_aug = torch.cat(
                [student_prob.unsqueeze(1), 1 - student_prob.unsqueeze(1)],
                dim=1)

            # NOTE(review): torch.log can produce -inf if a sigmoid output
            # saturates to exactly 0/1 — presumably tau1 keeps it in range;
            # confirm, or clamp before the log.
            loss_value = 1000 * loss_function(torch.log(student_prob_aug),
                                              teacher_prob_aug)

            optimizer1.zero_grad()
            loss_value.backward()
            optimizer1.step()

            # Weight the running loss by the number of mined triplets.
            training_loss += loss_value.cpu().item() * student_prob.size()[0]
            n_triplets += student_prob.size()[0]

            if not args.flag_no_bar:
                bar.update(1)

        # Average training loss over all triplets in the epoch.
        training_loss /= n_triplets
        training_loss_list1.append(training_loss)

        if not args.flag_no_bar:
            bar.close()

        if epoch % 10 == 0:
            # Validate with the nearest-class-mean classifier.
            validating_accuracy = test_ncm(args,
                                           validate_data_loader,
                                           student,
                                           description='validating')
            validating_accuracy_list1.append(validating_accuracy)
            print(
                'epoch %d finish: training_loss = %f, validating_accuracy = %f'
                % (epoch, training_loss, validating_accuracy))

            # Checkpoint only when validation improves (skipped in debug).
            if not args.flag_debug:
                if validating_accuracy > best_validating_accuracy:
                    best_validating_accuracy = validating_accuracy
                    record = {
                        'state_dict': student.state_dict(),
                        'validating_accuracy': validating_accuracy,
                        'epoch': epoch
                    }
                    torch.save(record, model_save_path1)
        else:
            print('epoch %d finish: training_loss = %f' %
                  (epoch, training_loss))

        # Adjust learning rate once per epoch.
        scheduler1.step()

    return training_loss_list1, validating_accuracy_list1