def test_list_input(self):
    """
    Check AJIVE can take a list input.

    When the blocks are given as a list, the fitted model is expected to
    expose the list positions (0 and 1) as its block names.
    """
    ajive = AJIVE(init_signal_ranks=[2, 3])
    ajive.fit(blocks=[self.X, self.Y])

    # assertEqual prints both sets on failure, unlike assertTrue(a == b)
    self.assertEqual(set(ajive.block_names), {0, 1})
def test_parallel_runs(self):
    """
    Smoke test: the wedin/random sampling steps should run when AJIVE is
    asked for parallel processing (n_jobs=-1).

    Only checks that the fitted object ends up with a ``blocks`` attribute.
    """
    signal_ranks = {'x': 2, 'y': 3}
    data_blocks = {'x': self.X, 'y': self.Y}

    model = AJIVE(init_signal_ranks=signal_ranks, n_jobs=-1)
    model.fit(blocks=data_blocks)

    has_blocks = hasattr(model, 'blocks')
    self.assertTrue(has_blocks)
def test_dont_store_full(self):
    """
    Make sure setting store_full = False works.

    After fitting with store_full=False, the ``full_`` attribute of both
    the joint and the individual component must be None for every block.
    """
    ajive = AJIVE(init_signal_ranks=[2, 3], store_full=False)
    ajive.fit(blocks=[self.X, self.Y])

    # blocks were passed as a list, so they are addressed by position
    for block_name in (0, 1):
        block = ajive.blocks[block_name]
        # assertIsNone shows the unexpected value if the check fails
        self.assertIsNone(block.joint.full_)
        self.assertIsNone(block.individual.full_)
def setUp(self):
    """
    Build the fixture shared by the tests.

    Wraps the generated data blocks in labelled DataFrames, fits an AJIVE
    model on them (initial signal ranks 2 for 'x' and 3 for 'y'), and
    stores the model, the frames and the labels on the test instance.
    """
    raw_x, raw_y = generate_data_ajive_fig2()

    # one label per row; both blocks share the same observation labels
    obs_names = ['sample_{}'.format(row) for row in range(raw_x.shape[0])]

    # one label per column of each block
    x_columns = ['x_var_{}'.format(col) for col in range(raw_x.shape[1])]
    y_columns = ['y_var_{}'.format(col) for col in range(raw_y.shape[1])]

    self.obs_names = obs_names
    self.var_names = {'x': x_columns, 'y': y_columns}
    self.X = pd.DataFrame(raw_x, index=obs_names, columns=x_columns)
    self.Y = pd.DataFrame(raw_y, index=obs_names, columns=y_columns)

    model = AJIVE(init_signal_ranks={'x': 2, 'y': 3})
    self.ajive = model.fit(blocks={'x': self.X, 'y': self.Y})
def ajive(data, joint):
    """
    Fit an AJIVE model on the rows shared by every block of ``data`` and
    apply ``ajive_predict`` to the full data.

    Parameters
    ----------
    data: pandas object with a MultiIndex
        The first index level names the block; ``data.loc[name]`` is the
        data of that block, indexed by cancer cell line.

    joint:
        Passed to AJIVE as ``joint_rank`` and to ``svals`` when choosing
        the initial signal rank of each block.

    Returns
    -------
    result, model
        The output of ``ajive_predict(model, data)`` and the fitted AJIVE
        model.
    """
    # Only iterate over block labels that actually occur in the index: a
    # filtered MultiIndex keeps labels it no longer uses in ``levels``, and
    # ``data.loc`` raises KeyError for those.
    block_names = data.index.remove_unused_levels().levels[0]
    blocks = {name: data.loc[name] for name in block_names}

    # Only overlapping cancer cell lines will be used
    overlap = reduce(np.intersect1d,
                     (block.index for block in blocks.values()))
    blocks = {name: block.loc[overlap] for name, block in blocks.items()}

    # Calculates the number of singular values that should be used for the
    # PCA projections of the blocks
    init = {name: svals(block, joint) for name, block in blocks.items()}

    model = AJIVE(init, joint_rank=joint)
    model.fit(blocks)

    result = ajive_predict(model, data)
    return result, model
def test_centering(self):
    """
    Check the centers stored by AJIVE for the different ``center`` options.

    Covers three cases: the fixture model from setUp (fitted without a
    ``center`` argument), ``center=False``, and a per-block dict that
    centers only 'x'.
    """
    xmean = self.X.mean(axis=0)
    ymean = self.Y.mean(axis=0)

    # fixture model: the model-level center and the center stored on the
    # joint and individual components should all equal the column means
    for name, mean in (('x', xmean), ('y', ymean)):
        block = self.ajive.blocks[name]
        self.assertTrue(np.allclose(self.ajive.centers_[name], mean))
        self.assertTrue(np.allclose(block.joint.m_, mean))
        self.assertTrue(np.allclose(block.individual.m_, mean))

    # no centering
    ajive = AJIVE(init_signal_ranks={'x': 2, 'y': 3}, center=False)
    ajive = ajive.fit(blocks={'x': self.X, 'y': self.Y})
    self.assertIsNone(ajive.centers_['x'])
    self.assertIsNone(ajive.centers_['y'])

    # only center x
    ajive = AJIVE(init_signal_ranks={'x': 2, 'y': 3},
                  center={'x': True, 'y': False})
    ajive = ajive.fit(blocks={'x': self.X, 'y': self.Y})
    self.assertTrue(np.allclose(ajive.centers_['x'], xmean))
    self.assertIsNone(ajive.centers_['y'])
def test_rank0(self):
    """
    Check setting joint/individual rank to zero works.

    A rank-zero component is expected to report ``rank == 0`` and to have
    no scores (``scores_`` is None).
    """
    # joint rank forced to zero
    ajive = AJIVE(init_signal_ranks=[2, 3], joint_rank=0)
    ajive.fit(blocks=[self.X, self.Y])
    self.assertEqual(ajive.common.rank, 0)
    self.assertEqual(ajive.blocks[0].joint.rank, 0)
    self.assertIsNone(ajive.blocks[0].joint.scores_)

    # individual rank of the first block forced to zero
    ajive = AJIVE(init_signal_ranks=[2, 3], indiv_ranks=[0, 1])
    ajive.fit(blocks=[self.X, self.Y])
    self.assertEqual(ajive.blocks[0].individual.rank, 0)
    self.assertIsNone(ajive.blocks[0].individual.scores_)
def _save_and_close(path):
    """
    Save the current figure to ``path`` with a tight bounding box, then
    close it.
    """
    plt.savefig(path, bbox_inches='tight')
    plt.close()


X, Y = generate_data_ajive_fig2()

# heatmaps of the two data blocks
plt.figure(figsize=[6.5, 3])
data_block_heatmaps({'x': X, 'y': Y})
_save_and_close('figures/data_heatmaps.png')

# determine initial signal ranks by inspecting scree plots
plt.figure(figsize=[8.4, 3])
for panel, block in enumerate([X, Y], start=1):
    # one scree plot per block, side by side
    plt.subplot(1, 2, panel)
    PCA().fit(block).plot_scree()
_save_and_close('figures/scree_plots.png')

# fit AJIVE with the chosen initial signal ranks
ajive = AJIVE(init_signal_ranks={'x': 2, 'y': 3})
ajive.fit(blocks={'x': X, 'y': Y})

# heatmaps of the full block estimates next to the data
plt.figure(figsize=[6.5, 12])
full_estimates = ajive.get_full_block_estimates()
jive_full_estimate_heatmaps(full_estimates, blocks={'x': X, 'y': Y})
_save_and_close('figures/jive_estimate_heatmaps.png')

# joint rank diagnostic plot
plt.figure(figsize=[7, 7])
ajive.plot_joint_diagnostic()
_save_and_close('figures/jive_diagnostic.png')