Example #1
def main():
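    """Align raw signal to predicted squiggles for a set of fast5 reads.

    Writes a '#<read_id> <score>' header per read, then one tab-separated
    row per signal sample.
    """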
    args = get_parser().parse_args()

    worker_kwarg_names = ['back_prob', 'localpen', 'minscore', 'trim']

    model = helpers.load_model(args.model)

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.read_dir, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    with helpers.open_file_or_stdout(args.output) as fh:
        for res in imap_mp(
                squiggle_match.worker, fast5_reads, threads=args.jobs,
                fix_kwargs=helpers.get_kwargs(args, worker_kwarg_names),
                unordered=True, init=squiggle_match.init_worker,
                initargs=[model, args.references]):
            if res is None:
                continue
            read_id, sig, score, path, squiggle, bases = res
            bases = bases.decode('ascii')
            fh.write('#{} {}\n'.format(read_id, score))
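            # one row per signal sample: read id, sample index, raw
            # current, path position, the base at that position, and the
            # three predicted squiggle parameters for that position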
            for i, (s, p) in enumerate(zip(sig, path)):
                fh.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(
                    read_id, i, s, p, bases[p], squiggle[p, 0], squiggle[p, 1],
                    squiggle[p, 2]))
Example #2
def main():
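    """Extract per-read reference sequences from SAM alignments and write
    them as FASTA, optionally reversed and/or complemented."""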
    args = get_parser().parse_args()

    sys.stderr.write(
        "* Loading references (this may take a while for large genomes)\n")
    references = fasta_file_to_dict(args.reference, filter_ambig=False)

    if args.input_strand_list is None:
        strand_list = None
    else:
        strand_list = readtsv(args.input_strand_list,
                              fields=['read_id'])['read_id']
        sys.stderr.write('* Strand list contains {} reads\n'.format(
            len(strand_list)))

    sys.stderr.write("* Extracting read references using SAM alignment\n")
    with open_file_or_stdout(args.output) as fh:
        for samfile in args.input:
            for name, read_ref in get_refs(samfile,
                                           references,
                                           args.min_coverage,
                                           args.pad,
                                           strand_list=strand_list):
                if args.reverse:
                    read_ref = read_ref[::-1]
                if args.complement:
                    read_ref = complement(read_ref)
                fasta = ">{}\n{}\n".format(name, read_ref)

                fh.write(fasta)
Example #3
def main():
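    """Dump a model as JSON to stdout or a file."""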
    args = parser.parse_args()
    model = load_model(args.model)

    json_out = model.json(args.params)

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
Example #4
def main():
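    """Dump a model as JSON, including the md5 checksum of the model
    file."""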
    args = get_parser().parse_args()
    model_md5 = file_md5(args.model)
    model = load_model(args.model)

    json_out = model.json()
    json_out['md5sum'] = model_md5

    with open_file_or_stdout(args.output) as fh:
        json.dump(json_out, fh, indent=4, cls=JsonEncoder)
Example #5
def main():
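    """Extract per-read reference sequences from SAM alignments and write
    them as FASTA."""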
    args = parser.parse_args()

    sys.stderr.write("* Loading references (this may take a while for large genomes)\n")
    references = fasta_file_to_dict(args.reference, filter_ambig=False)

    sys.stderr.write("* Extracting read references using SAM alignment\n")
    with open_file_or_stdout(args.output) as fh:
        for samfile in args.input:
            for name, read_ref in get_refs(samfile, references, args.min_coverage, args.pad):
                if args.reverse:
                    read_ref = read_ref[::-1]
                if args.complement:
                    read_ref = complement(read_ref)
                fasta = ">{}\n{}\n".format(name, read_ref)

                fh.write(fasta)
Example #6
def main():
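    """Predict a squiggle (current, sd, dwell) for each base of every
    sequence in a FASTA file."""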
    args = parser.parse_args()

    predict_squiggle = helpers.load_model(args.model)

    with helpers.open_file_or_stdout(args.output) as fh:
        for seq in SeqIO.parse(args.input, 'fasta'):
            seqstr = str(seq.seq)
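            # embed the sequence and add a singleton batch dimension
            # (axis 1) so a single sequence can be fed through the model;
            # the matching squeeze on axis 1 removes it again below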
            embedded_seq_numpy = np.expand_dims(
                squiggle_match.embed_sequence(seqstr), axis=1)
            embedded_seq_torch = torch.tensor(embedded_seq_numpy,
                                              dtype=torch.float32)

            with torch.no_grad():
                squiggle = np.squeeze(
                    predict_squiggle(embedded_seq_torch).cpu().numpy(), axis=1)

            fh.write('base\tcurrent\tsd\tdwell\n')
            for base, (mean, logsd, dwell) in zip(seq.seq, squiggle):
                fh.write('{}\t{}\t{}\t{}\n'.format(base, mean, np.exp(logsd),
                                                   np.exp(-dwell)))
Example #7
def main():
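    """Compute per-read shift and scale parameters and write them as
    TSV."""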
    args = parser.parse_args()

    trim_start, trim_end = args.trim

    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder,
        limit=args.limit,
        strand_list=args.input_strand_list,
        recursive=args.recursive)

    with open_file_or_stdout(args.output) as tsvfile:
        writer = csv.writer(tsvfile, delimiter='\t', lineterminator='\n')
        # UUID is 32hexdigits and four dashes eg. '43f6a05c-0856-4edc-8cd2-4866d9d60eaa'
        writer.writerow(['UUID', 'trim_start', 'trim_end', 'shift', 'scale'])

        results = imap_mp(one_read_shift_scale, fast5_reads, threads=args.jobs)

        for result in results:
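            # one_read_shift_scale is assumed to yield None fields on
            # failure; all() drops those rows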
            if all(result):
                read_id, shift, scale = result
                writer.writerow([read_id, trim_start, trim_end, shift, scale])
Example #8
def main():
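    """Basecall fast5 reads with a flip-flop model, optionally writing
    modified-base scores to HDF5."""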
    args = parser.parse_args()

    device = helpers.set_torch_device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from a canonical-base-only "
            "model.\n")
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model)
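    # chunk size and overlap are specified in units of the model stride;
    # convert them to raw-sample counts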
    chunk_size = args.chunk_size * stride
    chunk_overlap = args.overlap * stride

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = list(
        fast5utils.iterate_fast5_reads(args.input_folder,
                                       limit=args.limit,
                                       strand_list=args.input_strand_list,
                                       recursive=args.recursive))
    sys.stderr.write("* Found {} reads.\n".format(len(fast5_reads)))

    if args.scaling is not None:
        sys.stderr.write("* Loading read scaling parameters from {}.\n".format(
            args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        fast5_reads = [
            rec for rec in fast5_reads if rec[1] in scaling_read_ids
        ]
    else:
        all_read_params = {}

    mods_fp = None
    if do_output_mods:
        # open in write mode explicitly (the h5py default mode has
        # changed across versions)
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset('mod_long_names',
                               data=np.array(mod_long_names, dtype='S'),
                               dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                read_params = all_read_params.get(read_id)
                basecall, qstring, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size, chunk_overlap,
                    read_params, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp, args.max_concurrent_chunks,
                    args.fastq, args.qscore_scale, args.qscore_offset)
                if basecall is not None:
                    fh.write("{}{}\n{}\n".format(
                        startcharacter, read_id,
                        basecall[::-1] if args.reverse else basecall))
                    nbase += len(basecall)
                    ncalled += 1
                    if args.fastq:
                        fh.write("+\n{}\n".format(
                            qstring[::-1] if args.reverse else qstring))
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {:.2f}s\n".format(
        nread, total_time))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(nbase / total_time /
                                                    1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(nsample / total_time /
                                                      1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
Example #9
def main():
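    """Basecall fast5 reads with a multiprocessing pool; workers are
    initialized once per process via worker_init."""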
    args = get_parser().parse_args()

    # TODO convert to logging

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit,
        strand_list=args.input_strand_list, recursive=args.recursive)

    if args.scaling is not None:
        sys.stderr.write(
            "* Loading read scaling parameters from {}.\n".format(
                args.scaling))
        all_read_params = get_per_read_params_dict_from_tsv(args.scaling)
        # iterate_fast5_reads yields lazily; materialize the reads so
        # they can be traversed twice (once here, once when basecalling)
        fast5_reads = list(fast5_reads)
        input_read_ids = frozenset(rec[1] for rec in fast5_reads)
        scaling_read_ids = frozenset(all_read_params.keys())
        sys.stderr.write("* {} / {} reads have scaling information.\n".format(
            len(input_read_ids & scaling_read_ids), len(input_read_ids)))
        fast5_reads = [rec for rec in fast5_reads
                       if rec[1] in scaling_read_ids]
    else:
        all_read_params = {}

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    startcharacter = '@' if args.fastq else '>'
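    # all per-worker state (device, model path, calling parameters) is
    # passed via initargs; worker_init presumably sets it up once per
    # process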
    initargs = [args.device, args.model, args.chunk_size, args.overlap,
                all_read_params, args.alphabet,
                args.max_concurrent_chunks, args.fastq, args.qscore_scale,
                args.qscore_offset, args.beam, args.posterior,
                args.temperature]
    with Pool(args.jobs, initializer=worker_init, initargs=initargs) as pool, \
            open_file_or_stdout(args.output) as fh:
        for read_id, basecall, qstring, read_nsample in \
                pool.imap_unordered(worker, fast5_reads):
            if basecall is not None and len(basecall) > 0:
                fh.write("{}{}\n{}\n".format(
                    startcharacter, read_id,
                    basecall[::-1] if args.reverse else basecall))
                nbase += len(basecall)
                ncalled += 1
                if args.fastq:
                    fh.write("+\n{}\n".format(
                        qstring[::-1] if args.reverse else qstring))

            nread += 1
            nsample += read_nsample
            progress.step()
    total_time = time.time() - t0

    sys.stderr.write(
        "* Called {} reads in {:.2f}s\n".format(nread, total_time))
    sys.stderr.write(
        "* {:7.2f} kbase / s\n".format(nbase / total_time / 1000.0))
    sys.stderr.write(
        "* {:7.2f} ksample / s\n".format(nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return
Example #10
def main():
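    """Basecall fast5 reads with a flip-flop model on a GPU (requires
    cupy), optionally writing modified-base scores to HDF5."""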
    args = parser.parse_args()

    assert args.device != 'cpu', (
        "Flipflop basecalling in taiyaki requires a GPU and for cupy "
        "to be installed")
    device = torch.device(args.device)
    # TODO convert to logging
    sys.stderr.write("* Loading model.\n")
    model = load_model(args.model).to(device)
    is_cat_mod = isinstance(model.sublayers[-1],
                            layers.GlobalNormFlipFlopCatMod)
    do_output_mods = args.modified_base_output is not None
    if do_output_mods and not is_cat_mod:
        sys.stderr.write(
            "Cannot output modified bases from a canonical-base-only "
            "model.\n")
        sys.exit(1)
    n_can_states = nstate_flipflop(model.sublayers[-1].nbase)
    stride = guess_model_stride(model, device=device)
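    # round the requested chunk size and overlap to multiples of the
    # model stride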
    chunk_size, chunk_overlap = basecall_helpers.round_chunk_values(
        args.chunk_size, args.overlap, stride)

    sys.stderr.write("* Initializing reads file search.\n")
    fast5_reads = fast5utils.iterate_fast5_reads(
        args.input_folder, limit=args.limit, strand_list=args.input_strand_list,
        recursive=args.recursive)

    mods_fp = None
    if do_output_mods:
        # open in write mode explicitly (the h5py default mode has
        # changed across versions)
        mods_fp = h5py.File(args.modified_base_output, 'w')
        mods_fp.create_group('Reads')
        mod_long_names = model.sublayers[-1].ordered_mod_long_names
        sys.stderr.write("* Preparing modified base output: {}.\n".format(
            ', '.join(map(str, mod_long_names))))
        mods_fp.create_dataset(
            'mod_long_names', data=np.array(mod_long_names, dtype='S'),
            dtype=h5py.special_dtype(vlen=str))

    sys.stderr.write("* Calling reads.\n")
    nbase, ncalled, nread, nsample = 0, 0, 0, 0
    t0 = time.time()
    progress = Progress(quiet=args.quiet)
    try:
        with open_file_or_stdout(args.output) as fh:
            for read_filename, read_id in fast5_reads:
                basecall, read_nsample = process_read(
                    read_filename, read_id, model, chunk_size,
                    chunk_overlap, device, n_can_states, stride, args.alphabet,
                    is_cat_mod, mods_fp)
                if basecall is not None:
                    fh.write(">{}\n{}\n".format(read_id, basecall))
                    nbase += len(basecall)
                    ncalled += 1
                nread += 1
                nsample += read_nsample
                progress.step()
    finally:
        if mods_fp is not None:
            mods_fp.close()
    total_time = time.time() - t0

    sys.stderr.write("* Called {} reads in {}s\n".format(
        nread, int(total_time)))
    sys.stderr.write("* {:7.2f} kbase / s\n".format(
        nbase / total_time / 1000.0))
    sys.stderr.write("* {:7.2f} ksample / s\n".format(
        nsample / total_time / 1000.0))
    sys.stderr.write("* {} reads failed.\n".format(nread - ncalled))
    return