def test_gram_intercept_constrained_projection():
    """Calling ``fit_intercept`` with ``projection=1`` and ``input_dim=2`` must raise.

    A 10-feature ``OnlineGram`` is first fed a small random batch. The test
    then expects ``ValueError`` from ``fit_intercept`` when given an integer
    projection (presumably the "constrained" projection case — TODO confirm
    against ``OnlineGram.fit_intercept``).
    """
    key = random.generate_key()
    samples = jax.random.uniform(key, shape=(5, 10))

    gram = OnlineGram(10)
    gram.update(samples)

    with pytest.raises(ValueError):
        gram.fit_intercept(projection=1, input_dim=2)
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Compare ``OnlineGram.matrix`` with a directly computed projected ``X.T @ X``.

    The accumulator always receives the raw data. The reference is
    ``P.T @ X.T @ X @ P``, where ``X`` is column-standardized (mean /
    population std) when ``normalize`` is set and ``P`` is a random
    ``(X_dim, min(k, X_dim))`` projection. When ``fit_intercept`` is set,
    the reference is passed through ``gram.fit_intercept``. That path
    therefore checks ``matrix()`` for consistency with the object's own
    ``fit_intercept``, not against an independent result. Agreement is
    only required to one decimal place.
    """
    data = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    n_components = min(k, X_dim)
    proj = jax.random.uniform(random.generate_key(), shape=(X_dim, n_components))

    gram = OnlineGram(X_dim)
    # Feed the raw data; any standardization below applies only to the reference.
    gram.update(data)

    if normalize:
        reference = (data - data.mean(axis=0)) / data.std(axis=0)
    else:
        reference = data

    expected = proj.T @ reference.T @ reference @ proj
    if fit_intercept:
        expected = gram.fit_intercept(expected, normalize=normalize, projection=proj)

    actual = gram.matrix(normalize=normalize, projection=proj, fit_intercept=fit_intercept)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)