def make_sc_dataset(cnt_pth : str,
                    lbl_pth : str,
                    topn_genes : int = None,
                    gene_list_pth : str = None,
                    filter_genes : bool = False,
                    lbl_colname : str = 'bio_celltype',
                    min_counts : int = 300,
                    min_cells : int = 0,
                    transpose : bool = False,
                    upper_bound : int = None,
                    lower_bound : int = None,
                    ):
    """
    Generate CountData object for SC-data

    Parameter:
    ---------
    cnt_pth : str
        path to SC count data
    lbl_pth : str
        path to SC label data
    topn_genes : int
        number of top expressed genes (largest total count)
        to include; None keeps all genes
    gene_list_pth : str
        path to a text file with one gene name per line;
        only genes listed there are kept
    filter_genes : bool
        apply CountData.filter_genes (name based filtering)
    lbl_colname : str
        name of column containing labels. For non-h5ad input,
        None means "use the first column of the label file"
    min_counts : int
        minimal number of observed counts assigned to a
        specific spot/cell for it to be included
    min_cells : int
        minimal number of occurrences of a gene among all
        cells for it to be included
    transpose : bool
        transpose count data (only applied to non-h5ad input)
    upper_bound : int
        upper bound for the number of cells to include from each type
    lower_bound : int
        lower bound for the number of cells to include from each type

    Returns:
    -------
    CountData object for the SC data

    Raises:
    ------
    SystemExit
        if count and label data share no cell identifiers
    """

    sc_ext = utils.get_extenstion(cnt_pth)

    if sc_ext == 'h5ad':
        # h5ad reader returns both counts and labels
        cnt, lbl = utils.read_h5ad_sc(cnt_pth,
                                      lbl_colname,
                                      lbl_pth,
                                      )
    else:
        cnt = utils.read_file(cnt_pth, sc_ext)

        if transpose:
            cnt = cnt.T

        lbl = utils.read_file(lbl_pth)

        # get labels
        if lbl_colname is None:
            lbl = lbl.iloc[:, 0]
        else:
            lbl = lbl.loc[:, lbl_colname]

    # match count and label data
    inter = cnt.index.intersection(lbl.index)
    if inter.shape[0] < 1:
        print(("[ERROR] : single cell count and annotation"
               " data did not match. \nExiting."),
              file = sys.stderr,
              )
        # the message promises to exit; continuing would build an
        # empty data set and fail obscurely later on
        sys.exit(-1)

    cnt = cnt.loc[inter, :]
    lbl = lbl.loc[inter]

    # optionally sub-sample cells per type
    if upper_bound is not None or lower_bound is not None:
        cnt, lbl = utils.subsample_data(cnt,
                                        lbl,
                                        lower_bound,
                                        upper_bound,
                                        )

    # select top N expressed genes
    if topn_genes is not None:
        genesize = cnt.values.sum(axis = 0)
        topn_genes = np.min((topn_genes, genesize.shape[0]))
        # column positions sorted by decreasing total count
        sel = np.argsort(genesize)[::-1][0:topn_genes]
        cnt = cnt.iloc[:, sel]

    # only use genes in specific genes list
    # if specified
    if gene_list_pth is not None:
        # read-only access suffices; strip '\r' too so that
        # files with Windows line endings match gene names
        with open(gene_list_pth, 'r') as fopen:
            gene_list = pd.Index([line.rstrip('\r\n') for line in fopen])

        sel = cnt.columns.intersection(gene_list)
        cnt = cnt.loc[:, sel]

    # create sc data set
    dataset = CountData(cnt = cnt,
                        lbl = lbl)

    # filter genes based on names
    if filter_genes:
        dataset.filter_genes()

    # filter data based on quality
    if min_counts > 0 or min_cells > 0:
        dataset.filter_bad(min_counts = min_counts,
                           min_occurance = min_cells,
                           )

    return dataset
def run(prs : arp.ArgumentParser,
        args : arp.Namespace,
        )-> None:
    """Run analysis

    Depending on specified arguments performs
    either single cell parameter estimation,
    ST-data proportion estimates or both.

    Parameter:
    ---------
    prs : argparse.ArgumentParser
        argument parser; its help is printed when too
        few command line arguments are given
    args : argparse.Namespace
        parsed arguments. Modified in place: ``out_dir``
        defaults to the current working directory and
        ``st_cnt`` is wrapped in a list if it is not one.

    All outputs (log, loss traces, models, estimated
    parameters and proportions) are written to
    ``args.out_dir`` and tagged with a unique identifier.
    Calls ``sys.exit(-1)`` on missing arguments or paths.
    """
    # local import: makedirs also creates missing parent directories
    from os import makedirs

    # generate unique identifier for analysis
    timestamp = utils.generate_identifier()

    # ensure arguments are provided
    if len(sys.argv[1::]) < 2:
        prs.print_help()
        sys.exit(-1)

    # set output directory to cwd if none specified
    if args.out_dir is None:
        args.out_dir = getcwd()
    # create output directory (and parents) if non-existent
    elif not osp.exists(args.out_dir):
        makedirs(args.out_dir, exist_ok = True)

    # instantiate logger
    log = utils.Logger(osp.join(args.out_dir,
                                '.'.join(['stsc', timestamp, 'log'])
                                )
                       )

    # convert args to list if not
    if not isinstance(args.st_cnt, list):
        args.st_cnt = [args.st_cnt]

    # st_cnt is always a list at this point, so whether ST data
    # was given must be decided from its content (empty list or
    # [None] means no ST data)
    fit_st = len(args.st_cnt) > 0 and args.st_cnt[0] is not None

    # set device; fall back to cpu when cuda is not available
    if args.gpu and is_available():
        device = t.device('cuda')
    else:
        device = t.device('cpu')

    log.info("Using device {}".format(str(device)))

    # If parameters should be fitted from sc data
    # NOTE(review): assumes args.sc_fit is a sequence (e.g. [None, None]
    # when not given); all(None) would raise TypeError
    if not all(args.sc_fit):
        log.info(' | '.join(["fitting sc data",
                             "count file : {}".format(args.sc_cnt),
                             "labels file : {}".format(args.sc_labels),
                             ])
                 )

        # control that paths to sc data exists; the label path is
        # only checked when given (presumably optional for h5ad input)
        sc_paths = [args.sc_cnt]
        if args.sc_labels is not None:
            sc_paths.append(args.sc_labels)

        if not all([osp.exists(pth) for pth in sc_paths]):
            log.error(' '.join(["One or more of the specified paths to",
                                "the sc data does not exist"]))
            sys.exit(-1)

        # load pre-fitted model if provided
        if args.sc_model is not None:
            log.info("loading state from provided sc_model")

        # Create data set for single cell data
        sc_data = D.make_sc_dataset(args.sc_cnt,
                                    args.sc_labels,
                                    topn_genes = args.topn_genes,
                                    gene_list_pth = args.gene_list,
                                    lbl_colname = args.label_colname,
                                    filter_genes = args.filter_genes,
                                    min_counts = args.min_sc_counts,
                                    min_cells = args.min_cells,
                                    transpose = args.sc_transpose,
                                    )

        log.info(' '.join(["SC data GENES : {} ".format(sc_data.G),
                           "SC data CELLS : {} ".format(sc_data.M),
                           "SC data TYPES : {} \n".format(sc_data.Z),
                           ])
                 )

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,
                                    '.'.join(["sc_loss", timestamp, "txt"])
                                    )

        sc_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,
                                            )

        # estimate parameters from single cell data
        sc_res = fit.fit_sc_data(sc_data,
                                 loss_tracker = sc_loss_tracker,
                                 sc_epochs = args.sc_epochs,
                                 sc_batch_size = args.sc_batch_size,
                                 learning_rate = args.learning_rate,
                                 sc_from_model = args.sc_model,
                                 device = device,
                                 )

        R, logits, sc_model = sc_res['rates'], sc_res['logits'], sc_res['model']

        # save sc model
        oname_sc_model = osp.join(args.out_dir,
                                  '.'.join(['sc_model', timestamp, 'pt']))

        t.save(sc_model.state_dict(), oname_sc_model)

        # save estimated parameters
        oname_R = osp.join(args.out_dir,
                           '.'.join(['R', timestamp, 'tsv']))

        oname_logits = osp.join(args.out_dir,
                                '.'.join(['logits', timestamp, 'tsv']))

        utils.write_file(R, oname_R)
        utils.write_file(logits, oname_logits)

    # Load already estimated single cell parameters
    # (only needed when there is ST data to analyze)
    elif fit_st:
        log.info(' | '.join(["load sc parameter",
                             "rates (R) : {}".format(args.sc_fit[0]),
                             "logodds (logits) : {}".format(args.sc_fit[1]),
                             ])
                 )
        R = utils.read_file(args.sc_fit[0])
        logits = utils.read_file(args.sc_fit[1])

    # If ST data is provided estimate proportions
    if fit_st:
        # generate identifying tag for each section
        # (file name with its last extension removed)
        sectiontag = ['.'.join(osp.basename(pth).split('.')[0:-1])
                      for pth in args.st_cnt]

        log.info("fit st data section(s) : {}".format(args.st_cnt))

        # check that provided files exist
        if not all([osp.exists(pth) for pth in args.st_cnt]):
            log.error("Some of the provided ST-data paths does not exist")
            sys.exit(-1)

        if args.st_model is not None:
            log.info("loading state from provided st_model")

        # create data set for st data
        st_data = D.make_st_dataset(args.st_cnt,
                                    topn_genes = args.topn_genes,
                                    min_counts = args.min_st_counts,
                                    min_spots = args.min_spots,
                                    filter_genes = args.filter_genes,
                                    transpose = args.st_transpose,
                                    )

        log.info(' '.join(["ST data GENES : {} \n".format(st_data.G),
                           "ST data SPOTS : {} ".format(st_data.M),
                           ])
                 )

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,
                                    '.'.join(["st_loss", timestamp, "txt"])
                                    )

        st_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,
                                            )

        # estimate proportions of cell types within st data
        st_res = fit.fit_st_data(st_data,
                                 R = R,
                                 logits = logits,
                                 loss_tracker = st_loss_tracker,
                                 st_epochs = args.st_epochs,
                                 st_batch_size = args.st_batch_size,
                                 learning_rate = args.learning_rate,
                                 silent_mode = args.silent_mode,
                                 st_from_model = args.st_model,
                                 device = device,
                                 keep_noise = args.keep_noise,
                                 freeze_beta = args.freeze_beta,
                                 )

        W, st_model = st_res['proportions'], st_res['model']

        # split joint matrix into multiple
        wlist = utils.split_joint_matrix(W)

        # save st model
        oname_st_model = osp.join(args.out_dir,
                                  '.'.join(['st_model', timestamp, 'pt']))

        t.save(st_model.state_dict(), oname_st_model)

        # save st data proportion estimates results, one
        # sub-directory per section
        for s, w_section in enumerate(wlist):
            section_dir = osp.join(args.out_dir, sectiontag[s])
            if not osp.exists(section_dir):
                makedirs(section_dir, exist_ok = True)

            oname_W = osp.join(section_dir, '.'.join(['W', timestamp, 'tsv']))
            log.info("saving proportions for section {} to {}".format(sectiontag[s],
                                                                        oname_W))

            utils.write_file(w_section, oname_W)