def test_invalid_ranks(data, init_signal_ranks, individual_ranks):
    """Invalid initial-signal or individual ranks raise ValueError on fit.

    ``init_signal_ranks`` and ``individual_ranks`` are presumably supplied by
    a pytest parametrization that is not visible here. Each is checked on
    its own against the ``diff_views`` data.
    """
    # invalid ranks
    dat = data["diff_views"]
    with pytest.raises(ValueError):
        _ = AJIVE(init_signal_ranks=init_signal_ranks).fit_transform(dat)
    with pytest.raises(ValueError):
        _ = AJIVE(individual_ranks=individual_ranks).fit_transform(dat)
def test_indiv_rank_0(data):
    """Requesting individual rank 0 for both views gives all-zero individual matrices."""
    views = data["diff_views"]
    model = AJIVE(init_signal_ranks=[2, 2], individual_ranks=[0, 0])
    _ = model.fit_transform(views)
    indiv_mats = model.individual_mats_
    for view_idx in (0, 1):
        assert_allclose(indiv_mats[view_idx], 0)
def test_signal_rank_warning():
    """A signal rank larger than the data supports warns and is lowered.

    Each view is a 2x2 all-ones matrix, and an initial signal rank of 2 is
    requested. Fitting must emit a RuntimeWarning, and the fitted
    ``init_signal_ranks_`` must come back as ``[1, 1]``.
    """
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=2)
    ones_view = np.array([[1, 1], [1, 1]])
    with pytest.warns(RuntimeWarning):
        model.fit([ones_view, ones_view])
    assert model.init_signal_ranks_ == [1, 1]
def test_wrong_signal_ranks(data):
    """Negative initial signal ranks raise ValueError on fit.

    The second case pairs an over-large first rank with a negative second
    rank. The error is still expected in that combination.
    """
    views = data["diff_views"]
    too_large = max(views[0].shape) + 1
    for bad_ranks in ([-1, -4], [too_large, -4]):
        with pytest.raises(ValueError):
            _ = AJIVE(init_signal_ranks=bad_ranks).fit_transform(views)
def test_same_indiv(data, init_signal_ranks, individual_ranks, joint_rank):
    """Two identical views yield matching individual matrices for every rank setting."""
    params = dict(
        init_signal_ranks=init_signal_ranks,
        joint_rank=joint_rank,
        individual_ranks=individual_ranks,
    )
    fitted = AJIVE(**params).fit(Xs=data["same_views"])
    indiv_mats = fitted.individual_mats_
    assert_allclose(indiv_mats[0], indiv_mats[1])
def test_same_joint(data, init_signal_ranks, individual_ranks, joint_rank):
    """Two identical views yield identical joint matrices for every rank setting.

    The joint matrices are compared exactly over their full extent. The
    previous version counted matches over a hard-coded 20 rows x 100 columns,
    which tied the test to the fixture's shape.
    """
    dat = data["same_views"]
    ajive = AJIVE(
        init_signal_ranks=init_signal_ranks,
        joint_rank=joint_rank,
        individual_ranks=individual_ranks,
    )
    Js = ajive.fit_transform(Xs=dat)
    # Exact (not approximate) equality is intended: both views are the same
    # input, so their joint matrices must agree element for element.
    assert_equal(Js[0], Js[1])
def test_fit_elbows():
    """The initial signal rank estimated via ``n_elbows`` matches a built spectrum.

    Builds ``X = Q.T @ diag(d) @ Q`` where ``Q`` is an orthonormal basis of a
    random binary matrix and ``d`` is a staircase of values. With
    ``n_elbows=2`` the estimated initial signal rank of the first view is
    expected to be 4.
    """
    n = 10
    elbows = 3  # sets the staircase plateau width: len(d) // (elbows + 1)
    # A local RandomState seeded with 1 draws exactly the same values as
    # np.random.seed(1) followed by np.random.binomial. Unlike seeding, it
    # leaves the global RNG that other tests share untouched.
    rng = np.random.RandomState(1)
    x = rng.binomial(1, 0.6, (n**2)).reshape(n, n)
    xorth = orth(x)
    d = np.zeros(xorth.shape[0])
    # Each pass adds 10 to a growing prefix of d. For n=10 this gives
    # d = [40, 40, 30, 30, 20, 20, 10, 10, 0, 0].
    for i in range(0, len(d), len(d) // (elbows + 1)):
        d[:i] += 10
    # If x is full rank, xorth is square orthogonal and X's singular values
    # are exactly d (assumption for this seed; not checked here).
    X = xorth.T.dot(np.diag(d)).dot(xorth)
    Xs = [X, X]
    ajive = AJIVE(n_elbows=2)
    ajive = ajive.fit(Xs)
    assert_equal(ajive.init_signal_ranks_[0], 4)
def test_random_state(data):
    """Two fits with the same random_state reproduce the same random samples."""
    views = data["same_views"]

    def _fit_seeded():
        # Fresh estimator and rank list on every call, seeded identically.
        return AJIVE(init_signal_ranks=[2, 2], random_state=0).fit(Xs=views)

    first = _fit_seeded()
    second = _fit_seeded()
    for attr in ("wedin_samples_", "random_sv_samples_"):
        assert_allclose(getattr(first, attr), getattr(second, attr))
def test_joint_rank_0(data):
    """Requesting joint rank 0 yields all-zero joint matrices for both views."""
    views = data["diff_views"]
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    joint_mats = model.fit_transform(views)
    for view_idx in (0, 1):
        assert_allclose(joint_mats[view_idx], 0)
def ajive(data):
    """Return an AJIVE fitted on ``diff_views`` with its joint matrices attached.

    Presumably registered as a pytest fixture; the decorator is not visible
    here. The output of ``fit_transform`` is stored on the estimator as
    ``joint_mats``.
    """
    fitted = AJIVE(init_signal_ranks=[3, 2], random_state=0)
    fitted.joint_mats = fitted.fit_transform(Xs=data['diff_views'])
    return fitted
def test_invalid_samples(n_wedin_samples, n_randdir_samples):
    """Invalid Wedin or random-direction sample counts are rejected at construction."""
    bad_settings = (
        {"n_wedin_samples": n_wedin_samples},
        {"n_randdir_samples": n_randdir_samples},
    )
    for kwargs in bad_settings:
        with pytest.raises(ValueError):
            _ = AJIVE(**kwargs)
def test_indiv_rank(data):
    """A user-specified individual rank is reported for the first view."""
    model = AJIVE(init_signal_ranks=[2, 2], individual_ranks=[2, 1])
    model = model.fit(Xs=data["same_views"])
    assert_equal(model.individual_ranks_[0], 2)
def test_signal_ranks_None():
    """Leaving both init_signal_ranks and n_elbows as None is rejected at construction."""
    with pytest.raises(ValueError):
        AJIVE(init_signal_ranks=None, n_elbows=None)
def test_zero_rank_warn(data):
    """Fitting with joint_rank=0 emits a RuntimeWarning."""
    views = data["diff_views"]
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    with pytest.warns(RuntimeWarning):
        model.fit_transform(views)
def test_check_joint_rank_large():
    """A joint rank above sum(init_signal_ranks) is rejected at construction.

    Here the signal ranks sum to 4, so a requested joint rank of 5 must raise.
    """
    signal_ranks = [2, 2]
    with pytest.raises(ValueError):
        AJIVE(init_signal_ranks=signal_ranks, joint_rank=5)
def test_joint_rank(data):
    """A user-specified joint rank is the rank the fitted model reports.

    Checks the fitted attribute ``joint_rank_`` rather than the constructor
    parameter ``joint_rank``, which would trivially equal the value passed in.
    """
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=2)
    ajive = ajive.fit(Xs=dat)
    # NOTE(review): relies on AJIVE exposing the fitted ``joint_rank_``, in
    # line with the other trailing-underscore attributes (individual_ranks_,
    # init_signal_ranks_) these tests read; confirm against the estimator.
    assert_equal(ajive.joint_rank_, 2)
axes[0].set_ylabel('MVMDS component 2') axes[0].set_title('Multiview Kmeans Clusters') plt.tight_layout() plt.show() ############################################################################### # Decomposition using AJIVE # ------------------------- # # We can also apply joint decomposition tools to find features across views # that are jointly related. Using :class:`mvlearn.decomposition.AJIVE`, we can # find genes and lipids that are jointly related. from mvlearn.decomposition import AJIVE # noqa: E402 ajive = AJIVE() Xs_joint = ajive.fit_transform(Xs) f, axes = plt.subplots(2, 1, figsize=(5, 5)) sort_idx = np.hstack((np.argsort(y[:20, 1]), np.argsort(y[20:, 1]) + 20)) y_ticks = [ diet_names[j] + f' ({genotype_names[i]})' if idx % 4 == 0 else '' for idx, (i, j) in enumerate(y[sort_idx]) ] gene_ticks = [ n if i in [31, 36, 76, 94] else '' for i, n in enumerate(dataset['gene_feature_names']) ] g = sns.heatmap(Xs_joint[0][sort_idx], yticklabels=y_ticks,
# ``S`` on the next line is assumed to hold the first view's singular values
# from the SVD computed just before this section (not visible here; confirm).
vars1 = S**2 / np.sum(S**2)
U, S, V = np.linalg.svd(X2)
vars2 = S**2 / np.sum(S**2)
# vars1 and vars2 are fractions of total variance (each sums to 1). Scale them
# by 100 when plotting so the values match the "Percent" y-axis label.
axes[0].plot(np.arange(10) + 1, 100 * vars1[:10], 'ro-', linewidth=2)
axes[1].plot(np.arange(10) + 1, 100 * vars2[:10], 'ro-', linewidth=2)
axes[0].set_title('Scree Plot View 1')
axes[1].set_title('Scree Plot View 2')
axes[0].set_xlabel('Number of top singular values')
axes[1].set_xlabel('Number of top singular values')
axes[0].set_ylabel('Percent variance explained')
plt.show()

# Based on the scree plots, we fit AJIVE with both initial signal ranks set to
# 2.

ajive1 = AJIVE(init_signal_ranks=[2, 2])
Js_1 = ajive1.fit_transform(Xs_same)

ajive2 = AJIVE(init_signal_ranks=[2, 2])
Js_2 = ajive2.fit_transform(Xs_diff)

###############################################################################
# Heatmap Visualizations
# ----------------------
#
# Here we use heatmaps to visualize the decomposition of our views. As we can
# see, when we use two identical views, no individual variation is displayed.
# When we create two different views, the algorithm finds different
# decompositions, in which the common and individual structural artifacts can
# be seen in their corresponding heatmaps.
def test_same_joint_indiv_length(data): dat = data["same_views"] ajive = AJIVE(init_signal_ranks=[2, 2]) Js = ajive.fit_transform(Xs=dat) Is = ajive.individual_mats_ assert_equal(Js[0].shape, Is[0].shape)