def test_indiv_rank_0(data):
    """Zero individual ranks must produce all-zero individual matrices.

    AJIVE is fit on the ``"diff_views"`` entry of the ``data`` fixture with
    ``individual_ranks=[0, 0]``; the first two entries of
    ``individual_mats_`` are then checked to be (numerically) zero.
    """
    model = AJIVE(init_signal_ranks=[2, 2], individual_ranks=[0, 0])
    # The transform output is not needed; only the fitted attributes are.
    model.fit_transform(data["diff_views"])
    indiv_mats = model.individual_mats_
    for view_idx in (0, 1):
        assert_allclose(indiv_mats[view_idx], 0)
def test_same_joint(data, init_signal_ranks, individual_ranks, joint_rank):
    """Identical input views should give identical joint matrices.

    The rank arguments are supplied by the test harness (presumably via
    parametrization, defined outside this block). For each of the first 20
    rows, exactly 100 entries of the two joint matrices must compare equal.
    NOTE(review): 20 and 100 assume the ``"same_views"`` fixture's shape —
    confirm against the fixture definition.
    """
    model = AJIVE(
        init_signal_ranks=init_signal_ranks,
        joint_rank=joint_rank,
        individual_ranks=individual_ranks,
    )
    joint_mats = model.fit_transform(Xs=data["same_views"])
    first, second = joint_mats[0], joint_mats[1]
    for row in range(20):
        # Count entries that match exactly (no tolerance) in this row.
        n_matching = np.sum(first[row] == second[row])
        assert_equal(n_matching, 100)
# Singular value decomposition of the second view; the squared singular
# values, normalized to sum to 1, give the proportion of variance captured
# by each component.
U, S, V = np.linalg.svd(X2)
vars2 = S**2 / np.sum(S**2)

# Scree plots of the top 10 components for each view.
axes[0].plot(np.arange(10) + 1, vars1[:10], 'ro-', linewidth=2)
axes[1].plot(np.arange(10) + 1, vars2[:10], 'ro-', linewidth=2)
axes[0].set_title('Scree Plot View 1')
axes[1].set_title('Scree Plot View 2')
axes[0].set_xlabel('Number of top singular values')
axes[1].set_xlabel('Number of top singular values')
# The plotted values are fractions in [0, 1], not percentages.
axes[0].set_ylabel('Proportion of variance explained')
plt.show()

# Based on the scree plots, we fit AJIVE with both initial signal ranks set to
# 2.

ajive1 = AJIVE(init_signal_ranks=[2, 2])
Js_1 = ajive1.fit_transform(Xs_same)
ajive2 = AJIVE(init_signal_ranks=[2, 2])
Js_2 = ajive2.fit_transform(Xs_diff)

###############################################################################
# Heatmap Visualizations
# ----------------------
#
# Here we use heatmaps to visualize the decomposition of our views. As we can
# see, when we use two identical views, no individual variation is displayed.
# When we create two different views, the algorithm finds different
# decompositions, and the joint and individual structure can be seen in the
# corresponding heatmaps.
axes[0].set_title('Multiview Kmeans Clusters') plt.tight_layout() plt.show() ############################################################################### # Decomposition using AJIVE # ------------------------- # # We can also apply joint decomposition tools to find features across views # that are jointly related. Using :class:`mvlearn.decomposition.AJIVE`, we can # find genes and lipids that are jointly related. from mvlearn.decomposition import AJIVE # noqa: E402 ajive = AJIVE() Xs_joint = ajive.fit_transform(Xs) f, axes = plt.subplots(2, 1, figsize=(5, 5)) sort_idx = np.hstack((np.argsort(y[:20, 1]), np.argsort(y[20:, 1]) + 20)) y_ticks = [ diet_names[j] + f' ({genotype_names[i]})' if idx % 4 == 0 else '' for idx, (i, j) in enumerate(y[sort_idx]) ] gene_ticks = [ n if i in [31, 36, 76, 94] else '' for i, n in enumerate(dataset['gene_feature_names']) ] g = sns.heatmap(Xs_joint[0][sort_idx], yticklabels=y_ticks, cmap="RdBu_r",
def ajive(data):
    """Build an AJIVE model fit on the ``"diff_views"`` data.

    The model uses ``init_signal_ranks=[3, 2]`` and ``random_state=0``. The
    joint matrices returned by ``fit_transform`` are stored on the model as
    ``joint_mats`` so that consumers can read them without refitting.
    """
    model = AJIVE(init_signal_ranks=[3, 2], random_state=0)
    model.joint_mats = model.fit_transform(Xs=data['diff_views'])
    return model
def test_zero_rank_warn(data):
    """Forcing ``joint_rank=0`` must emit a ``RuntimeWarning``.

    The warning is expected because the fit returns a rank-0 joint space.
    """
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    views = data["diff_views"]
    with pytest.warns(RuntimeWarning):
        model.fit_transform(views)
def test_joint_rank_0(data):
    """A joint rank of zero must yield all-zero joint matrices.

    AJIVE is fit on the ``"diff_views"`` data with ``joint_rank=0``. The first
    two joint matrices returned by ``fit_transform`` are checked to be
    (numerically) zero.
    """
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=0)
    joint_mats = model.fit_transform(data["diff_views"])
    for view_idx in (0, 1):
        assert_allclose(joint_mats[view_idx], 0)
def test_same_joint_indiv_length(data):
    """Joint and individual matrices of the first view share a shape.

    AJIVE is fit on the ``"same_views"`` data. The first joint matrix, from
    the ``fit_transform`` output, and the first entry of ``individual_mats_``
    must have equal ``.shape``.
    """
    model = AJIVE(init_signal_ranks=[2, 2])
    first_joint = model.fit_transform(Xs=data["same_views"])[0]
    first_indiv = model.individual_mats_[0]
    assert_equal(first_joint.shape, first_indiv.shape)