def test_project(data, dtype):
    """Check that ``scHPF.project`` keeps gene variables and swaps cell data.

    Ten rows of ``data`` are held out as the "b" set and the remaining rows
    form the "a" set. A model is initialized on "a", "b" is projected onto
    it, and the resulting model is compared against the original.

    Parameters
    ----------
    data : sparse matrix
        Must support ``.shape`` and ``.tocsr()``; presumably a pytest
        fixture holding a cell-by-gene count matrix (confirm in conftest).
    dtype
        Passed straight through to ``scHPF(..., dtype=dtype)``.
    """
    ncells = data.shape[0]

    # Draw the held-out rows WITHOUT replacement so that "b" contains ten
    # distinct cells (sampling with replacement could repeat a row), in line
    # with how test_combine_across_cells builds its split.
    b_idx = np.random.choice(ncells, 10, replace=False)
    # everything not held out belongs to "a"
    a_idx = np.setdiff1d(np.arange(ncells), b_idx)

    # row indexing is done on CSR, then converted back to COO
    data_csr = data.tocsr()
    a_data = data_csr[a_idx].tocoo()
    b_data = data_csr[b_idx].tocoo()

    # set up the reference model on the "a" rows
    a_model = scHPF(5, dtype=dtype)
    a_model._initialize(a_data)
    bp = a_model.bp

    # project the held-out rows onto the reference model
    b_model = a_model.project(b_data)

    # gene-level variables must be identical in both models
    assert_equal(b_model.eta, a_model.eta)
    assert_equal(b_model.beta, a_model.beta)

    # each model's cell count must match the data it was given
    assert_equal(a_model.ncells, a_data.shape[0])
    assert_equal(b_model.ncells, b_data.shape[0])

    # bp is carried over untouched by default ...
    assert_equal(b_model.bp, bp)

    # ... and differs from the original when recalc_bp=True is requested
    c_model = a_model.project(b_data, recalc_bp=True)
    assert c_model.bp != bp
def test_combine_across_cells(data, dtype):
    """Check that ``combine_across_cells`` stitches two models' cells together.

    ``data`` is split by row into a ten-row "b" part and the remaining "a"
    part. One model is initialized on each part (the "b" model is built with
    the "a" model's ``dp``, ``eta`` and ``beta`` and initialized with
    ``freeze_genes=True``), the two are combined, and the combined model's
    per-cell and per-gene variables are compared with the inputs.
    """
    n_rows = data.shape[0]

    # ten distinct held-out rows; the rest go to the first model
    b_ixs = np.random.choice(n_rows, 10, replace=False)
    a_ixs = np.setdiff1d(np.arange(n_rows), b_ixs)

    # slice rows on CSR and hand COO matrices to the models
    csr = data.tocsr()
    a_data, b_data = csr[a_ixs].tocoo(), csr[b_ixs].tocoo()

    # model for the "a" rows
    a = scHPF(5, dtype=dtype)
    a._initialize(a_data)

    # model for the "b" rows, sharing dp, eta and beta with the "a" model
    b = scHPF(5, dtype=dtype, dp=a.dp, eta=a.eta, beta=a.beta)
    b._initialize(b_data, freeze_genes=True)

    ab = combine_across_cells(a, b, b_ixs)

    # bp is expected to be None because it differs between the two models
    assert_equal(ab.bp, None)

    # Cell-level variables (xi and theta): the rows of the combined model
    # indexed by each part's row ids must equal that part's own values.
    for ixs, part in ((a_ixs, a), (b_ixs, b)):
        for var_name in ('xi', 'theta'):
            combined_var = getattr(ab, var_name)
            part_var = getattr(part, var_name)
            for field in ('vi_shape', 'vi_rate'):
                assert_array_equal(getattr(combined_var, field)[ixs],
                                   getattr(part_var, field))

    # Gene-level variables (eta and beta) must match both input models.
    for var_name in ('eta', 'beta'):
        for part in (a, b):
            assert_equal(getattr(ab, var_name), getattr(part, var_name))
def model_uninit(request):
    """Build a fresh scHPF model with ``N_FACTORS`` factors.

    The dtype is read from ``request.param``; presumably this is a
    parametrized pytest fixture whose decorator is not visible here.
    The model is returned without any initialization call.
    """
    return scHPF(N_FACTORS, dtype=request.param)
def projection_loss_function(loss_function, X, nfactors,
        model_kwargs=None, proj_kwargs=None):
    """ Project new data onto an existing model and calculate loss from it

    Parameters
    ----------
    loss_function : function
        the loss function to use on the projected data
    X : coo_matrix
        Data to project onto the existing model.  Can have an arbitrary
        number of rows (cells) > 0, but must have the same number of
        columns (genes) as the existing model
    nfactors : int
        Number of factors in model
    model_kwargs : dict, optional
        additional keyword arguments for scHPF(). ``None`` (the default)
        means no extra arguments.
    proj_kwargs : dict, optional
        additional keyword arguments for scHPF.project(). The dict is
        never modified. Keys that are absent are filled in at call time
        with: ``reinit=False``, ``max_iter=10``, ``min_iter=10`` and
        ``check_freq=max_iter + 1``.

    Returns
    -------
    projection_loss_function : function
        A function which takes `a`, `ap`, `bp`, `c`, `cp`, `dp`, `eta`,
        and `beta` for an scHPF model (keyword-only; any further keyword
        arguments are accepted and ignored), projects a fixed dataset
        onto it, and takes the loss (using a fixed function) with
        respect to both the model and the data's projection.
    """
    # have to do import here to avoid issue with files importing each other
    from schpf import scHPF

    # make the model used for projection; it is reused by every call of the
    # returned function
    pmodel = scHPF(nfactors=nfactors,
                   **({} if model_kwargs is None else model_kwargs))

    # actual loss function for data
    def _projection_loss_function(*, a, ap, bp, c, cp, dp, eta, beta,
            **kwargs):
        assert eta.dims[0] == beta.dims[0]
        assert beta.dims[1] == nfactors

        # load the supplied parameters into the shared projection model
        pmodel.a = a
        pmodel.ap = ap
        pmodel.bp = bp
        pmodel.c = c
        pmodel.cp = cp
        pmodel.dp = dp
        pmodel.eta = eta
        pmodel.beta = beta

        # Work on a private copy so neither the caller's dict nor a shared
        # default is mutated, then fill in defaults for missing keys.
        call_kwargs = {} if proj_kwargs is None else dict(proj_kwargs)
        call_kwargs.setdefault('reinit', False)
        call_kwargs.setdefault('max_iter', 10)
        call_kwargs.setdefault('min_iter', 10)
        # one past max_iter -- presumably so that no periodic check fires
        # during the projection (confirm against scHPF.project)
        call_kwargs.setdefault('check_freq', call_kwargs['max_iter'] + 1)

        # do the projection
        pmodel.project(X, replace=True, **call_kwargs)

        # calculate loss
        return loss_function(X, a=pmodel.a, ap=pmodel.ap, bp=pmodel.bp,
                c=pmodel.c, cp=pmodel.cp, dp=pmodel.dp, xi=pmodel.xi,
                eta=pmodel.eta, theta=pmodel.theta, beta=pmodel.beta)

    return _projection_loss_function