Ejemplo n.º 1
0
def compute_gram(
    data: Tuple[np.ndarray, np.ndarray, Any], input_dim: int, output_dim: int, history_len: int,
) -> Tuple[OnlineGram, OnlineGram]:
    """Compute X.T @ X and X.T @ Y on history windows incrementally.

    Each (X, Y) chunk from ``data`` is expanded into history windows with
    ``historify``, flattened to ``history_len * input_dim`` features per row,
    and folded into two running ``OnlineGram`` accumulators chunk by chunk.

    Args:
        data: iterable of ``(X, Y, _)`` triples; the third element of each
            triple is ignored. NOTE(review): the annotation describes a single
            triple, but the function iterates over ``data``.
        input_dim: input width passed to ``internalize`` for X.
        output_dim: output width passed to ``internalize`` for Y.
        history_len: number of time steps stacked into each history window.

    Returns:
        ``(XTX, XTY)``: accumulators of history.T @ history and
        history.T @ Y respectively.

    Raises:
        ValueError: if a chunk's X and Y differ in number of observations, or
            if the total number of observations does not exceed the number of
            history features (underdetermined system).
        IndexError: if ``data`` contributes no observations at all.
    """
    num_features = input_dim * history_len
    XTX = OnlineGram(num_features)
    XTY = OnlineGram(num_features, output_dim)

    for X, Y, _ in data:
        X = internalize(X, input_dim)[0]
        Y = internalize(Y, output_dim)[0]

        if X.shape[0] != Y.shape[0]:
            raise ValueError("Input and output data must have the same number of observations")

        # Expand input time series X into histories, which should result in a
        # (num_histories, history_len * input_dim)-shaped array
        history = historify(X, history_len=history_len)
        history = history.reshape(history.shape[0], -1)

        XTX.update(history)
        # Drop the first history_len - 1 targets so each remaining Y row pairs
        # with one history window (presumably the window ending at that time
        # step -- TODO confirm against historify).
        XTY.update(history, Y[history_len - 1 :])

    if XTX.observations == 0:
        raise IndexError("No data to fit")

    # Strictly more observations than features are required for a
    # well-posed least-squares fit.
    if XTX.observations <= num_features:
        raise ValueError(
            "Underdetermined systems not currently supported (observations: {}, "
            "features: {})".format(XTX.observations, num_features)
        )

    return XTX, XTY
Ejemplo n.º 2
0
def test_gram_intercept_constrained_projection():
    """fit_intercept must raise ValueError for a constrained projection (projection=1, input_dim=2)"""
    samples = jax.random.uniform(random.generate_key(), shape=(5, 10))
    gram = OnlineGram(10)
    gram.update(samples)

    with pytest.raises(ValueError):
        gram.fit_intercept(projection=1, input_dim=2)
Ejemplo n.º 3
0
def test_gram_update(X_dim, Y_dim, n):
    """A single batched update should accumulate exactly X.T @ Y"""
    # Inputs are drawn before targets, matching the key order of the original.
    inputs, targets = (
        jax.random.uniform(random.generate_key(), shape=(n, width)) for width in (X_dim, Y_dim)
    )

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    actual = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, inputs.T @ targets)
Ejemplo n.º 4
0
def test_fit_constrained_bad_input_dim():
    """fit_gram should reject an input_dim that is inconsistent with the Gram matrices"""
    xtx, xty = OnlineGram(10), OnlineGram(5)

    for gram, width in ((xtx, 10), (xty, 5)):
        gram.update(jax.random.uniform(random.generate_key(), shape=(100, width)))

    # NOTE(review): input_dim=7 is presumably invalid because it does not
    # divide the 10 input features -- confirm against fit_gram.
    with pytest.raises(ValueError):
        fit_gram(xtx, xty, input_dim=7)
Ejemplo n.º 5
0
def test_gram_normalize(X_dim, Y_dim, n):
    """The normalized Gram matrix should match the product of column-standardized X and Y"""
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    def standardize(arr):
        return (arr - arr.mean(axis=0)) / arr.std(axis=0)

    expected = standardize(inputs).T @ standardize(targets)
    actual = gram.matrix(normalize=True, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)
Ejemplo n.º 6
0
def test_gram_update_iterative(X_dim, Y_dim, n):
    """Feeding rows one at a time should accumulate the same X.T @ Y as a batch update"""
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)

    # Each row is lifted to a (1, dim) array so update sees a one-observation batch.
    for x_row, y_row in zip(inputs, targets):
        gram.update(x_row[None, :], y_row[None, :])

    actual = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, inputs.T @ targets, decimal=3)
Ejemplo n.º 7
0
def test_gram_intercept_xtx(X_dim, n, normalize):
    """With fit_intercept, X.T @ X should equal the Gram of X prefixed by a column of ones"""
    samples = jax.random.uniform(random.generate_key(), shape=(n, X_dim))

    gram = OnlineGram(X_dim)
    gram.update(samples)

    design = (samples - samples.mean(axis=0)) / samples.std(axis=0) if normalize else samples
    # The intercept column is prepended after (optional) standardization.
    design = jnp.hstack((jnp.ones((n, 1)), design))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, design.T @ design, decimal=1)
Ejemplo n.º 8
0
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Projected X.T @ X should equal P.T @ X.T @ X @ P, optionally with an intercept"""
    samples = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    # The projection cannot have more columns than there are features.
    k = min(k, X_dim)
    projection = jax.random.uniform(random.generate_key(), shape=(X_dim, k))

    gram = OnlineGram(X_dim)
    gram.update(samples)

    design = (samples - samples.mean(axis=0)) / samples.std(axis=0) if normalize else samples

    expected = projection.T @ design.T @ design @ projection
    if fit_intercept:
        expected = gram.fit_intercept(expected, normalize=normalize, projection=projection)

    actual = gram.matrix(normalize=normalize, projection=projection, fit_intercept=fit_intercept)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)
Ejemplo n.º 9
0
def test_gram_intercept_xty(X_dim, Y_dim, n, normalize):
    """With fit_intercept, X.T @ Y should equal [1 | X].T @ Y; running stats must match X"""
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    # The accumulator's mean/std/observation count should agree with batch statistics of X.
    np.testing.assert_array_almost_equal(gram.mean.squeeze(), inputs.mean(axis=0))
    np.testing.assert_array_almost_equal(gram.std.squeeze(), inputs.std(axis=0))
    assert gram.observations == inputs.shape[0]

    def standardize(arr):
        return (arr - arr.mean(axis=0)) / arr.std(axis=0)

    if normalize:
        inputs = standardize(inputs)
        targets = standardize(targets)

    design = jnp.hstack((jnp.ones((n, 1)), inputs))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, design.T @ targets, decimal=1)
Ejemplo n.º 10
0
def test_gram_update_value_error():
    """update must raise ValueError for mismatched or missing/unexpected Y"""
    paired = OnlineGram(1, 1)

    # X has 4 rows but Y has only 1.
    tall = jax.random.uniform(random.generate_key(), shape=(4, 1))
    short = jax.random.uniform(random.generate_key(), shape=(1, 1))
    with pytest.raises(ValueError):
        paired.update(tall, short)

    # Y omitted on a gram constructed with an output dimension.
    x_only = jax.random.uniform(random.generate_key(), shape=(4, 1))
    with pytest.raises(ValueError):
        paired.update(x_only)

    # Y supplied to a gram constructed without an output dimension.
    unpaired = OnlineGram(1)
    with pytest.raises(ValueError):
        unpaired.update(np.ones((1, 1)), np.ones((1, 1)))
Ejemplo n.º 11
0
def test_gram_projection_no_projection():
    """project() with no arguments should return the raw X.T @ X"""
    samples = jax.random.uniform(random.generate_key(), shape=(5, 10))

    gram = OnlineGram(10)
    gram.update(samples)

    expected = samples.T @ samples
    np.testing.assert_array_almost_equal(expected, gram.project())