Пример #1
0
def test_arcface_with_cross_entropy_loss():
    """Check ArcFace logits + ``nn.CrossEntropyLoss`` against a NumPy reference.

    The reference recomputes the layer by hand: cosine similarity between the
    L2-normalized features and class weights, ``theta = arccos(cosine)``
    (clipped by ``eps``), an additive angular margin ``m`` on the target class
    (skipped where ``theta > pi - m``), then scaling by ``s``. The resulting
    loss is compared for the ``"none"``, ``"mean"`` and ``"sum"`` reductions.
    """
    emb_size = 4
    n_classes = 3
    s = 3.0
    m = 0.5
    eps = 1e-8

    # fmt: off
    features = np.array(
        [
            [1, 2, 3, 4],
            [5, 6, 7, 8],
        ],
        dtype="f",
    )
    target = np.array([0, 2], dtype="l")
    weight = np.array(
        [
            [0.1, 0.2, 0.3, 0.4],
            [1.1, 3.2, 5.3, 0.4],
            [0.1, 0.2, 6.3, 0.4],
        ],
        dtype="f",
    )
    # fmt: on

    layer = ArcFace(emb_size, n_classes, s, m, eps)
    layer.weight.data = torch.from_numpy(weight)

    # --- NumPy reference ---------------------------------------------------
    normalized_features = normalize(features)  # 2x4
    normalized_projection = normalize(weight)  # 3x4

    cosine = normalized_features @ normalized_projection.T  # 2x4 @ 4x3 = 2x3
    theta = np.arccos(np.clip(cosine, -1 + eps, 1 - eps))  # 2x3

    # one_hot(target), derived from ``target`` so the two cannot drift apart.
    one_hot = np.eye(n_classes, dtype="l")[target]  # 2x3
    # The margin is added only to target entries whose angle is <= pi - m.
    # This mask decides *where the margin goes*; it must not replace the
    # label, otherwise a thresholded row would lose its target entirely.
    margin_mask = np.where(theta > (np.pi - m), np.zeros_like(one_hot), one_hot)
    feats = np.cos(np.where(margin_mask > 0, theta + m, theta)) * s  # 2x3

    # Loss against the true labels (``one_hot``), matching what torch's
    # CrossEntropyLoss receives. Presumably per-sample values, since the
    # original compared it directly with reduction="none" — TODO confirm.
    expected_loss = cross_entropy(feats, one_hot, 1)

    # --- Torch implementation ----------------------------------------------
    # dtype "l" is platform-dependent (int32 on Windows), while
    # CrossEntropyLoss requires int64 class indices, hence ``.long()``.
    target_tensor = torch.from_numpy(target).long()
    logits = layer(torch.from_numpy(features), target_tensor)

    for reduction, reduce_expected in (
        ("none", lambda x: x),
        ("mean", np.mean),
        ("sum", np.sum),
    ):
        loss_fn = nn.CrossEntropyLoss(reduction=reduction)
        actual = loss_fn(logits, target_tensor).detach().numpy()
        assert np.allclose(reduce_expected(expected_loss), actual), reduction
Пример #2
0
def test_arcface_inference_mode():
    """Run the shared ``_check_layer`` helper on a small ArcFace layer.

    The layer is built with positional sizes 5 and 10 and ``s=1.31``,
    ``m=0.5``. What exactly is verified is up to ``_check_layer``
    (defined elsewhere in this file).
    """
    arcface_layer = ArcFace(5, 10, s=1.31, m=0.5)
    _check_layer(arcface_layer)