def test_with_no_valid_pairs(self):
    loss_func = NTXentLoss(temperature=0.1)
    embedding_angles = [0]
    embeddings = torch.tensor(
        [c_f.angle_to_coord(a) for a in embedding_angles],
        requires_grad=True,
        dtype=torch.float)  # 2D embeddings
    labels = torch.LongTensor([0])
    loss = loss_func(embeddings, labels)
    loss.backward()
    self.assertEqual(loss, 0)
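
When a batch contains no positive pairs, NTXentLoss returns a zero loss that is still attached to the autograd graph, so backward() is safe to call. A minimal standalone sketch of the same check (assuming pytorch-metric-learning is installed):

import torch
from pytorch_metric_learning.losses import NTXentLoss

loss_fn = NTXentLoss(temperature=0.1)
embeddings = torch.randn(1, 2, requires_grad=True)  # one embedding -> no pairs
labels = torch.LongTensor([0])
loss = loss_fn(embeddings, labels)
loss.backward()          # safe: the zero loss is still a graph node
assert loss.item() == 0  # no valid pairs, so the loss is defined as 0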
Example No. 2
    def test_ntxent_loss(self):
        temperature = 0.1
        loss_funcA = NTXentLoss(temperature=temperature)
        loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings

            labels = torch.LongTensor([0, 0, 1, 1, 2])

            lossA = loss_funcA(embeddings, labels)
            lossB = loss_funcB(embeddings, labels)

            pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
            neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4),
                         (2, 0), (2, 1), (2, 4), (3, 0), (3, 1), (3, 4),
                         (4, 0), (4, 1), (4, 2), (4, 3)]

            total_lossA, total_lossB = 0, 0
            for a1, p in pos_pairs:
                anchor, positive = embeddings[a1], embeddings[p]
                numeratorA = torch.exp(
                    torch.matmul(anchor, positive) / temperature)
                numeratorB = torch.exp(
                    -torch.sqrt(torch.sum(
                        (anchor - positive)**2)) / temperature)
                denominatorA = numeratorA.clone()
                denominatorB = numeratorB.clone()
                for a2, n in neg_pairs:
                    if a2 == a1:
                        negative = embeddings[n]
                    else:
                        continue
                    denominatorA += torch.exp(
                        torch.matmul(anchor, negative) / temperature)
                    denominatorB += torch.exp(
                        -torch.sqrt(torch.sum(
                            (anchor - negative)**2)) / temperature)
                curr_lossA = -torch.log(numeratorA / denominatorA)
                curr_lossB = -torch.log(numeratorB / denominatorB)
                total_lossA += curr_lossA
                total_lossB += curr_lossB

            total_lossA /= len(pos_pairs)
            total_lossB /= len(pos_pairs)
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(lossA, total_lossA, rtol=rtol))
            self.assertTrue(torch.isclose(lossB, total_lossB, rtol=rtol))
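
Each iteration of the loop above reconstructs the per-positive-pair NT-Xent term by hand. A hedged standalone sketch of that term (the function name is illustrative; the dot product equals cosine similarity here because angle_to_coord yields unit vectors):

import torch

def ntxent_pair_term(anchor, positive, negatives, temperature):
    # -log( exp(sim(a,p)/t) / (exp(sim(a,p)/t) + sum_n exp(sim(a,n)/t)) )
    pos = torch.exp(anchor @ positive / temperature)
    neg = sum(torch.exp(anchor @ n / temperature) for n in negatives)
    return -torch.log(pos / (pos + neg))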
Example No. 3
def test_backward(self):
    temperature = 0.1
    loss_funcA = NTXentLoss(temperature=temperature)
    loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
    for dtype in TEST_DTYPES:
        for loss_func in [loss_funcA, loss_funcB]:
            embedding_angles = [0, 20, 40, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype).to(self.device)  # 2D embeddings
            labels = torch.LongTensor([0, 0, 1, 1, 2])
            loss = loss_func(embeddings, labels)
            loss.backward()
Example No. 4
    def test_sanity_check(self):
        # cross batch memory with batch_size == memory_size should be equivalent to just using the inner loss function
        for dtype in TEST_DTYPES:
            for test_enqueue_idx in [False, True]:
                for memory_size in range(20, 40, 5):
                    inner_loss = NTXentLoss(temperature=0.1)
                    inner_miner = TripletMarginMiner(margin=0.1)
                    loss = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                    )
                    loss_with_miner = CrossBatchMemory(
                        loss=inner_loss,
                        embedding_size=self.embedding_size,
                        memory_size=memory_size,
                        miner=inner_miner,
                    )
                    for i in range(10):
                        if test_enqueue_idx:
                            enqueue_idx = torch.arange(memory_size, memory_size * 2)
                            not_enqueue_idx = torch.arange(memory_size)
                            batch_size = memory_size * 2
                        else:
                            enqueue_idx = None
                            batch_size = memory_size
                        embeddings = (
                            torch.randn(batch_size, self.embedding_size)
                            .to(TEST_DEVICE)
                            .type(dtype)
                        )
                        labels = torch.randint(0, 4, (batch_size,)).to(TEST_DEVICE)

                        if test_enqueue_idx:
                            pairs = lmu.get_all_pairs_indices(
                                labels[not_enqueue_idx], labels[enqueue_idx]
                            )
                            pairs = c_f.shift_indices_tuple(pairs, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, pairs)
                        else:
                            inner_loss_val = inner_loss(embeddings, labels)
                        loss_val = loss(embeddings, labels, enqueue_idx=enqueue_idx)
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                        if test_enqueue_idx:
                            triplets = inner_miner(
                                embeddings[not_enqueue_idx],
                                labels[not_enqueue_idx],
                                embeddings[enqueue_idx],
                                labels[enqueue_idx],
                            )
                            triplets = c_f.shift_indices_tuple(triplets, memory_size)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        else:
                            triplets = inner_miner(embeddings, labels)
                            inner_loss_val = inner_loss(embeddings, labels, triplets)
                        loss_val = loss_with_miner(
                            embeddings, labels, enqueue_idx=enqueue_idx
                        )
                        self.assertTrue(torch.isclose(inner_loss_val, loss_val))
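
The equivalence this test asserts also holds in isolation: when memory_size equals the batch size, the memory bank holds exactly the current batch, so CrossBatchMemory reduces to its inner loss. A minimal sketch (the sizes are illustrative):

import torch
from pytorch_metric_learning.losses import CrossBatchMemory, NTXentLoss

inner = NTXentLoss(temperature=0.1)
xbm = CrossBatchMemory(loss=inner, embedding_size=128, memory_size=32)

embeddings = torch.randn(32, 128)      # batch_size == memory_size
labels = torch.randint(0, 4, (32,))
assert torch.isclose(inner(embeddings, labels), xbm(embeddings, labels))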
Example No. 5
def test_with_distance_weighted_miner(self):
    for dtype in TEST_DTYPES:
        memory_size = 256
        inner_loss = NTXentLoss(temperature=0.1)
        inner_miner = DistanceWeightedMiner(cutoff=0.5,
                                            nonzero_loss_cutoff=1.4)
        loss_with_miner = CrossBatchMemory(
            loss=inner_loss,
            embedding_size=2,
            memory_size=memory_size,
            miner=inner_miner,
        )
        for i in range(20):
            embedding_angles = torch.arange(0, 32)
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=10,
                                   size=(32,)).to(TEST_DEVICE)
            loss_val = loss_with_miner(embeddings, labels)
            loss_val.backward()
            self.assertTrue(
                True)  # just check that we got here without an exception
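
This one is a pure smoke test: DistanceWeightedMiner draws negatives with distance-weighted sampling over CrossBatchMemory's enlarged (batch plus memory) pair space, and the point is only that mining, the loss, and backward() run without raising, hence the unconditional assertTrue(True).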
Example No. 6
def __init__(self, args):
    super(CP, self).__init__()
    self.model = BertForMaskedLM.from_pretrained(args.model_name)
    self.tokenizer = BertTokenizer.from_pretrained(args.model_name,
                                                   do_basic_tokenize=False)
    self.ntxloss = NTXentLoss(temperature=args.temperature)
    self.args = args
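
Only the constructor is shown here. A hypothetical forward pass illustrating how self.ntxloss might consume BERT embeddings (this method is not part of the original example, and pooling the [CLS] state is an assumption):

def forward(self, input_ids, attention_mask, labels):
    # Hypothetical pooling: use the [CLS] hidden state as the embedding.
    outputs = self.model.bert(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state[:, 0]
    return self.ntxloss(embeddings, labels)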
Example No. 7
    def test_ntxent_loss(self):
        temperature = 0.1
        loss_func = NTXentLoss(temperature=temperature)

        embedding_angles = [0, 20, 40, 60, 80]
        embeddings = torch.FloatTensor(
            [c_f.angle_to_coord(a) for a in embedding_angles])  # 2D embeddings
        labels = torch.LongTensor([0, 0, 1, 1, 2])

        loss = loss_func(embeddings, labels)

        pos_pairs = [(0, 1), (1, 0), (2, 3), (3, 2)]
        neg_pairs = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 0),
                     (2, 1), (2, 4), (3, 0), (3, 1), (3, 4), (4, 0), (4, 1),
                     (4, 2), (4, 3)]

        total_loss = 0
        for a1, p in pos_pairs:
            anchor, positive = embeddings[a1], embeddings[p]
            numerator = torch.exp(torch.matmul(anchor, positive) / temperature)
            denominator = numerator.clone()
            for a2, n in neg_pairs:
                if a2 == a1:
                    negative = embeddings[n]
                else:
                    continue
                denominator += torch.exp(
                    torch.matmul(anchor, negative) / temperature)
            curr_loss = -torch.log(numerator / denominator)
            total_loss += curr_loss

        total_loss /= len(pos_pairs)
        self.assertTrue(torch.isclose(loss, total_loss))
Example No. 8
def test_with_no_valid_pairs(self):
    loss = NTXentLoss(temperature=0.1)
    embedding_angles = [0]
    embeddings = torch.FloatTensor(
        [c_f.angle_to_coord(a) for a in embedding_angles])  # 2D embeddings
    labels = torch.LongTensor([0])
    self.assertEqual(loss(embeddings, labels), 0)
Example No. 9
    def test_sanity_check(self):
        # cross batch memory with batch_size == memory_size should be equivalent to just using the inner loss function
        for dtype in [torch.float16, torch.float32, torch.float64]:
            for memory_size in range(20, 40, 5):
                inner_loss = NTXentLoss(temperature=0.1)
                inner_miner = TripletMarginMiner(margin=0.1)
                loss = CrossBatchMemory(loss=inner_loss,
                                        embedding_size=self.embedding_size,
                                        memory_size=memory_size)
                loss_with_miner = CrossBatchMemory(
                    loss=inner_loss,
                    embedding_size=self.embedding_size,
                    memory_size=memory_size,
                    miner=inner_miner)
                for i in range(10):
                    embeddings = torch.randn(memory_size,
                                             self.embedding_size).to(
                                                 self.device).type(dtype)
                    labels = torch.randint(0, 4,
                                           (memory_size, )).to(self.device)
                    inner_loss_val = inner_loss(embeddings, labels)
                    loss_val = loss(embeddings, labels)
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))

                    triplets = inner_miner(embeddings, labels)
                    inner_loss_val = inner_loss(embeddings, labels, triplets)
                    loss_val = loss_with_miner(embeddings, labels)
                    self.assertTrue(torch.isclose(inner_loss_val, loss_val))
Example No. 10
    def test_metric_loss_only(self):
        loss_fn = NTXentLoss()
        dataset = datasets.FakeData()
        model = torch.nn.Identity()
        batch_size = 32

        for trainer_class in [
            MetricLossOnly,
            DeepAdversarialMetricLearning,
            TrainWithClassifier,
            TwoStreamMetricLoss,
        ]:

            model_dict = {"trunk": model}
            optimizer_dict = {"trunk_optimizer": None}
            loss_fn_dict = {"metric_loss": loss_fn}
            lr_scheduler_dict = {"trunk_scheduler_by_iteration": None}

            if trainer_class is DeepAdversarialMetricLearning:
                model_dict["generator"] = model
                loss_fn_dict["synth_loss"] = loss_fn
                loss_fn_dict["g_adv_loss"] = TripletMarginLoss()

            kwargs = {
                "models": model_dict,
                "optimizers": optimizer_dict,
                "batch_size": batch_size,
                "loss_funcs": loss_fn_dict,
                "mining_funcs": {},
                "dataset": dataset,
                "freeze_these": ["trunk"],
                "lr_schedulers": lr_scheduler_dict,
            }

            trainer = trainer_class(**kwargs)

            for k in [
                "models",
                "mining_funcs",
                "loss_funcs",
                "freeze_these",
                "lr_schedulers",
            ]:
                new_kwargs = copy.deepcopy(kwargs)
                if k == "models":
                    new_kwargs[k] = {}
                if k == "mining_funcs":
                    new_kwargs[k] = {"dog": None}
                if k == "loss_funcs":
                    if trainer_class is DeepAdversarialMetricLearning:
                        new_kwargs[k] = {}
                    else:
                        continue
                if k == "freeze_these":
                    new_kwargs[k] = ["frog"]
                if k == "lr_schedulers":
                    new_kwargs[k] = {"trunk_scheduler": None}
                with self.assertRaises(AssertionError):
                    trainer = trainer_class(**new_kwargs)
Example No. 11
def test_with_no_valid_pairs(self):
    loss_func = NTXentLoss(temperature=0.1)
    for dtype in TEST_DTYPES:
        embedding_angles = [0]
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype).to(self.device)  # 2D embeddings
        labels = torch.LongTensor([0])
        loss = loss_func(embeddings, labels)
        loss.backward()
        self.assertEqual(loss, 0)
Example No. 12
def test_with_no_valid_pairs(self):
    loss_func = NTXentLoss(temperature=0.1)
    all_embedding_angles = [[0], [0, 10, 20]]
    all_labels = [torch.LongTensor([0]), torch.LongTensor([0, 0, 0])]
    for dtype in TEST_DTYPES:
        for embedding_angles, labels in zip(all_embedding_angles,
                                            all_labels):
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            loss = loss_func(embeddings, labels)
            loss.backward()
            self.assertEqual(loss, 0)
Example No. 13
    def test_metric_loss_only(self):

        cifar_resnet_folder = "temp_cifar_resnet_for_pytorch_metric_learning_test"
        dataset_folder = "temp_dataset_for_pytorch_metric_learning_test"
        model_folder = "temp_saved_models_for_pytorch_metric_learning_test"
        logs_folder = "temp_logs_for_pytorch_metric_learning_test"
        tensorboard_folder = "temp_tensorboard_for_pytorch_metric_learning_test"

        os.system(
            "git clone https://github.com/akamaster/pytorch_resnet_cifar10.git {}"
            .format(cifar_resnet_folder))

        loss_fn = NTXentLoss()

        normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                   std=[0.229, 0.224, 0.225])

        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize_transform,
        ])

        eval_transform = transforms.Compose(
            [transforms.ToTensor(), normalize_transform])

        assert not os.path.isdir(dataset_folder)
        assert not os.path.isdir(model_folder)
        assert not os.path.isdir(logs_folder)
        assert not os.path.isdir(tensorboard_folder)

        subset_idx = np.arange(10000)

        train_dataset = datasets.CIFAR100(dataset_folder,
                                          train=True,
                                          download=True,
                                          transform=train_transform)

        train_dataset_for_eval = datasets.CIFAR100(dataset_folder,
                                                   train=True,
                                                   download=True,
                                                   transform=eval_transform)

        val_dataset = datasets.CIFAR100(dataset_folder,
                                        train=False,
                                        download=True,
                                        transform=eval_transform)

        train_dataset = torch.utils.data.Subset(train_dataset, subset_idx)
        train_dataset_for_eval = torch.utils.data.Subset(
            train_dataset_for_eval, subset_idx)
        val_dataset = torch.utils.data.Subset(val_dataset, subset_idx)

        for dtype in TEST_DTYPES:
            for splits_to_eval in [
                    None,
                [("train", ["train", "val"]), ("val", ["train", "val"])],
            ]:
                from temp_cifar_resnet_for_pytorch_metric_learning_test import resnet

                model = torch.nn.DataParallel(resnet.resnet20())
                checkpoint = torch.load(
                    "{}/pretrained_models/resnet20-12fca82f.th".format(
                        cifar_resnet_folder),
                    map_location=TEST_DEVICE,
                )
                model.load_state_dict(checkpoint["state_dict"])
                model.module.linear = c_f.Identity()
                if TEST_DEVICE == torch.device("cpu"):
                    model = model.module
                model = model.to(TEST_DEVICE).type(dtype)

                optimizer = torch.optim.Adam(
                    model.parameters(),
                    lr=0.0002,
                    weight_decay=0.0001,
                    eps=1e-04,
                )

                batch_size = 32
                iterations_per_epoch = None if splits_to_eval is None else 1
                model_dict = {"trunk": model}
                optimizer_dict = {"trunk_optimizer": optimizer}
                loss_fn_dict = {"metric_loss": loss_fn}
                sampler = MPerClassSampler(
                    np.array(train_dataset.dataset.targets)[subset_idx],
                    m=4,
                    batch_size=32,
                    length_before_new_iter=len(train_dataset),
                )

                record_keeper, _, _ = logging_presets.get_record_keeper(
                    logs_folder, tensorboard_folder)
                hooks = logging_presets.get_hook_container(
                    record_keeper, primary_metric="precision_at_1")
                dataset_dict = {
                    "train": train_dataset_for_eval,
                    "val": val_dataset
                }

                tester = GlobalEmbeddingSpaceTester(
                    end_of_testing_hook=hooks.end_of_testing_hook,
                    accuracy_calculator=accuracy_calculator.AccuracyCalculator(
                        include=("precision_at_1", "AMI"), k=1),
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                )

                end_of_epoch_hook = hooks.end_of_epoch_hook(
                    tester,
                    dataset_dict,
                    model_folder,
                    test_interval=1,
                    patience=1,
                    splits_to_eval=splits_to_eval,
                )

                trainer = MetricLossOnly(
                    models=model_dict,
                    optimizers=optimizer_dict,
                    batch_size=batch_size,
                    loss_funcs=loss_fn_dict,
                    mining_funcs={},
                    dataset=train_dataset,
                    sampler=sampler,
                    data_device=TEST_DEVICE,
                    dtype=dtype,
                    dataloader_num_workers=32,
                    iterations_per_epoch=iterations_per_epoch,
                    freeze_trunk_batchnorm=True,
                    end_of_iteration_hook=hooks.end_of_iteration_hook,
                    end_of_epoch_hook=end_of_epoch_hook,
                )

                num_epochs = 3
                trainer.train(num_epochs=num_epochs)
                best_epoch, best_accuracy = hooks.get_best_epoch_and_accuracy(
                    tester, "val")
                if splits_to_eval is None:
                    self.assertTrue(best_epoch == 3)
                    self.assertTrue(best_accuracy > 0.2)

                accuracies, primary_metric_key = hooks.get_accuracies_of_best_epoch(
                    tester, "val")
                accuracies = c_f.sqliteObjToDict(accuracies)
                self.assertTrue(
                    accuracies[primary_metric_key][0] == best_accuracy)
                self.assertTrue(primary_metric_key == "precision_at_1_level0")

                best_epoch_accuracies = hooks.get_accuracies_of_epoch(
                    tester, "val", best_epoch)
                best_epoch_accuracies = c_f.sqliteObjToDict(
                    best_epoch_accuracies)
                self.assertTrue(best_epoch_accuracies[primary_metric_key][0] ==
                                best_accuracy)

                accuracy_history = hooks.get_accuracy_history(tester, "val")
                self.assertTrue(accuracy_history[primary_metric_key][
                    accuracy_history["epoch"].index(best_epoch)] ==
                                best_accuracy)

                loss_history = hooks.get_loss_history()
                if splits_to_eval is None:
                    self.assertTrue(
                        len(loss_history["metric_loss"]) == (len(sampler) /
                                                             batch_size) *
                        num_epochs)

                curr_primary_metric = hooks.get_curr_primary_metric(
                    tester, "val")
                self.assertTrue(curr_primary_metric ==
                                accuracy_history[primary_metric_key][-1])

                base_record_group_name = hooks.base_record_group_name(tester)

                self.assertTrue(
                    base_record_group_name ==
                    "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0")

                record_group_name = hooks.record_group_name(tester, "val")

                if splits_to_eval is None:
                    self.assertTrue(
                        record_group_name ==
                        "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_self"
                    )
                else:
                    self.assertTrue(
                        record_group_name ==
                        "accuracies_normalized_GlobalEmbeddingSpaceTester_level_0_VAL_vs_TRAIN_and_VAL"
                    )

                shutil.rmtree(model_folder)
                shutil.rmtree(logs_folder)
                shutil.rmtree(tensorboard_folder)

        shutil.rmtree(cifar_resnet_folder)
        shutil.rmtree(dataset_folder)
Example No. 14
    model = SimCLR(i3d, n_projection, n_features)
    model.load_state_dict(model_weights)
    print("Loading the whole SimCLR model")
else:
    i3d = InceptionI3d(n_features, in_channels=3)
    i3d.load_state_dict(torch.load(model_weights))
    print("Previous weights loaded")

    model = SimCLR(i3d, n_projection, n_features)

# Move the SimCLR model to the device
model.to(device)

# Loss function
criterion = NTXentLoss(temperature=0.10)

# LARS optimizer

learning_rate = 0.3 * (batch_size * cumulation) / 256
optimizer = LARS(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.0005,
    exclude_from_weight_decay=["batch_normalization", "bias"],
    device=device
)

# "decay the learning rate with the cosine decay schedule without restarts"
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 100, eta_min=0, last_epoch=-1
)
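
The training loop itself is cut off in this example. A hypothetical step showing how criterion, optimizer, and scheduler would typically interact in SimCLR-style training (the loader and the model's output convention are assumptions, not part of the source):

for view_a, view_b in loader:  # two augmented views per clip (assumed loader)
    z = torch.cat([model(view_a.to(device)), model(view_b.to(device))], dim=0)
    # Giving both views of sample i the same label marks them as positives:
    labels = torch.arange(view_a.size(0)).repeat(2).to(device)
    loss = criterion(z, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
scheduler.step()  # advance the cosine decay once per epoch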
Example No. 15
    def test_ntxent_loss(self):
        temperature = 0.1
        loss_funcA = NTXentLoss(temperature=temperature)
        loss_funcB = NTXentLoss(temperature=temperature, distance=LpDistance())
        loss_funcC = NTXentLoss(
            temperature=temperature, reducer=PerAnchorReducer(AvgNonZeroReducer())
        )
        loss_funcD = SupConLoss(temperature=temperature)
        loss_funcE = SupConLoss(temperature=temperature, distance=LpDistance())

        for dtype in TEST_DTYPES:
            embedding_angles = [0, 10, 20, 50, 60, 80]
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(
                TEST_DEVICE
            )  # 2D embeddings

            labels = torch.LongTensor([0, 0, 0, 1, 1, 2])

            obtained_losses = [
                x(embeddings, labels)
                for x in [loss_funcA, loss_funcB, loss_funcC, loss_funcD, loss_funcE]
            ]

            pos_pairs = [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (3, 4), (4, 3)]
            neg_pairs = [
                (0, 3),
                (0, 4),
                (0, 5),
                (1, 3),
                (1, 4),
                (1, 5),
                (2, 3),
                (2, 4),
                (2, 5),
                (3, 0),
                (3, 1),
                (3, 2),
                (3, 5),
                (4, 0),
                (4, 1),
                (4, 2),
                (4, 5),
                (5, 0),
                (5, 1),
                (5, 2),
                (5, 3),
                (5, 4),
            ]

            total_lossA, total_lossB, total_lossC, total_lossD, total_lossE = (
                0,
                0,
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
                torch.zeros(5, device=TEST_DEVICE, dtype=dtype),
            )
            for a1, p in pos_pairs:
                anchor, positive = embeddings[a1], embeddings[p]
                numeratorA = torch.exp(torch.matmul(anchor, positive) / temperature)
                numeratorB = torch.exp(
                    -torch.sqrt(torch.sum((anchor - positive) ** 2)) / temperature
                )
                denominatorA = numeratorA.clone()
                denominatorB = numeratorB.clone()
                denominatorD = 0
                denominatorE = 0
                for a2, n in pos_pairs + neg_pairs:
                    if a2 == a1:
                        negative = embeddings[n]
                        curr_denomD = torch.exp(
                            torch.matmul(anchor, negative) / temperature
                        )
                        curr_denomE = torch.exp(
                            -torch.sqrt(torch.sum((anchor - negative) ** 2))
                            / temperature
                        )
                        denominatorD += curr_denomD
                        denominatorE += curr_denomE
                        if (a2, n) not in pos_pairs:
                            denominatorA += curr_denomD
                            denominatorB += curr_denomE
                    else:
                        continue

                curr_lossA = -torch.log(numeratorA / denominatorA)
                curr_lossB = -torch.log(numeratorB / denominatorB)
                curr_lossD = -torch.log(numeratorA / denominatorD)
                curr_lossE = -torch.log(numeratorB / denominatorE)
                total_lossA += curr_lossA
                total_lossB += curr_lossB
                total_lossC[a1] += curr_lossA
                total_lossD[a1] += curr_lossD
                total_lossE[a1] += curr_lossE

            total_lossA /= len(pos_pairs)
            total_lossB /= len(pos_pairs)
            pos_pair_per_anchor = torch.tensor(
                [2, 2, 2, 1, 1], device=TEST_DEVICE, dtype=dtype
            )
            total_lossC, total_lossD, total_lossE = [
                torch.mean(x / pos_pair_per_anchor)
                for x in [total_lossC, total_lossD, total_lossE]
            ]

            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(obtained_losses[0], total_lossA, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[1], total_lossB, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[2], total_lossC, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[3], total_lossD, rtol=rtol))
            self.assertTrue(torch.isclose(obtained_losses[4], total_lossE, rtol=rtol))
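
The bookkeeping above encodes the difference between the two losses: for NTXentLoss (A and B) each pair's denominator is its own positive term plus the negatives only, while for SupConLoss (D and E) the denominator sums over every pair sharing the anchor, other positives included. PerAnchorReducer (C) first averages the per-pair terms within each anchor (dividing by pos_pair_per_anchor) and then averages across anchors, instead of averaging over all pairs at once.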
Example No. 16
def __init__(self, args):
    super().__init__()
    self.model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    self.ntxloss = NTXentLoss(temperature=args.temperature)
    self.args = args
Example No. 17
def __init__(self, embedding_size, memory_size):
    super(MoCo, self).__init__(NTXentLoss(temperature=0.1), embedding_size,
                               memory_size)
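
Under the assumption implied by this constructor (MoCo implemented as a CrossBatchMemory subclass), usage mirrors any other memory-bank loss; a hedged sketch with illustrative sizes and labels:

import torch

moco_loss = MoCo(embedding_size=128, memory_size=4096)
embeddings = torch.randn(64, 128)    # queries and keys, stacked
labels = torch.arange(32).repeat(2)  # both views of sample i share label i (assumed convention)
loss = moco_loss(embeddings, labels)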