def compute_gram(
    data: Tuple[np.ndarray, np.ndarray, Any],
    input_dim: int,
    output_dim: int,
    history_len: int,
) -> Tuple[OnlineGram, OnlineGram]:
    """Compute X.T @ X and X.T @ Y on history windows incrementally.

    Args:
        data: Iterated as a collection of ``(X, Y, _)`` triples; the third
            element of each triple is ignored. NOTE(review): the annotation
            describes a single triple, but the argument is consumed as an
            iterable of triples -- confirm and tighten the annotation.
        input_dim: Dimension passed to ``internalize`` for each ``X``; also
            determines the feature count ``input_dim * history_len``.
        output_dim: Dimension passed to ``internalize`` for each ``Y`` and to
            the ``XTY`` accumulator.
        history_len: Window length passed to ``historify``.

    Returns:
        The pair ``(XTX, XTY)`` of ``OnlineGram`` accumulators after all
        triples have been folded in.

    Raises:
        ValueError: If an ``X``/``Y`` pair differs in its leading dimension,
            or if the accumulated observation count does not exceed the
            number of features.
        IndexError: If no observations were accumulated.
    """
    num_features = input_dim * history_len

    XTX = OnlineGram(num_features)
    XTY = OnlineGram(num_features, output_dim)

    for X, Y, _ in data:
        # Only the first element returned by internalize (the array) is used.
        X = internalize(X, input_dim)[0]
        Y = internalize(Y, output_dim)[0]

        if X.shape[0] != Y.shape[0]:
            raise ValueError("Input and output data must have the same number of observations")

        # Expand input time series X into histories, which should result in a
        # (num_histories, history_len * input_dim)-shaped array
        history = historify(X, history_len=history_len)
        history = history.reshape(history.shape[0], -1)

        XTX.update(history)
        # Drop the first history_len - 1 outputs so that Y is paired with the
        # history windows (presumably one window per remaining row -- this
        # depends on historify's output length).
        XTY.update(history, Y[history_len - 1 :])

    if XTX.observations == 0:
        raise IndexError("No data to fit")

    if XTX.observations <= num_features:
        # The original message was built from two adjacent literals with no
        # separating space, rendering as "(observations: N,features: M)".
        raise ValueError(
            "Underdetermined systems not currently supported "
            "(observations: {}, features: {})".format(XTX.observations, num_features)
        )

    return XTX, XTY
def test_internalize_scalar(n, expected):
    """Check internalize on the scalar 4 for a given dimension ``n``.

    A falsy ``expected`` means the call must raise ``ValueError``; otherwise
    the second element of the returned 4-tuple must equal ``expected``.
    """
    if not expected:
        with pytest.raises(ValueError):
            internalize(4, n)
        return

    _, scalar_flag, _, _ = internalize(4, n)
    assert scalar_flag == expected
def test_internalize(n, shape, expected):
    """Check internalize on an all-ones array of the given ``shape``.

    ``expected[0]`` is compared with the second element of the result and
    ``expected[1]`` with the returned array.
    """
    ones = jnp.ones(shape)
    result, value_flag, _, _ = internalize(ones, n)

    assert value_flag == expected[0]
    np.testing.assert_array_equal(result, expected[1])
def update(self, observation: Union[Real, jnp.ndarray]) -> None:
    """Fold a new observation into the running statistics.

    Updates ``self._mean``, ``self._var``, ``self._sum`` and
    ``self._observations`` in that order. When ``internalize`` reports the
    input as a single value, the value itself is used as the batch mean and
    sum and the batch variance is taken as 0; otherwise the per-column
    (axis 0) mean, variance and sum of the batch are used.
    """
    observation, is_value, _, _ = internalize(observation, self._dim)
    batch_count = observation.shape[0]
    seen = self._observations
    total = seen + batch_count

    # Mean: count-weighted average of the previous mean and the batch mean.
    old_mean = self._mean
    if is_value:
        batch_mean = observation
    else:
        batch_mean = observation.mean(axis=0)
    new_mean = (seen * old_mean + batch_count * batch_mean) / total
    self._mean = new_mean

    # Variance: count-weighted variances plus the squared shift of each
    # group's mean from the new combined mean.
    old_var = self._var
    if is_value:
        batch_var = 0
    else:
        batch_var = observation.var(axis=0)
    self._var = (
        seen * old_var
        + batch_count * batch_var
        + seen * ((old_mean - new_mean) ** 2)
        + batch_count * ((batch_mean - new_mean) ** 2)
    ) / total

    # Running sum and count.
    if is_value:
        self._sum += observation
    else:
        self._sum += observation.sum(axis=0)
    self._observations += batch_count
def test_internalize_type_error():
    """internalize must raise TypeError for a (1, 2, 3)-shaped array."""
    three_dim_input = jnp.ones((1, 2, 3))
    with pytest.raises(TypeError):
        internalize(three_dim_input, 10)
def test_internalize_value_error(n, shape):
    """internalize must raise ValueError for an all-ones array of ``shape``."""
    with pytest.raises(ValueError):
        # Array construction stays inside the context, as in the original.
        bad_input = jnp.ones(shape)
        internalize(bad_input, n)