예제 #1
# make_sc_dataset: build a CountData object from single cell count and
# label data.
#
# Visible flow: read counts/labels (dedicated h5ad reader vs. generic
# utils.read_file), keep only the index entries shared by counts and
# labels, optionally subsample, optionally keep the top N columns by total
# count, optionally restrict columns to a gene list read from file, wrap
# the result in CountData and apply quality filtering.
#
# NOTE(review): this listing appears to have lost lines: the closing `):`
# of the signature, the docstring delimiters, several `else:` lines and the
# tails of multi-line calls are not present. The affected spots are flagged
# inline; restore them from the upstream file rather than guessing.
# NOTE(review): the parameter description below lists `topn_genes` as
# `bool`, while the signature annotates it as `int` and the body uses it
# as a count.
# NOTE(review): `filter_genes` has no entry in the parameter description.
def make_sc_dataset(cnt_pth : str,
                    lbl_pth : str,
                    topn_genes : int = None,
                    gene_list_pth : str = None,
                    filter_genes : bool = False,
                    lbl_colname : str = 'bio_celltype',
                    min_counts : int = 300,
                    min_cells : int = 0,
                    transpose : bool = False,
                    upper_bound : int = None,
                    lower_bound : int = None,

    Generate CountData object for SC-data


    cnt_pth : str
        path to SC count data

    lbl_pth : str
        path to SC label data

    topn_genes : bool
        number of top expressed genes to

    gene_list_pth : str
        gene list

    lbl_colname : str
        name of column containing labels

    min_counts : int
        minimal number of observed
        counts assigned to a specific
        spot/cell for it to be included

    min_cells : int
        minimal number of occurances
        of a gene among all cells
        for it to be included

    transpose : bool
        transpose data

    lower_bound : int
          lower bound for the number of cells to
          include from each type
    upper_bound : int
         upper bound for the number of cells to
         include from each type


    CountData object for the SC data


    # the extension of the count file decides which reader is used
    # NOTE(review): `get_extenstion` is spelled this way in the call;
    # presumably it matches the helper's name in utils - verify before renaming
    sc_ext = utils.get_extenstion(cnt_pth)

    if sc_ext == 'h5ad' :
        # h5ad reader returns counts and labels together
        # NOTE(review): call is truncated here (remaining arguments and the
        # closing parenthesis are missing); the lines that follow presumably
        # belong to a lost `else:` branch for non-h5ad input - confirm
        cnt,lbl = utils.read_h5ad_sc(cnt_pth,
        cnt = utils.read_file(cnt_pth,sc_ext)
        if transpose:
            cnt = cnt.T
        lbl = utils.read_file(lbl_pth)

        # get labels
        # NOTE(review): an `else:` appears to be missing before the `.loc`
        # line; as listed, it would also run when lbl_colname is None
        if lbl_colname is None:
            lbl = lbl.iloc[:,0]
            lbl = lbl.loc[:,lbl_colname]

    # match count and label data
    # keep only index entries present in both counts and labels
    inter = cnt.index.intersection(lbl.index)
    if inter.shape[0] < 1:
        # NOTE(review): the print call is not closed and, although the
        # message says "Exiting.", no exit statement is visible - presumably
        # lost together with the closing parenthesis; confirm
        print("[ERROR] : single cell count and annotation"\
              " data did not match. Exiting.",
              file = sys.stderr,
    cnt = cnt.loc[inter,:]
    lbl = lbl.loc[inter]

    # subsampling is only requested when at least one bound is given
    if upper_bound is not None or\
       lower_bound is not None:
        # NOTE(review): call is truncated; remaining arguments are missing
        cnt,lbl = utils.subsample_data(cnt,

    # select top N expressed genes
    if topn_genes is not None:
        # total count per column
        genesize = cnt.values.sum(axis = 0)
        # never request more columns than are available
        topn_genes = np.min((topn_genes,genesize.shape[0]))
        # column positions ordered by descending total count
        sel = np.argsort(genesize)[::-1]
        sel = sel[0:topn_genes]
        cnt = cnt.iloc[:,sel]

    # only use genes in specific genes list
    # if specified
    if gene_list_pth is not None:
        # NOTE(review): 'r+' opens the file for reading AND writing although
        # it is only read here; this fails on read-only files - 'r' would do
        with open(gene_list_pth,'r+') as fopen:
            gene_list = fopen.readlines()

        # one entry per line, newline characters removed
        gene_list = pd.Index([ x.replace('\n','') for x in gene_list ])
        # keep only columns whose name occurs in the list
        sel = cnt.columns.intersection(gene_list)
        cnt = cnt.loc[:,sel]

    # create sc data set
    dataset = CountData(cnt = cnt,
                        lbl = lbl)

    # filter genes based on names
    # NOTE(review): the body of this `if` is missing from the listing
    if filter_genes:

    # filter data based on quality
    # runs when either threshold is positive
    if any([min_counts > 0,min_cells > 0]):
        # NOTE(review): call is not closed; any further arguments are unknown
        dataset.filter_bad(min_counts = min_counts,
                           min_occurance = min_cells,

    return dataset
예제 #2
def run(prs : arp.ArgumentParser,
        args : arp.Namespace,
       )-> None:

    """Run analysis

    Depending on specified arguments performs
    either single cell parameter estimation,
    ST-data proportion estimates or both.

    prs : argparse.ArgumentParser
    args : argparse.Namespace


    # generate unique identifier for analysis
    timestamp = utils.generate_identifier()

    # ensure arguments are provided
    if len(sys.argv[1::]) < 2:

    # set output directory to cwd if none specified
    if args.out_dir is None:
        args.out_dir = getcwd()
    # create output directory if non-existant
    elif not osp.exists(args.out_dir):

    # instatiate logger
    log = utils.Logger(osp.join(args.out_dir,

    # convert args to list if not
    args.st_cnt = (args.st_cnt if \
                   isinstance(args.st_cnt,list) else \

    # set device
    if args.gpu:
        device = t.device('cuda')
        device = t.device('cpu')

    device = (device if is_available() else t.device('cpu'))
    log.info("Using device {}".format(str(device)))

    # If parameters should be fitted from sc data
    if not all(args.sc_fit):
        log.info(' | '.join(["fitting sc data",
                             "count file : {}".format(args.sc_cnt),
                             "labels file : {}".format(args.sc_labels),

        # control that paths to sc data exists
        if not all([osp.exists(args.sc_cnt)]):

            log.error(' '.join(["One or more of the specified paths to",
                                "the sc data does not exist"]))

        # load pre-fitted model if provided
        if args.sc_model is not None:
            log.info("loading state from provided sc_model")

        # Create data set for single cell data
        sc_data = D.make_sc_dataset(args.sc_cnt,
                                    topn_genes = args.topn_genes,
                                    gene_list_pth = args.gene_list,
                                    lbl_colname = args.label_colname,
                                    filter_genes = args.filter_genes,
                                    min_counts = args.min_sc_counts,
                                    min_cells = args.min_cells,
                                    transpose = args.sc_transpose,

        log.info(' '.join(["SC data GENES : {} ".format(sc_data.G),
                           "SC data CELLS : {} ".format(sc_data.M),
                           "SC data TYPES : {} ".format(sc_data.Z),

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,

        sc_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,
        # estimate parameters from single cell data
        sc_res  = fit.fit_sc_data(sc_data,
                                  loss_tracker = sc_loss_tracker,
                                  sc_epochs = args.sc_epochs,
                                  sc_batch_size = args.sc_batch_size,
                                  learning_rate = args.learning_rate,
                                  sc_from_model = args.sc_model,
                                  device = device,

        R,logits,sc_model = sc_res['rates'],sc_res['logits'],sc_res['model']
        # save sc model
        oname_sc_model = osp.join(args.out_dir,


        # save estimated parameters
        oname_R = osp.join(args.out_dir,

        oname_logits = osp.join(args.out_dir,


    # Load already estimated single cell parameters
    elif args.st_cnt is not None:
        log.info(' | '.join(["load sc parameter",
                             "rates (R) : {}".format(args.sc_fit[0]),
                             "logodds (logits) : {}".format(args.sc_fit[1]),
        R = utils.read_file(args.sc_fit[0])
        logits = utils.read_file(args.sc_fit[1])

    # If ST data is provided estiamte proportions
    if args.st_cnt[0] is not None:
        # generate identifiying tag for each section
        sectiontag = list(map(lambda x: '.'.join(osp.basename(x).split('.')[0:-1]),args.st_cnt))
        log.info("fit st data section(s) : {}".format(args.st_cnt))

        # check that provided files exist
        if not all([osp.exists(x) for x in args.st_cnt]):
            log.error("Some of the provided ST-data paths does not exist")

        if args.st_model is not None:
            log.info("loading state from provided st_model")

        # create data set for st data
        st_data =  D.make_st_dataset(args.st_cnt,
                                     topn_genes = args.topn_genes,
                                     min_counts = args.min_st_counts,
                                     min_spots = args.min_spots,
                                     filter_genes = args.filter_genes,
                                     transpose = args.st_transpose,

        log.info(' '.join(["ST data GENES : {} ".format(st_data.G),
                           "ST data SPOTS : {} ".format(st_data.M),

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,

        st_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,

        # estimate proportions of cell types within st data
        st_res = fit.fit_st_data(st_data,
                                 R = R,
                                 logits = logits,
                                 loss_tracker = st_loss_tracker,
                                 st_epochs = args.st_epochs,
                                 st_batch_size = args.st_batch_size,
                                 learning_rate = args.learning_rate,
                                 silent_mode = args.silent_mode,
                                 st_from_model = args.st_model,
                                 device = device,
                                 keep_noise = args.keep_noise,
                                 freeze_beta = args.freeze_beta,

        W,st_model = st_res['proportions'],st_res['model']

        # split joint matrix into multiple
        wlist = utils.split_joint_matrix(W)

        # save st model
        oname_st_model = osp.join(args.out_dir,


        # save st data proportion estimates results
        for s in range(len(wlist)):
            section_dir = osp.join(args.out_dir,sectiontag[s])
            if not osp.exists(section_dir):

            oname_W = osp.join(section_dir,'.'.join(['W',timestamp,'tsv']))
            log.info("saving proportions for section {} to {}".format(sectiontag[s],