def test_fit_constrained(n, input_dim, output_dim, history_len):
    """Fit constrained regression.

    Compares ``fit_gram(..., input_dim=input_dim)`` against a second,
    hand-assembled constrained solve built from ``_form_constraints``,
    ``_compute_xtx_inverse``, ``_fit_unconstrained`` and ``_fit_constrained``.
    """
    # NOTE: we use random data because we want to test dimensions and
    # correctness vs a second implementation
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))
    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)
    result = fit_gram(XTX, XTY, input_dim=input_dim)

    # Reference implementation: build the constraints and solve by hand.
    R, r = _form_constraints(
        input_dim=input_dim,
        output_dim=output_dim,
        history_len=history_len,
        fit_intercept=True,
    )
    XTX = XTX.matrix(fit_intercept=True, input_dim=input_dim)
    XTY = XTY.matrix(fit_intercept=True, input_dim=input_dim)
    inv = _compute_xtx_inverse(XTX, alpha=1.0)
    beta = _fit_unconstrained(inv, XTY)
    beta = _fit_constrained(beta, inv, R, r)

    # Each chunk of input_dim features must share the same coefficient.
    beta = beta.reshape(history_len + 1, input_dim, -1)
    assert np.sum([np.abs(x - x[0]) for x in beta]) < 1e-4

    # Finally, check that resulting vector is of the correct length and the
    # values are self-consistent
    assert len(beta) == history_len + 1
    beta = beta[:, 0]
    # Chunk 0 is compared with fit_gram's bias, the remaining chunks with its kernel.
    kernel, bias = beta[1:], beta[0]

    # Check final results against fit_gram's (kernel, bias) output.
    np.testing.assert_array_almost_equal(kernel, result[0])
    np.testing.assert_array_almost_equal(bias, result[1])
def test_gram_update(X_dim, Y_dim, n):
    """A single batch update must reproduce the raw product X.T @ Y."""
    inputs = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    targets = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(inputs, targets)

    actual = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(actual, inputs.T @ targets)
def test_fit_constrained_bad_input_dim():
    """fit_gram must raise ValueError for an unsupported input_dim.

    The X.T @ X gram has 10 features while input_dim is 7; presumably
    rejected because 10 is not a multiple of 7 (TODO confirm). The test only
    pins that ValueError is raised.
    """
    xtx = OnlineGram(10)
    xty = OnlineGram(5)
    xtx.update(jax.random.uniform(random.generate_key(), shape=(100, 10)))
    xty.update(jax.random.uniform(random.generate_key(), shape=(100, 5)))

    with pytest.raises(ValueError):
        fit_gram(xtx, xty, input_dim=7)
def test_compute_gram(n, input_dim, output_dim, history_len):
    """Test computing gram matrices against explicit history-window products."""
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))
    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)

    # Flatten each history window into a single feature row.
    windows = historify(X, history_len)
    windows = windows.reshape(windows.shape[0], -1)
    # Window i ends at row i + history_len - 1, so targets start there.
    aligned_Y = Y[history_len - 1 :]

    np.testing.assert_array_almost_equal(windows.T @ windows, XTX.matrix(), decimal=4)
    np.testing.assert_array_almost_equal(windows.T @ aligned_Y, XTY.matrix(), decimal=4)
def test_gram_update_iterative(X_dim, Y_dim, n):
    """Row-by-row updates must accumulate to the same X.T @ Y as one batch."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    # Feed one (1, dim) row at a time.
    for row_idx in range(n):
        gram.update(X[row_idx].reshape(1, -1), Y[row_idx].reshape(1, -1))

    accumulated = gram.matrix(normalize=False, fit_intercept=False)
    np.testing.assert_array_almost_equal(accumulated, X.T @ Y, decimal=3)
def test_gram_normalize(X_dim, Y_dim, n):
    """normalize=True must match the product of column-standardized X and Y."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    def _zscore(values):
        # Column-wise standardization: subtract the mean, divide by the std.
        return (values - values.mean(axis=0)) / values.std(axis=0)

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=True, fit_intercept=False),
        _zscore(X).T @ _zscore(Y),
        decimal=1,
    )
def test_gram_update_value_error():
    """OnlineGram.update must raise ValueError on inconsistent X/Y inputs."""
    gram = OnlineGram(1, 1)

    # X and Y disagree on the number of rows (4 vs 1).
    with pytest.raises(ValueError):
        x_batch = jax.random.uniform(random.generate_key(), shape=(4, 1))
        y_batch = jax.random.uniform(random.generate_key(), shape=(1, 1))
        gram.update(x_batch, y_batch)

    # A gram constructed with a Y dimension is updated without Y.
    with pytest.raises(ValueError):
        gram.update(jax.random.uniform(random.generate_key(), shape=(4, 1)))

    # A gram constructed without a Y dimension is updated with Y.
    x_only = OnlineGram(1)
    with pytest.raises(ValueError):
        x_only.update(np.ones((1, 1)), np.ones((1, 1)))
def test_gram_intercept_constrained_projection():
    """fit_intercept must reject a projection combined with input_dim."""
    data = jax.random.uniform(random.generate_key(), shape=(5, 10))
    xtx = OnlineGram(10)
    xtx.update(data)

    with pytest.raises(ValueError):
        xtx.fit_intercept(projection=1, input_dim=2)
def test_random():
    """Test the random-key singleton against fixed values after seeding with 0."""
    # Calling set_key() with no argument must not fail; set_key(0) then
    # installs a deterministic seed.
    random.set_key()
    random.set_key(0)

    # Checked strictly in this order: get_key, generate_key, get_key. After
    # generate_key, the stored key is expected to have changed.
    expectations = (
        (random.get_key, jnp.array([0, 0])),
        (random.generate_key, jnp.array([2718843009, 1272950319], dtype=jnp.uint32)),
        (random.get_key, jnp.array([4146024105, 967050713], dtype=jnp.uint32)),
    )
    for fetch, expected in expectations:
        assert jnp.array_equal(fetch(), expected)
def test_experiment(shape, num_args):
    """Test normal experiment behavior.

    Each (a, b) pair registered via ``@experiment("a,b", args)`` must produce
    ``a + b`` in the corresponding position of ``dummy.run()``'s results.
    """
    args = [
        (
            jax.random.uniform(random.generate_key(), shape=shape),
            jax.random.uniform(random.generate_key(), shape=shape),
        )
        for _ in range(num_args)
    ]

    @experiment("a,b", args)
    def dummy(a, b):
        """dummy"""
        return a + b

    results = dummy.run()

    # Guard against a vacuous pass: without this, missing results would go
    # unnoticed because the loop below only visits results that exist.
    assert len(results) == len(args)
    for result, pair in zip(results, args):
        np.testing.assert_array_almost_equal(result, np.sum(pair, axis=0))
def test_compute_projection(shape):
    """Projection computed from X must match the one from X.T @ X, up to sign."""
    X = jax.random.uniform(random.generate_key(), shape=shape)
    gram = X.T @ X

    if X.ndim == 1:
        k = 1
    else:
        k = min(X.shape)

    from_data = compute_projection(X, k)
    from_gram = compute_projection(gram, k)
    # Component signs are arbitrary, so compare magnitudes only.
    np.testing.assert_array_almost_equal(abs(from_data), abs(from_gram), decimal=3)
def test_fit_unconstrained(n, input_dim, output_dim, history_len):
    """Fit unconstrained regression.

    Compares fit_gram's (kernel, bias) against ``_compute_kernel_bias`` run
    directly on the flattened history windows and the aligned targets.
    """
    # NOTE: we use random data because we want to test dimensions and
    # correctness vs a second implementation
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))
    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)
    kernel, bias = fit_gram(XTX, XTY)

    history = historify(X, history_len)
    history = history.reshape(history.shape[0], -1)
    expected_kernel, expected_bias = _compute_kernel_bias(history, Y[history_len - 1 :], alpha=1.0)
    # Reshape to the (1, history_len * input_dim, output) layout that the
    # kernel from fit_gram is compared in.
    expected_kernel = expected_kernel.reshape(1, history_len * input_dim, -1)

    np.testing.assert_array_almost_equal(expected_kernel, kernel, decimal=3)
    np.testing.assert_array_almost_equal(expected_bias, bias, decimal=3)
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Projected X.T @ X from the gram must equal P.T @ X.T @ X @ P."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    k = min(k, X_dim)
    projection = jax.random.uniform(random.generate_key(), shape=(X_dim, k))

    gram = OnlineGram(X_dim)
    gram.update(X)

    design = (X - X.mean(axis=0)) / X.std(axis=0) if normalize else X
    expected = projection.T @ design.T @ design @ projection
    if fit_intercept:
        expected = gram.fit_intercept(expected, normalize=normalize, projection=projection)

    actual = gram.matrix(normalize=normalize, projection=projection, fit_intercept=fit_intercept)
    np.testing.assert_array_almost_equal(actual, expected, decimal=1)
def test_compute_projection_sklearn(shape):
    """Centered projection must match sklearn PCA's first component, up to sign."""
    X = jax.random.uniform(random.generate_key(), shape=shape)
    ours = compute_projection(X, 1, center=True)

    reference = PCA(n_components=1)
    reference.fit(X)

    # components_ is (n_components, n_features); transpose to column layout.
    np.testing.assert_array_almost_equal(abs(ours), abs(reference.components_.T), decimal=3)
def test_gram_intercept_xty(X_dim, Y_dim, n, normalize):
    """X.T @ Y with an intercept column, plus running mean/std/count checks."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))
    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    # The gram's running statistics must match the batch statistics of X.
    np.testing.assert_array_almost_equal(gram.mean.squeeze(), X.mean(axis=0))
    np.testing.assert_array_almost_equal(gram.std.squeeze(), X.std(axis=0))
    assert gram.observations == X.shape[0]

    if normalize:
        X, Y = (
            (X - X.mean(axis=0)) / X.std(axis=0),
            (Y - Y.mean(axis=0)) / Y.std(axis=0),
        )
    # Prepend a column of ones for the intercept term.
    design = jnp.hstack((jnp.ones((n, 1)), X))

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=normalize, fit_intercept=True), design.T @ Y, decimal=1
    )
def test_historify(m, n, history_len):
    """Test history-making.

    historify must raise ValueError when history_len < 1 or when the window
    is longer than the data. Otherwise it must return one window per valid
    start index, and window i must equal X[i : i + history_len] flattened.
    """
    X = jax.random.uniform(random.generate_key(), shape=(m, n))
    if history_len < 1 or X.shape[0] - history_len < 0:
        with pytest.raises(ValueError):
            historify(X, history_len)
        return

    batched = historify(X, history_len)
    batched = batched.reshape(batched.shape[0], -1)
    # Without this, the loop below passes vacuously when windows are missing.
    assert batched.shape[0] == X.shape[0] - history_len + 1
    for i, batch in enumerate(batched):
        np.testing.assert_array_almost_equal(X[i : i + history_len].ravel(), batch)
def test_gram_intercept_xtx(X_dim, n, normalize):
    """X.T @ X with a leading intercept column, optionally on standardized X."""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    gram = OnlineGram(X_dim)
    gram.update(X)

    features = (X - X.mean(axis=0)) / X.std(axis=0) if normalize else X
    # A column of ones goes first, so row/column 0 of the product is the intercept.
    design = jnp.hstack((jnp.ones((n, 1)), features))

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=normalize, fit_intercept=True), design.T @ design, decimal=1
    )
def test_gram_projection_no_projection():
    """project() called without a projection must return the plain X.T @ X."""
    data = jax.random.uniform(random.generate_key(), shape=(5, 10))
    xtx = OnlineGram(10)
    xtx.update(data)

    expected = data.T @ data
    np.testing.assert_array_almost_equal(expected, xtx.project())