Example #1
def basecall(model, reads, aligner=None, beamsize=5, chunksize=0, overlap=0, batchsize=1, qscores=False):
    """
    Basecalls a set of reads.
    """
    # Split each read's raw signal into overlapping, fixed-size chunks.
    chunks = (
        (read, chunk(torch.tensor(read.signal), chunksize, overlap)) for read in reads
    )
    # Score batches of chunks on the model, then ungroup back to per-read results.
    scores = unbatchify(
        (k, compute_scores(model, v)) for k, v in batchify(chunks, batchsize)
    )
    # Stitch the per-chunk scores back into a single score matrix per read.
    scores = (
        (read, {'scores': stitch(v, chunksize, overlap, len(read.signal), model.stride)}) for read, v in scores
    )
    # Decode scores to sequences in a pool of worker processes.
    decoder = partial(decode, decode=model.decode, beamsize=beamsize, qscores=qscores)
    basecalls = process_map(decoder, scores, n_proc=4)
    if aligner: return align_map(aligner, basecalls)
    return basecalls
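
Each stage above is a lazy generator, so no signal is processed until the caller iterates the results. The central trick is overlapped chunking; below is a minimal, self-contained sketch of that idea (an illustration only; the `chunk` helper used above is assumed to behave along these lines, its actual implementation is not shown here):

import torch
import torch.nn.functional as F

def chunk_signal(signal, chunksize, overlap):
    # Split a 1-D signal into fixed-size windows that overlap by `overlap`
    # samples, zero-padding the final window to full length.
    stride = chunksize - overlap
    rows = []
    for s in range(0, max(len(signal) - overlap, 1), stride):
        piece = signal[s:s + chunksize]
        rows.append(F.pad(piece, (0, chunksize - len(piece))))
    return torch.stack(rows)

signal = torch.arange(10_000, dtype=torch.float32)
print(chunk_signal(signal, chunksize=4000, overlap=500).shape)  # torch.Size([3, 4000])

With chunksize=4000 and overlap=500, a 10,000-sample signal yields three rows of 4,000 samples, the last one zero-padded.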
Example #2
def basecall(model,
             reads,
             aligner=None,
             beamsize=40,
             chunksize=4000,
             overlap=500,
             batchsize=32,
             qscores=False,
             reverse=False):
    """
    Basecalls a set of reads.
    """
    _decode = partial(decode_int8, seqdist=model.seqdist, beamsize=beamsize)
    # Split long reads into sub-reads, processing them back-to-front when
    # basecalling the reverse strand.
    reads = (read_chunk for read in reads
             for read_chunk in split_read(read)[::-1 if reverse else 1])
    # Cut each sub-read's signal into overlapping chunks, keyed by
    # (read, start, end) so the pieces can be reassembled later.
    chunks = (((read, start, end),
               chunk(torch.from_numpy(read.signal[start:end]), chunksize,
                     overlap)) for (read, start, end) in reads)
    # Score batches of chunks on the model and quantise the output to int8.
    batches = (
        (k, quantise_int8(compute_scores(model, batch, reverse=reverse)))
        for k, batch in thread_iter(batchify(chunks, batchsize=batchsize)))
    # Stitch the overlapping chunk scores back into one block per sub-read.
    stitched = ((read,
                 stitch(x,
                        chunksize,
                        overlap,
                        end - start,
                        model.stride,
                        reverse=reverse))
                for ((read, start, end), x) in unbatchify(batches))

    # Transfer the stitched scores on a single thread, then decode them
    # across eight threads.
    transferred = thread_map(transfer, stitched, n_thread=1)
    basecalls = thread_map(_decode, transferred, n_thread=8)

    # Regroup decoded pieces under their parent read and join the partial
    # sequences back together in order.
    basecalls = ((read, ''.join(seq for k, seq in parts))
                 for read, parts in groupby(
                     basecalls,
                     lambda x: x[0].parent if hasattr(x[0], 'parent') else x[0]))
    # Emit the standard result record; the qstring here is a placeholder.
    basecalls = ((read, {
        'sequence': seq,
        'qstring': '?' * len(seq) if qscores else '*',
        'mean_qscore': 0.0
    }) for read, seq in basecalls)

    if aligner: return align_map(aligner, basecalls)
    return basecalls
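
The scores are quantised to int8 before decoding. Since `quantise_int8` itself is not shown above, the following is a hedged sketch of the general technique its name suggests (an assumption about the approach, not its actual implementation): scale the float scores into the int8 range and keep the scale factor for dequantisation.

import torch

def quantise_int8_sketch(scores):
    # Symmetric quantisation: map [-max, max] onto [-127, 127].
    scale = scores.abs().max() / 127.0
    q = torch.clamp((scores / scale).round(), min=-127, max=127).to(torch.int8)
    return q, scale

scores = torch.randn(32, 400)
q, scale = quantise_int8_sketch(scores)
print((scores - q.float() * scale).abs().max())  # small, bounded error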
Example #3
def call(model,
         reads_directory,
         templates,
         complements,
         aligner=None,
         cudapoa=True):

    temp_reads = read_gen(reads_directory,
                          templates,
                          n_proc=8,
                          cancel=process_cancel())
    comp_reads = read_gen(reads_directory,
                          complements,
                          n_proc=8,
                          cancel=process_cancel())

    temp_scores = basecall(model, temp_reads, reverse=False)
    comp_scores = basecall(model, comp_reads, reverse=True)

    # Pair template and complement positionally: zip() requires both
    # streams to yield corresponding reads in the same order.
    scores = (((r1, r2), (s1, s2))
              for (r1, s1), (r2, s2) in zip(temp_scores, comp_scores))
    calls = thread_map(decode, scores, n_thread=12)

    if cudapoa:
        # GPU consensus: batch the candidate sequence sets and stream them
        # through poagen.
        sequences = ((reads, [seqs]) for reads, seqs in calls if len(seqs) > 2)
        consensus = (zip(reads, poagen(calls))
                     for reads, calls in batchify(sequences, 100))
        res = ((reads[0], {'sequence': seq})
               for seqs in consensus for reads, seq in seqs)
    else:
        # CPU consensus: run poa in a pool of worker processes.
        sequences = ((reads, seqs) for reads, seqs in calls if len(seqs) > 2)
        consensus = process_map(poa, sequences, n_proc=4)
        res = ((reads, {'sequence': seq})
               for reads, seqs in consensus for seq in seqs)

    if aligner is None: return res
    return align_map(aligner, res)
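
The pairing of template and complement calls is purely positional; a self-contained illustration with hypothetical data shows what the `scores` generator above produces:

temp_scores = [('temp-1', 'S1'), ('temp-2', 'S2')]
comp_scores = [('comp-1', 'C1'), ('comp-2', 'C2')]
scores = (((r1, r2), (s1, s2))
          for (r1, s1), (r2, s2) in zip(temp_scores, comp_scores))
print(list(scores))
# [(('temp-1', 'comp-1'), ('S1', 'C1')), (('temp-2', 'comp-2'), ('S2', 'C2'))]

Because zip() stops at the shorter stream and pairs items by position, the two read generators must be kept in matching order for the pairing to be meaningful.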
Example #4
def basecall(model, reads, aligner=None, beamsize=40, chunksize=4000, overlap=500, batchsize=32, qscores=False):
    """
    Basecalls a set of reads.
    """
    split_read_length = 400000
    # Keep only a central window of each chunk's output so that, after
    # stitching, neighbouring chunks contribute non-overlapping frames.
    _stitch = partial(
        stitch,
        start=overlap // 2 // model.stride,
        end=(chunksize - overlap // 2) // model.stride,
    )
    _decode = partial(decode_int8, seqdist=model.seqdist, beamsize=beamsize)
    # Split each read's signal into fixed-length sub-reads keyed by (read, i).
    reads = (
        ((read, i), x) for read in reads
        for (i, x) in enumerate(torch.split(torch.from_numpy(read.signal), split_read_length))
    )
    chunks = (
        (read, chunk(signal, chunksize, overlap, pad_start=True)) for (read, signal) in reads
    )
    batches = (
        (read, quantise_int8(compute_scores(model, batch)))
        for read, batch in thread_iter(batchify(chunks, batchsize=batchsize))
    )
    stitched = ((read, _stitch(x)) for (read, x) in unbatchify(batches))
    transferred = thread_map(transfer, stitched, n_thread=1)
    basecalls = thread_map(_decode, transferred, n_thread=8)

    # Regroup decoded pieces by parent read (the first element of the
    # (read, i) key) and join the partial sequences in order.
    basecalls = (
        (read, ''.join(seq for k, seq in parts)) for read, parts in groupby(basecalls, lambda x: x[0][0])
    )
    basecalls = (
        (read, {'sequence': seq, 'qstring': '?' * len(seq) if qscores else '*', 'mean_qscore': 0.0})
        for read, seq in basecalls
    )

    if aligner: return align_map(aligner, basecalls)
    return basecalls
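
The `_stitch` window trims half the overlap from each side of every chunk, in model-output frames, so consecutive chunks tile the read without double-counting. A worked example, using the defaults chunksize=4000 and overlap=500 and a hypothetical model stride of 5:

chunksize, overlap, stride = 4000, 500, 5
start = overlap // 2 // stride               # 250 // 5  -> 50
end = (chunksize - overlap // 2) // stride   # 3750 // 5 -> 750
print(start, end)  # each chunk keeps output frames [50, 750)

Each trimmed window covers chunksize - overlap = 3500 samples, exactly the distance between consecutive chunk starts.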
Example #5
def main(args):

    init(args.seed, args.device)

    if args.model_directory in models and args.model_directory not in os.listdir(__models__):
        sys.stderr.write("> downloading model\n")
        File(__models__, models[args.model_directory]).download()

    sys.stderr.write(f"> loading model {args.model_directory}\n")
    try:
        model = load_model(
            args.model_directory,
            args.device,
            weights=int(args.weights),
            chunksize=args.chunksize,
            overlap=args.overlap,
            batchsize=args.batchsize,
            quantize=args.quantize,
            use_koi=True,
        )
    except FileNotFoundError:
        sys.stderr.write(f"> error: failed to load {args.model_directory}\n")
        sys.stderr.write(f"> available models:\n")
        for model in sorted(models):
            sys.stderr.write(f" - {model}\n")
        exit(1)

    if args.verbose:
        sys.stderr.write(
            f"> model basecaller params: {model.config['basecaller']}\n")

    basecall = load_symbol(args.model_directory, "basecall")

    mods_model = None
    if args.modified_base_model is not None or args.modified_bases is not None:
        sys.stderr.write("> loading modified base model\n")
        mods_model = load_mods_model(args.modified_bases, args.model_directory,
                                     args.modified_base_model)
        sys.stderr.write(f"> {mods_model[1]['alphabet_str']}\n")

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map', best_n=1)
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    fmt = biofmt(aligned=args.reference is not None)

    if args.reference and args.reference.endswith(".mmi") and fmt.name == "cram":
        sys.stderr.write(
            "> error: reference cannot be a .mmi when outputting cram\n")
        exit(1)
    elif args.reference and fmt.name == "fastq":
        sys.stderr.write(
            f"> warning: did you really want {fmt.aligned} {fmt.name}?\n")
    else:
        sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n")

    if args.save_ctc and not args.reference:
        sys.stderr.write(
            "> a reference is needed to output ctc training data\n")
        exit(1)

    if fmt.name != 'fastq':
        groups = get_read_groups(args.reads_directory,
                                 args.model_directory,
                                 n_proc=8,
                                 recursive=args.recursive,
                                 read_ids=column_to_set(args.read_ids),
                                 skip=args.skip,
                                 cancel=process_cancel())
    else:
        groups = []

    reads = get_reads(args.reads_directory,
                      n_proc=8,
                      recursive=args.recursive,
                      read_ids=column_to_set(args.read_ids),
                      skip=args.skip,
                      cancel=process_cancel())

    if args.max_reads:
        reads = take(reads, args.max_reads)

    if args.save_ctc:
        reads = (chunk for read in reads for chunk in read_chunks(
            read,
            chunksize=model.config["basecaller"]["chunksize"],
            overlap=model.config["basecaller"]["overlap"]))
        ResultsWriter = CTCWriter
    else:
        ResultsWriter = Writer

    results = basecall(model,
                       reads,
                       reverse=args.revcomp,
                       batchsize=model.config["basecaller"]["batchsize"],
                       chunksize=model.config["basecaller"]["chunksize"],
                       overlap=model.config["basecaller"]["overlap"])

    if mods_model is not None:
        results = process_itemmap(partial(call_mods, mods_model), results)
    if aligner:
        results = align_map(aligner, results, n_thread=os.cpu_count())

    writer = ResultsWriter(
        fmt.mode,
        tqdm(results, desc="> calling", unit=" reads", leave=False),
        aligner=aligner,
        group_key=args.model_directory,
        ref_fn=args.reference,
        groups=groups,
    )

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" %
                     timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")