Beispiel #1
0
def plot_state_annotation_relationship(model,
                                       storage,
                                       labels,
                                       title,
                                       plottype='boxplot',
                                       threshold=0.0,
                                       groupby=None):
    """Plot the relationship between states and annotations, one axis per label.

    The figure is written to ``<storage>/annotation/<title>.png`` and closed
    afterwards so that repeated calls do not accumulate open figures.

    Parameters
    ----------
    model
        Object whose ``_segments`` table provides the ``Prob_max`` and
        ``name`` columns plus the annotation columns referenced below.
    storage
        Base output directory; the ``annotation`` sub-folder is created.
    labels
        Annotation labels; one subplot is created per label.
    title
        File name stem of the resulting PNG.
    plottype
        ``'countplot'`` uses the column ``<label>`` as hue; ``'boxplot'``
        plots the column ``log_<label>`` (which must already exist in the
        segments table). Any other value leaves the axes empty.
    threshold
        Only segments with ``Prob_max >= threshold`` are plotted.
    groupby
        Optional hue column; only used for ``'boxplot'``.
    """
    annotation_dir = os.path.join(storage, 'annotation')
    make_folders(annotation_dir)
    outfile = os.path.join(annotation_dir, '{}.png'.format(title))

    fig, axes = plt.subplots(len(labels))
    try:
        # plt.subplots returns a bare Axes (not a sequence) for one subplot.
        if len(labels) == 1:
            axes = [axes]

        segdf = model._segments[model._segments.Prob_max >= threshold].copy()

        for ax, label in zip(axes, labels):
            if plottype == 'countplot':
                sns.countplot(y='name', hue=label, data=segdf, ax=ax)
            elif plottype == 'boxplot':
                # NOTE(review): the 'log_<label>' column is not derived here
                # (the derivation was commented out in the original), so it
                # presumably has to be present in the segments table already.
                sns.boxplot(x='log_' + label,
                            y='name',
                            data=segdf,
                            hue=groupby,
                            orient='h',
                            ax=ax)

        logging.debug('writing {}'.format(outfile))
        fig.tight_layout()
        fig.savefig(outfile)
    finally:
        # Release the figure; pyplot keeps it alive until explicitly closed.
        plt.close(fig)
Beispiel #2
0
def make_state_summary(model, output, labels):
    """Write the state summary table and two summary plots.

    Files are created in ``<output>/summary``: ``statesummary.csv`` (from
    ``model.get_state_frequency()``), ``state_abundance.svg`` and
    ``state_readdepth.svg``. ``labels`` is accepted but not used.
    """
    summary_dir = os.path.join(output, 'summary')
    make_folders(summary_dir)

    # Tabular state frequencies.
    frequency_table = model.get_state_frequency()
    frequency_table.to_csv(os.path.join(summary_dir, 'statesummary.csv'))

    # State abundance plot.
    abundance_fig, abundance_ax = plt.subplots()
    model.plot_state_frequency(ax=abundance_ax)
    abundance_fig.savefig(os.path.join(summary_dir, 'state_abundance.svg'))
    plt.close(abundance_fig)

    # Read depth plot.
    depth_fig, depth_ax = plt.subplots()
    model.plot_readdepth(depth_ax)
    depth_fig.savefig(os.path.join(summary_dir, 'state_readdepth.svg'))
    plt.close(depth_fig)
Beispiel #3
0
def plot_fragmentsize(scmodel, output, labels, cmats):
    """Save one fragment-size-per-state plot for each count matrix.

    For every ``(label, cmat)`` pair whose ``cmat.adata`` passes
    ``has_fragmentlength``, ``scmodel.plot_fragmentsize`` is drawn on a fresh
    7x7 figure and written to
    ``<output>/summary/fragmentsize_per_state_<label>.svg``. Pairs without
    fragment length information are skipped.

    Parameters
    ----------
    scmodel
        Model providing ``plot_fragmentsize(adata, ax, cmap=...)``.
    output
        Base output directory; the ``summary`` sub-folder is created.
    labels
        Labels used in the output file names, paired with ``cmats``.
    cmats
        Count matrices exposing an ``adata`` attribute.
    """
    resultspath = os.path.join(output, 'summary')
    make_folders(resultspath)

    for label, cmat in zip(labels, cmats):
        if not has_fragmentlength(cmat.adata):
            continue
        fig, ax = plt.subplots(figsize=(7, 7))
        try:
            scmodel.plot_fragmentsize(cmat.adata, ax, cmap='Blues')
            fig.savefig(
                os.path.join(resultspath,
                             'fragmentsize_per_state_{}.svg'.format(label)))
        finally:
            # One figure per matrix: close it so figures do not pile up.
            plt.close(fig)
Beispiel #4
0
def plot_state_annotation_relationship_heatmap(model,
                                               storage,
                                               labels,
                                               title,
                                               threshold=0.0,
                                               groupby=None):
    """Plot a heatmap of mean z-scored annotation values per state.

    Segments with ``Prob_max >= threshold`` are selected, each column in
    ``labels`` is passed through ``zscore``, and the result is averaged per
    segment ``name``. The heatmap is written to
    ``<storage>/annotation/<title>.png`` and the figure is closed afterwards.

    Parameters
    ----------
    model
        Object whose ``_segments`` table provides ``Prob_max``, ``name`` and
        the columns listed in ``labels``.
    storage
        Base output directory; the ``annotation`` sub-folder is created.
    labels
        Annotation columns to include in the heatmap.
    title
        File name stem of the resulting PNG.
    threshold
        Minimum ``Prob_max`` for a segment to be included.
    groupby
        Currently unused.
    """
    annotation_dir = os.path.join(storage, 'annotation')
    make_folders(annotation_dir)
    outfile = os.path.join(annotation_dir, '{}.png'.format(title))

    fig, ax = plt.subplots(figsize=(10, 20))
    try:
        segdf = model._segments[model._segments.Prob_max >= threshold].copy()

        # Standardize every annotation column, then average within each state.
        segdf_ = segdf[labels].apply(zscore)
        segdf_['name'] = segdf.name
        segdf = segdf_.groupby("name").agg('mean')

        sns.heatmap(segdf, cmap="RdBu_r", robust=True, center=0.0, ax=ax)
        logging.debug('writing {}'.format(outfile))
        fig.tight_layout()
        fig.savefig(outfile)
    finally:
        # Release the figure; pyplot keeps it alive until explicitly closed.
        plt.close(fig)
Beispiel #5
0
def local_main(args):
    """Run the sub-command selected by ``args.program``.

    If ``args`` has a ``storage`` attribute, logging is first configured to
    write to ``<storage>/log/logfile.log`` at DEBUG level. The function then
    dispatches on ``args.program``; recognised values are ``bam_to_counts``,
    ``fragments_to_counts``, ``pseudobulk_tracks``, ``make_tile``,
    ``filter``, ``collapse``, ``subset``, ``merge``, ``fit_segment``,
    ``segment``, ``seg_to_bed``, ``annotate``, ``plot_annot``,
    ``enrichment`` and ``extract_motifs``. Any other value does nothing.

    Parameters
    ----------
    args
        Namespace-like object carrying ``program`` and the attributes read
        by the selected branch.

    Raises
    ------
    ValueError
        For ``seg_to_bed`` with a missing ``--statenames``/``--nstates`` or
        an unknown ``--method``; for ``enrichment`` plotting with an unknown
        ``--method``; for ``extract_motifs`` with an unknown ``--method``.

    Notes
    -----
    NOTE(review): several branches read a name ``modelname`` that is not
    defined inside this function; presumably it is a module-level variable —
    verify against the rest of the file.
    """
    if hasattr(args, 'storage'):
        logfile = os.path.join(args.storage, 'log', 'logfile.log')
        make_folders(os.path.dirname(logfile))
        logging.basicConfig(filename=logfile,
                            level=logging.DEBUG,
                            format='%(asctime)s;%(levelname)s;%(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S')

    logging.debug(args)

    if args.program == 'bam_to_counts':

        logging.debug('Make countmatrix ...')
        cm = CountMatrix.from_bam(args.bamfile,
                                  args.regions,
                                  barcodetag=args.barcodetag,
                                  mode=args.mode,
                                  with_fraglen=args.with_fraglen)
        # Fall back to the input file name when no sample name was given.
        cm.adata.var.loc[:, "sample"] = (args.samplename
                                         if args.samplename is not None
                                         else args.bamfile)
        if args.cellgroup is not None:
            cells, groups = get_cell_grouping(args.cellgroup, cm)
            cm = cm.pseudobulk(cells, groups)

        cm.export_counts(args.counts)

    elif args.program == 'fragments_to_counts':

        logging.debug('Make countmatrix ...')
        cm = CountMatrix.from_fragments(args.fragmentfile,
                                        args.regions,
                                        with_fraglen=args.with_fraglen)

        # Fall back to the input file name when no sample name was given.
        cm.adata.var.loc[:, "sample"] = (args.samplename
                                         if args.samplename is not None
                                         else args.fragmentfile)
        if args.cellgroup is not None:
            cells, groups = get_cell_grouping(args.cellgroup, cm)
            cm = cm.pseudobulk(cells, groups)

        cm.export_counts(args.counts)

    elif args.program == 'pseudobulk_tracks':

        logging.debug('Make pseudobulk bam-files')

        cells, groups = get_cell_grouping(args.cellgroup)

        make_pseudobulk_bam(args.bamfile,
                            args.outdir,
                            cells,
                            groups,
                            tag=args.barcodetag)

    elif args.program == "make_tile":
        make_counting_bins(args.bamfile, args.binsize, args.regions,
                           args.remove_chroms)

    elif args.program == 'filter':
        logging.debug('Filter counts ...')
        cm = CountMatrix.load(args.incounts, args.regions)
        cm = cm.filter(args.mincounts,
                       args.maxcounts,
                       args.minregioncounts,
                       binarize=False,
                       trimcount=args.trimcounts)
        cm.export_counts(args.outcounts)

    elif args.program == 'collapse':
        logging.debug('Collapse cells (pseudobulk)...')
        cm = CountMatrix.load(args.incounts, args.regions)

        cells, groups = get_cell_grouping(args.cellgroup, cm)
        pscm = cm.pseudobulk(cells, groups)
        pscm.export_counts(args.outcounts)

    elif args.program == 'subset':

        logging.debug('Subset cells ...')
        cm = CountMatrix.load(args.incounts, args.regions)

        cells = get_cells(args.subset, args.barcodecolumn)
        pscm = cm.subset(cells)
        pscm.export_counts(args.outcounts)

    elif args.program == 'merge':
        logging.debug('Merge count matrices ...')
        cms = [
            CountMatrix.load(incount, args.regions)
            for incount in args.incounts
        ]

        merged_cm = CountMatrix.merge(cms)
        merged_cm.export_counts(args.outcounts)

    elif args.program == 'fit_segment':
        if args.labels is None:
            args.labels = ["sample"] * len(args.counts)
        assert len(args.labels) == len(args.counts)

        outputpath = os.path.join(args.storage, modelname)
        logging.debug('Segmentation ...')
        # fit on subset of the data
        data = load_count_matrices(args.counts, args.regions, args.mincounts,
                                   args.maxcounts, args.trimcounts,
                                   args.minregioncounts)

        scmodel, models = run_segmentation(data, args.nstates, args.niter,
                                           args.randomseed, args.n_jobs)

        # predict on the entire genome
        data = load_count_matrices(args.counts, args.regions, args.mincounts,
                                   args.maxcounts, args.trimcounts, None)

        logging.debug('segmentation data:')
        for d in data:
            logging.debug(d)

        scmodel.segment(data, args.regions)
        scmodel.save(outputpath)
        # NOTE(review): this writes the selected ``scmodel`` once per random
        # seed; the per-seed entry of ``models`` is not used. Confirm whether
        # saving the individual per-seed model was intended instead.
        for s, _seed_model in zip(args.randomseed, models):
            scmodel.save(outputpath + f'_rseed{s}')

        logging.debug('summarize results ...')
        make_state_summary(scmodel, outputpath, args.labels)
        plot_normalized_emissions(scmodel, outputpath, args.labels)
        save_score(scmodel, data, outputpath)
        plot_fragmentsize(scmodel, outputpath, args.labels, data)

    elif args.program == 'segment':
        assert len(args.labels) == len(args.counts)

        outputpath = os.path.join(args.storage, modelname)
        data = load_count_matrices(args.counts, args.regions, args.mincounts,
                                   args.maxcounts, args.trimcounts, 0)
        scmodel = Scregseg.load(outputpath)
        logging.debug('State calling ...')
        scmodel.segment(data, args.regions)
        scmodel.save(outputpath)
        make_state_summary(scmodel, outputpath, args.labels)
        plot_normalized_emissions(scmodel, outputpath, args.labels)
        save_score(scmodel, data, outputpath)

    elif args.program == 'seg_to_bed':
        outputpath = os.path.join(args.storage, modelname)

        scmodel = Scregseg.load(outputpath)

        sdf = scmodel._segments.copy()
        if args.method == "manualselect":
            if args.statenames is None:
                raise ValueError(
                    "--method manualselect also requires --statenames <list state names>"
                )
            query_states = args.statenames
        elif args.method == "rarest":
            if args.nstates <= 0:
                raise ValueError(
                    "--method rarest also requires --nstates <int>")
            query_states = pd.Series(
                scmodel.model.get_stationary_distribution(),
                index=[
                    'state_{}'.format(i) for i in range(scmodel.n_components)
                ])
            query_states = query_states.nsmallest(args.nstates).index.tolist()
        elif args.method == "abundancethreshold":
            query_states = [
                'state_{}'.format(i) for i, p in enumerate(
                    scmodel.model.get_stationary_distribution())
                if p <= args.max_state_abundance
            ]
        else:
            # Without this, ``query_states`` would be unbound below.
            raise ValueError(
                "--method {} unknown. manualselect, rarest or "
                "abundancethreshold supported.".format(args.method))

        logging.debug("method={}: {}".format(args.method, query_states))

        if args.exclude_states is not None:
            query_states = list(
                set(query_states).difference(set(args.exclude_states)))

        # subset and merge the state calls
        subset, perm_matrix = get_statecalls(
            sdf,
            query_states,
            ntop=args.nregsperstate,
            collapse_neighbors=not args.no_bookended_merging,
            state_prob_threshold=args.threshold)

        logging.debug("Exporting {} states with {} regions".format(
            len(query_states), subset.shape[0]))
        if args.output == '':
            output = os.path.join(args.storage, modelname, 'summary',
                                  'segments.bed')
        else:
            output = args.output

        # export the state calls as a bed file
        export_bed(subset, output, individual_beds=args.individualbeds)

        if len(args.counts) > 0:
            labels = _get_labels(args.counts, args.labels)
            data = load_count_matrices(args.counts, args.regions,
                                       args.mincounts, args.maxcounts,
                                       args.trimcounts, 0)
            for mat, datum, fname in zip(labels, data, args.counts):
                x = perm_matrix.dot(datum.adata.X).tocsr()
                dat = CountMatrix(x, subset, datum.cannot)
                # ``output[:-4]`` drops the last four characters of the bed
                # path (e.g. a '.bed' suffix) before appending the label.
                if fname.endswith('.h5ad'):
                    dat.export_counts(output[:-4] + f'_{mat}.h5ad')
                else:
                    dat.export_counts(output[:-4] + f'_{mat}.mtx')

    elif args.program == 'annotate':
        outputpath = os.path.join(args.storage, modelname)
        scmodel = Scregseg.load(outputpath)

        assert len(args.labels) == len(
            args.files), "Number of files and labels mismatching"
        logging.debug('annotate states ...')
        files = dict(zip(args.labels, args.files))
        scmodel.annotate(files)

        scmodel.save(outputpath)

    elif args.program == 'plot_annot':
        outputpath = os.path.join(args.storage, modelname)
        logging.debug('Plot annotation ...')
        scmodel = Scregseg.load(outputpath)

        if args.plottype == 'heatmap':
            plot_state_annotation_relationship_heatmap(scmodel, outputpath,
                                                       args.labels, args.title,
                                                       args.threshold,
                                                       args.groupby)
        else:
            plot_state_annotation_relationship(scmodel, outputpath,
                                               args.labels, args.title,
                                               args.plottype, args.threshold,
                                               args.groupby)

    elif args.program == 'enrichment':
        outputpath = os.path.join(args.storage, modelname)

        logging.debug('enrichment analysis')
        scmodel = Scregseg.load(outputpath)

        if args.output is None:
            outputenr = os.path.join(outputpath, 'annotation')
        else:
            outputenr = args.output

        make_folders(outputenr)

        if os.path.isdir(args.features):
            # A directory is treated as a collection of *.bed feature sets.
            featuresets = glob.glob(os.path.join(args.features, '*.bed'))
            featurenames = [
                os.path.basename(name)[:-4] for name in featuresets
            ]
            obs, lens, _ = scmodel.geneset_observed_state_counts(
                featuresets, flanking=args.flanking)
        else:
            obs, lens, featurenames = scmodel.observed_state_counts(
                args.features,
                flanking=args.flanking,
                using_tss=not args.using_genebody)
            obs.to_csv(os.path.join(outputenr,
                                    'state_counts_{}.tsv'.format(args.title)),
                       sep='\t')

        enr = scmodel.broadregion_enrichment(obs, mode=args.method)
        # Keep the union of the top ``ntop`` rows of every column.
        cats = []
        for cluster in enr.columns:
            cats += list(enr[cluster].nlargest(args.ntop).index)
        cats = list(set(cats))

        enr = enr.loc[cats, :]

        def _getfigsize(s):
            # Parse a "width,height" string into a tuple of ints.
            return tuple(int(x) for x in s.split(','))

        if not args.noplot:
            if args.method == 'logfold':
                plot_kwargs = {
                    'cmap': "RdBu_r",
                    'center': 0.0,
                    'vmin': -1.5,
                    'vmax': 1.5
                }
            elif args.method in ('chisqstat', 'pvalue'):
                plot_kwargs = {'cmap': "Reds"}
            else:
                # Without this, ``g`` would be unbound at savefig below.
                raise ValueError(
                    "--method {} unknown. logfold, chisqstat or pvalue "
                    "supported for plotting.".format(args.method))

            g = sns.clustermap(enr,
                               figsize=_getfigsize(args.figsize),
                               robust=True,
                               **plot_kwargs)
            g.savefig(
                os.path.join(
                    outputenr, "state_enrichment_{}_{}.png".format(
                        args.method, args.title)))

        enr.to_csv(os.path.join(
            outputenr,
            'state_enrichment_{}_{}.tsv'.format(args.method, args.title)),
                   sep='\t')

    elif args.program == 'extract_motifs':
        outputpath = os.path.join(args.storage, modelname)

        if args.output is None:
            motifoutput = os.path.join(outputpath, 'motifs')
        else:
            motifoutput = args.output
        make_folders(motifoutput)

        scmodel = Scregseg.load(outputpath)
        if args.method == "regression":
            motifextractor = MotifExtractor(scmodel,
                                            args.refgenome,
                                            ntop=args.ntop,
                                            nbottom=args.nbottom,
                                            ngap=args.ngap,
                                            nmotifs=args.nmotifs,
                                            flank=args.flank)
        elif args.method == "betweenstates":
            motifextractor = MotifExtractor2(scmodel,
                                             args.refgenome,
                                             ntop=args.ntop,
                                             nmotifs=args.nmotifs,
                                             flank=args.flank)
        else:
            raise ValueError(
                "--method {} unknown. regression or betweenstates supported.".
                format(args.method))

        os.environ['JANGGU_OUTPUT'] = motifoutput
        motifextractor.extract_motifs()
        motifextractor.save_motifs(
            os.path.join(motifoutput, 'scregseg_motifs.meme'))
Beispiel #6
0
def plot_normalized_emissions(model, output, labels):
    """Save the model's emission plot as ``<output>/summary/emission.png``.

    The ``summary`` folder is created first. ``labels`` is accepted but not
    used.
    """
    summary_dir = os.path.join(output, 'summary')
    make_folders(summary_dir)

    emission_plot = model.plot_emissions()
    emission_plot.savefig(os.path.join(summary_dir, 'emission.png'))
Beispiel #7
0
 def save_motifs(self, output):
     """Create the parent folder of ``output`` and save ``self.meme`` there."""
     parent_dir = os.path.dirname(output)
     make_folders(parent_dir)
     self.meme.save(output)