Ejemplo n.º 1
0
    def test_centering(self):
        """Check the centering vectors AJIVE stores for each view.

        The fixture model ``self.ajive`` (built in set-up, not shown here) is
        expected to hold the column means of X and Y as its centers, both on
        ``centers_`` and on the joint/individual parts of each block. Then
        ``center=False`` must store no center for either view, and
        ``center=[True, False]`` must center only the first view.
        """
        xmean = self.X.mean(axis=0)
        ymean = self.Y.mean(axis=0)

        for name, mean in (("x", xmean), ("y", ymean)):
            self.assertTrue(np.allclose(self.ajive.centers_[name], mean))
            self.assertTrue(
                np.allclose(self.ajive.blocks_[name].joint.m_, mean))
            self.assertTrue(
                np.allclose(self.ajive.blocks_[name].individual.m_, mean))

        # no centering
        ajive = AJIVE(init_signal_ranks=[2, 3], center=False)
        ajive = ajive.fit(Xs=[self.X, self.Y], view_names=["x", "y"])
        self.assertIsNone(ajive.centers_["x"])
        self.assertIsNone(ajive.centers_["y"])

        # only center x
        ajive = AJIVE(init_signal_ranks=[2, 3], center=[True, False])
        ajive = ajive.fit(Xs=[self.X, self.Y], view_names=["x", "y"])
        self.assertTrue(np.allclose(ajive.centers_["x"], xmean))
        self.assertIsNone(ajive.centers_["y"])
Ejemplo n.º 2
0
 def test_list_input(self):
     """
     Check AJIVE can take a list input.

     When no view names are given, the views should be labelled by their
     integer positions in the input list.
     """
     ajive = AJIVE(init_signal_ranks=[2, 3])
     ajive.fit(Xs=[self.X, self.Y])
     # assertEqual reports both sets on failure, unlike assertTrue(a == b)
     self.assertEqual(set(ajive.block_names), {0, 1})
Ejemplo n.º 3
0
def test_ajive_plot(data):
    """Smoke test: plot the dict-form AJIVE estimates with the heatmap helper.

    Fits AJIVE on the ``same_views`` fixture, predicts with
    ``return_dict=True`` and passes the result to
    ``ajive_full_estimate_heatmaps``. The test passes as long as none of these
    calls raises; there is no output value to compare against.
    """
    x = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=x)
    blocks = ajive.predict(return_dict=True)
    ajive_full_estimate_heatmaps(x, blocks)
Ejemplo n.º 4
0
def test_indiv(data):
    """Identical views must yield identical individual estimates.

    For each of the first 100 indices, the entries of view 0's and view 1's
    ``individual`` estimates are compared elementwise, and all 20 must match.
    """
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    estimates = model.predict(return_dict=True)
    first_indiv = estimates[0]["individual"]
    second_indiv = estimates[1]["individual"]
    for idx in range(100):
        n_equal = np.sum(first_indiv[idx] == second_indiv[idx])
        assert n_equal == 20
Ejemplo n.º 5
0
def test_ajive_plot_list(data):
    """Smoke test: plot the list-form AJIVE estimates with custom view names.

    Same as the dict-form plotting test, but predicts with
    ``return_dict=False`` and passes explicit ``names``. The test passes as
    long as no call raises; there is no output value to compare against.
    """
    x = data["same_views"]
    ajive = AJIVE(init_signal_ranks=[2, 2])
    ajive.fit(Xs=x)
    blocks = ajive.predict(return_dict=False)
    ajive_full_estimate_heatmaps(x, blocks, names=["x1", "x2"])
Ejemplo n.º 6
0
def test_check_sparse(data):
    """Sparse input views should yield a mostly-zero individual estimate.

    First confirms that the first fixture view really has more zero than
    non-zero entries. Then it checks that the same holds for view 0's
    ``individual`` estimate after fitting.
    """
    views = data["sparse_views"]
    first_view = views[0]
    n_zero_in = np.sum(first_view == 0)
    n_nonzero_in = np.sum(first_view != 0)
    assert n_zero_in > n_nonzero_in

    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    estimates = model.predict(return_dict=True)
    indiv = estimates[0]["individual"]
    # double sum kept as in the original; presumably guards against a
    # per-column result if `indiv` is a DataFrame — TODO confirm the type
    n_zero_out = np.sum(np.sum(indiv == 0))
    n_nonzero_out = np.sum(np.sum(indiv != 0))
    assert n_zero_out > n_nonzero_out
Ejemplo n.º 7
0
def test_wrong_sig(data):
    """Fitting with negative initial signal ranks must raise an error.

    The exact exception type is not pinned down here, so any ``Exception``
    subclass is accepted. Unlike a bare ``except:``, this no longer swallows
    ``KeyboardInterrupt``/``SystemExit``.
    """
    import pytest

    dat = data["diff_views"]
    ajive = AJIVE(init_signal_ranks=[-1, -4])
    with pytest.raises(Exception):
        ajive.fit(Xs=dat)
Ejemplo n.º 8
0
def test_precomp_init_svd(data):
    """Smoke test: AJIVE accepts precomputed initial SVDs for each view.

    One ``svd_wrapper`` result per view is passed through
    ``precomp_init_svd``. The test passes as long as fitting does not raise.
    """
    dat = data["same_views"]
    precomp = [svd_wrapper(view) for view in dat]
    ajive = AJIVE(init_signal_ranks=[2, 2], joint_rank=1)
    ajive.fit(dat, precomp_init_svd=precomp)
Ejemplo n.º 9
0
    def test_dont_store_full(self):
        """
        Make sure setting store_full = False works

        With ``store_full=False`` neither the joint nor the individual part of
        either block should keep its full matrix (``full_`` must be None).
        """
        ajive = AJIVE(init_signal_ranks=[2, 3], store_full=False)
        ajive.fit(Xs=[self.X, self.Y])

        for block in (ajive.blocks_[0], ajive.blocks_[1]):
            self.assertIsNone(block.joint.full_)
            self.assertIsNone(block.individual.full_)
Ejemplo n.º 10
0
    def test_rank0(self):
        """
        Check setting joint/individual rank to zero works

        A zero joint rank must give a rank-0 common part and a rank-0 joint
        part with no scores. A zero individual rank for the first view must
        give a rank-0 individual part with no scores.
        """
        ajive = AJIVE(init_signal_ranks=[2, 3], joint_rank=0)
        ajive.fit(Xs=[self.X, self.Y])
        self.assertEqual(ajive.common_.rank, 0)
        self.assertEqual(ajive.blocks_[0].joint.rank, 0)
        self.assertIsNone(ajive.blocks_[0].joint.scores_)

        ajive = AJIVE(init_signal_ranks=[2, 3], indiv_ranks=[0, 1])
        ajive.fit(Xs=[self.X, self.Y])
        self.assertEqual(ajive.blocks_[0].individual.rank, 0)
        self.assertIsNone(ajive.blocks_[0].individual.scores_)
Ejemplo n.º 11
0
def test_fit_elbows():
    """Elbow-based rank selection should pick an initial signal rank of 4.

    Builds a symmetric 10x10 matrix Q^T diag(d) Q, where Q = orth(B) for a
    random 0/1 matrix B. The spectrum d is a staircase (each step adds 10 to
    a growing prefix). Both views are that same matrix, and AJIVE is fitted
    with ``n_elbows=2``.
    """
    size = 10
    n_steps = 3
    # global seed kept on purpose: the fit itself may also draw from np.random
    np.random.seed(1)
    bits = np.random.binomial(1, 0.6, (size**2)).reshape(size, size)
    basis = orth(bits)

    spectrum = np.zeros(basis.shape[0])
    stride = int(len(spectrum) / (n_steps + 1))
    for stop in range(0, len(spectrum), stride):
        spectrum[:stop] += 10
    X = basis.T @ np.diag(spectrum) @ basis

    ajive = AJIVE(n_elbows=2).fit([X, X])

    first_rank = next(iter(ajive.init_signal_ranks_.values()))
    np.testing.assert_equal(first_rank, 4)
Ejemplo n.º 12
0
def test_check_joint_rank_large(data):
    """A joint rank larger than the initial signal ranks must be rejected."""
    views = data["same_views"]
    with pytest.raises(ValueError):
        model = AJIVE(init_signal_ranks=[2, 2], joint_rank=5)
        model.fit(Xs=views)
Ejemplo n.º 13
0
def test_name_values_type(data):
    """Passing ``view_names`` as a dict (not a list) must raise ValueError."""
    views = data["same_views"]
    bad_names = {"jon": "first", "rich": "second"}
    with pytest.raises(ValueError):
        model = AJIVE(init_signal_ranks=[2, 2])
        model.fit(Xs=views, view_names=bad_names)
Ejemplo n.º 14
0
def test_traditional_output(data):
    """Smoke test: list-form prediction works after fitting with view names."""
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views, view_names=["x", "y"])
    model.predict(return_dict=False)
Ejemplo n.º 15
0
def test_name_values(data):
    """More view names than views must raise ValueError."""
    views = data["same_views"]
    too_many_names = ["1", "2", "3"]
    with pytest.raises(ValueError):
        model = AJIVE(init_signal_ranks=[2, 2])
        model.fit(Xs=views, view_names=too_many_names)
Ejemplo n.º 16
0
def test_joint_noise_length(data):
    """The joint and noise estimates of view 0 must share the same shape."""
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2])
    model.fit(Xs=views)
    first_view = model.predict(return_dict=True)[0]
    joint_shape = first_view["joint"].shape
    noise_shape = first_view["noise"].shape
    assert joint_shape == noise_shape
Ejemplo n.º 17
0
def test_check_gen_lin_op_scipy(data):
    """Fitting on the ``bad_views`` fixture must raise TypeError."""
    views = data["bad_views"]
    with pytest.raises(TypeError):
        model = AJIVE(init_signal_ranks=[2, 2])
        model.fit(Xs=views)
Ejemplo n.º 18
0
def test_indiv_rank(data):
    """The first requested individual rank is still 2 after fitting."""
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2], indiv_ranks=[2, 1])
    model.fit(Xs=views)
    first_indiv_rank = model.indiv_ranks[0]
    assert first_indiv_rank == 2
Ejemplo n.º 19
0
def test_joint_rank(data):
    """The requested joint rank is still 2 after fitting."""
    views = data["same_views"]
    model = AJIVE(init_signal_ranks=[2, 2], joint_rank=2)
    model.fit(Xs=views)
    fitted_joint_rank = model.joint_rank
    assert fitted_joint_rank == 2