def test_arcface_with_cross_entropy_loss():
    """Compare ArcFace + ``nn.CrossEntropyLoss`` with a NumPy reference.

    The reference logits are rebuilt by hand: features and weights are
    normalized, their cosine is clipped to ``[-1 + eps, 1 - eps]`` and turned
    into an angle, ``m`` is added to the angle of the target class (unless the
    angle already exceeds ``pi - m``), and the cosine of the result is
    multiplied by ``s``.  The resulting loss is checked for the ``"none"``,
    ``"mean"`` and ``"sum"`` reductions.
    """
    emb_size = 4
    n_classes = 3
    s = 3.0
    m = 0.5
    eps = 1e-8

    # fmt: off
    features = np.array(
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
        ],
        dtype="f",
    )
    target = np.array([0, 2], dtype="l")
    weight = np.array(
        [
            [0.1, 0.2, 0.3, 0.4],
            [1.1, 3.2, 5.3, 0.4],
            [0.1, 0.2, 6.3, 0.4],
        ],
        dtype="f",
    )
    # fmt: on

    layer = ArcFace(emb_size, n_classes, s, m, eps)
    layer.weight.data = torch.from_numpy(weight)

    normalized_features = normalize(features)  # 2x4
    normalized_projection = normalize(weight)  # 3x4

    cosine = normalized_features @ normalized_projection.T  # 2x4 * 4x3 = 2x3
    theta = np.arccos(np.clip(cosine, -1 + eps, 1 - eps))  # 2x3

    # Derive the one-hot targets from `target` so they cannot drift apart
    # from the fixture above.
    one_hot = np.eye(n_classes, dtype="l")[target]  # 2x3
    # The margin is only applied where theta + m stays within [0, pi].
    margin_mask = np.where(
        theta > (np.pi - m), np.zeros_like(one_hot), one_hot
    )  # 2x3
    feats = np.cos(np.where(margin_mask > 0, theta + m, theta)) * s  # 2x3

    # The loss is always taken against the true targets; the margin mask
    # only decides which logits receive the margin.
    expected_loss = cross_entropy(feats, one_hot, 1)

    features_t = torch.from_numpy(features)
    target_t = torch.LongTensor(target)

    reductions = (
        ("none", lambda per_sample: per_sample),
        ("mean", np.mean),
        ("sum", np.sum),
    )
    for reduction, reduce_expected in reductions:
        loss_fn = nn.CrossEntropyLoss(reduction=reduction)
        actual = (
            loss_fn(layer(features_t, target_t), target_t).detach().numpy()
        )
        expected = reduce_expected(expected_loss)
        assert np.allclose(expected, actual), (
            f"reduction={reduction!r}: expected {expected}, got {actual}"
        )
def test_arcface_inference_mode():
    """Run the shared ``_check_layer`` checks on an ArcFace layer.

    The layer is built as ``ArcFace(5, 10, s=1.31, m=0.5)``; what exactly is
    verified is defined by ``_check_layer``.
    """
    arcface_layer = ArcFace(5, 10, s=1.31, m=0.5)
    _check_layer(arcface_layer)