Example #1
0
def test_gram_update(X_dim, Y_dim, n):
    """Check that one batch update yields the raw cross-product X^T Y.

    With both ``normalize`` and ``fit_intercept`` disabled, the finalized
    matrix must equal the plain product of the inputs and targets.
    """
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    raw = gram.matrix(normalize=False, fit_intercept=False)
    reference = inputs.T @ targets
    np.testing.assert_array_almost_equal(raw, reference)
Example #2
0
def test_gram_normalize(X_dim, Y_dim, n):
    """Check the normalized Gram matrix against column-standardized X^T Y.

    Both sides are z-scored per column (mean / ``std`` along axis 0) before
    the product is formed; the comparison uses a loose ``decimal=1`` tolerance.
    """
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    def standardize(a):
        # Column-wise z-score; ``.std`` uses its default ddof=0 (population std).
        return (a - a.mean(axis=0)) / a.std(axis=0)

    inputs_z = standardize(inputs)
    targets_z = standardize(targets)
    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=True, fit_intercept=False),
        inputs_z.T @ targets_z,
        decimal=1,
    )
Example #3
0
def test_gram_update_iterative(X_dim, Y_dim, n):
    """Check that streaming rows one at a time reproduces the batch X^T Y.

    Each update receives a single observation shaped ``(1, dim)``; after all
    ``n`` rows the raw matrix must match ``X.T @ Y`` to 3 decimals.
    """
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)

    for x_row, y_row in zip(inputs, targets):
        # Promote each 1-D row to a (1, dim) batch of one observation.
        gram.update(x_row[None, :], y_row[None, :])

    accumulated = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(accumulated, inputs.T @ targets, decimal=3)
Example #4
0
def test_gram_intercept_xtx(X_dim, n, normalize):
    """Check the intercept-augmented X^T X from an X-only OnlineGram.

    The reference standardizes the columns of X when ``normalize`` is set,
    prepends a column of ones, and forms the self Gram matrix of that
    design matrix.
    """
    features = jax.random.uniform(random.generate_key(), shape=(n, X_dim))

    gram = OnlineGram(X_dim)
    gram.update(features)

    if normalize:
        features = (features - features.mean(axis=0)) / features.std(axis=0)
    design = jnp.hstack((jnp.ones((n, 1)), features))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, design.T @ design, decimal=1)
Example #5
0
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Check a projected Gram matrix P^T X^T X P.

    ``k`` is clipped to ``X_dim`` to size the random projection. X is
    standardized first when ``normalize`` is set.

    NOTE(review): when ``fit_intercept`` is set, the expected value is built
    with ``gram.fit_intercept`` itself. That branch is therefore compared
    against the class's own helper rather than an independent reference.
    """
    features = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    n_components = min(k, X_dim)
    projection = jax.random.uniform(random.generate_key(), shape=(X_dim, n_components))

    gram = OnlineGram(X_dim)
    gram.update(features)

    if normalize:
        features = (features - features.mean(axis=0)) / features.std(axis=0)

    reference = projection.T @ features.T @ features @ projection
    if fit_intercept:
        reference = gram.fit_intercept(reference, normalize=normalize, projection=projection)

    actual = gram.matrix(normalize=normalize, projection=projection, fit_intercept=fit_intercept)
    np.testing.assert_array_almost_equal(actual, reference, decimal=1)
Example #6
0
def test_gram_intercept_xty(X_dim, Y_dim, n, normalize):
    """Check running statistics and the intercept-augmented X^T Y.

    After one batch update:

    * ``gram.mean`` and ``gram.std`` (squeezed) must match the column
      statistics of X.
    * ``gram.observations`` must equal the row count.
    * The finalized matrix is compared with ``[1 | X]^T Y``. X and Y are
      standardized first when ``normalize`` is set.
    """
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    # The tracked statistics are checked against the X columns only.
    np.testing.assert_array_almost_equal(gram.mean.squeeze(), inputs.mean(axis=0))
    np.testing.assert_array_almost_equal(gram.std.squeeze(), inputs.std(axis=0))
    assert gram.observations == inputs.shape[0]

    if normalize:
        inputs = (inputs - inputs.mean(axis=0)) / inputs.std(axis=0)
        targets = (targets - targets.mean(axis=0)) / targets.std(axis=0)
    augmented = jnp.hstack((jnp.ones((n, 1)), inputs))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, augmented.T @ targets, decimal=1)
Example #7
0
def fit_gram(
    XTX: "OnlineGram",
    XTY: "OnlineGram",
    alpha: float = 1.0,
    normalize: bool = False,
    projection: "np.ndarray | None" = None,
    input_dim: "int | None" = None,
):
    """Compute linear regression parameters from accumulated Gram matrices.

    Args:
        XTX: Accumulator for the feature Gram matrix. Its ``feature_dim`` and
            ``observations`` are read for validation before it is finalized
            via ``matrix(...)``.
        XTY: Accumulator for the feature/target cross-product. Its
            ``output_dim`` is used when forming constraints.
        alpha: Forwarded to ``_compute_xtx_inverse``. Presumably a ridge
            regularization strength (TODO confirm).
        normalize: Forwarded to ``matrix(...)`` for both accumulators.
        projection: Optional projection matrix. When given, its column count
            ``projection.shape[1]`` is used as the feature dimension.
        input_dim: Optional original input dimension. When given, it must be a
            positive integer that evenly divides the feature dimension. The
            quotient is the history length, and a constrained fit is run.

    Returns:
        ``(weights, intercept)``, where ``intercept`` is row 0 of the
        coefficients (the prepended intercept column, since ``fit_intercept``
        is always True).

        * Without ``input_dim``, ``weights`` is ``beta[1:]`` reshaped to
          ``(1, feature_dim, -1)``.
        * With ``input_dim``, the constrained coefficients are first subsampled
          to every ``input_dim``-th row (starting at row 0). ``weights`` is
          then all of those rows except the first.

    Raises:
        ValueError: If ``input_dim`` is not positive, if it does not evenly
            divide the feature dimension, or if there are not more
            observations than features.

    Notes:
        * Assumes over-determined systems
        * Assumes we always fit an intercept
    """
    fit_intercept = True
    feature_dim = XTX.feature_dim if projection is None else projection.shape[1]
    output_dim = XTY.output_dim

    if input_dim is None:
        history_len = None
    else:
        # Reject 0 / negative values up front. 0 would otherwise raise
        # ZeroDivisionError in the modulo below; a negative value would
        # produce a negative history length.
        if input_dim <= 0:
            raise ValueError("input_dim must be a positive integer (got {})".format(input_dim))
        if feature_dim % input_dim != 0:
            raise ValueError("Original input dimension must evenly divide feature dimensions")
        history_len = feature_dim // input_dim

    if XTX.observations <= feature_dim:
        raise ValueError(
            "Underdetermined systems not currently supported (observations: {}, "
            "features: {})".format(XTX.observations, feature_dim)
        )

    # Finalize both gram matrices with identical options so they stay aligned.
    matrix_kwargs = dict(
        normalize=normalize,
        projection=projection,
        fit_intercept=fit_intercept,
        input_dim=input_dim,
    )
    xtx_matrix = XTX.matrix(**matrix_kwargs)
    xty_matrix = XTY.matrix(**matrix_kwargs)

    inv = _compute_xtx_inverse(xtx_matrix, alpha)
    beta = _fit_unconstrained(inv, xty_matrix)

    # The constrained fit runs whenever input_dim is given, regardless of
    # projection. NOTE(review): confirm the constraints are meaningful for
    # projected features.
    if input_dim is not None:
        R, r = _form_constraints(
            input_dim=input_dim,
            output_dim=output_dim,
            history_len=history_len,
            fit_intercept=fit_intercept,
        )
        beta = _fit_constrained(beta, inv, R, r)
        beta = beta.take(jnp.arange(0, len(beta), input_dim), axis=0)
        return beta[1:], beta[0]

    return beta[1:].reshape(1, feature_dim, -1), beta[0]