def test_ajive_plot_list(data):
    """Smoke test: the heatmap helper accepts the list-form output of transform.

    Fits AJIVE on the ``same_views`` fixture data and passes the
    ``return_dict=False`` output of ``transform`` to
    ``ajive_full_estimate_heatmaps``. The test passes if no call raises.
    """
    x = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=x)
    blocks = ajive.transform(return_dict=False)
    # The former ``p = 1; assert p == 1`` was a tautology and checked nothing.
    # Completing this call without an exception is the actual check.
    ajive_full_estimate_heatmaps(x, blocks, names=["x1", "x2"])
def test_indiv(data):
    """Individual components of the two views must match entry-for-entry.

    For each index 0..99, indexes both views' ``individual`` component and
    requires exactly 20 equal entries.
    NOTE(review): assumes each component indexes to 20-element slices for at
    least 100 indices; confirm against the ``same_views`` fixture.
    """
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    components = model.transform(return_dict=True)

    first_indiv = components[0]["individual"]
    second_indiv = components[1]["individual"]
    for row in np.arange(100):
        n_equal = np.sum(first_indiv[row] == second_indiv[row])
        assert n_equal == 20
def test_check_sparse(data):
    """Sparse input should yield a mostly-zero individual component.

    First checks that the fixture's first view is itself mostly zeros. Then,
    after fitting AJIVE, checks that the first view's individual component
    also has more zero entries than nonzero ones.
    """
    views = data["sparse_views"]
    first_view = views[0]
    n_zero_input = np.sum(first_view == 0)
    n_nonzero_input = np.sum(first_view != 0)
    assert n_zero_input > n_nonzero_input

    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    components = model.transform(return_dict=True)

    indiv = components[0]["individual"]
    # The double np.sum collapses the comparison result down to a scalar count.
    assert np.sum(np.sum(indiv == 0)) > np.sum(np.sum(indiv != 0))
def test__repr__(data):
    """Check AJIVE's repr before fitting and after fitting and transforming.

    Before ``fit`` the repr must be the "No data has been fitted yet" message.
    After ``fit`` it must list the joint rank, then each block's individual
    rank in ``ajive.block_names`` order, separated by ", ".
    """
    dat = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    assert repr(ajive) == "No data has been fitted yet"

    ajive.fit(Xs=dat)
    # transform is still exercised before the repr check. Its return value was
    # previously bound to an unused local, so it is not stored.
    ajive.transform(return_dict=True)

    parts = ["joint rank: {}".format(ajive.common_.rank)]
    parts.extend(
        "block {} indiv rank: {}".format(bn, ajive.blocks_[bn].individual.rank)
        for bn in ajive.block_names
    )
    expected = ", ".join(parts)
    assert repr(ajive) == expected
def test_traditional_output(data):
    """Smoke test: fit with explicit view names, then request list output.

    There is no assertion. The test passes if neither ``fit`` (with
    ``view_names``) nor ``transform(return_dict=False)`` raises.
    """
    views = data["same_views"]
    names = ["x", "y"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views, view_names=names)
    model.transform(return_dict=False)
def test_joint_noise_length(data):
    """The first block's joint and noise components must have equal shapes."""
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    first_block = model.transform(return_dict=True)[0]
    joint_shape = first_block["joint"].shape
    noise_shape = first_block["noise"].shape
    assert joint_shape == noise_shape