def test_regular_face_regularizer(self):
        """Compare NormalizedSoftmaxLoss + RegularFaceRegularizer to a manual result.

        The expected value is rebuilt by hand as cross-entropy over
        temperature-scaled logits (embeddings times column-normalized ``W``)
        plus ``reg_weight`` times the mean, over classes, of each class's
        largest off-diagonal entry in the weight cosine matrix.
        """
        num_classes = 10
        embedding_size = 512
        temperature = 0.1
        reg_weight = 0.1
        batch_size = 180
        loss_func = NormalizedSoftmaxLoss(
            temperature=temperature,
            num_classes=num_classes,
            embedding_size=embedding_size,
            regularizer=RegularFaceRegularizer(),
            reg_weight=reg_weight,
        )

        raw_embeddings = torch.randn(
            (batch_size, embedding_size), requires_grad=True, dtype=torch.float
        )
        embeddings = torch.nn.functional.normalize(raw_embeddings)
        labels = torch.randint(low=0, high=num_classes, size=(batch_size,))

        # Run the loss under test, including a backward pass.
        loss = loss_func(embeddings, labels)
        loss.backward()

        # Classification term recomputed from the loss function's own weights.
        unit_weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
        scaled_logits = torch.matmul(embeddings, unit_weights) / temperature
        expected_class_loss = torch.nn.functional.cross_entropy(
            scaled_logits, labels
        )

        # Regularization term: mask the diagonal so a class is never matched
        # with itself, then average the per-class maxima.
        class_cos = torch.matmul(unit_weights.t(), unit_weights)
        class_cos.fill_diagonal_(float('-inf'))
        expected_reg_loss = (
            sum(torch.max(class_cos[i]) for i in range(num_classes)) / num_classes
        )

        expected_total = expected_class_loss + (expected_reg_loss * reg_weight)
        self.assertTrue(torch.isclose(loss, expected_total))
    def test_normalized_softmax_loss(self):
        """Compare NormalizedSoftmaxLoss to a manually computed cross-entropy.

        For every dtype in TEST_DTYPES the expected loss is cross-entropy over
        ``embeddings @ normalize(W) / temperature``; float16 is compared with
        a looser relative tolerance.
        """
        temperature = 0.1
        for dtype in TEST_DTYPES:
            loss_func = NormalizedSoftmaxLoss(
                temperature=temperature, num_classes=10, embedding_size=2
            )
            # 2D embeddings, one per value of arange(0, 180).
            # NOTE(review): angle_to_coord is defined elsewhere; presumably it
            # maps an angle to a 2D coordinate - confirm.
            coords = [c_f.angle_to_coord(angle) for angle in torch.arange(0, 180)]
            embeddings = torch.tensor(coords, requires_grad=True, dtype=dtype).to(
                TEST_DEVICE
            )
            labels = torch.randint(low=0, high=10, size=(180,)).to(TEST_DEVICE)

            # Run the loss under test, including a backward pass.
            loss = loss_func(embeddings, labels)
            loss.backward()

            # Reference value rebuilt from the loss function's own weights.
            unit_weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
            scaled_logits = torch.matmul(embeddings, unit_weights) / temperature
            expected_loss = torch.nn.functional.cross_entropy(scaled_logits, labels)

            if dtype == torch.float16:
                rtol = 1e-2
            else:
                rtol = 1e-5
            self.assertTrue(torch.isclose(loss, expected_loss, rtol=rtol))
Example #3
0
    def test_logit_getter(self):
        """Exercise LogitGetter's constructor options on several losses.

        For every dtype in TEST_DTYPES this covers: default construction on
        three loss classes, a wrong and a correct ``layer_name``, a custom
        ``distance``, ``transpose`` set incorrectly and correctly, and both
        settings of ``copy_weights``. Output checks are delegated to
        ``self.helper_tester``, which is defined elsewhere.
        """
        embedding_size = 512
        num_classes = 10
        batch_size = 32

        for dtype in TEST_DTYPES:
            embeddings = (
                torch.randn(batch_size, embedding_size).to(TEST_DEVICE).type(dtype)
            )
            kwargs = {"num_classes": num_classes, "embedding_size": embedding_size}
            loss1 = ArcFaceLoss(**kwargs).to(TEST_DEVICE).type(dtype)
            loss2 = NormalizedSoftmaxLoss(**kwargs).to(TEST_DEVICE).type(dtype)
            loss3 = ProxyAnchorLoss(**kwargs).to(TEST_DEVICE).type(dtype)

            # test the ability to infer shape
            for loss in [loss1, loss2, loss3]:
                self.helper_tester(loss, embeddings, batch_size, num_classes)

            # test specifying wrong layer name
            self.assertRaises(AttributeError, LogitGetter, loss1, layer_name="blah")

            # test specifying correct layer name
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, layer_name="W"
            )

            # test specifying a distance metric
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, distance=LpDistance()
            )

            # test specifying transpose incorrectly
            LG = LogitGetter(loss1, transpose=False)
            self.assertRaises(RuntimeError, LG, embeddings)

            # test specifying transpose correctly
            self.helper_tester(
                loss1, embeddings, batch_size, num_classes, transpose=True
            )

            # test copying weights: zeroing the loss's W must not affect the
            # getter's own copy
            LG = LogitGetter(loss1)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data *= 0
            self.assertFalse(torch.all(LG.weights == loss1.W))

            # test not copying weights: the getter must follow changes to W.
            # W is all zeros at this point, so multiplying by zero again would
            # leave it unchanged and the check would pass even if the weights
            # had been copied; add 1 so W actually changes.
            LG = LogitGetter(loss1, copy_weights=False)
            self.assertTrue(torch.all(LG.weights == loss1.W))
            loss1.W.data += 1
            self.assertTrue(torch.all(LG.weights == loss1.W))
Example #4
0
    def test_center_invariant_regularizer(self):
        """Compare NormalizedSoftmaxLoss + CenterInvariantRegularizer to a manual result.

        For every dtype in TEST_DTYPES the expected value is cross-entropy
        over temperature-scaled logits plus ``reg_weight`` times the
        regularization term. That term is the mean, over classes, of
        ``(||W[:, i]||^2 - mean_j ||W[:, j]||^2)^2 / 4``.
        """
        temperature = 0.1
        num_classes = 10
        embedding_size = 512
        reg_weight = 0.1
        for dtype in TEST_DTYPES:
            loss_func = NormalizedSoftmaxLoss(
                temperature=temperature,
                num_classes=num_classes,
                embedding_size=embedding_size,
                weight_regularizer=CenterInvariantRegularizer(),
                weight_reg_weight=reg_weight,
            ).to(TEST_DEVICE)

            raw_embeddings = torch.randn((180, embedding_size), requires_grad=True)
            embeddings = torch.nn.functional.normalize(
                raw_embeddings.type(dtype).to(TEST_DEVICE)
            )
            labels = torch.randint(low=0, high=10, size=(180,)).to(TEST_DEVICE)

            # Run the loss under test, including a backward pass.
            loss = loss_func(embeddings, labels)
            loss.backward()

            # Classification term rebuilt from the loss function's own weights.
            unit_weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
            scaled_logits = torch.matmul(embeddings, unit_weights) / temperature
            expected_class_loss = torch.nn.functional.cross_entropy(
                scaled_logits, labels
            )

            # Squared L2 norm of each class column of the unnormalized W.
            squared_norms = [
                torch.norm(loss_func.W[:, i], p=2) ** 2 for i in range(num_classes)
            ]
            mean_squared_norm = sum(squared_norms) / num_classes

            # Average of the squared deviations from the mean, each divided by 4.
            expected_reg_loss = (
                sum(
                    ((sq_norm - mean_squared_norm) ** 2) / 4
                    for sq_norm in squared_norms
                )
                / num_classes
            )

            expected_total = expected_class_loss + (expected_reg_loss * reg_weight)
            if dtype == torch.float16:
                rtol = 1e-2
            else:
                rtol = 1e-5
            self.assertTrue(torch.isclose(loss, expected_total, rtol=rtol))
Example #5
0
    def test_regular_face_regularizer(self):
        """Compare NormalizedSoftmaxLoss + RegularFaceRegularizer to a manual result.

        For every dtype in TEST_DTYPES the expected value is cross-entropy
        over temperature-scaled logits plus ``reg_weight`` times the mean,
        over classes, of each class's largest off-diagonal entry in the
        cosine matrix of the column-normalized weights.
        """
        temperature = 0.1
        num_classes = 10
        embedding_size = 512
        reg_weight = 0.1
        for dtype in TEST_DTYPES:
            loss_func = NormalizedSoftmaxLoss(
                temperature=temperature,
                num_classes=num_classes,
                embedding_size=embedding_size,
                weight_regularizer=RegularFaceRegularizer(),
                weight_reg_weight=reg_weight,
            ).to(TEST_DEVICE)

            raw_embeddings = torch.randn((180, embedding_size), requires_grad=True)
            embeddings = torch.nn.functional.normalize(
                raw_embeddings.type(dtype).to(TEST_DEVICE)
            )
            labels = torch.randint(low=0, high=10, size=(180,)).to(TEST_DEVICE)

            # Run the loss under test, including a backward pass.
            loss = loss_func(embeddings, labels)
            loss.backward()

            # Classification term rebuilt from the loss function's own weights.
            unit_weights = torch.nn.functional.normalize(loss_func.W, p=2, dim=0)
            scaled_logits = torch.matmul(embeddings, unit_weights) / temperature
            expected_class_loss = torch.nn.functional.cross_entropy(
                scaled_logits, labels
            )

            # Mask the diagonal so a class is never compared with itself,
            # then average the per-class maxima.
            class_cos = torch.matmul(unit_weights.t(), unit_weights)
            class_cos.fill_diagonal_(c_f.neg_inf(dtype))
            expected_reg_loss = (
                sum(torch.max(class_cos[i]) for i in range(num_classes))
                / num_classes
            )

            expected_total = expected_class_loss + (expected_reg_loss * reg_weight)
            if dtype == torch.float16:
                rtol = 1e-2
            else:
                rtol = 1e-5
            self.assertTrue(torch.isclose(loss, expected_total, rtol=rtol))