예제 #1
0
def make_sc_dataset(cnt_pth : str,
                    lbl_pth : str,
                    topn_genes : int = None,
                    gene_list_pth : str = None,
                    filter_genes : bool = False,
                    lbl_colname : str = 'bio_celltype',
                    min_counts : int = 300,
                    min_cells : int = 0,
                    transpose : bool = False,
                    upper_bound : int = None,
                    lower_bound : int = None,
                    ):

    """
    Generate CountData object for SC-data

    Reads count and label data, keeps only the observations
    present in both, optionally subsamples and restricts the
    gene set, and finally wraps the result in a CountData object
    on which the requested filters are applied.

    Parameter:
    ---------

    cnt_pth : str
        path to SC count data

    lbl_pth : str
        path to SC label data

    topn_genes : int
        number of top expressed genes to
        include. If None, no selection is made.

    gene_list_pth : str
        path to a text file with one gene name
        per line. Only genes found both in this
        list and in the count data are kept.
        If None, no selection is made.

    filter_genes : bool
        if True, CountData.filter_genes() is
        called on the assembled data set

    lbl_colname : str
        name of column containing labels. For
        non-h5ad input, the first column of the
        label file is used when None.

    min_counts : int
        minimal number of observed
        counts assigned to a specific
        spot/cell for it to be included

    min_cells : int
        minimal number of occurances
        of a gene among all cells
        for it to be included

    transpose : bool
        transpose data. Only applied when
        the count file is not a h5ad file.

    lower_bound : int
          lower bound for the number of cells to
          include from each type
    upper_bound : int
         upper bound for the number of cells to
         include from each type


    Returns:
    -------

    CountData object for the SC data

    Raises:
    ------

    SystemExit
        if the count and label data share no
        index entries

    """

    sc_ext = utils.get_extenstion(cnt_pth)

    if sc_ext == 'h5ad' :
        # h5ad reader returns counts and labels together
        cnt,lbl = utils.read_h5ad_sc(cnt_pth,
                                     lbl_colname,
                                     lbl_pth,
                                     )
    else:
        cnt = utils.read_file(cnt_pth,sc_ext)
        if transpose:
            cnt = cnt.T
        lbl = utils.read_file(lbl_pth)

        # get labels
        if lbl_colname is None:
            lbl = lbl.iloc[:,0]
        else:
            lbl = lbl.loc[:,lbl_colname]

    # match count and label data
    inter = cnt.index.intersection(lbl.index)
    if inter.shape[0] < 1:
        print("[ERROR] : single cell count and annotation"\
              " data did not match. Exiting.",
              file = sys.stderr,
              )
        # the message announces an exit; continuing would only
        # propagate empty data into the steps below
        sys.exit(-1)

    cnt = cnt.loc[inter,:]
    lbl = lbl.loc[inter]

    # subsample only if at least one bound is given
    if upper_bound is not None or lower_bound is not None:
        cnt,lbl = utils.subsample_data(cnt,
                                       lbl,
                                       lower_bound,
                                       upper_bound,
                                       )

    # select top N expressed genes
    if topn_genes is not None:
        genesize = cnt.values.sum(axis = 0)
        # never request more genes than are available
        n_keep = min(topn_genes,genesize.shape[0])
        # descending order of total expression
        sel = np.argsort(genesize)[::-1][0:n_keep]
        cnt = cnt.iloc[:,sel]

    # only use genes in specific genes list
    # if specified
    if gene_list_pth is not None:
        # read-only access is sufficient; 'r+' would require
        # write permission on the gene list file
        with open(gene_list_pth,'r') as fopen:
            # strip newline/whitespace and skip blank lines so they
            # do not end up as (empty) gene names
            gene_list = [ x.strip() for x in fopen if x.strip() ]

        gene_list = pd.Index(gene_list)
        sel = cnt.columns.intersection(gene_list)
        cnt = cnt.loc[:,sel]

    # create sc data set
    dataset = CountData(cnt = cnt,
                        lbl = lbl)

    # filter genes based on names
    if filter_genes:
        dataset.filter_genes()

    # filter data based on quality
    if min_counts > 0 or min_cells > 0:
        dataset.filter_bad(min_counts = min_counts,
                           min_occurance = min_cells,
                          )

    return dataset
예제 #2
0
def run(prs : arp.ArgumentParser,
        args : arp.Namespace,
       )-> None:

    """Run analysis

    Depending on specified arguments performs
    either single cell parameter estimation,
    ST-data proportion estimates or both.

    Single cell parameters (R and logits) are either
    fitted from single cell data or, when both entries
    of args.sc_fit are set, read from those files.
    Models, parameters, loss logs and proportion
    estimates are written to args.out_dir, with file
    names tagged by a generated identifier.

    Parameter:
    ---------
    prs : argparse.ArgumentParser
        parser, used to print the help message
        when too few arguments are given
    args : argparse.Namespace
        parsed arguments. args.out_dir and
        args.st_cnt are modified in place.

    Raises:
    ------
    SystemExit
        if fewer than two command line arguments
        are given, or if a required input path
        does not exist

    """

    # local import; makedirs also creates missing parent directories
    from os import makedirs

    # generate unique identifier for analysis
    timestamp = utils.generate_identifier()

    # ensure arguments are provided
    if len(sys.argv[1::]) < 2:
        prs.print_help()
        sys.exit(-1)

    # set output directory to cwd if none specified
    if args.out_dir is None:
        args.out_dir = getcwd()
    # create output directory if non-existant
    elif not osp.exists(args.out_dir):
        makedirs(args.out_dir, exist_ok = True)

    # instatiate logger
    log_name = '.'.join(['stsc',timestamp,'log'])
    log = utils.Logger(osp.join(args.out_dir,log_name))

    # convert args to list if not
    if not isinstance(args.st_cnt,list):
        args.st_cnt = [args.st_cnt]

    # set device; fall back to cpu when cuda is unavailable
    if args.gpu and is_available():
        device = t.device('cuda')
    else:
        device = t.device('cpu')

    log.info("Using device {}".format(str(device)))

    # If parameters should be fitted from sc data
    if not all(args.sc_fit):
        log.info(' | '.join(["fitting sc data",
                             "count file : {}".format(args.sc_cnt),
                             "labels file : {}".format(args.sc_labels),
                            ])
                )

        # control that paths to sc data exists; the label path is
        # only checked when given
        sc_paths = [args.sc_cnt]
        if args.sc_labels is not None:
            sc_paths.append(args.sc_labels)

        if not all(p is not None and osp.exists(p) for p in sc_paths):

            log.error(' '.join(["One or more of the specified paths to",
                                "the sc data does not exist"]))
            sys.exit(-1)

        # load pre-fitted model if provided
        if args.sc_model is not None:
            log.info("loading state from provided sc_model")

        # Create data set for single cell data
        sc_data = D.make_sc_dataset(args.sc_cnt,
                                    args.sc_labels,
                                    topn_genes = args.topn_genes,
                                    gene_list_pth = args.gene_list,
                                    lbl_colname = args.label_colname,
                                    filter_genes = args.filter_genes,
                                    min_counts = args.min_sc_counts,
                                    min_cells = args.min_cells,
                                    transpose = args.sc_transpose,
                                    )

        log.info(' '.join(["SC data GENES : {} ".format(sc_data.G),
                           "SC data CELLS : {} ".format(sc_data.M),
                           "SC data TYPES : {} ".format(sc_data.Z),
                           ])
                )

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,
                                    '.'.join(["sc_loss",timestamp,"txt"])
                                   )

        sc_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,
                                            )
        # estimate parameters from single cell data
        sc_res  = fit.fit_sc_data(sc_data,
                                  loss_tracker = sc_loss_tracker,
                                  sc_epochs = args.sc_epochs,
                                  sc_batch_size = args.sc_batch_size,
                                  learning_rate = args.learning_rate,
                                  sc_from_model = args.sc_model,
                                  device = device,
                                 )

        R,logits,sc_model = sc_res['rates'],sc_res['logits'],sc_res['model']
        # save sc model
        oname_sc_model = osp.join(args.out_dir,
                                  '.'.join(['sc_model',timestamp,'pt']))

        t.save(sc_model.state_dict(),oname_sc_model)

        # save estimated parameters
        oname_R = osp.join(args.out_dir,
                           '.'.join(['R',timestamp,'tsv']))

        oname_logits = osp.join(args.out_dir,
                                '.'.join(['logits',timestamp,'tsv']))

        utils.write_file(R,oname_R)
        utils.write_file(logits,oname_logits)

    # Load already estimated single cell parameters
    # NOTE: st_cnt was wrapped in a list above, so this condition
    # always holds when the branch above is not taken
    elif args.st_cnt is not None:
        log.info(' | '.join(["load sc parameter",
                             "rates (R) : {}".format(args.sc_fit[0]),
                             "logodds (logits) : {}".format(args.sc_fit[1]),
                            ])
                )
        R = utils.read_file(args.sc_fit[0])
        logits = utils.read_file(args.sc_fit[1])

    # If ST data is provided estiamte proportions
    if args.st_cnt[0] is not None:
        # generate identifiying tag for each section: the file name
        # without its last extension. splitext keeps the full name for
        # files lacking an extension, which would otherwise give an
        # empty tag
        sectiontag = [osp.splitext(osp.basename(x))[0] for x in args.st_cnt]
        log.info("fit st data section(s) : {}".format(args.st_cnt))

        # check that provided files exist
        if not all(osp.exists(x) for x in args.st_cnt):
            log.error("Some of the provided ST-data paths does not exist")
            sys.exit(-1)

        if args.st_model is not None:
            log.info("loading state from provided st_model")

        # create data set for st data
        st_data =  D.make_st_dataset(args.st_cnt,
                                     topn_genes = args.topn_genes,
                                     min_counts = args.min_st_counts,
                                     min_spots = args.min_spots,
                                     filter_genes = args.filter_genes,
                                     transpose = args.st_transpose,
                                    )

        log.info(' '.join(["ST data GENES : {} ".format(st_data.G),
                           "ST data SPOTS : {} ".format(st_data.M),
                           ])
                )

        # generate LossTracker object
        oname_loss_track = osp.join(args.out_dir,
                                    '.'.join(["st_loss",timestamp,"txt"])
                                   )

        st_loss_tracker = utils.LossTracker(oname_loss_track,
                                            interval = 100,
                                           )

        # estimate proportions of cell types within st data
        st_res = fit.fit_st_data(st_data,
                                 R = R,
                                 logits = logits,
                                 loss_tracker = st_loss_tracker,
                                 st_epochs = args.st_epochs,
                                 st_batch_size = args.st_batch_size,
                                 learning_rate = args.learning_rate,
                                 silent_mode = args.silent_mode,
                                 st_from_model = args.st_model,
                                 device = device,
                                 keep_noise = args.keep_noise,
                                 freeze_beta = args.freeze_beta,
                                 )

        W,st_model = st_res['proportions'],st_res['model']

        # split joint matrix into multiple
        wlist = utils.split_joint_matrix(W)

        # save st model
        oname_st_model = osp.join(args.out_dir,
                                  '.'.join(['st_model',timestamp,'pt']))

        t.save(st_model.state_dict(),oname_st_model)

        # save st data proportion estimates results,
        # one sub-directory per section
        w_name = '.'.join(['W',timestamp,'tsv'])
        for s in range(len(wlist)):
            tag = sectiontag[s]
            section_dir = osp.join(args.out_dir,tag)
            if not osp.exists(section_dir):
                makedirs(section_dir, exist_ok = True)

            oname_W = osp.join(section_dir,w_name)
            log.info("saving proportions for section {} to {}".format(tag,
                                                                      oname_W))
            utils.write_file(wlist[s],oname_W)