def main(args):
    """Evaluate one or more sets of model weights on chunked test data,
    reporting per-model accuracy and (optionally) POA consensus accuracy
    across all evaluated models.
    """
    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(limit=args.chunks, shuffle=args.shuffle))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    # args.weights is a comma-separated list of weight checkpoint numbers
    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        for data, *_ in dataloader:
            with torch.no_grad():
                log_probs = model(data.to(args.device))
                predictions.append(log_probs.exp().cpu().numpy())

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            decode_ctc(post, model.alphabet)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(starmap(accuracy, zip(references, sequences)))

        if args.poa:
            poas.append(sequences)

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
        # data is the last batch seen; shape[2] is the chunk length in samples
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:
        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        # references here are those computed for the last model in the loop
        accuracies = list(starmap(accuracy, zip(references, consensuses)))
        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
def main(args):
    """Basecall reads from a directory using a producer (reader) /
    consumer (decoder-writer pool) pipeline.
    """
    sys.stderr.write("> loading model\n")
    model = load_model(
        args.model_directory, args.device, weights=int(args.weights),
        half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
    )

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # reads longer than this are skipped
    dtype = np.float16 if args.half else np.float32

    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, reference=args.reference)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, reader, torch.no_grad():
        while True:
            read = reader.queue.get()
            if read is None:  # sentinel: reader exhausted
                break

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            raw_data = torch.tensor(read.signal.astype(dtype))
            chunks = chunk(raw_data, args.chunksize, args.overlap)

            posteriors = model(chunks.to(args.device)).cpu().numpy()
            posteriors = stitch(posteriors, args.overlap // model.stride // 2)

            # trim the stitched posteriors back to the original signal length
            writer.queue.put((read, posteriors[:raw_data.shape[0]]))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Basecall every fast5 file in a directory and print FASTA to stdout."""
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights))

    num_reads = 0
    num_chunks = 0

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    for fast5 in tqdm(glob("%s/*fast5" % args.reads_directory), ascii=True):
        for read_id, raw_data in get_raw_data(fast5):

            # short reads fit in a single chunk; longer ones are windowed
            # with an overlap so the pieces can be stitched back together
            if len(raw_data) <= args.chunksize:
                batch = np.expand_dims(raw_data, axis=0)
            else:
                batch = window(raw_data, args.chunksize, stepsize=args.chunksize - args.overlap)
            batch = np.expand_dims(batch, axis=1)

            num_reads += 1
            num_chunks += batch.shape[0]

            with torch.no_grad():
                # gpu round trip: copy in, run the model, copy back out
                probs = torch.exp(model(torch.tensor(batch).to(args.device))).cpu()

            if len(probs) > 1:
                probs = stitch(probs, int(args.overlap / model.stride / 2))
            else:
                probs = np.squeeze(probs, axis=0)

            print(">%s" % read_id)
            print('\n'.join(wrap(decode_ctc(probs, model.alphabet), 100)))

    t1 = time.perf_counter()
    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> samples per second %.1E\n" % (num_chunks * args.chunksize / (t1 - t0)))
    sys.stderr.write("> done\n")
def basecall(rank, total_gpu, args, input_files):
    """Distributed basecall worker (one process per GPU).

    Loads the model onto device ``rank`` (wrapped in DDP), basecalls every
    read in ``input_files[rank]``, and writes per-read kmer means to an
    HDF5 file plus the called sequences to a FASTA file.

    Fixes vs. previous revision:
      * the HDF5 and FASTA handles were never closed (risking truncated or
        unflushed output) — now closed in a ``finally`` block;
      * the "SAMPLES PER SECOND" figure divided ``num_reads`` (reads/s)
        rather than ``samples`` despite its label.
    """
    setup(rank, total_gpu)
    device_id = rank

    sys.stderr.write("INFO: LOADING MODEL ON DEVICE: {}\n".format(device_id))
    model = load_model(args.model_directory, args.device, weights=int(args.weights), half=args.half)
    alphabet = model.alphabet
    torch.cuda.set_device(device_id)
    model.to(device_id)
    model.eval()
    model = DDP(model, device_ids=[device_id])
    sys.stderr.write("INFO: LOADED MODEL ON DEVICE: {}\n".format(device_id))

    samples = 0
    num_reads = 0
    # NOTE(review): never compared against read length below — confirm intent
    max_read_size = 1e9
    dtype = np.float16 if args.half else np.float32

    sys.stderr.write('No of files:{}, index: {}'.format(len(input_files[rank]), rank))

    hdf5_file = h5py.File('{}/{}_{}.hdf5'.format(args.output_directory, args.prefix, device_id), 'w')
    fasta_file = open('{}/{}_{}.fasta'.format(args.output_directory, args.prefix, device_id), 'w')

    try:
        hdf5_file.create_group('Reads')
        reads = hdf5_file['Reads']

        t0 = time.perf_counter()
        sys.stderr.write("STARTING INFERENCE\n")
        st = time.time()

        with torch.no_grad():
            for fast5 in input_files[device_id]:
                for read_id, raw_data in get_raw_data(fast5):
                    num_reads += 1
                    samples += len(raw_data)

                    signal_data = raw_data
                    raw_data = raw_data[np.newaxis, np.newaxis, :].astype(dtype)
                    gpu_data = torch.tensor(raw_data).to(args.device)
                    posteriors = model(gpu_data).exp().cpu().numpy().squeeze()
                    sequence, means = decode_revised(posteriors, alphabet, signal_data, args.kmer_length, args.beamsize)

                    # only keep reads that produced at least one kmer mean
                    if len(means) > 0:
                        reads.create_group(read_id)
                        reads[read_id]['means'] = means
                        fasta_file.write(">%s\n" % read_id)
                        fasta_file.write("%s\n" % os.linesep.join(wrap(sequence, 100)))

                ct = time.time()
                sys.stderr.write("\nINFO: FINISHED PROCESSING: {}/{} FILES. DEVICE: {} ELAPSED TIME: {}".format(num_reads, len(input_files), device_id, ct - st))
    finally:
        # bug fix: flush/close outputs even if inference fails
        hdf5_file.close()
        fasta_file.close()

    t1 = time.perf_counter()
    sys.stderr.write("INFO: TOTAL READS: %s\n" % num_reads)
    sys.stderr.write("INFO: TOTAL DURATION %.1E\n" % (t1 - t0))
    # bug fix: previously num_reads / (t1 - t0) despite the label
    sys.stderr.write("INFO: SAMPLES PER SECOND %.1E\n" % (samples / (t1 - t0)))
    sys.stderr.write("DONE\n")
    cleanup()
def main(args):
    """Basecall reads, optionally aligning to a reference and/or emitting
    CTC training data.
    """
    if args.save_ctc and not args.reference:
        sys.stderr.write("> a reference is needed to output ctc training data\n")
        exit(1)

    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights))

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    reads = get_reads(
        args.reads_directory, n_proc=8, recursive=args.recursive,
        read_ids=column_to_set(args.read_ids), skip=args.skip,
    )

    basecall = load_symbol(args.model_directory, "basecall")

    if args.save_ctc:
        # only reads with at least one full chunk's worth of signal are
        # usable as ctc training data
        reads = (
            chunk for read in reads
            if len(read.signal) >= 3600
            for chunk in read_chunks(read)
        )
        basecalls = basecall(model, reads, aligner=aligner, qscores=args.fastq, batchsize=64)
        writer = CTCWriter(
            tqdm(basecalls, desc="> calling", unit=" reads", leave=False),
            aligner, args.ctc_min_coverage, args.ctc_min_accuracy
        )
    else:
        basecalls = basecall(model, reads, aligner=aligner, qscores=args.fastq)
        writer = Writer(
            tqdm(basecalls, desc="> calling", unit=" reads", leave=False),
            aligner, fastq=args.fastq
        )

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    # writer.log holds (read_id, num_samples) entries
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Basecall reads via a reader / decoder-writer pipeline, passing the
    raw signal to the decoder alongside the network posteriors.

    Fixes vs. previous revision: the oversized-read branch used ``pass``
    instead of ``continue`` (the read was reported as skipped but still
    processed) and the message arguments were swapped (the sample count
    was printed where the read id belongs).
    """
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights), half=args.half)

    samples = 0
    num_reads = 0
    max_read_size = 1e9  # reads longer than this are skipped
    dtype = np.float16 if args.half else np.float32

    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterRevised(model.alphabet, args.beamsize, args.kmer_length, args.hdf5_filename)

    t0 = time.perf_counter()

    with writer, reader, torch.no_grad():
        while True:
            read = reader.queue.get()
            if read is None:  # sentinel: reader exhausted
                break

            read_id, raw_data = read

            if len(raw_data) > max_read_size:
                # bug fix: arguments were swapped and `pass` fell through
                sys.stderr.write("> skipping %s: %s too long\n" % (read_id, len(raw_data)))
                continue

            num_reads += 1
            samples += len(raw_data)

            signal_data = raw_data
            raw_data = raw_data[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data = torch.tensor(raw_data).to(args.device)
            posteriors = model(gpu_data).exp().cpu().numpy().squeeze()
            writer.queue.put((read_id, posteriors, signal_data))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> total duration : %ss\n" % duration)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Single-process basecaller: stream reads through the network and hand
    the posteriors to a decoder/writer.
    """
    sys.stderr.write("> loading model\n")
    model = load_model(args.model_directory, args.device, weights=int(args.weights), half=args.half)

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # reads longer than this are skipped
    dtype = np.float16 if args.half else np.float32

    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriter(model, beamsize=args.beamsize, fastq=args.fastq)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, reader, torch.no_grad():
        # a None item on the queue marks the end of the read stream
        while (read := reader.queue.get()) is not None:
            read_id, signal = read

            if len(signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read_id, len(signal)))
                continue

            num_reads += 1
            samples += len(signal)

            batch = torch.tensor(signal[np.newaxis, np.newaxis, :].astype(dtype)).to(args.device)
            posteriors = model(batch).exp().cpu().numpy().squeeze()
            writer.queue.put((read_id, posteriors.astype(np.float32)))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Evaluate one or more sets of model weights on validation chunks,
    reporting accuracy (and, optionally, POA consensus accuracy across
    the evaluated models).

    Fixes vs. previous revision: the POA branches referenced undefined
    names (``sequences``, ``references``, ``accuracy_with_coverage_filter``)
    which raised NameError whenever ``--poa`` was used; they are now the
    in-scope ``seqs``, ``refs`` and ``accuracy_with_cov``.
    """
    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    # args.weights is a comma-separated list of weight checkpoint numbers
    for w in [int(i) for i in args.weights.split(',')]:
        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        if args.poa:
            poas.append(seqs)  # bug fix: was undefined name `sequences`

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
        # data is the last batch seen; shape[2] is the chunk length in samples
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:
        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        # bug fix: was `accuracy_with_coverage_filter` / `references`,
        # neither of which exists in this scope
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))
        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
def main(args):
    """Top-level basecaller entry point: resolve (and download on demand)
    the model, choose the output format, run basecalling with optional
    modified-base calling and alignment, and write the results.
    """
    init(args.seed, args.device)

    # download the model on demand if it's a known name not yet on disk
    if args.model_directory in models and args.model_directory not in os.listdir(
            __models__):
        sys.stderr.write("> downloading model\n")
        File(__models__, models[args.model_directory]).download()

    sys.stderr.write(f"> loading model {args.model_directory}\n")

    try:
        model = load_model(
            args.model_directory, args.device, weights=int(args.weights),
            chunksize=args.chunksize, overlap=args.overlap,
            batchsize=args.batchsize, quantize=args.quantize, use_koi=True,
        )
    except FileNotFoundError:
        sys.stderr.write(f"> error: failed to load {args.model_directory}\n")
        sys.stderr.write(f"> available models:\n")
        for model in sorted(models):
            sys.stderr.write(f" - {model}\n")
        exit(1)

    if args.verbose:
        sys.stderr.write(
            f"> model basecaller params: {model.config['basecaller']}\n")

    basecall = load_symbol(args.model_directory, "basecall")

    mods_model = None
    if args.modified_base_model is not None or args.modified_bases is not None:
        sys.stderr.write("> loading modified base model\n")
        mods_model = load_mods_model(args.modified_bases, args.model_directory, args.modified_base_model)
        sys.stderr.write(f"> {mods_model[1]['alphabet_str']}\n")

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map', best_n=1)
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    fmt = biofmt(aligned=args.reference is not None)

    # sanity-check the reference/format combination before doing any work
    if args.reference and args.reference.endswith(".mmi") and fmt.name == "cram":
        sys.stderr.write("> error: reference cannot be a .mmi when outputting cram\n")
        exit(1)
    elif args.reference and fmt.name == "fastq":
        sys.stderr.write(f"> warning: did you really want {fmt.aligned} {fmt.name}?\n")
    else:
        sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n")

    if args.save_ctc and not args.reference:
        sys.stderr.write("> a reference is needed to output ctc training data\n")
        exit(1)

    if fmt.name != 'fastq':
        # read groups are only needed for sam/bam/cram headers
        groups = get_read_groups(args.reads_directory, args.model_directory,
                                 n_proc=8, recursive=args.recursive,
                                 read_ids=column_to_set(args.read_ids),
                                 skip=args.skip, cancel=process_cancel())
    else:
        groups = []

    reads = get_reads(args.reads_directory, n_proc=8, recursive=args.recursive,
                      read_ids=column_to_set(args.read_ids), skip=args.skip,
                      cancel=process_cancel())

    if args.max_reads:
        reads = take(reads, args.max_reads)

    if args.save_ctc:
        # split every read into fixed-size chunks for ctc training output
        reads = (chunk for read in reads for chunk in read_chunks(
            read, chunksize=model.config["basecaller"]["chunksize"],
            overlap=model.config["basecaller"]["overlap"]))
        ResultsWriter = CTCWriter
    else:
        ResultsWriter = Writer

    results = basecall(model, reads, reverse=args.revcomp,
                       batchsize=model.config["basecaller"]["batchsize"],
                       chunksize=model.config["basecaller"]["chunksize"],
                       overlap=model.config["basecaller"]["overlap"])

    if mods_model is not None:
        results = process_itemmap(partial(call_mods, mods_model), results)
    if aligner:
        results = align_map(aligner, results, n_thread=os.cpu_count())

    writer = ResultsWriter(
        fmt.mode, tqdm(results, desc="> calling", unit=" reads", leave=False),
        aligner=aligner, group_key=args.model_directory,
        ref_fn=args.reference, groups=groups,
    )

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    # writer.log holds (read_id, num_samples) entries
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Basecall reads, optionally aligning them and/or saving chunk and
    posterior pairs as CTC training data.
    """
    if args.save_ctc and not args.reference:
        sys.stderr.write("> a reference is needed to output ctc training data\n")
        exit(1)

    if args.save_ctc:
        # fixed chunk geometry is required when saving ctc training data
        args.overlap = 900
        args.chunksize = 3600

    sys.stderr.write("> loading model\n")
    model = load_model(
        args.model_directory, args.device, weights=int(args.weights),
        half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
    )

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            sys.exit(1)
    else:
        aligner = None

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # reads longer than this are skipped
    dtype = np.float16 if args.half else np.float32

    ctc_writer = CTCWriter(model, aligner)
    reader = PreprocessReader(args.reads_directory)
    writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, aligner=aligner)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, ctc_writer, reader, torch.no_grad():
        while True:
            read = reader.queue.get()
            if read is None:  # sentinel: reader exhausted
                break

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            raw_data = torch.tensor(read.signal.astype(dtype))
            chunks = chunk(raw_data, args.chunksize, args.overlap)

            posteriors_ = model(chunks.to(args.device)).cpu().numpy()
            posteriors = stitch(posteriors_, args.overlap // model.stride // 2)

            # trim the stitched posteriors back to the original signal length
            writer.queue.put((read, posteriors[:raw_data.shape[0]]))

            # only multi-chunk reads are useful as ctc training data
            if args.save_ctc and len(raw_data) > args.chunksize:
                ctc_writer.queue.put((chunks.numpy(), posteriors_))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Pair-decode template/complement read pairs with two models and a
    pool of pair decoders.
    """
    samples = 0
    num_pairs = 0
    max_read_size = 4e6  # reads longer than this are skipped
    dtype = np.float16 if half_supported() else np.float32

    if args.index is not None:
        sys.stderr.write("> loading read index\n")
        index = json.load(open(args.index, 'r'))
    else:
        sys.stderr.write("> building read index\n")
        files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
        index = build_index(files)
        if args.save_index:
            with open('bonito-read-id.idx', 'w') as f:
                json.dump(index, f)

    sys.stderr.write("> loading model\n")
    model_temp = load_model(args.temp_model_directory, args.device)
    model_comp = load_model(args.comp_model_directory, args.device)

    decoders = PairDecoderWriterPool(model_temp.alphabet, procs=args.num_procs)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with torch.no_grad(), open(args.pairs_file) as pairs, decoders:
        for pair in tqdm(pairs, ascii=True, ncols=100):
            read_id_1, read_id_2 = pair.strip().split(args.sep)

            # both halves of the pair must be present in the index
            if read_id_1 not in index or read_id_2 not in index:
                continue

            read_1 = get_raw_data_for_read(
                os.path.join(args.reads_directory, index[read_id_1]), read_id_1)
            raw_data_1 = read_1.signal
            if len(raw_data_1) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read_id_1, len(raw_data_1)))
                continue

            read_2 = get_raw_data_for_read(
                os.path.join(args.reads_directory, index[read_id_2]), read_id_2)
            raw_data_2 = read_2.signal
            if len(raw_data_2) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read_id_2, len(raw_data_2)))
                continue

            # call the template strand
            raw_data_1 = raw_data_1[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data_1 = torch.tensor(raw_data_1).to(args.device)
            logits_1 = model_temp(gpu_data_1).cpu().numpy().squeeze().astype(
                np.float32)

            # call the complement strand
            raw_data_2 = raw_data_2[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data_2 = torch.tensor(raw_data_2).to(args.device)
            logits_2 = model_comp(gpu_data_2).cpu().numpy().squeeze().astype(
                np.float32)

            num_pairs += 1
            samples += raw_data_1.shape[-1] + raw_data_2.shape[-1]

            # pair decode
            decoders.queue.put((read_id_1, logits_1, read_id_2, logits_2))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed pairs: %s\n" % num_pairs)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """Train a model (optionally resuming from an existing work directory),
    logging per-epoch losses and validation accuracy.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(
            directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        # hold out the final 3% of the training chunks for validation
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data), batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
    valid_loader = DataLoader(ChunkDataSet(*valid_data), batch_size=args.batch, num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(os.path.join(chunk_config_file))

    os.makedirs(workdir, exist_ok=True)
    # persist the merged configuration next to the checkpoints
    toml.dump({
        **config, **argsdict, **chunk_config
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    # resume from the latest checkpoint in workdir (0 if none exists)
    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    lr_scheduler = func_scheduler(optimizer, cosine_decay_schedule(1.0, 0.1),
                                  args.epochs * len(train_loader),
                                  warmup_steps=500,
                                  start_step=last_epoch * len(train_loader))

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # re-expose attributes hidden behind the DataParallel wrapper
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(
                    workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(model, device, train_loader,
                                             optimizer, criterion=criterion,
                                             use_amp=args.amp,
                                             lr_scheduler=lr_scheduler,
                                             loss_log=loss_log)

            # unwrap DataParallel before saving so the checkpoint loads
            # cleanly into a plain model
            model_state = model.state_dict(
            ) if not args.multi_gpu else model.module.state_dict()
            torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model, device, valid_loader, criterion=criterion)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(
                OrderedDict([('time', datetime.today()),
                             ('duration', int(duration)),
                             ('epoch', epoch),
                             ('train_loss', train_loss),
                             ('validation_loss', val_loss),
                             ('validation_mean', val_mean),
                             ('validation_median', val_median)]))
def main(args):
    """Train a model via the Trainer abstraction, loading chunks from numpy
    files or falling back to a dataset loading script.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device, (not args.nondeterministic))
    device = torch.device(args.device)

    print("[loading data]")
    try:
        train_loader_kwargs, valid_loader_kwargs = load_numpy(
            args.chunks, args.directory)
    except FileNotFoundError:
        # no prepared numpy chunks: use the dataset's loading script instead
        train_loader_kwargs, valid_loader_kwargs = load_script(
            args.directory, seed=args.seed, chunks=args.chunks,
            valid_chunks=args.valid_chunks)

    loader_kwargs = {
        "batch_size": args.batch, "num_workers": 4, "pin_memory": True
    }
    train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
    valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

    if args.pretrained:
        dirname = args.pretrained
        # allow a bare model name that resolves under the models directory
        if not os.path.isdir(dirname) and os.path.isdir(
                os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    os.makedirs(workdir, exist_ok=True)
    # persist the merged configuration next to the checkpoints
    toml.dump({
        **config, **argsdict
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    if config.get("lr_scheduler"):
        # the scheduler factory is resolved dynamically from the config
        sched_config = config["lr_scheduler"]
        lr_scheduler_fn = getattr(import_module(sched_config["package"]),
                                  sched_config["symbol"])(**sched_config)
    else:
        lr_scheduler_fn = None

    trainer = Trainer(model, device, train_loader, valid_loader,
                      use_amp=half_supported() and not args.no_amp,
                      lr_scheduler_fn=lr_scheduler_fn,
                      restore_optim=args.restore_optim,
                      save_optim_every=args.save_optim_every,
                      grad_accum_split=args.grad_accum_split)

    # support a comma-separated list of per-parameter-group learning rates
    if (',' in args.lr):
        lr = [float(x) for x in args.lr.split(',')]
    else:
        lr = float(args.lr)
    trainer.fit(workdir, args.epochs, lr)
def main(args):
    """Duplex basecall template/complement read pairs, found either via a
    sequencing summary (follow-on strands) or an explicit pairs file.
    """
    sys.stderr.write("> loading model\n")
    model = load_model(args.model, args.device)

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    if args.summary:
        sys.stderr.write("> finding follow on strands\n")
        pairs = pd.read_csv(args.summary, '\t', low_memory=False)
        pairs = pairs[pairs.sequence_length_template.gt(0)]
        # normalise column names across summary file versions
        if 'filename' in pairs.columns:
            pairs = pairs.rename(columns={'filename': 'filename_fast5'})
        if 'alignment_strand_coverage' in pairs.columns:
            pairs = pairs.rename(
                columns={'alignment_strand_coverage': 'alignment_coverage'})
        # keep only rows whose fast5 file actually exists on disk
        valid_fast5s = [
            f for f in pairs.filename_fast5.unique()
            if ((args.reads_directory / Path(f)).exists())
        ]
        pairs = pairs[pairs.filename_fast5.isin(valid_fast5s)]
        pairs = find_follow_on(pairs)
        sys.stderr.write("> found %s follow strands in summary\n" % (len(pairs) // 2))

        if args.max_reads > 0:
            pairs = pairs.head(args.max_reads)

        # follow-on pairs alternate template/complement rows
        temp_reads = pairs.iloc[0::2]
        comp_reads = pairs.iloc[1::2]
    else:
        if args.index is not None:
            sys.stderr.write("> loading read index\n")
            index = json.load(open(args.index, 'r'))
        else:
            sys.stderr.write("> building read index\n")
            files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
            index = build_index(files, n_proc=8)
            if args.save_index:
                with open('bonito-read-id.idx', 'w') as f:
                    json.dump(index, f)

        pairs = pd.read_csv(args.pairs, sep=args.sep, names=['read_1', 'read_2'])
        if args.max_reads > 0:
            pairs = pairs.head(args.max_reads)

        pairs['file_1'] = pairs['read_1'].apply(index.get)
        pairs['file_2'] = pairs['read_2'].apply(index.get)
        # drop pairs where either read is missing from the index
        pairs = pairs.dropna().reset_index()

        temp_reads = pairs[['read_1', 'file_1']].rename(columns={
            'read_1': 'read_id',
            'file_1': 'filename_fast5'
        })
        comp_reads = pairs[['read_2', 'file_2']].rename(columns={
            'read_2': 'read_id',
            'file_2': 'filename_fast5'
        })

    if len(pairs) == 0:
        print("> no matched pairs found in given directory", file=sys.stderr)
        exit(1)

    # https://github.com/clara-parabricks/GenomeWorks/issues/648
    with devnull():
        CudaPoaBatch(1000, 1000, 3724032)

    basecalls = call(model, args.reads_directory, temp_reads, comp_reads, aligner=aligner)
    writer = Writer(tqdm(basecalls, desc="> calling", unit=" reads", leave=False), aligner, duplex=True)

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0

    # writer.log holds (read_id, num_samples) entries
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    print("> duration: %s" % timedelta(seconds=np.round(duration)), file=sys.stderr)
    print("> samples per second %.1E" % (num_samples / duration), file=sys.stderr)
def main(args):
    """Basecall reads and dump the raw posteriors to a file, optionally
    writing basecalls and/or saving CTC training chunks.
    """
    if args.save_ctc and not args.reference:
        sys.stderr.write(
            "> a reference is needed to output ctc training data\n")
        exit(1)

    if args.save_ctc:
        # fixed chunk geometry is required when saving ctc training data
        args.overlap = 900
        args.chunksize = 3600

    sys.stderr.write("> loading model\n")
    model = load_model(
        args.model_directory, args.device, weights=int(args.weights),
        half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
    )

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map')
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            sys.exit(1)
        write_sam_header(aligner)
    else:
        aligner = None

    # with open(summary_file(), 'w') as summary:
    #     write_summary_header(summary, alignment=aligner)

    samples = 0
    num_reads = 0
    max_read_size = 4e6  # reads longer than this are skipped
    read_ids = column_to_set(args.read_ids)
    dtype = np.float16 if args.half else np.float32

    reader = ProcessIterator(get_reads(args.reads_directory, read_ids=read_ids, skip=args.skip), progress=True)
    writer = ProcessPool(DecoderWriter, model=model, aligner=aligner, beamsize=args.beamsize, fastq=args.fastq)
    ctc_writer = CTCWriter(model, aligner, min_coverage=args.ctc_min_coverage, min_accuracy=args.ctc_min_accuracy)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with writer, ctc_writer, reader, torch.no_grad():
        while True:
            read = reader.queue.get()
            if read is None:  # sentinel: reader exhausted
                break

            if len(read.signal) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" % (read.read_id, len(read.signal)))
                continue

            num_reads += 1
            samples += len(read.signal)

            raw_data = torch.tensor(read.signal.astype(dtype))
            print('bonito: raw_data.shape: ', raw_data.shape)
            chunks = chunk(raw_data, args.chunksize, args.overlap)

            posteriors_ = model(chunks.to(args.device)).cpu().numpy()
            posteriors = stitch(posteriors_, args.overlap // model.stride // 2)

            if args.write_basecall:
                # trim to the original signal length before decoding
                writer.queue.put((read, posteriors[:raw_data.shape[0]]))

            # only multi-chunk reads are useful as ctc training data
            if args.save_ctc and len(raw_data) > args.chunksize:
                ctc_writer.queue.put((chunks.numpy(), posteriors_))

            print('bonito: posteriors.shape', posteriors.shape)
            posteriors.tofile(args.post_file)

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed reads: %s\n" % num_reads)
    sys.stderr.write("> duration: %s\n" % timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")