def test_proxyanchor_loss(self):
        num_classes = 10
        embedding_size = 2
        margin = 0.5
        for dtype in TEST_DTYPES:
            alpha = 1 if dtype == torch.float16 else 32
            loss_func = ProxyAnchorLoss(num_classes,
                                        embedding_size,
                                        margin=margin,
                                        alpha=alpha).to(TEST_DEVICE)
            original_loss_func = OriginalImplementationProxyAnchor(
                num_classes, embedding_size, mrg=margin,
                alpha=alpha).to(TEST_DEVICE)
            original_loss_func.proxies.data = original_loss_func.proxies.data.type(
                dtype)
            loss_func.proxies = original_loss_func.proxies

            embedding_angles = list(range(0, 180))
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=dtype,
            ).to(TEST_DEVICE)  # 2D embeddings
            labels = torch.randint(low=0, high=5, size=(180, )).to(TEST_DEVICE)

            loss = loss_func(embeddings, labels)
            loss.backward()
            correct_loss = original_loss_func(embeddings, labels)
            rtol = 1e-2 if dtype == torch.float16 else 1e-5
            self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))
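
These snippets are methods of a unittest.TestCase and rely on module-level fixtures that are not shown on this page (TEST_DEVICE, TEST_DTYPES, c_f, and the reference OriginalImplementationProxyAnchor). A minimal sketch of the assumed setup follows; the import paths mirror pytorch-metric-learning, but the exact device and dtype choices are assumptions:

import unittest
from contextlib import nullcontext

import torch

from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import (ArcFaceLoss, NormalizedSoftmaxLoss,
                                            ProxyAnchorLoss)
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils.inference import LogitGetter

# Assumed fixtures: run on GPU when available and sweep a few dtypes.
TEST_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_DTYPES = [torch.float16, torch.float32, torch.float64]

# OriginalImplementationProxyAnchor is assumed to be a copy of the reference
# ProxyAnchor loss from the official Proxy-Anchor-CVPR2020 repository, pasted
# into the test module; it is not part of pytorch-metric-learning itself.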
    def test_proxyanchor_loss(self):
        num_classes = 10
        embedding_size = 2
        margin = 0.5

        for use_autocast in [True, False]:

            if use_autocast:
                cm = torch.cuda.amp.autocast()
            else:
                cm = nullcontext()

            for dtype in TEST_DTYPES:
                alpha = 1 if dtype == torch.float16 else 32
                loss_func = ProxyAnchorLoss(num_classes,
                                            embedding_size,
                                            margin=margin,
                                            alpha=alpha).to(TEST_DEVICE)
                original_loss_func = OriginalImplementationProxyAnchor(
                    num_classes, embedding_size, mrg=margin,
                    alpha=alpha).to(TEST_DEVICE)

                if not use_autocast:
                    original_loss_func.proxies.data = (
                        original_loss_func.proxies.data.type(dtype))
                loss_func.proxies = original_loss_func.proxies

                embedding_angles = list(range(0, 180))
                embeddings = torch.tensor(
                    [c_f.angle_to_coord(a) for a in embedding_angles],
                    requires_grad=True,
                    dtype=torch.float32,
                ).to(TEST_DEVICE)  # 2D embeddings

                if not use_autocast:
                    embeddings = embeddings.type(dtype)
                labels = torch.randint(low=0, high=5,
                                       size=(180, )).to(TEST_DEVICE)

                with cm:
                    loss = loss_func(embeddings, labels)

                loss.backward()
                correct_loss = original_loss_func(embeddings, labels)
                rtol = 1e-2 if dtype == torch.float16 or use_autocast else 1e-5

                self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))
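
In the autocast variant above, the embeddings stay float32 and torch.cuda.amp.autocast performs the reduced-precision casts inside the forward pass, which is why the comparison tolerance is loosened to 1e-2 whenever autocast is enabled or the dtype is float16.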
Example #3
    def test_proxyanchor_loss(self):
        num_classes = 10
        embedding_size = 2
        margin = 0.5
        alpha = 32
        device = torch.device("cuda")
        loss_func = ProxyAnchorLoss(num_classes,
                                    embedding_size,
                                    margin=margin,
                                    alpha=alpha).to(device)
        original_loss_func = OriginalImplementationProxyAnchor(
            num_classes, embedding_size, mrg=margin, alpha=alpha).to(device)

        loss_func.proxies = original_loss_func.proxies

        embedding_angles = list(range(0, 180))
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=torch.float,
        ).to(device)  # 2D embeddings
        labels = torch.randint(low=0, high=5, size=(180,)).to(device)

        loss = loss_func(embeddings, labels)
        loss.backward()
        correct_loss = original_loss_func(embeddings, labels)
        self.assertTrue(torch.isclose(loss, correct_loss))
Example #4
    def test_logit_getter(self):
        embedding_size = 512
        num_classes = 10
        batch_size = 32

        for dtype in TEST_DTYPES:
            embeddings = (
                torch.randn(batch_size, embedding_size).to(TEST_DEVICE).type(dtype)
            )
            kwargs = {"num_classes": num_classes, "embedding_size": embedding_size}
            loss1 = ArcFaceLoss(**kwargs).to(TEST_DEVICE).type(dtype)
            loss2 = NormalizedSoftmaxLoss(**kwargs).to(TEST_DEVICE).type(dtype)
            loss3 = ProxyAnchorLoss(**kwargs).to(TEST_DEVICE).type(dtype)

            # test the ability to infer shape
            for loss in [loss1, loss2, loss3]:
                self.helper_tester(loss, embeddings, batch_size, num_classes)

            # test specifying wrong layer name
            self.assertRaises(AttributeError, LogitGetter, loss1, layer_name="blah")

            # test specifying correct layer name
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, layer_name="W"
            )

            # test specifying a distance metric
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, distance=LpDistance()
            )

            # test specifying transpose incorrectly
            LG = LogitGetter(loss1, transpose=False)
            self.assertRaises(RuntimeError, LG, embeddings)

            # test specifying transpose correctly
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, transpose=True
            )

            # test copying weights
            LG = LogitGetter(loss1)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data *= 0
            self.assertTrue(not torch.all(LG.weights == loss1.W))

            # test not copying weights
            LG = LogitGetter(loss1, copy_weights=False)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data *= 0
            self.assertTrue(torch.all(LG.weights == loss1.W))
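
The self.helper_tester method called throughout this example is not shown on this page. A plausible minimal reconstruction, inferred only from how it is called here (the name and the shape check are hypothetical), would wrap the loss in a LogitGetter and verify the logit shape:

    def helper_tester(self, loss, embeddings, batch_size, num_classes, **kwargs):
        # Hypothetical reconstruction: build a LogitGetter around the loss's
        # weight matrix, forward the embeddings, and check the output shape.
        LG = LogitGetter(loss, **kwargs)
        logits = LG(embeddings)
        self.assertEqual(logits.shape, torch.Size([batch_size, num_classes]))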