Example 1
def test_invalid_ranks(data, init_signal_ranks, individual_ranks):
    # invalid ranks raise an error
    dat = data["diff_views"]
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=init_signal_ranks).fit_transform(dat)
    with pytest.raises(ValueError):
        _ = AJIVE(individual_ranks=individual_ranks).fit_transform(dat)
Example 2
def test_indiv_rank_0(data):
    dat = data["diff_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], individual_ranks=[0, 0])
    _ = ajive.fit_transform(dat)
    Is = ajive.individual_mats_
    assert_allclose(Is[0], 0)
    assert_allclose(Is[1], 0)
Example 3
def test_signal_rank_warning():
    # Warns when the init signal rank is larger than the possible rank and sets it to max - 1
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=2)
    a = np.vstack([[1, 1], [1, 1]])
    with pytest.warns(RuntimeWarning):
        ajive.fit([a, a])

    assert ajive.init_signal_ranks_ == [1, 1]
Example 4
def test_wrong_signal_ranks(data):
    # negative ranks or ranks larger than the view's maximum dimension are invalid
    dat = data["diff_views"]
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=[-1, -4]).fit_transform(dat)
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=[max(dat[0].shape) +
                                     1, -4]).fit_transform(dat)
Example 5
def test_same_indiv(data, init_signal_ranks, individual_ranks, joint_rank):
    # Identical views should yield identical individual matrices across varying inputs
    dat = data["same_views"]
    ajive = AJIVE(
        init_signal_ranks=init_signal_ranks,
        joint_rank=joint_rank,
        individual_ranks=individual_ranks,
    )
    ajive = ajive.fit(Xs=dat)
    Is = ajive.individual_mats_
    assert_allclose(Is[0], Is[1])
Example 6
def test_same_joint(data, init_signal_ranks, individual_ranks, joint_rank):
    # Identical views should yield identical joint matrices across varying inputs
    dat = data["same_views"]
    ajive = AJIVE(
        init_signal_ranks=init_signal_ranks,
        joint_rank=joint_rank,
        individual_ranks=individual_ranks,
    )
    Js = ajive.fit_transform(Xs=dat)
    for i in np.arange(20):
        j = np.sum(Js[0][i] == Js[1][i])
        assert_equal(j, 100)
Example 7
def test_fit_elbows():
    n = 10
    elbows = 3
    np.random.seed(1)
    x = np.random.binomial(1, 0.6, (n**2)).reshape(n, n)
    xorth = orth(x)
    d = np.zeros(xorth.shape[0])
    for i in range(0, len(d), int(len(d) / (elbows + 1))):
        d[:i] += 10
    X = xorth.T.dot(np.diag(d)).dot(xorth)

    Xs = [X, X]

    ajive = AJIVE(n_elbows=2)
    ajive = ajive.fit(Xs)

    assert_equal(ajive.init_signal_ranks_[0], 4)
Example 8
def test_random_state(data):
    # Results of random sampling are reproducible with a fixed random_state
    dat = data["same_views"]
    ajive1 = AJIVE(init_signal_ranks=[2, 2], random_state=0)
    ajive1 = ajive1.fit(Xs=dat)
    ajive2 = AJIVE(init_signal_ranks=[2, 2], random_state=0)
    ajive2 = ajive2.fit(Xs=dat)
    assert_allclose(ajive1.wedin_samples_, ajive2.wedin_samples_)
    assert_allclose(ajive1.random_sv_samples_, ajive2.random_sv_samples_)
Example 9
def test_joint_rank_0(data):
    dat = data["diff_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    Js = ajive.fit_transform(dat)
    assert_allclose(Js[0], 0)
    assert_allclose(Js[1], 0)
Example 10
def ajive(data):
    aj = AJIVE(init_signal_ranks=[3, 2], random_state=0)
    joint_mats = aj.fit_transform(Xs=data['diff_views'])
    aj.joint_mats = joint_mats

    return aj
Example 11
def test_invalid_samples(n_wedin_samples, n_randdir_samples):
    # invalid number of samples
    with pytest.raises(ValueError):
        _ = AJIVE(n_wedin_samples=n_wedin_samples)
    with pytest.raises(ValueError):
        _ = AJIVE(n_randdir_samples=n_randdir_samples)
Example 12
def test_indiv_rank(data):
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], individual_ranks=[2, 1])
    ajive = ajive.fit(Xs=dat)
    assert_equal(ajive.individual_ranks_[0], 2)
Example 13
def test_signal_ranks_None():
    # Both init rank inputs are None
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=None, n_elbows=None)
Example 14
def test_zero_rank_warn(data):
    # warns when a rank-0 joint matrix is returned
    dat = data["diff_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    with pytest.warns(RuntimeWarning):
        _ = ajive.fit_transform(dat)
Example 15
def test_check_joint_rank_large():
    # joint_rank must not exceed sum(init_signal_ranks)
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=[2, 2], joint_rank=5)
Example 16
def test_joint_rank(data):
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=2)
    ajive = ajive.fit(Xs=dat)
    assert_equal(ajive.joint_rank, 2)
Example 17
axes[0].set_ylabel('MVMDS component 2')
axes[0].set_title('Multiview Kmeans Clusters')
plt.tight_layout()
plt.show()

###############################################################################
# Decomposition using AJIVE
# -------------------------
#
# We can also apply joint decomposition tools to find features that are
# related across views. Using :class:`mvlearn.decomposition.AJIVE`, we can
# identify genes and lipids that share joint structure.

from mvlearn.decomposition import AJIVE  # noqa: E402

ajive = AJIVE()
Xs_joint = ajive.fit_transform(Xs)

f, axes = plt.subplots(2, 1, figsize=(5, 5))
sort_idx = np.hstack((np.argsort(y[:20, 1]), np.argsort(y[20:, 1]) + 20))
y_ticks = [
    diet_names[j] + f' ({genotype_names[i]})' if idx % 4 == 0 else ''
    for idx, (i, j) in enumerate(y[sort_idx])
]

gene_ticks = [
    n if i in [31, 36, 76, 94] else ''
    for i, n in enumerate(dataset['gene_feature_names'])
]
g = sns.heatmap(Xs_joint[0][sort_idx],
                yticklabels=y_ticks)
Example 18
U, S, V = np.linalg.svd(X1)
vars1 = S**2 / np.sum(S**2)
U, S, V = np.linalg.svd(X2)
vars2 = S**2 / np.sum(S**2)
axes[0].plot(np.arange(10) + 1, vars1[:10], 'ro-', linewidth=2)
axes[1].plot(np.arange(10) + 1, vars2[:10], 'ro-', linewidth=2)
axes[0].set_title('Scree Plot View 1')
axes[1].set_title('Scree Plot View 2')
axes[0].set_xlabel('Number of top singular values')
axes[1].set_xlabel('Number of top singular values')
axes[0].set_ylabel('Percent variance explained')
plt.show()

# Based on the scree plots, we fit AJIVE with both initial signal ranks set to
# 2.

ajive1 = AJIVE(init_signal_ranks=[2, 2])
Js_1 = ajive1.fit_transform(Xs_same)

ajive2 = AJIVE(init_signal_ranks=[2, 2])
Js_2 = ajive2.fit_transform(Xs_diff)

###############################################################################
# Heatmap Visualizations
# ----------------------
#
# Here we use heatmaps to visualize the decomposition of each view. When the
# two views are identical, no individual variation appears. When the two views
# differ, the algorithm finds decompositions in which both the common and the
# individual structure are visible in the corresponding heatmaps.
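
# Below is a minimal sketch of one way to draw these heatmaps, assuming the
# Js_1, Js_2, ajive1, ajive2, Xs_same and Xs_diff variables from the block
# above; plot_blocks is a hypothetical helper, not the library's own plotting
# code.

import matplotlib.pyplot as plt  # noqa: E402
import seaborn as sns  # noqa: E402


def plot_blocks(Xs, joints, individuals, title):
    # One row per view; columns show the data, joint part, and individual part.
    fig, axes = plt.subplots(len(Xs), 3, figsize=(9, 6))
    for v, (X, J, I) in enumerate(zip(Xs, joints, individuals)):
        for ax, mat in zip(axes[v], (X, J, I)):
            sns.heatmap(mat, ax=ax, cbar=False,
                        xticklabels=False, yticklabels=False)
        axes[v][0].set_ylabel(f'View {v + 1}')
    for ax, name in zip(axes[0], ('Data', 'Joint', 'Individual')):
        ax.set_title(name)
    fig.suptitle(title)
    plt.tight_layout()
    plt.show()


plot_blocks(Xs_same, Js_1, ajive1.individual_mats_, 'Same views')
plot_blocks(Xs_diff, Js_2, ajive2.individual_mats_, 'Different views')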
Example 19
def test_same_joint_indiv_length(data):
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    Js = ajive.fit_transform(Xs=dat)
    Is = ajive.individual_mats_
    assert_equal(Js[0].shape, Is[0].shape)