Example #1
0
def test_gram_intercept_constrained_projection():
    """Expect ``ValueError`` when ``fit_intercept`` gets ``projection`` and ``input_dim``."""
    n_samples, n_features = 5, 10
    samples = jax.random.uniform(
        random.generate_key(), shape=(n_samples, n_features)
    )

    # Feed a single batch so the accumulator holds some statistics
    # before the failing call is attempted.
    gram = OnlineGram(n_features)
    gram.update(samples)

    # NOTE(review): fit_intercept is invoked with keyword arguments only
    # (no matrix argument) -- confirm the signature permits this so the
    # failure is the intended ValueError rather than a TypeError.
    with pytest.raises(ValueError):
        gram.fit_intercept(projection=1, input_dim=2)
Example #2
0
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Compare ``OnlineGram.matrix`` under a projection with a direct computation.

    A random ``(n, X_dim)`` data matrix is fed to an ``OnlineGram`` in one
    update.  The reference value is ``P.T @ X.T @ X @ P`` for a random
    ``(X_dim, k)`` projection ``P`` (``k`` is capped at ``X_dim``), with the
    data standardized column-wise first when ``normalize`` is truthy and
    passed through ``gram.fit_intercept`` when ``fit_intercept`` is truthy.
    The two results must agree to one decimal place.
    """
    data = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    # Cap the number of projected columns at the number of features.
    n_components = min(k, X_dim)
    proj = jax.random.uniform(
        random.generate_key(), shape=(X_dim, n_components)
    )

    # The accumulator always sees the raw (un-normalized) data.
    online = OnlineGram(X_dim)
    online.update(data)

    # Standardize each column for the reference computation only.
    reference = (
        (data - data.mean(axis=0)) / data.std(axis=0) if normalize else data
    )

    # Same left-to-right association as ``P.T @ X.T @ X @ P``.
    left = proj.T @ reference.T
    expected = left @ reference @ proj
    if fit_intercept:
        expected = online.fit_intercept(
            expected, normalize=normalize, projection=proj
        )

    actual = online.matrix(
        normalize=normalize, projection=proj, fit_intercept=fit_intercept
    )
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)