Example #1
    def __init__(self, args, display_title='Filtering'):
        super().__init__(args, display_title)

        # data
        with open(args.raw_read_count_file) as f:
            self.count_dict = json.load(f)

        self.raw_umi = 0
        self.total_corrected_umi = 0
        self.del_umi = 0
        self.read_threshold_dict = {}
        self.umi_threshold_dict = {}  # if not set explicitly, use 1 as default

        self.barcode_ref_umi_dict = utils.genDict(dim=2)
        self.ref_barcode_umi_dict = utils.genDict(dim=2)

        match_dir_dict = utils.parse_match_dir(args.match_dir)
        self.df_tsne = match_dir_dict['df_tsne']

        self.df_filter_tsne = self.df_tsne.copy()

        # out
        self.corrected_read_count_file = f'{self.out_prefix}_corrected_read_count.json'
        self.filter_read_count_file = f'{self.out_prefix}_filtered_read_count.json'
        self.filter_tsne_file = f'{self.out_prefix}_filtered_UMI_tsne.csv'
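
Example #1 (and most of the examples below) relies on `utils.genDict` to build nested counting dictionaries such as `barcode -> ref -> count`. The project's actual helper lives in its `utils` module; the sketch below is only a minimal stand-in with the same call shape, assuming it is a recursive defaultdict.

from collections import defaultdict

def genDict_sketch(dim=3, valType=int):
    # Minimal stand-in for utils.genDict (assumed behavior): a defaultdict
    # nested `dim` levels deep whose leaves default to valType().
    if dim == 1:
        return defaultdict(valType)
    return defaultdict(lambda: genDict_sketch(dim - 1, valType))

# usage mirroring barcode_ref_umi_dict above (hypothetical keys)
barcode_ref_umi_dict = genDict_sketch(dim=2)
barcode_ref_umi_dict['AAACATCG']['ref_1'] += 1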
Example #2
    def __init__(self, args, display_title=None):
        Step.__init__(self, args, display_title=display_title)

        # read args
        self.fq = args.fq
        self.fq_pattern = args.fq_pattern
        self.linker_fasta = args.linker_fasta
        self.barcode_fasta = args.barcode_fasta

        # process
        self.barcode_dict, self.barcode_length = utils.read_fasta(self.barcode_fasta, equal=True)
        if self.linker_fasta and self.linker_fasta != 'None':
            self.linker_dict, self.linker_length = utils.read_fasta(self.linker_fasta, equal=True)
        else:
            self.linker_dict, self.linker_length = {}, 0
        self.pattern_dict = parse_pattern(self.fq_pattern)

        # check barcode length
        barcode1 = self.pattern_dict["C"][0]
        # end - start
        pattern_barcode_length = barcode1[1] - barcode1[0]
        if pattern_barcode_length != self.barcode_length:
            raise Exception(
                f'barcode fasta length {self.barcode_length} != '
                f'pattern barcode length {pattern_barcode_length}'
            )

        self.res_dic = utils.genDict()
        self.res_sum_dic = utils.genDict(dim=2)
        self.match_barcode = []

        # out files
        self.read_count_file = f'{self.outdir}/{self.sample}_read_count.tsv'
        self.UMI_count_file = f'{self.outdir}/{self.sample}_UMI_count.tsv'
        self.stat_file = f'{self.outdir}/stat.txt'
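
`parse_pattern` converts a pattern string such as `C8L10C8L10C8U8` into per-segment intervals; Example #9 below shows the expected shape in a comment. A minimal sketch consistent with that comment (hypothetical name, not the project's implementation):

import re
from collections import defaultdict

def parse_pattern_sketch(pattern):
    # 'C8L10C8L10C8U8' -> {'C': [[0, 8], [18, 26], [36, 44]],
    #                      'L': [[8, 18], [26, 36]], 'U': [[44, 52]]}
    pattern_dict = defaultdict(list)
    start = 0
    for letter, length in re.findall(r'([A-Z])(\d+)', pattern):
        end = start + int(length)
        pattern_dict[letter].append([start, end])
        start = end
    return pattern_dict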
Example #3
    def __init__(self, args, display_title=None):
        Step.__init__(self, args, display_title=display_title)

        # set
        self.match_barcode_list, self.n_cell = utils.read_barcode_file(args.match_dir)
        self.match_barcode = set(self.match_barcode_list)

        if args.panel:
            self.gene_list = utils.get_gene_region_from_bed(args.panel)[0]
            self.n_gene = len(self.gene_list)
        else:
            self.gene_list, self.n_gene = utils.read_one_col(args.gene_list)

        if not self.gene_list:
            sys.exit("You must provide either --panel or --gene_list!")

        self.count_dict = utils.genDict(dim=3, valType=int)

        self.add_metric(
            name="Number of Target Genes",
            value=self.n_gene,
        )
        self.add_metric(
            name="Number of Cells",
            value=self.n_cell,
        )

        # out file
        self.out_bam_file = f'{self.out_prefix}_filtered.bam'
        self.out_bam_file_sorted = f'{self.out_prefix}_filtered_sorted.bam'
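
`utils.read_one_col` is used here to return both the gene list and its length. A compatible sketch under that assumption (the real helper may handle headers or delimiters differently):

def read_one_col_sketch(path):
    # Read the first column of a text file; return (values, number of values).
    with open(path) as f:
        values = [line.split()[0] for line in f if line.strip()]
    return values, len(values)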
Example #4
    def get_read_threshold(self):
        self.add_metric('Read Threshold Method',
                        self.args.read_threshold_method)

        read_dict = utils.genDict(dim=1, valType=list)

        for barcode in self.count_dict:
            for ref in self.count_dict[barcode]:
                read_dict[ref] += list(self.count_dict[barcode][ref].values())

        if self.debug:
            print(read_dict)

        for ref in read_dict:
            otsu_plot_path = f'{self.out_prefix}_{ref}_read_otsu.png'
            runner = Threshold(
                array=read_dict[ref],
                threshold_method=self.args.read_threshold_method,
                otsu_plot_path=otsu_plot_path,
                hard_threshold=self.args.read_hard_threshold)
            read_threshold = runner.run()

            self.read_threshold_dict[ref] = read_threshold
            self.add_metric(
                f'{ref} Read Threshold',
                read_threshold,
                help_info='filter UMIs with fewer than this number of reads',
            )
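
The `Threshold` runner encapsulates the chosen method (`otsu`, hard threshold, etc.). For reference, a minimal Otsu-style cut on log2-transformed read counts could look like the sketch below; this is an assumption about the approach, not the project's `Threshold` class.

import numpy as np

def otsu_threshold_sketch(counts, bins=256):
    # Otsu's method on log2(count + 1): choose the cut that maximizes
    # between-class variance, then map back to read-count space.
    array = np.log2(np.asarray(counts, dtype=float) + 1)
    hist, bin_edges = np.histogram(array, bins=bins)
    hist = hist.astype(float)
    mids = (bin_edges[:-1] + bin_edges[1:]) / 2
    weight1 = np.cumsum(hist)
    weight2 = np.cumsum(hist[::-1])[::-1]
    mean1 = np.cumsum(hist * mids) / np.maximum(weight1, 1e-12)
    mean2 = (np.cumsum((hist * mids)[::-1]) / np.maximum(weight2[::-1], 1e-12))[::-1]
    between_var = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2
    cut = mids[np.argmax(between_var)]
    return int(round(2 ** cut - 1))

# e.g. read_threshold = otsu_threshold_sketch(read_dict[ref])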
Example #5
def get_nCell_barcodes(fq, nCell):
    '''
    Get the barcodes of the top nCell cells, ranked by UMI count.
    '''
    count_dict = genDict(dim=2)
    barcode_dict = {}
    with pysam.FastxFile(fq) as fh:
        for entry in fh:
            attr = entry.name.split('_')
            barcode = attr[0]
            umi = attr[1]
            count_dict[barcode][umi] += 1
    for barcode in count_dict:
        barcode_dict[barcode] = len(count_dict[barcode])
    barcodes = pd.DataFrame.from_dict(
        barcode_dict,
        orient='index').sort_values(0, ascending=False).iloc[0:nCell, ].index
    return barcodes
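
A minimal usage sketch with a hypothetical file name; read names are assumed to follow the `barcode_UMI_...` convention used throughout these examples.

# hypothetical FASTQ whose read names look like "<barcode>_<UMI>_<rest>"
top_barcodes = get_nCell_barcodes('demo_R2.fq.gz', nCell=1000)
print(len(top_barcodes))  # at most 1000 barcodes, ranked by distinct-UMI count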
Example #6
    def __init__(self, args, display_title='Count'):
        super().__init__(args, display_title)

        # set
        self.min_query_length = int(args.min_query_length)
        self.capture_bam = args.capture_bam

        # read barcodes
        match_dir_dict = utils.parse_match_dir(args.match_dir)
        self.match_barcode = match_dir_dict['match_barcode']
        self.n_match_barcode = match_dir_dict['n_match_barcode']
        self.add_metric(
            name=HELP_INFO_DICT['matched_barcode_number']['display'],
            value=self.n_match_barcode,
            help_info=HELP_INFO_DICT['matched_barcode_number']['info'])

        # data
        self.total_corrected_umi = 0
        self.count_dict = utils.genDict(dim=3)

        # out
        self.raw_read_count_file = f'{self.out_prefix}_raw_read_count.json'
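
The `count_dict` built here (`genDict(dim=3)`) is what Example #1 later loads back from `raw_read_count.json`. Its serialized shape would look roughly like the illustration below (values are made up).

# illustrative shape only: barcode -> reference -> UMI -> read count
example_raw_read_count = {
    "AAACATCGAAACATCG": {
        "ref_1": {"ATCGGTAC": 12, "TTGCAACG": 3},
        "ref_2": {"GGATCCAA": 1},
    },
}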
Example #7
def count_fusion(args):

    outdir = args.outdir
    sample = args.sample
    bam = args.bam
    flanking_base = int(args.flanking_base)
    fusion_pos_file = args.fusion_pos
    match_dir = args.match_dir
    UMI_min = int(args.UMI_min)

    # check dir
    if not os.path.exists(outdir):
        os.system('mkdir -p %s' % (outdir))

    fusion_pos = read_pos(fusion_pos_file)
    out_prefix = outdir + "/" + sample
    # barcode
    match_barcode, _n_barcode = read_barcode_file(match_dir)
    # tsne
    match_tsne_file = parse_match_dir(match_dir)['tsne_coord']
    df_tsne = pd.read_csv(match_tsne_file, sep="\t", index_col=0)
    # out
    out_read_count_file = out_prefix + "_fusion_read_count.tsv"
    out_umi_count_file = out_prefix + "_fusion_UMI_count.tsv"
    out_barcode_count_file = out_prefix + "_fusion_barcode_count.tsv"
    out_tsne_file = out_prefix + "_fusion_tsne.tsv"

    # process bam
    samfile = pysam.AlignmentFile(bam, "rb")
    header = samfile.header
    new_bam = pysam.AlignmentFile(out_prefix + "_fusion.bam",
                                  "wb",
                                  header=header)
    count_dic = genDict(dim=3)
    for read in samfile:
        tag = read.reference_name
        read_start = int(read.reference_start)
        read_length = len(read.query_sequence)
        attr = read.query_name.split('_')
        barcode = attr[0]
        umi = attr[1]
        if tag in fusion_pos.keys():
            if barcode in match_barcode:
                if is_fusion(pos=fusion_pos[tag],
                             read_start=read_start,
                             read_length=read_length,
                             flanking_base=flanking_base):
                    new_bam.write(read)
                    count_dic[barcode][tag][umi] += 1
    new_bam.close()

    # write dic to pandas df
    rows = []
    for barcode in count_dic:
        for tag in count_dic[barcode]:
            for umi in count_dic[barcode][tag]:
                rows.append([barcode, tag, umi, count_dic[barcode][tag][umi]])
    df_read = pd.DataFrame(rows)
    df_read.rename(columns={
        0: "barcode",
        1: "tag",
        2: "UMI",
        3: "read_count"
    },
                   inplace=True)
    df_read.to_csv(out_read_count_file, sep="\t", index=False)

    if not rows:
        count_fusion.logger.error('***** NO FUSION FOUND! *****')
    else:
        df_umi = df_read.groupby(["barcode", "tag"]).agg({"UMI": "count"})
        df_umi = df_umi[df_umi["UMI"] >= UMI_min]
        df_umi.to_csv(out_umi_count_file, sep="\t")

        df_umi.reset_index(inplace=True)
        df_barcode = df_umi.groupby(["tag"]).agg({"barcode": "count"})
        n_match_barcode = len(match_barcode)
        # add zero count tag
        for tag in fusion_pos.keys():
            if tag not in df_barcode.barcode:
                # DataFrame.append was removed in pandas 2.0; concat a one-row frame instead
                new_row = pd.DataFrame({'barcode': 0}, index=[tag])
                df_barcode = pd.concat([df_barcode, new_row])
        df_barcode["percent"] = df_barcode["barcode"] / n_match_barcode
        df_barcode.to_csv(out_barcode_count_file, sep="\t")

        df_pivot = df_umi.pivot(index="barcode", columns="tag", values="UMI")
        df_pivot.fillna(0, inplace=True)
        df_tsne_fusion = pd.merge(df_tsne,
                                  df_pivot,
                                  right_index=True,
                                  left_index=True,
                                  how="left")
        df_tsne_fusion.fillna(0, inplace=True)
        df_tsne_fusion.to_csv(out_tsne_file, sep="\t")

        # plot
        count_fusion.logger.info("plot fusion...!")
        app = fusionDir + "/plot_fusion.R"
        cmd = f"Rscript {app} --tsne_fusion {out_tsne_file} --outdir {outdir}"
        os.system(cmd)
        count_fusion.logger.info("plot done.")
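
`is_fusion` decides whether a read supports the fusion breakpoint. A plausible criterion, sketched under the assumption that the read must span the breakpoint with `flanking_base` bases on both sides (the project's implementation may differ):

def is_fusion_sketch(pos, read_start, read_length, flanking_base):
    # Read must cover [pos - flanking_base, pos + flanking_base].
    return (read_start <= pos - flanking_base and
            read_start + read_length >= pos + flanking_base)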
Example #8
    def bam2table(self):
        """
        Count per-barcode gene/UMI reads from the BAM and collect per-probe statistics.
        """
        probe_gene_count_dict = utils.genDict(dim=4, valType=int)

        samfile = pysam.AlignmentFile(self.bam, "rb")
        with open(self.count_detail_file, 'wt') as fh1:
            fh1.write('\t'.join(['Barcode', 'geneID', 'UMI', 'count']) + '\n')

            def keyfunc(x):
                return x.query_name.split('_', 1)[0]

            for _, g in groupby(samfile, keyfunc):
                gene_umi_dict = defaultdict(lambda: defaultdict(int))
                for seg in g:
                    (barcode, umi, probe) = seg.query_name.split('_')[:3]
                    if probe != 'None':
                        probe_gene_count_dict[probe]['total'][barcode][
                            umi] += 1
                        if seg.has_tag('XT'):
                            geneID = seg.get_tag('XT')
                            geneName = self.gtf_dict[geneID]
                            probe_gene_count_dict[probe][geneName][barcode][
                                umi] += 1
                        else:
                            probe_gene_count_dict[probe]['None'][barcode][
                                umi] += 1
                    if not seg.has_tag('XT'):
                        continue
                    geneID = seg.get_tag('XT')
                    gene_umi_dict[geneID][umi] += 1
                for gene_id in gene_umi_dict:
                    Count.correct_umi(gene_umi_dict[gene_id])

                # output
                for gene_id in gene_umi_dict:
                    for umi in gene_umi_dict[gene_id]:
                        fh1.write('%s\t%s\t%s\t%s\n' %
                                  (barcode, gene_id, umi,
                                   gene_umi_dict[gene_id][umi]))

        # out probe
        row_list = []
        for probe in probe_gene_count_dict:
            for geneName in probe_gene_count_dict[probe]:
                barcode_count = len(probe_gene_count_dict[probe][geneName])
                umi_count = 0
                read_count = 0
                for barcode in probe_gene_count_dict[probe][geneName]:
                    # count each UMI once per barcode
                    umi_count += len(
                        probe_gene_count_dict[probe][geneName][barcode])
                    for umi in probe_gene_count_dict[probe][geneName][barcode]:
                        read_count += probe_gene_count_dict[probe][geneName][
                            barcode][umi]
                row_list.append({
                    'probe': probe,
                    'gene': geneName,
                    'barcode_count': barcode_count,
                    'read_count': read_count,
                    'UMI_count': umi_count
                })

        df_probe = pd.DataFrame(row_list,
                                columns=[
                                    'probe', 'gene', 'barcode_count',
                                    'read_count', 'UMI_count'
                                ])
        df_probe = df_probe.groupby([
            'probe'
        ]).apply(lambda x: x.sort_values('UMI_count', ascending=False))
        return df_probe
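
`Count.correct_umi` collapses sequencing errors in UMIs before the per-gene counts are written. A common single-mismatch merge, sketched here as an assumption about the rule (the project's method may also apply count-ratio checks):

def hamming_distance(a, b):
    return sum(x != y for x, y in zip(a, b))

def correct_umi_sketch(umi_dict):
    # Merge each low-count UMI into a higher-count UMI within Hamming
    # distance 1; mutates umi_dict in place.
    for umi in sorted(umi_dict, key=umi_dict.get):  # ascending by count
        for other in sorted(umi_dict, key=umi_dict.get, reverse=True):
            if other == umi:
                continue
            if umi_dict[other] > umi_dict[umi] and hamming_distance(umi, other) == 1:
                umi_dict[other] += umi_dict.pop(umi)
                break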
Example #9
def barcode(args):

    # check dir
    if not os.path.exists(args.outdir):
        os.system('mkdir -p %s' % args.outdir)

    bc_pattern = __PATTERN_DICT__[args.chemistry]
    if (bc_pattern):
        (linker, whitelist) = get_scope_bc(args.chemistry)
    else:
        bc_pattern = args.pattern
        linker = args.linker
        whitelist = args.whitelist
    if (not linker) or (not whitelist) or (not bc_pattern):
        barcode.logger.error("invalid chemistry or [pattern,linker,whitelist]")
        sys.exit()

    # parse pattern to dict, C8L10C8L10C8U8
    # defaultdict(<type 'list'>, {'C': [[0, 8], [18, 26], [36, 44]], 'U':
    # [[44, 52]], 'L': [[8, 18], [26, 36]]})
    pattern_dict = parse_pattern(bc_pattern)

    # check linker
    check_seq(linker, pattern_dict, "L")

    bool_T = 'T' in pattern_dict
    bool_L = 'L' in pattern_dict

    C_len = sum([item[1] - item[0] for item in pattern_dict['C']])

    barcode_qual_Counter = Counter()
    umi_qual_Counter = Counter()
    C_U_base_Counter = Counter()
    args.lowQual = ord2chr(args.lowQual)

    # generate list with mismatch 1, substitute one base in raw sequence with
    # A,T,C,G
    barcode_dict = generate_seq_dict(whitelist, n=1)
    linker_dict = generate_seq_dict(linker, n=2)

    fq1_list = args.fq1.split(",")
    fq2_list = args.fq2.split(",")
    # merge multiple fastq files
    if len(fq1_list) > 1:
        barcode.logger.info("merge fastq with same sample name...")
        fastq_dir = args.outdir + "/../merge_fastq"
        if not os.path.exists(fastq_dir):
            os.system('mkdir -p %s' % fastq_dir)
        fastq1_file = f"{fastq_dir}/{args.sample}_1.fq.gz"
        fastq2_file = f"{fastq_dir}/{args.sample}_2.fq.gz"
        fq1_files = " ".join(fq1_list)
        fq2_files = " ".join(fq2_list)
        fq1_cmd = f"cat {fq1_files} > {fastq1_file}"
        fq2_cmd = f"cat {fq2_files} > {fastq2_file}"
        barcode.logger.info(fq1_cmd)
        os.system(fq1_cmd)
        barcode.logger.info(fq2_cmd)
        os.system(fq2_cmd)
        barcode.logger.info("merge fastq done.")
    else:
        fastq1_file = args.fq1
        fastq2_file = args.fq2

    fh1 = xopen(fastq1_file)
    fh2 = xopen(fastq2_file)
    out_fq2 = args.outdir + '/' + args.sample + '_2.fq.gz'
    fh3 = xopen(out_fq2, 'w')

    (total_num, clean_num, no_polyT_num, lowQual_num,
     no_linker_num, no_barcode_num) = (0, 0, 0, 0, 0, 0)
    Barcode_dict = defaultdict(int)

    if args.nopolyT:
        fh1_without_polyT = xopen(args.outdir + '/noPolyT_1.fq', 'w')
        fh2_without_polyT = xopen(args.outdir + '/noPolyT_2.fq', 'w')

    if args.noLinker:
        fh1_without_linker = xopen(args.outdir + '/noLinker_1.fq', 'w')
        fh2_without_linker = xopen(args.outdir + '/noLinker_2.fq', 'w')

    bool_probe = False
    if args.probe_file and args.probe_file != 'None':
        bool_probe = True
        count_dic = genDict(dim=3)
        valid_count_dic = genDict(dim=2)
        probe_dic = read_fasta(args.probe_file)
        reads_without_probe = 0

    g1 = read_fastq(fh1)
    g2 = read_fastq(fh2)

    while True:
        try:
            (header1, seq1, qual1) = next(g1)
            (header2, seq2, qual2) = next(g2)
        except StopIteration:
            break
        if total_num > 0 and total_num % 1000000 == 0:
            barcode.logger.info(
                f'processed reads: {format_number(total_num)}. '
                f'valid reads: {format_number(clean_num)}.'
            )

        total_num += 1

        # polyT filter
        if bool_T:
            polyT = seq_ranges(seq1, pattern_dict['T'])
            if no_polyT(polyT):
                no_polyT_num += 1
                if args.nopolyT:
                    fh1_without_polyT.write(
                        '@%s\n%s\n+\n%s\n' % (header1, seq1, qual1))
                    fh2_without_polyT.write(
                        '@%s\n%s\n+\n%s\n' % (header2, seq2, qual2))
                continue

        # lowQual filter
        C_U_quals_ascii = seq_ranges(
            qual1, pattern_dict['C'] + pattern_dict['U'])
        # C_U_quals_ord = [ord(q) - 33 for q in C_U_quals_ascii]
        if low_qual(C_U_quals_ascii, args.lowQual, args.lowNum):
            lowQual_num += 1
            continue

        # linker filter
        barcode_arr = [seq_ranges(seq1, [i]) for i in pattern_dict['C']]
        raw_cb = ''.join(barcode_arr)
        if bool_L:
            linker = seq_ranges(seq1, pattern_dict['L'])
            if (no_linker(linker, linker_dict)):
                no_linker_num += 1

                if args.noLinker:
                    fh1_without_linker.write(
                        '@%s\n%s\n+\n%s\n' % (header1, seq1, qual1))
                    fh2_without_linker.write(
                        '@%s\n%s\n+\n%s\n' % (header2, seq2, qual2))
                continue

        # barcode filter
        # barcode_arr = [seq_ranges(seq1, [i]) for i in pattern_dict['C']]
        # raw_cb = ''.join(barcode_arr)
        res = no_barcode(barcode_arr, barcode_dict)
        if res is True:
            no_barcode_num += 1
            continue
        elif res == "correct":
            cb = raw_cb
        else:
            cb = res

        umi = seq_ranges(seq1, pattern_dict['U'])
        Barcode_dict[cb] += 1
        clean_num += 1
        read_name_probe = 'None'

        if bool_probe:
            # valid count
            valid_count_dic[cb][umi] += 1

            # output probe UMi and read count
            find_probe = False
            for probe_name in probe_dic:
                probe_seq = probe_dic[probe_name]
                probe_seq = probe_seq.upper()
                if seq1.find(probe_seq) != -1:
                    count_dic[probe_name][cb][umi] += 1
                    read_name_probe = probe_name
                    find_probe = True
                    break

            if not find_probe:
                reads_without_probe += 1

        barcode_qual_Counter.update(C_U_quals_ascii[:C_len])
        umi_qual_Counter.update(C_U_quals_ascii[C_len:])
        C_U_base_Counter.update(raw_cb + umi)

        # new readID: @barcode_umi_old readID
        fh3.write(f'@{cb}_{umi}_{read_name_probe}_{total_num}\n{seq2}\n+\n{qual2}\n')

    fh3.close()

    # logging
    if total_num % 1000000 != 0:
        barcode.logger.info(
            f'processed reads: {format_number(total_num)}. '
            f'valid reads: {format_number(clean_num)}. '
        )

    if clean_num == 0:
        raise Exception(
            'no valid reads found! please check the --chemistry parameter.')

    if bool_probe:
        # total probe summary
        total_umi = 0
        total_valid_read = 0
        for cb in valid_count_dic:
            total_umi += len(valid_count_dic[cb])
            total_valid_read += sum(valid_count_dic[cb].values())
        barcode.logger.info("total umi:"+str(total_umi))
        barcode.logger.info("total valid read:"+str(total_valid_read))
        barcode.logger.info("reads without probe:"+str(reads_without_probe))

        # probe summary
        count_list = []
        for probe_name in probe_dic:
            UMI_count = 0
            read_count = 0
            if probe_name in count_dic:
                for cb in count_dic[probe_name]:
                    UMI_count += len(count_dic[probe_name][cb])
                    read_count += sum(count_dic[probe_name][cb].values())
            count_list.append(
                {"probe_name": probe_name, "UMI_count": UMI_count, "read_count": read_count})

        df_count = pd.DataFrame(count_list, columns=[
                                "probe_name", "read_count", "UMI_count"])

        def format_percent(x):
            x = str(round(x*100, 2))+"%"
            return x
        df_count["read_fraction"] = (
            df_count["read_count"]/total_valid_read).apply(format_percent)
        df_count["UMI_fraction"] = (
            df_count["UMI_count"]/total_umi).apply(format_percent)
        df_count.sort_values(by="UMI_count", inplace=True, ascending=False)
        df_count_file = args.outdir + '/' + args.sample + '_probe_count.tsv'
        df_count.to_csv(df_count_file, sep="\t", index=False)

    # stat
    BarcodesQ30 = sum([barcode_qual_Counter[k] for k in barcode_qual_Counter if k >= ord2chr(
        30)]) / float(sum(barcode_qual_Counter.values())) * 100
    UMIsQ30 = sum([umi_qual_Counter[k] for k in umi_qual_Counter if k >= ord2chr(
        30)]) / float(sum(umi_qual_Counter.values())) * 100

    global stat_info
    def cal_percent(x): return "{:.2%}".format((x + 0.0) / total_num)
    with open(args.outdir + '/stat.txt', 'w') as fh:
        """
        Raw Reads: %s
        Valid Reads: %s(%s)
        Q30 of Barcodes: %.2f%%
        Q30 of UMIs: %.2f%%
        """
        stat_info = stat_info % (format_number(total_num), format_number(clean_num),
                                 cal_percent(clean_num), BarcodesQ30,
                                 UMIsQ30)
        stat_info = re.sub(r'^\s+', r'', stat_info, flags=re.M)
        fh.write(stat_info)

    barcode.logger.info('fastqc ...!')
    cmd = ['fastqc', '-t', str(args.thread), '-o', args.outdir, out_fq2]
    barcode.logger.info('%s' % (' '.join(cmd)))
    subprocess.check_call(cmd)
    barcode.logger.info('fastqc done!')

    t = reporter(name='barcode', assay=args.assay, sample=args.sample,
                 stat_file=args.outdir + '/stat.txt', outdir=args.outdir + '/..')
    t.get_report()
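
`generate_seq_dict` builds the mismatch lookup described in the comment above its call: every sequence within n substitutions maps back to its whitelist or linker sequence. A sketch under the assumption that the argument is a path to a file with one sequence per line (the real helper may accept other inputs):

from itertools import combinations, product

def generate_mismatch_seqs(seq, n=1, bases='ACGTN'):
    # All sequences within Hamming distance <= n of seq (including seq itself).
    results = {seq}
    for positions in combinations(range(len(seq)), n):
        for substitutions in product(bases, repeat=n):
            s = list(seq)
            for pos, base in zip(positions, substitutions):
                s[pos] = base
            results.add(''.join(s))
    return results

def generate_seq_dict_sketch(seq_file, n=1):
    # Map every <=n-mismatch variant back to its original sequence.
    seq_dict = {}
    with open(seq_file) as f:
        for line in f:
            seq = line.strip()
            if not seq:
                continue
            for variant in generate_mismatch_seqs(seq, n):
                seq_dict[variant] = seq
    return seq_dict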
Example #10
def bam2table(bam, detail_file, id_name):
    # Group reads in the BAM that share a barcode and tally the reads assigned to each gene
    # probe
    probe_gene_count_dict = genDict(dim=4, valType=int)

    samfile = pysam.AlignmentFile(bam, "rb")
    with open(detail_file, 'w') as fh1:
        fh1.write('\t'.join(['Barcode', 'geneID', 'UMI', 'count']) + '\n')

        # pysam.libcalignedsegment.AlignedSegment
        # AAACAGGCCAGCGTTAACACGACC_CCTAACGT_A00129:340:HHH72DSXX:2:1353:23276:30843
        # extract the barcode from the read name
        def keyfunc(x):
            return x.query_name.split('_', 1)[0]

        for _, g in groupby(samfile, keyfunc):
            gene_umi_dict = defaultdict(lambda: defaultdict(int))
            for seg in g:
                (barcode, umi, probe) = seg.query_name.split('_')[:3]
                if probe != 'None':
                    probe_gene_count_dict[probe]['total'][barcode][umi] += 1
                    if seg.has_tag('XT'):
                        geneID = seg.get_tag('XT')
                        geneName = id_name[geneID]
                        probe_gene_count_dict[probe][geneName][barcode][
                            umi] += 1
                    else:
                        probe_gene_count_dict[probe]['None'][barcode][umi] += 1
                if not seg.has_tag('XT'):
                    continue
                geneID = seg.get_tag('XT')
                gene_umi_dict[geneID][umi] += 1
            res_dict = correct_umi(fh1, barcode, gene_umi_dict)

            # output
            for geneID in res_dict:
                for umi in res_dict[geneID]:
                    fh1.write('%s\t%s\t%s\t%s\n' %
                              (barcode, geneID, umi, res_dict[geneID][umi]))

    # out probe
    row_list = []
    for probe in probe_gene_count_dict:
        for geneName in probe_gene_count_dict[probe]:
            barcode_count = len(probe_gene_count_dict[probe][geneName])
            umi_count = 0
            read_count = 0
            for barcode in probe_gene_count_dict[probe][geneName]:
                # count each UMI once per barcode
                umi_count += len(
                    probe_gene_count_dict[probe][geneName][barcode])
                for umi in probe_gene_count_dict[probe][geneName][barcode]:
                    read_count += probe_gene_count_dict[probe][geneName][
                        barcode][umi]
            row_list.append({
                'probe': probe,
                'gene': geneName,
                'barcode_count': barcode_count,
                'read_count': read_count,
                'UMI_count': umi_count
            })

    df_probe = pd.DataFrame(
        row_list,
        columns=['probe', 'gene', 'barcode_count', 'read_count', 'UMI_count'])
    df_probe = df_probe.groupby(
        ['probe']).apply(lambda x: x.sort_values('UMI_count', ascending=False))
    return df_probe
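
A minimal usage sketch with hypothetical paths; `id_name` is assumed to be a geneID -> geneName mapping like the `gtf_dict` used in Example #8, and the BAM is assumed to be grouped by read name so that reads with the same barcode are consecutive.

# hypothetical inputs
id_name = {'ENSG00000141510': 'TP53'}
df_probe = bam2table('demo_name_sorted.bam', 'demo_count_detail.tsv', id_name)
df_probe.to_csv('demo_probe_count.tsv', sep='\t', index=False)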