def compute_scores(model, batch, beam_width=32, beam_cut=100.0, scale=1.0, offset=0.0, blank_score=2.0, reverse=False):
    """
    Forward a batch through the model and beam-search decode the output.

    Returns a dict with the called 'sequence', its per-base 'qstring' and a
    boolean 'moves' array from the beam search.
    """
    with torch.inference_mode():
        run_device = next(model.parameters()).device
        # half precision only where the hardware supports it
        run_dtype = torch.float16 if half_supported() else torch.float32
        posteriors = model(batch.to(run_dtype).to(run_device))
        if reverse:
            posteriors = model.seqdist.reverse_complement(posteriors)
        sequence, qstring, moves = beam_search(
            posteriors,
            beam_width=beam_width,
            beam_cut=beam_cut,
            scale=scale,
            offset=offset,
            blank_score=blank_score,
        )
        return {
            'qstring': qstring,
            'sequence': sequence,
            'moves': np.array(moves, dtype=bool),
        }
def compute_scores(model, batch, reverse=False):
    """
    Forward a batch through the model encoder and derive transition
    probabilities from the backward scores.

    Returns a dict with batch-major 'trans' probabilities and 'init' state
    distributions, both cast back to the inference dtype.
    """
    with torch.no_grad():
        run_device = next(model.parameters()).device
        run_dtype = torch.float16 if half_supported() else torch.float32
        encoder_out = model.encoder(batch.to(run_dtype).to(run_device))
        if reverse:
            encoder_out = model.seqdist.reverse_complement(encoder_out)
        # backward pass in float32 — presumably for numerical stability
        betas = model.seqdist.backward_scores(encoder_out.to(torch.float32))
        trans, init = model.seqdist.compute_transition_probs(encoder_out, betas)
        return {
            'trans': trans.to(run_dtype).transpose(0, 1),
            'init': init.to(run_dtype).unsqueeze(1),
        }
def compute_scores(model, batch):
    """
    Forward a batch through the model encoder and compute backward scores.

    Returns a dict of batch-major 'scores' and rescaled 'betas'.
    """
    with torch.no_grad():
        run_device = next(model.parameters()).device
        run_dtype = torch.float16 if half_supported() else torch.float32
        encoder_out = model.encoder(batch.to(run_dtype).to(run_device))
        betas = model.seqdist.backward_scores(encoder_out.to(torch.float32))
        # shift each position so its max beta sits at 5.0
        betas -= betas.max(2, keepdim=True)[0] - 5.0
        return {
            'scores': encoder_out.transpose(0, 1),
            'betas': betas.transpose(0, 1),
        }
def argparser():
    """Build the command line argument parser for model evaluation."""
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, add_help=False)
    p.add_argument("model_directory")
    p.add_argument("--directory", default=None)
    p.add_argument("--device", default="cuda")
    p.add_argument("--half", action="store_true", default=half_supported())
    p.add_argument("--seed", type=int, default=9)
    p.add_argument("--weights", type=str, default="0")
    p.add_argument("--chunks", type=int, default=500)
    p.add_argument("--batchsize", type=int, default=100)
    p.add_argument("--beamsize", type=int, default=5)
    p.add_argument("--poa", action="store_true", default=False)
    # NOTE(review): default=True makes passing --shuffle a no-op — confirm intent
    p.add_argument("--shuffle", action="store_true", default=True)
    return p
def argparser():
    """Build the command line argument parser for basecalling."""
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, add_help=False)
    # positional inputs
    p.add_argument("model_directory")
    p.add_argument("reads_directory")
    # options
    p.add_argument("--reference")
    p.add_argument("--device", default="cuda")
    p.add_argument("--weights", type=str, default="0")
    p.add_argument("--beamsize", type=int, default=5)
    p.add_argument("--chunksize", type=int, default=0)
    p.add_argument("--overlap", type=int, default=0)
    p.add_argument("--half", action="store_true", default=half_supported())
    p.add_argument("--fastq", action="store_true", default=False)
    p.add_argument("--cudart", action="store_true", default=False)
    p.add_argument("--save-ctc", action="store_true", default=False)
    return p
def argparser():
    """Build the command line argument parser (posteriors-file variant)."""
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, add_help=False)
    # positional inputs
    p.add_argument("model_directory")
    p.add_argument("reads_directory")
    # options
    p.add_argument("--reference")
    p.add_argument("--read-ids")
    p.add_argument("--device", default="cuda")
    p.add_argument("--weights", type=str, default="0")
    p.add_argument("--beamsize", type=int, default=5)
    p.add_argument("--post_file", type=str, required=True)
    p.add_argument("--write_basecall", action='store_true', default=False)
    p.add_argument("--chunksize", type=int, default=0)
    p.add_argument("--overlap", type=int, default=0)
    p.add_argument("--half", action="store_true", default=half_supported())
    p.add_argument("--skip", action="store_true", default=False)
    p.add_argument("--fastq", action="store_true", default=False)
    p.add_argument("--cudart", action="store_true", default=False)
    p.add_argument("--save-ctc", action="store_true", default=False)
    p.add_argument("--ctc-min-coverage", type=float, default=0.9)
    p.add_argument("--ctc-min-accuracy", type=float, default=0.9)
    return p
def main(args):
    """
    Evaluate one or more sets of model weights against validation chunks,
    reporting accuracy and throughput per model and, with --poa, the
    accuracy of the partial-order-alignment consensus across models.
    """
    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    # NOTE(review): requires a --min-coverage option on the argparser — verify
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [
            accuracy_with_cov(ref, seq) if len(seq) else 0.
            for ref, seq in zip(refs, seqs)
        ]

        if args.poa:
            # BUG FIX: was `poas.append(sequences)` — `sequences` is undefined (NameError)
            poas.append(seqs)

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0

        # BUG FIX: `accuracy_with_coverage_filter` and `references` were undefined
        # names — use the coverage-filtering helper and references built above
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean %.2f%%" % np.mean(accuracies))
        print("* median %.2f%%" % np.median(accuracies))
        print("* time %.2f" % duration)
def main(args):
    """
    Pair-decode template/complement reads: basecall each strand with its own
    model and hand the two logit sets to a pool of pair decoders.
    """
    samples = 0
    num_pairs = 0
    max_read_size = 4e6
    dtype = np.float16 if half_supported() else np.float32

    if args.index is not None:
        sys.stderr.write("> loading read index\n")
        index = json.load(open(args.index, 'r'))
    else:
        sys.stderr.write("> building read index\n")
        files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
        index = build_index(files)
        if args.save_index:
            with open('bonito-read-id.idx', 'w') as f:
                json.dump(index, f)

    sys.stderr.write("> loading model\n")
    model_temp = load_model(args.temp_model_directory, args.device)
    model_comp = load_model(args.comp_model_directory, args.device)

    decoders = PairDecoderWriterPool(model_temp.alphabet, procs=args.num_procs)

    def fetch_signal(read_id):
        # load the raw signal for a read, or None when it exceeds max_read_size
        read = get_raw_data_for_read(
            os.path.join(args.reads_directory, index[read_id]), read_id)
        signal = read.signal
        if len(signal) > max_read_size:
            sys.stderr.write(
                "> skipping long read %s (%s samples)\n" % (read_id, len(signal)))
            return None
        return signal

    def basecall(model, signal):
        # shape to (1, 1, T), run the model, return float32 logits and T
        batch = signal[np.newaxis, np.newaxis, :].astype(dtype)
        logits = model(torch.tensor(batch).to(args.device))
        return logits.cpu().numpy().squeeze().astype(np.float32), batch.shape[-1]

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with torch.no_grad(), open(args.pairs_file) as pairs, decoders:

        for pair in tqdm(pairs, ascii=True, ncols=100):

            read_id_1, read_id_2 = pair.strip().split(args.sep)
            if read_id_1 not in index or read_id_2 not in index:
                continue

            signal_1 = fetch_signal(read_id_1)
            if signal_1 is None:
                continue

            signal_2 = fetch_signal(read_id_2)
            if signal_2 is None:
                continue

            # call the template strand
            logits_1, len_1 = basecall(model_temp, signal_1)

            # call the complement strand
            logits_2, len_2 = basecall(model_comp, signal_2)

            num_pairs += 1
            samples += len_1 + len_2

            # pair decode
            decoders.queue.put((read_id_1, logits_1, read_id_2, logits_2))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed pairs: %s\n" % num_pairs)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
def main(args):
    """
    Train a model: load (or split) chunk data, optionally resume from a
    pretrained model, and run train/validation epochs with checkpointing
    and CSV logging under the working directory.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        # hold out the final 3% of the training chunks for validation
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(
        ChunkDataSet(*train_data), batch_size=args.batch,
        shuffle=True, num_workers=4, pin_memory=True
    )
    valid_loader = DataLoader(
        ChunkDataSet(*valid_data), batch_size=args.batch,
        num_workers=4, pin_memory=True
    )

    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # FIX: os.path.join() with a single argument was a pointless wrapper
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    # FIX: close the merged config file deterministically (was a bare open())
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict, **chunk_config}, conf_file)

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    scaler = GradScaler(enabled=half_supported() and not args.no_amp)

    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=not args.no_amp)

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch * len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # re-expose attributes DataParallel hides behind .module
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(
                    model, device, train_loader, optimizer, criterion=criterion,
                    use_amp=half_supported() and not args.no_amp, scaler=scaler,
                    lr_scheduler=lr_scheduler, loss_log=loss_log
                )

            model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
            torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model, device, valid_loader, criterion=criterion)
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(OrderedDict([
                ('time', datetime.today()),
                ('duration', int(duration)),
                ('epoch', epoch),
                ('train_loss', train_loss),
                ('validation_loss', val_loss),
                ('validation_mean', val_mean),
                ('validation_median', val_median),
            ]))
def main(args):
    """
    Train a model via the Trainer class: build data loaders from numpy
    chunks (or a loader script), write out the merged config, construct the
    model and optional LR scheduler, then fit for the requested epochs.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device, (not args.nondeterministic))
    device = torch.device(args.device)

    print("[loading data]")
    try:
        train_loader_kwargs, valid_loader_kwargs = load_numpy(args.chunks, args.directory)
    except FileNotFoundError:
        train_loader_kwargs, valid_loader_kwargs = load_script(
            args.directory, seed=args.seed, chunks=args.chunks,
            valid_chunks=args.valid_chunks
        )

    loader_kwargs = {"batch_size": args.batch, "num_workers": 4, "pin_memory": True}
    train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
    valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    os.makedirs(workdir, exist_ok=True)
    # FIX: close the merged config file deterministically (was a bare open())
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict}, conf_file)

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    if config.get("lr_scheduler"):
        sched_config = config["lr_scheduler"]
        lr_scheduler_fn = getattr(
            import_module(sched_config["package"]), sched_config["symbol"]
        )(**sched_config)
    else:
        lr_scheduler_fn = None

    trainer = Trainer(
        model, device, train_loader, valid_loader,
        use_amp=half_supported() and not args.no_amp,
        lr_scheduler_fn=lr_scheduler_fn,
        restore_optim=args.restore_optim,
        save_optim_every=args.save_optim_every,
        grad_accum_split=args.grad_accum_split
    )

    # learning rate may be a single float or a comma-separated list
    if ',' in args.lr:
        lr = [float(x) for x in args.lr.split(',')]
    else:
        lr = float(args.lr)
    trainer.fit(workdir, args.epochs, lr)