def test_gram_update(X_dim, Y_dim, n):
    """Test update"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=False, fit_intercept=False), X.T @ Y
    )
def test_gram_normalize(X_dim, Y_dim, n):
    """Test normalize"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    X_norm = (X - X.mean(axis=0)) / X.std(axis=0)
    Y_norm = (Y - Y.mean(axis=0)) / Y.std(axis=0)

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=True, fit_intercept=False), X_norm.T @ Y_norm, decimal=1
    )
def test_gram_update_iterative(X_dim, Y_dim, n):
    """Test update iteratively"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    for x, y in zip(X, Y):
        gram.update(x.reshape(1, -1), y.reshape(1, -1))

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=False, fit_intercept=False), X.T @ Y, decimal=3
    )
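# Illustration (not part of the library): the iterative test above relies on the
# identity that summing per-observation outer products reproduces the batched
# product X.T @ Y; OnlineGram additionally tracks mean, std, and observation
# counts for normalization and intercept fitting. The helper name below is
# hypothetical and uses plain numpy only.
def _demo_rank_one_accumulation():
    import numpy

    X = numpy.arange(12.0).reshape(6, 2)
    Y = numpy.arange(18.0).reshape(6, 3)

    acc = numpy.zeros((2, 3))
    for x, y in zip(X, Y):
        acc += numpy.outer(x, y)  # rank-1 update per observation

    numpy.testing.assert_array_almost_equal(acc, X.T @ Y)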
def compute_gram(
    data: Tuple[np.ndarray, np.ndarray, Any],
    input_dim: int,
    output_dim: int,
    history_len: int,
) -> Tuple[OnlineGram, OnlineGram]:
    """Compute X.T @ X and X.T @ Y on history windows incrementally"""
    num_features = input_dim * history_len

    XTX = OnlineGram(num_features)
    XTY = OnlineGram(num_features, output_dim)

    for X, Y, _ in data:
        X = internalize(X, input_dim)[0]
        Y = internalize(Y, output_dim)[0]

        if X.shape[0] != Y.shape[0]:
            raise ValueError("Input and output data must have the same number of observations")

        # Expand input time series X into histories, which should result in a
        # (num_histories, history_len * input_dim)-shaped array
        history = historify(X, history_len=history_len)
        history = history.reshape(history.shape[0], -1)

        XTX.update(history)
        XTY.update(history, Y[history_len - 1 :])

    if XTX.observations == 0:
        raise IndexError("No data to fit")

    if XTX.observations <= num_features:
        raise ValueError(
            "Underdetermined systems not currently supported (observations: {}, "
            "features: {})".format(XTX.observations, num_features)
        )

    return XTX, XTY
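# Usage sketch (hedged): compute_gram consumes an iterable of (X, Y, metadata)
# tuples, expands each series into history windows, and accumulates the Gram
# matrices that fit_gram (below) consumes. Dimensions, data, and the helper name
# here are illustrative assumptions, not fixtures from this suite.
def _demo_compute_gram():
    n, input_dim, output_dim, history_len = 100, 2, 1, 3
    X = jax.random.uniform(random.generate_key(), shape=(n, input_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))

    # Single series; the third tuple element (metadata) is ignored by compute_gram
    XTX, XTY = compute_gram([(X, Y, None)], input_dim, output_dim, history_len)

    # One observation per history window: n - history_len + 1 in total, which must
    # exceed input_dim * history_len features for the fit to be determined
    assert XTX.observations == n - history_len + 1
    return XTX, XTY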
def test_gram_intercept_xtx(X_dim, n, normalize):
    """Test intercept on X.T @ X"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))

    gram = OnlineGram(X_dim)
    gram.update(X)

    if normalize:
        X = (X - X.mean(axis=0)) / X.std(axis=0)

    X = jnp.hstack((jnp.ones((n, 1)), X))

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=normalize, fit_intercept=True), X.T @ X, decimal=1
    )
def test_gram_update_value_error():
    """Test update value error"""
    gram = OnlineGram(1, 1)

    with pytest.raises(ValueError):
        gram.update(
            jax.random.uniform(random.generate_key(), shape=(4, 1)),
            jax.random.uniform(random.generate_key(), shape=(1, 1)),
        )

    with pytest.raises(ValueError):
        gram.update(jax.random.uniform(random.generate_key(), shape=(4, 1)))

    gram = OnlineGram(1)
    with pytest.raises(ValueError):
        gram.update(np.ones((1, 1)), np.ones((1, 1)))
def test_fit_gram_underdetermined():
    """Test underdetermined"""
    XTX = OnlineGram(1)
    XTY = XTX

    with pytest.raises(ValueError):
        fit_gram(XTX, XTY)
def test_fit_constrained_bad_input_dim():
    """Bad input for constrained"""
    XTX = OnlineGram(10)
    XTY = OnlineGram(5)

    XTX.update(jax.random.uniform(random.generate_key(), shape=(100, 10)))
    XTY.update(jax.random.uniform(random.generate_key(), shape=(100, 5)))

    with pytest.raises(ValueError):
        fit_gram(XTX, XTY, input_dim=7)
def test_gram_intercept_xty(X_dim, Y_dim, n, normalize):
    """Test intercept on X.T @ Y"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, Y_dim))

    gram = OnlineGram(X_dim, Y_dim)
    gram.update(X, Y)

    np.testing.assert_array_almost_equal(gram.mean.squeeze(), X.mean(axis=0))
    np.testing.assert_array_almost_equal(gram.std.squeeze(), X.std(axis=0))
    assert gram.observations == X.shape[0]

    if normalize:
        X = (X - X.mean(axis=0)) / X.std(axis=0)
        Y = (Y - Y.mean(axis=0)) / Y.std(axis=0)

    X = jnp.hstack((jnp.ones((n, 1)), X))

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=normalize, fit_intercept=True), X.T @ Y, decimal=1
    )
def test_gram_intercept_constrained_projection():
    """Constrained projection should error"""
    X = jax.random.uniform(random.generate_key(), shape=(5, 10))
    XTX = OnlineGram(10)
    XTX.update(X)

    with pytest.raises(ValueError):
        XTX.fit_intercept(projection=1, input_dim=2)
def test_gram_projection_xtx(X_dim, n, k, normalize, fit_intercept):
    """Test projection on X.T @ X"""
    X = jax.random.uniform(random.generate_key(), shape=(n, X_dim))
    k = min(k, X_dim)
    projection = jax.random.uniform(random.generate_key(), shape=(X_dim, k))

    gram = OnlineGram(X_dim)
    gram.update(X)

    if normalize:
        X = (X - X.mean(axis=0)) / X.std(axis=0)

    expected = projection.T @ X.T @ X @ projection

    if fit_intercept:
        expected = gram.fit_intercept(expected, normalize=normalize, projection=projection)

    np.testing.assert_array_almost_equal(
        gram.matrix(normalize=normalize, projection=projection, fit_intercept=fit_intercept),
        expected,
        decimal=1,
    )
def test_gram_projection_no_projection():
    """Test empty projection"""
    X = jax.random.uniform(random.generate_key(), shape=(5, 10))
    XTX = OnlineGram(10)
    XTX.update(X)

    np.testing.assert_array_almost_equal(X.T @ X, XTX.project())
def fit_gram(
    XTX: OnlineGram,
    XTY: OnlineGram,
    alpha: float = 1.0,
    normalize: bool = False,
    projection: np.ndarray = None,
    input_dim: int = None,
):
    """Compute linear regression parameters from Gram matrices

    Notes:
        * Assumes over-determined systems
        * Assumes we always fit an intercept
    """
    fit_intercept = True
    feature_dim = XTX.feature_dim if projection is None else projection.shape[1]
    output_dim = XTY.output_dim

    if input_dim is None:
        history_len = None
    else:
        if feature_dim % input_dim != 0:
            raise ValueError("Original input dimension must evenly divide feature dimensions")
        history_len = feature_dim // input_dim

    if XTX.observations <= feature_dim:
        raise ValueError(
            "Underdetermined systems not currently supported (observations: {}, "
            "features: {})".format(XTX.observations, feature_dim)
        )

    # Finalize gram matrices
    XTX = XTX.matrix(
        normalize=normalize,
        projection=projection,
        fit_intercept=fit_intercept,
        input_dim=input_dim,
    )
    XTY = XTY.matrix(
        normalize=normalize,
        projection=projection,
        fit_intercept=fit_intercept,
        input_dim=input_dim,
    )

    inv = _compute_xtx_inverse(XTX, alpha)
    beta = _fit_unconstrained(inv, XTY)

    # Constrained regression applies only when the original input dimension is
    # provided; it is skipped when projecting
    if input_dim is not None:
        R, r = _form_constraints(
            input_dim=input_dim,
            output_dim=output_dim,
            history_len=history_len,
            fit_intercept=fit_intercept,
        )
        beta = _fit_constrained(beta, inv, R, r)
        beta = beta.take(jnp.arange(0, len(beta), input_dim), axis=0)
        return beta[1:], beta[0]

    return beta[1:].reshape(1, feature_dim, -1), beta[0]
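# Usage sketch (hedged): with more observations than features, fit_gram solves the
# regularized normal equations accumulated in the Gram matrices and returns the
# weights and intercept. The dimensions and helper name are illustrative; output
# shapes follow the unconstrained return path above.
def _demo_fit_gram():
    n, feature_dim, output_dim = 200, 4, 2
    X = jax.random.uniform(random.generate_key(), shape=(n, feature_dim))
    Y = jax.random.uniform(random.generate_key(), shape=(n, output_dim))

    XTX = OnlineGram(feature_dim)
    XTY = OnlineGram(feature_dim, output_dim)
    XTX.update(X)
    XTY.update(X, Y)

    # Unconstrained path: weights shaped (1, feature_dim, output_dim), intercept
    # holding one value per output dimension
    beta, intercept = fit_gram(XTX, XTY, alpha=1.0)
    return beta, intercept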