Пример #1
0
def main(args):
    """Evaluate one or more model checkpoints on a validation chunk set.

    For each comma-separated weight id in ``args.weights`` the model is
    loaded, run over the full dataloader, and per-read accuracy statistics
    are printed.  With ``--poa`` the per-model sequences are collected and
    a partial-order-alignment consensus is scored at the end.
    """

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(limit=args.chunks, shuffle=args.shuffle))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    # evaluate each requested checkpoint independently
    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        for data, *_ in dataloader:
            with torch.no_grad():
                log_probs = model(data.to(args.device))
                predictions.append(log_probs.exp().cpu().numpy())

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            decode_ctc(post, model.alphabet)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(starmap(accuracy, zip(references, sequences)))

        if args.poa: poas.append(sequences)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        # NOTE(review): `data` is the last batch left over from the loop above;
        # presumably every batch shares the same chunk length — confirm.
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]

        consensuses = poa(poas)
        duration = time.perf_counter() - t0

        # `references` carries over from the last checkpoint iteration above
        accuracies = list(starmap(accuracy, zip(references, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Пример #2
0
def main(args):
    """Basecall reads from a directory, streaming posteriors to a decoder pool.

    Reads arrive on ``reader.queue`` (terminated by a ``None`` sentinel);
    each one is chunked, run through the model, stitched back together, and
    handed to the writer pool for decoding and output.
    """
    sys.stderr.write("> loading model\n")

    model = load_model(
        args.model_directory,
        args.device,
        weights=int(args.weights),
        half=args.half,
        chunksize=args.chunksize,
        use_rt=args.cudart,
    )

    dtype = np.float16 if args.half else np.float32
    max_read_size = 4e6
    num_reads = 0
    samples = 0

    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterPool(
        model,
        beamsize=args.beamsize,
        fastq=args.fastq,
        reference=args.reference,
    )

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, reader, torch.no_grad():

        # the reader signals end-of-input with a None sentinel
        for read in iter(reader.queue.get, None):

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" %
                                 (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            signal = torch.tensor(read.signal.astype(dtype))
            pieces = chunk(signal, args.chunksize, args.overlap)

            scores = model(pieces.to(args.device)).cpu().numpy()
            scores = stitch(scores, args.overlap // model.stride // 2)

            # trim the stitched output back to the read's own length
            writer.queue.put((read, scores[:signal.shape[0]]))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" %
                     timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Пример #3
0
def main(args):
    """Basecall every fast5 in the reads directory, printing fasta to stdout."""
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory,
                       args.device,
                       weights=int(args.weights))

    num_reads = 0
    num_chunks = 0

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    for fast5 in tqdm(glob("%s/*fast5" % args.reads_directory), ascii=True):

        for read_id, raw_data in get_raw_data(fast5):

            # a single window when the read fits in one chunk, otherwise a
            # strided set of overlapping windows across the signal
            if len(raw_data) <= args.chunksize:
                batch = np.expand_dims(raw_data, axis=0)
            else:
                batch = window(raw_data,
                               args.chunksize,
                               stepsize=args.chunksize - args.overlap)

            batch = np.expand_dims(batch, axis=1)

            num_reads += 1
            num_chunks += batch.shape[0]

            with torch.no_grad():
                # copy to gpu, run the model, bring probabilities back to cpu
                scores = torch.exp(model(torch.tensor(batch).to(args.device))).cpu()

                if len(scores) > 1:
                    scores = stitch(scores,
                                    int(args.overlap / model.stride / 2))
                else:
                    scores = np.squeeze(scores, axis=0)

                sequence = decode_ctc(scores, model.alphabet)

                print(">%s" % read_id)
                print('\n'.join(wrap(sequence, 100)))

    t1 = time.perf_counter()
    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> samples per second %.1E\n" %
                     (num_chunks * args.chunksize / (t1 - t0)))
    sys.stderr.write("> done\n")
Пример #4
0
def basecall(rank, total_gpu, args, input_files):
    """Basecall the fast5 files assigned to this rank on its own GPU.

    Loads the model under DistributedDataParallel on device ``rank``, then
    writes per-read kmer means to an hdf5 file and called sequences to a
    fasta file (one pair of output files per device).

    Fixes: both output files are now closed via context managers, the final
    throughput line reports samples (not reads) per second, and the progress
    message counts processed files against this rank's own file list.
    """
    setup(rank, total_gpu)

    device_id = rank
    sys.stderr.write("INFO: LOADING MODEL ON DEVICE: {}\n".format(device_id))
    model = load_model(args.model_directory, args.device, weights=int(args.weights), half=args.half)
    alphabet = model.alphabet
    torch.cuda.set_device(device_id)
    model.to(device_id)
    model.eval()
    model = DDP(model, device_ids=[device_id])
    sys.stderr.write("INFO: LOADED MODEL ON DEVICE: {}\n".format(device_id))

    samples = 0
    num_reads = 0
    num_files = 0
    dtype = np.float16 if args.half else np.float32
    sys.stderr.write('No of files:{}, index: {}'.format(len(input_files[rank]), rank))

    t0 = time.perf_counter()
    sys.stderr.write("STARTING INFERENCE\n")
    st = time.time()

    # context managers guarantee both outputs are flushed/closed even on error
    with h5py.File('{}/{}_{}.hdf5'.format(args.output_directory, args.prefix, device_id), 'w') as hdf5_file, \
         open('{}/{}_{}.fasta'.format(args.output_directory, args.prefix, device_id), 'w') as fasta_file:

        hdf5_file.create_group('Reads')
        reads = hdf5_file['Reads']

        with torch.no_grad():
            for fast5 in input_files[device_id]:
                for read_id, raw_data in get_raw_data(fast5):
                    num_reads += 1
                    samples += len(raw_data)
                    signal_data = raw_data

                    # shape to (batch=1, channel=1, signal) and run the model
                    raw_data = raw_data[np.newaxis, np.newaxis, :].astype(dtype)
                    gpu_data = torch.tensor(raw_data).to(args.device)
                    posteriors = model(gpu_data).exp().cpu().numpy().squeeze()

                    sequence, means = decode_revised(posteriors, alphabet, signal_data, args.kmer_length, args.beamsize)
                    if len(means) > 0:
                        reads.create_group(read_id)
                        reads[read_id]['means'] = means
                    fasta_file.write(">%s\n" % read_id)
                    fasta_file.write("%s\n" % os.linesep.join(wrap(sequence, 100)))

                num_files += 1
                ct = time.time()
                # progress in files against this rank's list (was reads vs rank count)
                sys.stderr.write("\nINFO: FINISHED PROCESSING: {}/{} FILES. DEVICE: {} ELAPSED TIME: {}".format(num_files, len(input_files[device_id]), device_id, ct - st))

    t1 = time.perf_counter()
    sys.stderr.write("INFO: TOTAL READS: %s\n" % num_reads)
    sys.stderr.write("INFO: TOTAL DURATION %.1E\n" % (t1 - t0))
    # bug fix: throughput previously divided num_reads, not samples
    sys.stderr.write("INFO: SAMPLES PER SECOND %.1E\n" % (samples / (t1 - t0)))
    sys.stderr.write("DONE\n")

    cleanup()
Пример #5
0
def main(args):
    """Basecall reads and hand them to a Writer (or CTCWriter) for output."""
    if args.save_ctc and not args.reference:
        sys.stderr.write("> a reference is needed to output ctc training data\n")
        exit(1)

    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights))

    aligner = None
    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)

    reads = get_reads(
        args.reads_directory, n_proc=8, recursive=args.recursive,
        read_ids=column_to_set(args.read_ids), skip=args.skip,
    )

    basecall = load_symbol(args.model_directory, "basecall")

    if args.save_ctc:
        # basecall fixed-size chunks of sufficiently long reads for ctc output
        reads = (
            chunk for read in reads if len(read.signal) >= 3600 for chunk in read_chunks(read)
        )
        calls = basecall(model, reads, aligner=aligner, qscores=args.fastq, batchsize=64)
        progress = tqdm(calls, desc="> calling", unit=" reads", leave=False)
        writer = CTCWriter(progress, aligner, args.ctc_min_coverage, args.ctc_min_accuracy)
    else:
        calls = basecall(model, reads, aligner=aligner, qscores=args.fastq)
        progress = tqdm(calls, desc="> calling", unit=" reads", leave=False)
        writer = Writer(progress, aligner, fastq=args.fastq)

    # the writer thread drives the whole (lazy) pipeline
    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0

    num_samples = sum(n for _, n in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")
Пример #6
0
def main(args):
    """Basecall reads from a preprocessing queue and hand posteriors to the decoder.

    Reads are pulled off ``reader.queue`` until a ``None`` sentinel arrives;
    reads longer than ``max_read_size`` samples are skipped.
    """
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights), half=args.half)

    samples = 0
    num_reads = 0
    max_read_size = 1e9
    dtype = np.float16 if args.half else np.float32
    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterRevised(model.alphabet, args.beamsize, args.kmer_length, args.hdf5_filename)

    t0 = time.perf_counter()

    with writer, reader, torch.no_grad():

        while True:

            read = reader.queue.get()
            if read is None:
                break

            read_id, raw_data = read
            if len(raw_data) > max_read_size:
                # bug fix: report id and length in the right order, and
                # actually skip the read (previously `pass` fell through
                # and the over-long read was basecalled anyway)
                sys.stderr.write("> skipping %s: %s too long\n" % (read_id, len(raw_data)))
                continue

            num_reads += 1
            samples += len(raw_data)
            signal_data = raw_data

            # shape to (batch=1, channel=1, signal) and run the model
            raw_data = raw_data[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data = torch.tensor(raw_data).to(args.device)
            posteriors = model(gpu_data).exp().cpu().numpy().squeeze()

            writer.queue.put((read_id, posteriors, signal_data))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> total duration : %ss\n" % duration)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Пример #7
0
def main(args):
    """Stream reads through the model and queue posteriors for decoding/writing."""
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory,
                       args.device,
                       weights=int(args.weights),
                       half=args.half)

    dtype = np.float16 if args.half else np.float32
    max_read_size = 4e6
    num_reads = 0
    samples = 0

    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriter(model, beamsize=args.beamsize, fastq=args.fastq)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, reader, torch.no_grad():

        # the reader signals end-of-input with a None sentinel
        for read in iter(reader.queue.get, None):

            read_id, raw_data = read

            if len(raw_data) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" %
                                 (read_id, len(raw_data)))
                continue

            num_reads += 1
            samples += len(raw_data)

            # shape to (batch=1, channel=1, signal) and run the model
            batch = raw_data[np.newaxis, np.newaxis, :].astype(dtype)
            device_batch = torch.tensor(batch).to(args.device)
            posteriors = model(device_batch).exp().cpu().numpy().squeeze()

            writer.queue.put((read_id, posteriors.astype(np.float32)))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Пример #8
0
def main(args):
    """Evaluate model checkpoint(s) on validation data; optionally POA-consensus.

    For each comma-separated weight id in ``args.weights`` the model is run
    over the validation set and per-read accuracies (with a minimum-coverage
    filter) are printed.  With ``--poa`` the decoded sequences from every
    checkpoint are combined into a partial-order-alignment consensus and
    scored against the same references.
    """

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    refs = []
    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        # bug fix: append this model's decoded calls (`sequences` was undefined)
        if args.poa: poas.append(seqs)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        # bug fix: score consensus against the references decoded above —
        # `accuracy_with_coverage_filter` and `references` were undefined
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Пример #9
0
def main(args):
    """Top-level basecalling entry point.

    Downloads the model on first use if needed, loads it (plus an optional
    modified-base model and reference aligner), selects the output format,
    then streams reads through ``basecall`` into the chosen results writer.
    """

    init(args.seed, args.device)

    # download known models on demand if not already present on disk
    if args.model_directory in models and args.model_directory not in os.listdir(
            __models__):
        sys.stderr.write("> downloading model\n")
        File(__models__, models[args.model_directory]).download()

    sys.stderr.write(f"> loading model {args.model_directory}\n")
    try:
        model = load_model(
            args.model_directory,
            args.device,
            weights=int(args.weights),
            chunksize=args.chunksize,
            overlap=args.overlap,
            batchsize=args.batchsize,
            quantize=args.quantize,
            use_koi=True,
        )
    except FileNotFoundError:
        sys.stderr.write(f"> error: failed to load {args.model_directory}\n")
        sys.stderr.write(f"> available models:\n")
        for model in sorted(models):
            sys.stderr.write(f" - {model}\n")
        exit(1)

    if args.verbose:
        sys.stderr.write(
            f"> model basecaller params: {model.config['basecaller']}\n")

    basecall = load_symbol(args.model_directory, "basecall")

    # optional modified-base model (explicit model path or named base mods)
    mods_model = None
    if args.modified_base_model is not None or args.modified_bases is not None:
        sys.stderr.write("> loading modified base model\n")
        mods_model = load_mods_model(args.modified_bases, args.model_directory,
                                     args.modified_base_model)
        sys.stderr.write(f"> {mods_model[1]['alphabet_str']}\n")

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map', best_n=1)
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    # choose output format based on whether alignment is requested
    fmt = biofmt(aligned=args.reference is not None)

    if args.reference and args.reference.endswith(
            ".mmi") and fmt.name == "cram":
        sys.stderr.write(
            "> error: reference cannot be a .mmi when outputting cram\n")
        exit(1)
    elif args.reference and fmt.name == "fastq":
        sys.stderr.write(
            f"> warning: did you really want {fmt.aligned} {fmt.name}?\n")
    else:
        sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n")

    if args.save_ctc and not args.reference:
        sys.stderr.write(
            "> a reference is needed to output ctc training data\n")
        exit(1)

    # read groups are only needed for sam/bam/cram headers, not fastq
    if fmt.name != 'fastq':
        groups = get_read_groups(args.reads_directory,
                                 args.model_directory,
                                 n_proc=8,
                                 recursive=args.recursive,
                                 read_ids=column_to_set(args.read_ids),
                                 skip=args.skip,
                                 cancel=process_cancel())
    else:
        groups = []

    reads = get_reads(args.reads_directory,
                      n_proc=8,
                      recursive=args.recursive,
                      read_ids=column_to_set(args.read_ids),
                      skip=args.skip,
                      cancel=process_cancel())

    if args.max_reads:
        reads = take(reads, args.max_reads)

    if args.save_ctc:
        # emit fixed-size chunks instead of whole reads for ctc training data
        reads = (chunk for read in reads for chunk in read_chunks(
            read,
            chunksize=model.config["basecaller"]["chunksize"],
            overlap=model.config["basecaller"]["overlap"]))
        ResultsWriter = CTCWriter
    else:
        ResultsWriter = Writer

    results = basecall(model,
                       reads,
                       reverse=args.revcomp,
                       batchsize=model.config["basecaller"]["batchsize"],
                       chunksize=model.config["basecaller"]["chunksize"],
                       overlap=model.config["basecaller"]["overlap"])

    # optional post-processing stages: modified-base calling, then alignment
    if mods_model is not None:
        results = process_itemmap(partial(call_mods, mods_model), results)
    if aligner:
        results = align_map(aligner, results, n_thread=os.cpu_count())

    writer = ResultsWriter(
        fmt.mode,
        tqdm(results, desc="> calling", unit=" reads", leave=False),
        aligner=aligner,
        group_key=args.model_directory,
        ref_fn=args.reference,
        groups=groups,
    )

    # the writer thread drives the whole (lazy) pipeline
    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" %
                     timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")
Пример #10
0
def main(args):
    """Basecall reads from a preprocessing queue; optionally emit ctc training data.

    With ``--save-ctc`` the chunk/overlap geometry is fixed and raw chunks
    plus unstitched posteriors are additionally written through a
    ``CTCWriter`` (this mode requires a reference).
    """

    if args.save_ctc and not args.reference:
        sys.stderr.write("> a reference is needed to output ctc training data\n")
        exit(1)

    if args.save_ctc:
        # fixed geometry expected by the ctc training data writer
        args.overlap = 900
        args.chunksize = 3600

    sys.stderr.write("> loading model\n")

    model = load_model(
        args.model_directory, args.device, weights=int(args.weights),
        half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
    )

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            sys.exit(1)
    else:
        aligner = None

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # reads longer than this many samples are skipped
    dtype = np.float16 if args.half else np.float32
    ctc_writer = CTCWriter(model, aligner)
    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, aligner=aligner)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, ctc_writer, reader, torch.no_grad():

        while True:

            # the reader pushes a None sentinel when the input is exhausted
            read = reader.queue.get()
            if read is None:
                break

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            raw_data = torch.tensor(read.signal.astype(dtype))
            chunks = chunk(raw_data, args.chunksize, args.overlap)

            # keep the unstitched posteriors for the ctc writer below
            posteriors_ = model(chunks.to(args.device)).cpu().numpy()
            posteriors = stitch(posteriors_, args.overlap // model.stride // 2)

            writer.queue.put((read, posteriors[:raw_data.shape[0]]))
            if args.save_ctc and len(raw_data) > args.chunksize:
                ctc_writer.queue.put((chunks.numpy(), posteriors_))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Пример #11
0
def main(args):
    """Pair-decode template/complement read pairs listed in a pairs file."""

    samples = 0
    num_pairs = 0
    max_read_size = 4e6
    dtype = np.float16 if half_supported() else np.float32

    # read-id -> fast5 filename index: load from disk or build it fresh
    if args.index is not None:
        sys.stderr.write("> loading read index\n")
        index = json.load(open(args.index, 'r'))
    else:
        sys.stderr.write("> building read index\n")
        files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
        index = build_index(files)
        if args.save_index:
            with open('bonito-read-id.idx', 'w') as f:
                json.dump(index, f)

    sys.stderr.write("> loading model\n")

    model_temp = load_model(args.temp_model_directory, args.device)
    model_comp = load_model(args.comp_model_directory, args.device)

    decoders = PairDecoderWriterPool(model_temp.alphabet, procs=args.num_procs)

    def fetch_signal(read_id):
        """Load a read's raw signal; log and return None if it is too long."""
        read = get_raw_data_for_read(
            os.path.join(args.reads_directory, index[read_id]), read_id)
        signal = read.signal
        if len(signal) > max_read_size:
            sys.stderr.write("> skipping long read %s (%s samples)\n" %
                             (read_id, len(signal)))
            return None
        return signal

    def call(model, signal):
        """Run one strand through its model, returning float32 logits."""
        batch = signal[np.newaxis, np.newaxis, :].astype(dtype)
        scores = model(torch.tensor(batch).to(args.device))
        return scores.cpu().numpy().squeeze().astype(np.float32)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with torch.no_grad(), open(args.pairs_file) as pairs, decoders:

        for pair in tqdm(pairs, ascii=True, ncols=100):

            read_id_1, read_id_2 = pair.strip().split(args.sep)

            if read_id_1 not in index or read_id_2 not in index: continue

            raw_data_1 = fetch_signal(read_id_1)
            if raw_data_1 is None:
                continue

            raw_data_2 = fetch_signal(read_id_2)
            if raw_data_2 is None:
                continue

            # call the template strand, then the complement strand
            logits_1 = call(model_temp, raw_data_1)
            logits_2 = call(model_comp, raw_data_2)

            num_pairs += 1
            samples += len(raw_data_1) + len(raw_data_2)

            # pair decode
            decoders.queue.put((read_id_1, logits_1, read_id_2, logits_2))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed pairs: %s\n" % num_pairs)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Пример #12
0
def main(args):
    """Train a basecalling model, checkpointing weights and logging losses per epoch."""

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(
            directory=os.path.join(args.directory, 'validation'))
    else:
        # no held-out set on disk: carve the last 3% off the training chunks
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    valid_loader = DataLoader(ChunkDataSet(*valid_data),
                              batch_size=args.batch,
                              num_workers=4,
                              pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # NOTE(review): os.path.join with a single argument is a no-op;
        # toml.load(chunk_config_file) would be equivalent.
        chunk_config = toml.load(os.path.join(chunk_config_file))

    # persist the merged configuration alongside the checkpoints
    os.makedirs(workdir, exist_ok=True)
    toml.dump({
        **config,
        **argsdict,
        **chunk_config
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    # resume from the newest checkpoint in workdir (if any)
    last_epoch = load_state(workdir,
                            args.device,
                            model,
                            optimizer,
                            use_amp=args.amp)

    # cosine decay 1.0 -> 0.1 over all epochs, with a 500-step warmup
    lr_scheduler = func_scheduler(optimizer,
                                  cosine_decay_schedule(1.0, 0.1),
                                  args.epochs * len(train_loader),
                                  warmup_steps=500,
                                  start_step=last_epoch * len(train_loader))

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    # Ctrl-C inside an epoch breaks out cleanly after the last full epoch
    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(
                    workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             criterion=criterion,
                                             use_amp=args.amp,
                                             lr_scheduler=lr_scheduler,
                                             loss_log=loss_log)

            # save unwrapped weights so they load without DataParallel
            model_state = model.state_dict(
            ) if not args.multi_gpu else model.module.state_dict()
            torch.save(model_state,
                       os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model,
                                                  device,
                                                  valid_loader,
                                                  criterion=criterion)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(
                OrderedDict([('time', datetime.today()),
                             ('duration', int(duration)), ('epoch', epoch),
                             ('train_loss', train_loss),
                             ('validation_loss', val_loss),
                             ('validation_mean', val_mean),
                             ('validation_median', val_median)]))
Пример #13
0
def main(args):
    """Train (or fine-tune) a model, saving config and checkpoints under workdir."""

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device, (not args.nondeterministic))
    device = torch.device(args.device)

    print("[loading data]")
    # prefer pre-chunked numpy data; fall back to the dataset loader script
    try:
        train_loader_kwargs, valid_loader_kwargs = load_numpy(
            args.chunks, args.directory)
    except FileNotFoundError:
        train_loader_kwargs, valid_loader_kwargs = load_script(
            args.directory,
            seed=args.seed,
            chunks=args.chunks,
            valid_chunks=args.valid_chunks)

    common = dict(batch_size=args.batch, num_workers=4, pin_memory=True)
    train_loader = DataLoader(**common, **train_loader_kwargs)
    valid_loader = DataLoader(**common, **valid_loader_kwargs)

    # fine-tuning re-uses the pretrained model's config file
    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(
                os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)
    argsdict = dict(training=vars(args))

    # persist the merged configuration alongside the checkpoints
    os.makedirs(workdir, exist_ok=True)
    toml.dump({**config, **argsdict},
              open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    sched_config = config.get("lr_scheduler")
    if sched_config:
        lr_scheduler_fn = getattr(
            import_module(sched_config["package"]),
            sched_config["symbol"])(**sched_config)
    else:
        lr_scheduler_fn = None

    trainer = Trainer(model,
                      device,
                      train_loader,
                      valid_loader,
                      use_amp=half_supported() and not args.no_amp,
                      lr_scheduler_fn=lr_scheduler_fn,
                      restore_optim=args.restore_optim,
                      save_optim_every=args.save_optim_every,
                      grad_accum_split=args.grad_accum_split)

    # a comma-separated lr gives one rate per parameter group
    lr = [float(x) for x in args.lr.split(',')] if ',' in args.lr else float(args.lr)
    trainer.fit(workdir, args.epochs, lr)
Пример #14
0
def main(args):
    """Basecall follow-on (duplex) read pairs.

    Template/complement pairs are discovered either from a sequencing
    summary (``args.summary``) or from an explicit pairs file plus a
    read-id -> fast5 filename index.  The paired reads are basecalled
    together and optionally aligned against ``args.reference``.

    Progress and timing information is written to stderr; basecalls are
    emitted by the ``Writer``.
    """
    sys.stderr.write("> loading model\n")
    model = load_model(args.model, args.device)

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            # sys.exit instead of the interactive-only exit() builtin.
            sys.exit(1)
    else:
        aligner = None

    if args.summary:
        sys.stderr.write("> finding follow on strands\n")
        # sep must be passed by keyword: the positional form was
        # deprecated in pandas 1.x and removed in pandas 2.0.
        pairs = pd.read_csv(args.summary, sep='\t', low_memory=False)
        pairs = pairs[pairs.sequence_length_template.gt(0)]
        # Normalise column names across summary-file versions.
        if 'filename' in pairs.columns:
            pairs = pairs.rename(columns={'filename': 'filename_fast5'})
        if 'alignment_strand_coverage' in pairs.columns:
            pairs = pairs.rename(
                columns={'alignment_strand_coverage': 'alignment_coverage'})
        # Keep only pairs whose fast5 file is actually present on disk.
        valid_fast5s = [
            f for f in pairs.filename_fast5.unique()
            if ((args.reads_directory / Path(f)).exists())
        ]
        pairs = pairs[pairs.filename_fast5.isin(valid_fast5s)]
        pairs = find_follow_on(pairs)
        sys.stderr.write("> found %s follow strands in summary\n" %
                         (len(pairs) // 2))

        if args.max_reads > 0: pairs = pairs.head(args.max_reads)

        # find_follow_on interleaves template/complement rows: evens are
        # templates, odds are complements.
        temp_reads = pairs.iloc[0::2]
        comp_reads = pairs.iloc[1::2]
    else:
        if args.index is not None:
            sys.stderr.write("> loading read index\n")
            index = json.load(open(args.index, 'r'))
        else:
            sys.stderr.write("> building read index\n")
            files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
            index = build_index(files, n_proc=8)
            if args.save_index:
                with open('bonito-read-id.idx', 'w') as f:
                    json.dump(index, f)

        pairs = pd.read_csv(args.pairs,
                            sep=args.sep,
                            names=['read_1', 'read_2'])
        if args.max_reads > 0: pairs = pairs.head(args.max_reads)

        # Map read ids to fast5 filenames; drop pairs with no known file.
        pairs['file_1'] = pairs['read_1'].apply(index.get)
        pairs['file_2'] = pairs['read_2'].apply(index.get)
        pairs = pairs.dropna().reset_index()

        temp_reads = pairs[['read_1',
                            'file_1']].rename(columns={
                                'read_1': 'read_id',
                                'file_1': 'filename_fast5'
                            })
        comp_reads = pairs[['read_2',
                            'file_2']].rename(columns={
                                'read_2': 'read_id',
                                'file_2': 'filename_fast5'
                            })

    if len(pairs) == 0:
        print("> no matched pairs found in given directory", file=sys.stderr)
        sys.exit(1)

    # Warm up the CUDA POA batch once up front to avoid a crash; see
    # https://github.com/clara-parabricks/GenomeWorks/issues/648
    with devnull():
        CudaPoaBatch(1000, 1000, 3724032)

    basecalls = call(model,
                     args.reads_directory,
                     temp_reads,
                     comp_reads,
                     aligner=aligner)
    writer = Writer(tqdm(basecalls,
                         desc="> calling",
                         unit=" reads",
                         leave=False),
                    aligner,
                    duplex=True)

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    # writer.log holds (read_id, num_samples) tuples for each read written.
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    print("> duration: %s" % timedelta(seconds=np.round(duration)),
          file=sys.stderr)
    print("> samples per second %.1E" % (num_samples / duration),
          file=sys.stderr)
Пример #15
0
def main(args):
    """Basecall reads from a directory of fast5 files.

    Reads are chunked, run through the model, and the chunked posteriors
    are stitched back together.  Optionally aligns against
    ``args.reference``, emits CTC training data (``args.save_ctc``) and
    writes raw posteriors to ``args.post_file``.  All status output goes
    to stderr so it cannot corrupt basecall output on stdout.
    """
    if args.save_ctc and not args.reference:
        sys.stderr.write(
            "> a reference is needed to output ctc training data\n")
        # sys.exit instead of the interactive-only exit() builtin, for
        # consistency with the aligner failure path below.
        sys.exit(1)

    if args.save_ctc:
        # Larger chunks with more overlap are required for CTC output.
        # NOTE(review): magic values — presumably tuned for training data
        # quality; confirm before changing.
        args.overlap = 900
        args.chunksize = 3600

    sys.stderr.write("> loading model\n")

    model = load_model(
        args.model_directory,
        args.device,
        weights=int(args.weights),
        half=args.half,
        chunksize=args.chunksize,
        use_rt=args.cudart,
    )

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            sys.exit(1)
        write_sam_header(aligner)
    else:
        aligner = None

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # skip pathologically long reads (in samples)
    read_ids = column_to_set(args.read_ids)
    dtype = np.float16 if args.half else np.float32
    reader = ProcessIterator(get_reads(args.reads_directory,
                                       read_ids=read_ids,
                                       skip=args.skip),
                             progress=True)
    writer = ProcessPool(DecoderWriter,
                         model=model,
                         aligner=aligner,
                         beamsize=args.beamsize,
                         fastq=args.fastq)
    ctc_writer = CTCWriter(model,
                           aligner,
                           min_coverage=args.ctc_min_coverage,
                           min_accuracy=args.ctc_min_accuracy)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, ctc_writer, reader, torch.no_grad():

        while True:

            # The reader signals end-of-stream with a None sentinel.
            read = reader.queue.get()
            if read is None:
                break

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" %
                                 (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            raw_data = torch.tensor(read.signal.astype(dtype))
            chunks = chunk(raw_data, args.chunksize, args.overlap)

            posteriors_ = model(chunks.to(args.device)).cpu().numpy()
            # Halve the overlap (in model-stride units) so stitched chunks
            # each contribute half of the overlapping region.
            posteriors = stitch(posteriors_, args.overlap // model.stride // 2)
            if args.write_basecall:
                # Trim stitching padding back to the original read length.
                writer.queue.put((read, posteriors[:raw_data.shape[0]]))
            if args.save_ctc and len(raw_data) > args.chunksize:
                ctc_writer.queue.put((chunks.numpy(), posteriors_))
            posteriors.tofile(args.post_file)

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" %
                     timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")