# Example #1
0
def test_fit_constrained(n, input_dim, output_dim, history_len):
    """Fit constrained regression and compare it with a manual recomputation.

    ``fit_gram(..., input_dim=input_dim)`` is checked against a fit rebuilt by
    hand from the same Gram matrices via ``_compute_xtx_inverse`` ->
    ``_fit_unconstrained`` -> ``_fit_constrained``.
    """
    # NOTE: we use random data because we want to test dimensions and
    # correctness vs a second implementation
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))

    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)
    result = fit_gram(XTX, XTY, input_dim=input_dim)

    # Rebuild the constrained fit manually; the constraints (R, r) are what
    # should make each chunk of input_dim features share one coefficient.
    R, r = _form_constraints(
        input_dim=input_dim, output_dim=output_dim, history_len=history_len, fit_intercept=True,
    )

    XTX = XTX.matrix(fit_intercept=True, input_dim=input_dim)
    XTY = XTY.matrix(fit_intercept=True, input_dim=input_dim)
    inv = _compute_xtx_inverse(XTX, alpha=1.0)
    beta = _fit_unconstrained(inv, XTY)
    beta = _fit_constrained(beta, inv, R, r)
    beta = beta.reshape(history_len + 1, input_dim, -1)

    # Within every chunk, each row must equal the chunk's first row.
    assert np.sum([np.abs(x - x[0]) for x in beta]) < 1e-4

    # One leading chunk for the intercept plus one per history step.
    assert len(beta) == history_len + 1

    # Chunks are constant, so the first row of each chunk represents it.
    coefficients = beta[:, 0]
    kernel, bias = coefficients[1:], coefficients[0]

    # Check final results against fit_gram's (kernel, bias) output
    np.testing.assert_array_almost_equal(kernel, result[0])
    np.testing.assert_array_almost_equal(bias, result[1])
# Example #2
0
def test_gram_update(X_dim, Y_dim, n):
    """A single batched update must reproduce the raw cross product X.T @ Y."""
    key_x = random.generate_key()
    X = jax.random.uniform(key_x, shape=(n, X_dim))
    key_y = random.generate_key()
    Y = jax.random.uniform(key_y, shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    # No normalization and no intercept column: plain X.T @ Y is expected.
    actual = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, X.T @ Y)
# Example #3
0
def test_fit_constrained_bad_input_dim():
    """``fit_gram`` raises ``ValueError`` for an incompatible ``input_dim``.

    Here the XTX gram has 10 features and ``input_dim=7`` is requested.
    """
    xtx_gram, xty_gram = OnlineGram(10), OnlineGram(5)

    # Fill each gram with 100 random rows of its own width (XTX first).
    for gram, width in ((xtx_gram, 10), (xty_gram, 5)):
        gram.update(jax.random.uniform(random.generate_key(), shape=(100, width)))

    with pytest.raises(ValueError):
        fit_gram(xtx_gram, xty_gram, input_dim=7)
# Example #4
0
def test_compute_gram(n, input_dim, output_dim, history_len):
    """Test computing gram matrices against an explicit history matrix."""
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))

    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)

    # Flatten each history window into a single feature row.
    windows = historify(X, history_len)
    flat = windows.reshape(windows.shape[0], -1)
    # Offset Y by history_len - 1 so every window lines up with a target row.
    aligned_targets = Y[history_len - 1 :]

    np.testing.assert_array_almost_equal(flat.T @ flat, XTX.matrix(), decimal=4)
    np.testing.assert_array_almost_equal(flat.T @ aligned_targets, XTY.matrix(), decimal=4)
# Example #5
0
def test_gram_update_iterative(X_dim, Y_dim, n):
    """Row-by-row updates must accumulate to the same X.T @ Y as one batch."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)

    # Feed one sample at a time; each update receives a (1, dim) batch.
    for row_x, row_y in zip(X, Y):
        gram.update(row_x[None, :], row_y[None, :])

    accumulated = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(accumulated, X.T @ Y, decimal=3)
# Example #6
0
def test_gram_normalize(X_dim, Y_dim, n):
    """With ``normalize=True`` the gram matches the product of z-scored X and Y."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    def standardize(values):
        # Column-wise z-score: subtract the mean, divide by the std.
        return (values - values.mean(axis=0)) / values.std(axis=0)

    expected = standardize(X).T @ standardize(Y)
    actual = gram.matrix(normalize=True, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)
# Example #7
0
def test_gram_update_value_error():
    """``OnlineGram.update`` raises ``ValueError`` on inconsistent inputs."""
    paired = OnlineGram(1, 1)

    # X and Y with different numbers of rows
    x_rows = jax.random.uniform(random.generate_key(), shape=(4, 1))
    y_rows = jax.random.uniform(random.generate_key(), shape=(1, 1))
    with pytest.raises(ValueError):
        paired.update(x_rows, y_rows)

    # Omitting Y on a gram constructed with a Y dimension
    with pytest.raises(ValueError):
        paired.update(jax.random.uniform(random.generate_key(), shape=(4, 1)))

    # Supplying Y to a gram constructed without a Y dimension
    x_only = OnlineGram(1)
    with pytest.raises(ValueError):
        x_only.update(np.ones((1, 1)), np.ones((1, 1)))
# Example #8
0
def test_gram_intercept_constrained_projection():
    """``fit_intercept`` with both a projection and ``input_dim`` raises."""
    samples = jax.random.uniform(random.generate_key(), shape=(5, 10))
    gram = OnlineGram(10)
    gram.update(samples)

    with pytest.raises(ValueError):
        gram.fit_intercept(projection=1, input_dim=2)
# Example #9
0
def test_random():
    """Seeding the key singleton makes key generation reproducible."""
    # Exercise the no-argument form before seeding explicitly.
    random.set_key()
    random.set_key(0)
    assert jnp.array_equal(jnp.array([0, 0]), random.get_key())

    first_generated = jnp.array([2718843009, 1272950319], dtype=jnp.uint32)
    advanced_state = jnp.array([4146024105, 967050713], dtype=jnp.uint32)

    # generate_key returns a new key and also moves the stored key forward.
    assert jnp.array_equal(random.generate_key(), first_generated)
    assert jnp.array_equal(random.get_key(), advanced_state)
# Example #10
0
def test_experiment(shape, num_args):
    """``experiment`` runs the wrapped function for every argument tuple.

    ``dummy.run()`` must return one result per ``(a, b)`` pair in ``args``,
    and each result must equal ``a + b``.
    """
    args = [
        (
            jax.random.uniform(random.generate_key(), shape=shape),
            jax.random.uniform(random.generate_key(), shape=shape),
        )
        for _ in range(num_args)
    ]

    @experiment("a,b", args)
    def dummy(a, b):
        """dummy"""
        return a + b

    results = dummy.run()

    # Without this, a short (or empty) results list would make the
    # comparison loop below pass vacuously.
    assert len(results) == len(args)
    for i, pair in enumerate(args):
        np.testing.assert_array_almost_equal(results[i], np.sum(pair, axis=0))
# Example #11
0
def test_compute_projection(shape):
    """The PCA projection from X matches the one from X.T @ X, up to sign."""
    X = jax.random.uniform(random.generate_key(), shape=shape)
    gram = X.T @ X

    if X.ndim == 1:
        k = 1
    else:
        k = min(X.shape)

    from_data = compute_projection(X, k)
    from_gram = compute_projection(gram, k)

    # Projection directions are presumably sign-ambiguous (as with
    # eigen/singular vectors), so only magnitudes are compared.
    np.testing.assert_array_almost_equal(abs(from_data), abs(from_gram), decimal=3)
# Example #12
0
def test_fit_unconstrained(n, input_dim, output_dim, history_len):
    """Fit unconstrained regression and compare it with a direct solve.

    ``fit_gram`` works from the accumulated Gram matrices, while the
    reference ``_compute_kernel_bias`` solves on the explicit history matrix.
    """
    # NOTE: we use random data because we want to test dimensions and
    # correctness vs a second implementation
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))

    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)

    kernel, bias = fit_gram(XTX, XTY)

    # Explicit design matrix: one flattened history window per row.
    history = historify(X, history_len)
    history = history.reshape(history.shape[0], -1)

    # Targets are offset by history_len - 1 to line up with the windows.
    expected_kernel, expected_bias = _compute_kernel_bias(history, Y[history_len - 1 :], alpha=1.0)
    # Match fit_gram's kernel layout before comparing.
    expected_kernel = expected_kernel.reshape(1, history_len * input_dim, -1)

    np.testing.assert_array_almost_equal(expected_kernel, kernel, decimal=3)
    np.testing.assert_array_almost_equal(expected_bias, bias, decimal=3)
# Example #13
0
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """A projected X.T @ X from the online gram matches the direct computation."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    k = min(k, X_dim)
    projection = jax.random.uniform(random.generate_key(), shape=(X_dim, k))

    gram = OnlineGram(X_dim)
    gram.update(X)

    data = (X - X.mean(axis=0)) / X.std(axis=0) if normalize else X
    expected = projection.T @ data.T @ data @ projection

    if fit_intercept:
        # NOTE(review): the intercept adjustment for the expected value is
        # delegated to gram.fit_intercept itself, so that helper is not
        # independently verified here.
        expected = gram.fit_intercept(expected, normalize=normalize, projection=projection)

    actual = gram.matrix(normalize=normalize, projection=projection, fit_intercept=fit_intercept)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)
# Example #14
0
def test_compute_projection_sklearn(shape):
    """A centered one-component projection agrees with scikit-learn's PCA."""
    X = jax.random.uniform(random.generate_key(), shape=shape)

    ours = compute_projection(X, 1, center=True)

    # scikit-learn stores components as rows; transpose to columns.
    reference = PCA(n_components=1).fit(X).components_.T

    # Compare magnitudes only: a principal direction's sign is arbitrary.
    np.testing.assert_array_almost_equal(abs(ours), abs(reference), decimal=3)
# Example #15
0
def test_gram_intercept_xty(X_dim, Y_dim, n, normalize):
    """With ``fit_intercept=True`` the XY gram equals [1, X].T @ Y.

    It also checks the running mean, std and observation count kept for X.
    """
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    # Running statistics tracked for X
    np.testing.assert_array_almost_equal(gram.mean.squeeze(), X.mean(axis=0))
    np.testing.assert_array_almost_equal(gram.std.squeeze(), X.std(axis=0))
    assert gram.observations == X.shape[0]

    if normalize:
        features = (X - X.mean(axis=0)) / X.std(axis=0)
        targets = (Y - Y.mean(axis=0)) / Y.std(axis=0)
    else:
        features, targets = X, Y

    # Prepend a column of ones for the intercept term
    design = jnp.hstack((jnp.ones((n, 1)), features))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, design.T @ targets, decimal=1)
# Example #16
0
def test_historify(m, n, history_len):
    """``historify`` builds sliding windows of ``history_len`` consecutive rows.

    ``ValueError`` is expected when ``history_len < 1`` or when the window is
    longer than the data. Otherwise there must be exactly one window per
    start index, and window ``i`` must equal the flattened slice
    ``X[i : i + history_len]``.
    """
    X = jax.random.uniform(random.generate_key(), shape=(m, n))

    if history_len < 1 or X.shape[0] < history_len:
        with pytest.raises(ValueError):
            historify(X, history_len)
        return

    batched = historify(X, history_len)
    batched = batched.reshape(batched.shape[0], -1)

    # One window per start index 0 .. m - history_len. Without this check, a
    # truncated result would still pass the per-window comparison below.
    assert batched.shape[0] == X.shape[0] - history_len + 1

    for i, window in enumerate(batched):
        np.testing.assert_array_almost_equal(X[i : i + history_len].ravel().squeeze(), window)
# Example #17
0
def test_gram_intercept_xtx(X_dim, n, normalize):
    """With ``fit_intercept=True`` the XX gram equals [1, X].T @ [1, X].

    X is z-scored first when ``normalize`` is set.
    """
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))

    gram = OnlineGram(X_dim)
    gram.update(X)

    features = (X - X.mean(axis=0)) / X.std(axis=0) if normalize else X
    # Prepend a column of ones for the intercept term
    design = jnp.hstack((jnp.ones((n, 1)), features))

    actual = gram.matrix(normalize=normalize, fit_intercept=True)
    np.testing.assert_array_almost_equal(actual, design.T @ design, decimal=1)
# Example #18
0
def test_gram_projection_no_projection():
    """Without a projection, ``project()`` returns the plain X.T @ X."""
    samples = jax.random.uniform(random.generate_key(), shape=(5, 10))
    gram = OnlineGram(10)
    gram.update(samples)

    expected = samples.T @ samples
    np.testing.assert_array_almost_equal(expected, gram.project())