def test_ajive_plot_list(data):
    """Smoke test: plot full estimates from the non-dict ``predict`` output.

    Fits AJIVE on the ``same_views`` fixture, predicts with
    ``return_dict=False`` and passes the result to
    ``ajive_full_estimate_heatmaps`` with explicit view names. The test
    passes as long as none of these calls raises; no output is inspected.
    """
    x = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=x)
    blocks = ajive.predict(return_dict=False)
    # The call itself is the check: an exception here fails the test.
    ajive_full_estimate_heatmaps(x, blocks, names=["x1", "x2"])
def test_ajive_plot(data):
    """Smoke test: plot full estimates from the dict ``predict`` output.

    Fits AJIVE on the ``same_views`` fixture, predicts with
    ``return_dict=True`` and passes the result to
    ``ajive_full_estimate_heatmaps`` without view names. The test passes
    as long as none of these calls raises; no output is inspected.
    """
    x = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=x)
    blocks = ajive.predict(return_dict=True)
    # The call itself is the check: an exception here fails the test.
    ajive_full_estimate_heatmaps(x, blocks)
def test_indiv(data):
    """Both views' individual estimates must agree on the first 100 rows.

    With identical input views (``same_views``), every row in 0..99 of
    view 0's individual estimate must match the corresponding row of
    view 1's element-wise, with exactly 20 matching entries per row
    (presumably the fixture's feature count -- TODO confirm).
    """
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    estimates = model.predict(return_dict=True)
    first_indiv = estimates[0]["individual"]
    second_indiv = estimates[1]["individual"]
    for row in range(100):
        n_matching = np.sum(first_indiv[row] == second_indiv[row])
        assert n_matching == 20
def test_check_sparse(data):
    """Sparse input should produce a mostly-zero individual estimate.

    First sanity-checks that the fixture's first view has more zero than
    non-zero entries. Then it fits AJIVE and requires the same property of
    view 0's individual estimate.
    """
    views = data["sparse_views"]
    first_view = views[0]
    # Guard the fixture itself: the input must actually be sparse.
    assert np.sum(first_view == 0) > np.sum(first_view != 0)

    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    estimates = model.predict(return_dict=True)
    indiv = estimates[0]["individual"]
    # Nested np.sum reduces the boolean mask to a single count. Presumably
    # it is nested in case the estimate is a 2-D container such as a
    # DataFrame, where one np.sum only sums per column -- TODO confirm.
    n_zero = np.sum(np.sum(indiv == 0))
    n_nonzero = np.sum(np.sum(indiv != 0))
    assert n_zero > n_nonzero
def test_traditional_output(data):
    """Smoke test: non-dict ``predict`` works after fitting with view names.

    Fits AJIVE with explicit ``view_names`` and calls
    ``predict(return_dict=False)``. Nothing is asserted; the test passes
    if no exception is raised.
    """
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views, view_names=["x", "y"])
    model.predict(return_dict=False)
def test_joint_noise_length(data):
    """Joint and noise estimates must have the same shape for every view.

    Fits AJIVE on the ``same_views`` fixture and, for each input view
    index, compares the shapes of the ``"joint"`` and ``"noise"`` entries
    returned by ``predict(return_dict=True)``.
    """
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=dat)
    blocks = ajive.predict(return_dict=True)
    # Assumes the returned mapping is keyed by integer view index, one
    # entry per input view -- TODO confirm against AJIVE.predict.
    for view_idx in range(len(dat)):
        joint_shape = blocks[view_idx]["joint"].shape
        noise_shape = blocks[view_idx]["noise"].shape
        assert joint_shape == noise_shape, (
            f"view {view_idx}: joint shape {joint_shape} "
            f"!= noise shape {noise_shape}"
        )