def test_set_components():
    """Check that ``components_`` on a fitted MKSStructureAnalysis can be reassigned.

    Fits the model on random two-phase microstructures, overwrites
    ``components_`` with twice the fitted values, and asserts that the
    attribute reads back the newly assigned values.
    """
    from pymks import MKSStructureAnalysis, PrimitiveBasis

    analysis = MKSStructureAnalysis(basis=PrimitiveBasis(2))

    # 50 random 10x10 samples with integer values in {0, 1}.
    microstructures = np.random.randint(2, size=(50, 10, 10))
    analysis.fit(microstructures)

    fitted = analysis.components_
    analysis.components_ = fitted * 2

    expected = fitted * 2
    assert np.allclose(analysis.components_, expected)
def test_store_correlations():
    """Check the correlations stored when ``store_correlations=True``.

    After ``fit``, ``fit_correlations`` must match ``correlate`` applied to the
    fitted data with correlations ``[(0, 0), (0, 1)]``.

    After each ``transform`` call, ``transform_correlations`` must match the
    concatenation of the correlations of every batch transformed so far.
    This verifies that successive transforms are accumulated rather than
    replaced.
    """
    from pymks import MKSStructureAnalysis, PrimitiveBasis
    from pymks.stats import correlate

    basis = PrimitiveBasis(2)
    analysis = MKSStructureAnalysis(basis=basis, store_correlations=True)

    def expected_correlations(arr):
        # Build a fresh correlations list per call, matching the original's
        # inline literals.
        return correlate(arr, basis, correlations=[(0, 0), (0, 1)])

    training = np.random.randint(2, size=(2, 4, 4))
    analysis.fit(training)
    assert np.allclose(expected_correlations(training), analysis.fit_correlations)

    # Two transform rounds; the stored correlations should grow with each one.
    collected = []
    for _ in range(2):
        batch = np.random.randint(2, size=(2, 4, 4))
        analysis.transform(batch)
        collected.append(expected_correlations(batch))
        assert np.allclose(np.concatenate(collected), analysis.transform_correlations)