def test_proxyanchor_loss(self):
    num_classes = 10
    embedding_size = 2
    margin = 0.5
    for dtype in TEST_DTYPES:
        alpha = 1 if dtype == torch.float16 else 32
        loss_func = ProxyAnchorLoss(
            num_classes, embedding_size, margin=margin, alpha=alpha
        ).to(TEST_DEVICE)
        original_loss_func = OriginalImplementationProxyAnchor(
            num_classes, embedding_size, mrg=margin, alpha=alpha
        ).to(TEST_DEVICE)
        original_loss_func.proxies.data = original_loss_func.proxies.data.type(dtype)
        loss_func.proxies = original_loss_func.proxies

        embedding_angles = list(range(0, 180))
        embeddings = torch.tensor(
            [c_f.angle_to_coord(a) for a in embedding_angles],
            requires_grad=True,
            dtype=dtype,
        ).to(TEST_DEVICE)  # 2D embeddings
        labels = torch.randint(low=0, high=5, size=(180,)).to(TEST_DEVICE)

        loss = loss_func(embeddings, labels)
        loss.backward()
        correct_loss = original_loss_func(embeddings, labels)
        rtol = 1e-2 if dtype == torch.float16 else 1e-5
        self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))
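# For reference: a minimal sketch of the Proxy Anchor loss (Kim et al., CVPR 2020)
# that both loss_func and original_loss_func above are expected to agree on.
# This is an illustration, not the library's implementation; the function name
# `proxy_anchor_loss` and its defaults are hypothetical.
import torch
import torch.nn.functional as F

def proxy_anchor_loss(embeddings, labels, proxies, margin=0.1, alpha=32):
    # Cosine similarities between all embeddings and all proxies: (N, C).
    cos = F.linear(F.normalize(embeddings), F.normalize(proxies))
    pos_onehot = F.one_hot(labels, num_classes=proxies.shape[0]).float()
    neg_onehot = 1.0 - pos_onehot

    pos_exp = torch.exp(-alpha * (cos - margin))
    neg_exp = torch.exp(alpha * (cos + margin))

    # The positive term averages only over proxies that have at least one
    # positive embedding in the batch; the negative term averages over all proxies.
    with_pos = torch.nonzero(pos_onehot.sum(dim=0) != 0).squeeze(dim=1)
    pos_term = torch.log(
        1 + (pos_exp * pos_onehot).sum(dim=0)[with_pos]
    ).sum() / len(with_pos)
    neg_term = torch.log(1 + (neg_exp * neg_onehot).sum(dim=0)).sum() / proxies.shape[0]
    return pos_term + neg_term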
def test_proxyanchor_loss(self):
    num_classes = 10
    embedding_size = 2
    margin = 0.5
    for use_autocast in [True, False]:
        if use_autocast:
            cm = torch.cuda.amp.autocast()
        else:
            cm = nullcontext()
        for dtype in TEST_DTYPES:
            alpha = 1 if dtype == torch.float16 else 32
            loss_func = ProxyAnchorLoss(
                num_classes, embedding_size, margin=margin, alpha=alpha
            ).to(TEST_DEVICE)
            original_loss_func = OriginalImplementationProxyAnchor(
                num_classes, embedding_size, mrg=margin, alpha=alpha
            ).to(TEST_DEVICE)
            if not use_autocast:
                original_loss_func.proxies.data = original_loss_func.proxies.data.type(
                    dtype
                )
            loss_func.proxies = original_loss_func.proxies

            embedding_angles = list(range(0, 180))
            embeddings = torch.tensor(
                [c_f.angle_to_coord(a) for a in embedding_angles],
                requires_grad=True,
                dtype=torch.float32,
            ).to(TEST_DEVICE)  # 2D embeddings
            if not use_autocast:
                embeddings = embeddings.type(dtype)
            labels = torch.randint(low=0, high=5, size=(180,)).to(TEST_DEVICE)

            with cm:
                loss = loss_func(embeddings, labels)
            loss.backward()
            correct_loss = original_loss_func(embeddings, labels)
            rtol = 1e-2 if dtype == torch.float16 or use_autocast else 1e-5
            self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))
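# For context: in real training, the autocast path exercised above is typically
# paired with a GradScaler. A minimal sketch; `model`, `optimizer`, and
# `train_step` are hypothetical names, not part of this test file.
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, loss_func, optimizer, data, labels):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        embeddings = model(data)              # forward pass in mixed precision
        loss = loss_func(embeddings, labels)  # loss computed under autocast
    scaler.scale(loss).backward()             # backward on the scaled loss
    scaler.step(optimizer)                    # unscales gradients, then steps
    scaler.update()
    return loss.item()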
def test_proxyanchor_loss(self):
    num_classes = 10
    embedding_size = 2
    margin = 0.5
    alpha = 32
    device = torch.device("cuda")
    loss_func = ProxyAnchorLoss(
        num_classes, embedding_size, margin=margin, alpha=alpha
    ).to(device)
    original_loss_func = OriginalImplementationProxyAnchor(
        num_classes, embedding_size, mrg=margin, alpha=alpha
    ).to(device)
    loss_func.proxies = original_loss_func.proxies

    embedding_angles = list(range(0, 180))
    embeddings = torch.tensor(
        [c_f.angle_to_coord(a) for a in embedding_angles],
        requires_grad=True,
        dtype=torch.float,
    ).to(device)  # 2D embeddings
    labels = torch.randint(low=0, high=5, size=(180,)).to(device)

    loss = loss_func(embeddings, labels)
    loss.backward()
    correct_loss = original_loss_func(embeddings, labels)
    self.assertTrue(torch.isclose(loss, correct_loss))
def test_logit_getter(self):
    embedding_size = 512
    num_classes = 10
    batch_size = 32
    for dtype in TEST_DTYPES:
        embeddings = (
            torch.randn(batch_size, embedding_size).to(TEST_DEVICE).type(dtype)
        )
        kwargs = {"num_classes": num_classes, "embedding_size": embedding_size}
        loss1 = ArcFaceLoss(**kwargs).to(TEST_DEVICE).type(dtype)
        loss2 = NormalizedSoftmaxLoss(**kwargs).to(TEST_DEVICE).type(dtype)
        loss3 = ProxyAnchorLoss(**kwargs).to(TEST_DEVICE).type(dtype)

        # test the ability to infer shape
        for loss in [loss1, loss2, loss3]:
            self.helper_tester(loss, embeddings, batch_size, num_classes)

        # test specifying wrong layer name
        self.assertRaises(AttributeError, LogitGetter, loss1, layer_name="blah")

        # test specifying correct layer name
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, layer_name="W"
        )

        # test specifying a distance metric
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, distance=LpDistance()
        )

        # test specifying transpose incorrectly
        LG = LogitGetter(loss1, transpose=False)
        self.assertRaises(RuntimeError, LG, embeddings)

        # test specifying transpose correctly
        self.helper_tester(
            loss1, embeddings, batch_size, num_classes, transpose=True
        )

        # test copying weights
        LG = LogitGetter(loss1)
        self.assertTrue(torch.all(LG.weights == loss1.W))
        loss1.W.data *= 0
        self.assertTrue(not torch.all(LG.weights == loss1.W))

        # test not copying weights
        LG = LogitGetter(loss1, copy_weights=False)
        self.assertTrue(torch.all(LG.weights == loss1.W))
        loss1.W.data *= 0
        self.assertTrue(torch.all(LG.weights == loss1.W))
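# Usage sketch: LogitGetter turns a classification-style metric loss into a
# logit producer for inference. The import path is assumed to be
# pytorch_metric_learning.utils.inference; shapes follow the test above.
import torch
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.utils.inference import LogitGetter

loss_fn = ArcFaceLoss(num_classes=10, embedding_size=512)
LG = LogitGetter(loss_fn)  # infers the weight layer ("W") and its orientation

embeddings = torch.randn(32, 512)
logits = LG(embeddings)             # shape: (32, 10)
predictions = logits.argmax(dim=1)  # predicted class per embedding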