def test_n_components_change():
    """Assigning ``n_components`` after construction is visible on read-back."""
    # NOTE(review): a second function with this exact name is defined later in
    # this file; Python keeps only the last binding, so this one is shadowed.
    from pymks import DiscreteIndicatorBasis, MKSStructureAnalysis
    wanted = 27
    analyzer = MKSStructureAnalysis(basis=DiscreteIndicatorBasis(n_states=2))
    analyzer.n_components = wanted
    assert analyzer.n_components == wanted
def test_reshape_X():
    """``_reduce_shape`` flattens each sample and centres it on its own mean.

    The two 3x3 samples hold 0..8 and 9..17; both are expected to come back
    as the 9 values -4..4.
    """
    # NOTE(review): a second function with this exact name is defined later in
    # this file; Python keeps only the last binding, so this one is shadowed.
    from pymks import MKSStructureAnalysis
    from pymks import PrimitiveBasis
    analyzer = MKSStructureAnalysis(basis=PrimitiveBasis())
    samples = np.arange(18, dtype='float64').reshape(2, 3, 3)
    centred = np.arange(-4, 5)
    expected = np.stack([centred, centred])
    assert np.allclose(analyzer._reduce_shape(samples), expected)
def test_reshape_X():
    """Each sample is flattened and mean-centred by ``_reduce_shape``.

    Inputs 0..8 and 9..17 (two 3x3 samples) should both map to -4..4.
    """
    # NOTE(review): this redefines an identically named test earlier in the
    # file; only this definition survives at import time.
    from pymks import MKSStructureAnalysis
    from pymks import PrimitiveBasis
    structure_analyzer = MKSStructureAnalysis(basis=PrimitiveBasis())
    two_samples = np.arange(18, dtype='float64').reshape(2, 3, 3)
    reduced = structure_analyzer._reduce_shape(two_samples)
    assert np.allclose(reduced, np.tile(np.arange(-4, 5), (2, 1)))
def test_n_components_change():
    """The ``n_components`` setter stores the value the getter returns."""
    # NOTE(review): this redefines an identically named test earlier in the
    # file; only this definition survives at import time.
    from pymks import MKSStructureAnalysis
    from pymks import DiscreteIndicatorBasis
    two_state_basis = DiscreteIndicatorBasis(n_states=2)
    structure_model = MKSStructureAnalysis(basis=two_state_basis)
    new_count = 27
    structure_model.n_components = new_count
    assert structure_model.n_components == new_count
def test_set_components():
    """Overwriting ``components_`` after ``fit`` replaces the fitted components."""
    # NOTE(review): a second function with this exact name is defined later in
    # this file; Python keeps only the last binding, so this one is shadowed.
    from pymks import MKSStructureAnalysis
    from pymks import PrimitiveBasis
    model = MKSStructureAnalysis(basis=PrimitiveBasis(2))
    microstructures = np.random.randint(2, size=(50, 10, 10))
    model.fit(microstructures)
    fitted = model.components_
    model.components_ = fitted * 2
    # Recompute the product after assignment, exactly as the check intends.
    assert np.allclose(model.components_, fitted * 2)
def test_set_components():
    """Fitted ``components_`` can be replaced by assignment."""
    # NOTE(review): this redefines an identically named test earlier in the
    # file; only this definition survives at import time.
    from pymks import MKSStructureAnalysis
    from pymks import PrimitiveBasis
    basis = PrimitiveBasis(2)
    analysis = MKSStructureAnalysis(basis=basis)
    analysis.fit(np.random.randint(2, size=(50, 10, 10)))
    before = analysis.components_
    analysis.components_ = before * 2
    assert np.allclose(analysis.components_, before * 2)
def test_default_correlations():
    """Without explicit ``correlations``, a 6-state basis yields (0, 0) .. (0, 5)."""
    from pymks import PrimitiveBasis
    from pymks import MKSStructureAnalysis
    n_states = 6
    model_prim = MKSStructureAnalysis(basis=PrimitiveBasis(n_states))
    expected = [(0, state) for state in range(n_states)]
    assert model_prim.correlations == expected
def test_set_correlations():
    """An explicit ``correlations`` argument is stored unchanged."""
    from pymks import MKSStructureAnalysis, PrimitiveBasis
    requested = [(0, 0), (0, 2), (0, 4)]
    analyzer = MKSStructureAnalysis(basis=PrimitiveBasis(6),
                                    correlations=requested)
    assert analyzer.correlations == requested
def test_n_componets_from_reducer():
    """When ``n_components`` is not given, it is taken from the supplied reducer."""
    from pymks import MKSStructureAnalysis
    from pymks import DiscreteIndicatorBasis
    from sklearn.manifold import LocallyLinearEmbedding
    reducer_dims = 7
    analyzer = MKSStructureAnalysis(
        dimension_reducer=LocallyLinearEmbedding(n_components=reducer_dims),
        basis=DiscreteIndicatorBasis(n_states=3, domain=[0, 2]),
    )
    assert analyzer.n_components == reducer_dims
def test_n_components_with_reducer():
    """An explicit ``n_components`` (9) wins over the reducer's own setting (7)."""
    from pymks import MKSStructureAnalysis
    from pymks import DiscreteIndicatorBasis
    from sklearn.manifold import Isomap
    analyzer = MKSStructureAnalysis(
        dimension_reducer=Isomap(n_components=7),
        basis=DiscreteIndicatorBasis(n_states=3, domain=[0, 2]),
        n_components=9,
    )
    assert analyzer.n_components == 9
def test_store_correlations():
    """With ``store_correlations=True`` the analyzer keeps the correlations it computes.

    ``fit_correlations`` must match the correlations of the fitted data, and
    ``transform_correlations`` must accumulate, in call order, the
    correlations of every array passed to ``transform``.
    """
    from pymks import MKSStructureAnalysis
    from pymks import PrimitiveBasis
    from pymks.stats import correlate
    p_basis = PrimitiveBasis(2)

    def reference(micro):
        # Correlations computed directly, to compare against what the model kept.
        return correlate(micro, p_basis, correlations=[(0, 0), (0, 1)])

    model = MKSStructureAnalysis(basis=p_basis, store_correlations=True)
    fit_data = np.random.randint(2, size=(2, 4, 4))
    model.fit(fit_data)
    assert np.allclose(reference(fit_data), model.fit_correlations)

    first = np.random.randint(2, size=(2, 4, 4))
    model.transform(first)
    first_corr = reference(first)
    assert np.allclose(first_corr, model.transform_correlations)

    second = np.random.randint(2, size=(2, 4, 4))
    model.transform(second)
    accumulated = np.concatenate((first_corr, reference(second)))
    assert np.allclose(accumulated, model.transform_correlations)
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from pymks import MKSStructureAnalysis
from pymks.bases import GSHBasis
from pymks.tools import draw_components_scatter

# Load orientation maps, centre-crop them to a common size, project their
# GSH-basis statistics onto 3 components and scatter-plot them per sample group.
data_dict = loadmat('orientaiton_data.mat')
# scipy's loadmat also returns '__header__', '__version__' and '__globals__'
# entries; drop them by name rather than relying on where they happen to sort.
keys = sorted(k for k in data_dict if not k.startswith('__'))
data_list = [data_dict[k] for k in keys]
shapes = [s.shape for s in data_list]
X = []
crop_size = (100, 100)  # (400, 1100)
for s, d in zip(shapes, data_list):
    if s[0] < crop_size[0] or s[1] < crop_size[1]:
        raise ValueError('map of shape {0} is smaller than crop_size {1}'
                         .format(s[:2], crop_size))
    # Centre a crop_size window. Floor division keeps the bounds integral
    # (true division gives floats, which are not valid slice indices), and
    # explicit end = start + size always yields exactly crop_size pixels.
    x_0 = (s[0] - crop_size[0]) // 2
    x_1 = x_0 + crop_size[0]
    y_0 = (s[1] - crop_size[1]) // 2
    y_1 = y_0 + crop_size[1]
    X.append(d[x_0:x_1, y_0:y_1][None])
X = np.concatenate(X)
X_masks = np.sum(X, axis=-1) > 0
gsh_basis = GSHBasis(n_states=20, domain='cubic')
anaylzer = MKSStructureAnalysis(basis=gsh_basis, n_components=3)
y = anaylzer.fit_transform(X, confidence_index=X_masks)
# Assumes keys come in consecutive groups of 3 sharing a name prefix (the last
# 3 characters differ) -- TODO confirm against the data file. Walk forward so
# labels line up with the in-order chunks from np.array_split; the previous
# keys[::-3] produced them in reverse order.
labels = [k[:-3] for k in keys[::3]]
draw_components_scatter(np.array_split(y, len(labels)), labels=labels)
def test_default_dimension_reducer():
    """Without an explicit ``dimension_reducer`` the analyzer uses a PCA instance."""
    from pymks import MKSStructureAnalysis, PrimitiveBasis
    from sklearn.decomposition import PCA
    reducer = MKSStructureAnalysis(basis=PrimitiveBasis()).dimension_reducer
    assert isinstance(reducer, PCA)