Code example #1
def compute_scores(model,
                   batch,
                   beam_width=32,
                   beam_cut=100.0,
                   scale=1.0,
                   offset=0.0,
                   blank_score=2.0,
                   reverse=False):
    """
    Compute scores for model.
    """
    with torch.inference_mode():
        device = next(model.parameters()).device
        dtype = torch.float16 if half_supported() else torch.float32
        scores = model(batch.to(dtype).to(device))
        if reverse:
            scores = model.seqdist.reverse_complement(scores)
        sequence, qstring, moves = beam_search(scores,
                                               beam_width=beam_width,
                                               beam_cut=beam_cut,
                                               scale=scale,
                                               offset=offset,
                                               blank_score=blank_score)
        return {
            'qstring': qstring,
            'sequence': sequence,
            'moves': np.array(moves, dtype=bool),
        }
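
The following is a minimal, illustrative sketch of the inference pattern used above (torch.inference_mode, moving the batch to the model's device before the forward pass). ToyEncoder and the chunk shapes are hypothetical stand-ins, not bonito's model or API; beam_search and half_supported are omitted because they are not defined in this snippet.

# Illustrative only: a toy model exercising the same inference-mode and
# device-handling pattern as compute_scores above.
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, features=64, classes=5):
        super().__init__()
        self.conv = nn.Conv1d(1, features, kernel_size=5, padding=2)
        self.head = nn.Linear(features, classes)

    def forward(self, x):
        # x: (batch, 1, samples) -> (batch, samples, classes)
        return self.head(self.conv(x).transpose(1, 2))

model = ToyEncoder()
batch = torch.randn(4, 1, 1000)  # four fake chunks of raw signal

with torch.inference_mode():
    device = next(model.parameters()).device
    # the real code switches to float16 when half precision is supported;
    # float32 keeps this sketch portable
    scores = model(batch.to(torch.float32).to(device))

print(scores.shape)  # torch.Size([4, 1000, 5])
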
Code example #2
def compute_scores(model, batch, reverse=False):
    with torch.no_grad():
        device = next(model.parameters()).device
        dtype = torch.float16 if half_supported() else torch.float32
        scores = model.encoder(batch.to(dtype).to(device))
        if reverse:
            scores = model.seqdist.reverse_complement(scores)
        betas = model.seqdist.backward_scores(scores.to(torch.float32))
        trans, init = model.seqdist.compute_transition_probs(scores, betas)
    return {
        'trans': trans.to(dtype).transpose(0, 1),
        'init': init.to(dtype).unsqueeze(1),
    }
Code example #3
File: basecall.py  Project: vellamike/bonito
def compute_scores(model, batch):
    """
    Compute scores for model.
    """
    with torch.no_grad():
        device = next(model.parameters()).device
        dtype = torch.float16 if half_supported() else torch.float32
        scores = model.encoder(batch.to(dtype).to(device))
        betas = model.seqdist.backward_scores(scores.to(torch.float32))
        betas -= (betas.max(2, keepdim=True)[0] - 5.0)
    return {
        'scores': scores.transpose(0, 1),
        'betas': betas.transpose(0, 1),
    }
Code example #4
def argparser():
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False
    )
    parser.add_argument("model_directory")
    parser.add_argument("--directory", default=None)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--half", action="store_true", default=half_supported())
    parser.add_argument("--seed", default=9, type=int)
    parser.add_argument("--weights", default="0", type=str)
    parser.add_argument("--chunks", default=500, type=int)
    parser.add_argument("--batchsize", default=100, type=int)
    parser.add_argument("--beamsize", default=5, type=int)
    parser.add_argument("--poa", action="store_true", default=False)
    parser.add_argument("--shuffle", action="store_true", default=True)
    return parser
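
The add_help=False here suggests this parser is meant to be composed into a larger command-line entry point rather than used on its own. bonito's actual CLI wiring is not shown in these snippets, so the following is only a generic sketch of attaching such a sub-parser via argparse's parents mechanism; the example-cli name, the subcommand name, and the trimmed argument list are placeholders.

# Generic sketch, not bonito's cli module: attaching an add_help=False
# sub-parser (like the one above) to a top-level command via parents=[...].
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

def argparser():
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False   # the subcommand parser adds its own -h/--help
    )
    parser.add_argument("model_directory")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--beamsize", default=5, type=int)
    return parser

top = ArgumentParser("example-cli")
subcommands = top.add_subparsers(dest="command")
subcommands.add_parser("evaluate", parents=[argparser()],
                       formatter_class=ArgumentDefaultsHelpFormatter)

args = top.parse_args(["evaluate", "/path/to/model", "--beamsize", "10"])
print(args.command, args.model_directory, args.beamsize)
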
Code example #5
File: basecaller.py  Project: TimD1/bonito
def argparser():
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False
    )
    parser.add_argument("model_directory")
    parser.add_argument("reads_directory")
    parser.add_argument("--reference")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--weights", default="0", type=str)
    parser.add_argument("--beamsize", default=5, type=int)
    parser.add_argument("--chunksize", default=0, type=int)
    parser.add_argument("--overlap", default=0, type=int)
    parser.add_argument("--half", action="store_true", default=half_supported())
    parser.add_argument("--fastq", action="store_true", default=False)
    parser.add_argument("--cudart", action="store_true", default=False)
    parser.add_argument("--save-ctc", action="store_true", default=False)
    return parser
Code example #6
File: basecaller.py  Project: shubhamchandak94/bonito
def argparser():
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter,
                            add_help=False)
    parser.add_argument("model_directory")
    parser.add_argument("reads_directory")
    parser.add_argument("--reference")
    parser.add_argument("--read-ids")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--weights", default="0", type=str)
    parser.add_argument("--beamsize", default=5, type=int)
    parser.add_argument("--post_file", type=str, required=True)
    parser.add_argument("--write_basecall", default=False, action='store_true')
    parser.add_argument("--chunksize", default=0, type=int)
    parser.add_argument("--overlap", default=0, type=int)
    parser.add_argument("--half",
                        action="store_true",
                        default=half_supported())
    parser.add_argument("--skip", action="store_true", default=False)
    parser.add_argument("--fastq", action="store_true", default=False)
    parser.add_argument("--cudart", action="store_true", default=False)
    parser.add_argument("--save-ctc", action="store_true", default=False)
    parser.add_argument("--ctc-min-coverage", default=0.9, type=float)
    parser.add_argument("--ctc-min-accuracy", default=0.9, type=float)
    return parser
Code example #7
File: evaluate.py  Project: sirelkhatim/bonito
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        if args.poa: poas.append(seqs)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Code example #8
File: pair.py  Project: sirelkhatim/bonito
def main(args):

    samples = 0
    num_pairs = 0
    max_read_size = 4e6
    dtype = np.float16 if half_supported() else np.float32

    if args.index is not None:
        sys.stderr.write("> loading read index\n")
        index = json.load(open(args.index, 'r'))
    else:
        sys.stderr.write("> building read index\n")
        files = list(glob(os.path.join(args.reads_directory, '*.fast5')))
        index = build_index(files)
        if args.save_index:
            with open('bonito-read-id.idx', 'w') as f:
                json.dump(index, f)

    sys.stderr.write("> loading model\n")

    model_temp = load_model(args.temp_model_directory, args.device)
    model_comp = load_model(args.comp_model_directory, args.device)

    decoders = PairDecoderWriterPool(model_temp.alphabet, procs=args.num_procs)

    t0 = time.perf_counter()
    sys.stderr.write("> calling\n")

    with torch.no_grad(), open(args.pairs_file) as pairs, decoders:

        for pair in tqdm(pairs, ascii=True, ncols=100):

            read_id_1, read_id_2 = pair.strip().split(args.sep)

            if read_id_1 not in index or read_id_2 not in index: continue

            read_1 = get_raw_data_for_read(
                os.path.join(args.reads_directory, index[read_id_1]),
                read_id_1)
            raw_data_1 = read_1.signal

            if len(raw_data_1) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" %
                                 (read_id_1, len(raw_data_1)))
                continue

            read_2 = get_raw_data_for_read(
                os.path.join(args.reads_directory, index[read_id_2]),
                read_id_2)
            raw_data_2 = read_2.signal

            if len(raw_data_2) > max_read_size:
                sys.stderr.write("> skipping long read %s (%s samples)\n" %
                                 (read_id_2, len(raw_data_2)))
                continue

            # call the template strand
            raw_data_1 = raw_data_1[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data_1 = torch.tensor(raw_data_1).to(args.device)
            logits_1 = model_temp(gpu_data_1).cpu().numpy().squeeze().astype(
                np.float32)

            # call the complement strand
            raw_data_2 = raw_data_2[np.newaxis, np.newaxis, :].astype(dtype)
            gpu_data_2 = torch.tensor(raw_data_2).to(args.device)
            logits_2 = model_comp(gpu_data_2).cpu().numpy().squeeze().astype(
                np.float32)

            num_pairs += 1
            samples += raw_data_1.shape[-1] + raw_data_2.shape[-1]

            # pair decode
            decoders.queue.put((read_id_1, logits_1, read_id_2, logits_2))

    duration = time.perf_counter() - t0

    sys.stderr.write("> completed pairs: %s\n" % num_pairs)
    sys.stderr.write("> samples per second %.1E\n" % (samples / duration))
    sys.stderr.write("> done\n")
Code example #9
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(
            directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    valid_loader = DataLoader(ChunkDataSet(*valid_data),
                              batch_size=args.batch,
                              num_workers=4,
                              pin_memory=True)

    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(
                os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    toml.dump({
        **config,
        **argsdict,
        **chunk_config
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    scaler = GradScaler(enabled=half_supported() and not args.no_amp)

    last_epoch = load_state(workdir,
                            args.device,
                            model,
                            optimizer,
                            use_amp=not args.no_amp)

    lr_scheduler = func_scheduler(optimizer,
                                  cosine_decay_schedule(1.0, 0.1),
                                  args.epochs * len(train_loader),
                                  warmup_steps=500,
                                  start_step=last_epoch * len(train_loader))

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(
                    workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             criterion=criterion,
                                             use_amp=(half_supported()
                                                      and not args.no_amp),
                                             scaler=scaler,
                                             lr_scheduler=lr_scheduler,
                                             loss_log=loss_log)

            model_state = (model.module.state_dict() if args.multi_gpu
                           else model.state_dict())
            torch.save(model_state,
                       os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model,
                                                  device,
                                                  valid_loader,
                                                  criterion=criterion)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(
                OrderedDict([('time', datetime.today()),
                             ('duration', int(duration)), ('epoch', epoch),
                             ('train_loss', train_loss),
                             ('validation_loss', val_loss),
                             ('validation_mean', val_mean),
                             ('validation_median', val_median)]))
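
When no validation directory exists, the training script above carves the last 3% of the training chunks off as a validation set. Below is a minimal illustration of that split on dummy arrays; load_data's real return values are signal chunks and targets, so these arrays are stand-ins only.

# Dummy stand-ins for the arrays returned by load_data; only the split
# logic from the example above is reproduced here.
import numpy as np

train_data = [np.arange(100), np.arange(100) * 2]
split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
valid_data = [x[split:] for x in train_data]
train_data = [x[:split] for x in train_data]
print(len(train_data[0]), len(valid_data[0]))  # 97 3
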
Code example #10
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device, (not args.nondeterministic))
    device = torch.device(args.device)

    print("[loading data]")
    try:
        train_loader_kwargs, valid_loader_kwargs = load_numpy(
            args.chunks, args.directory)
    except FileNotFoundError:
        train_loader_kwargs, valid_loader_kwargs = load_script(
            args.directory,
            seed=args.seed,
            chunks=args.chunks,
            valid_chunks=args.valid_chunks)

    loader_kwargs = {
        "batch_size": args.batch,
        "num_workers": 4,
        "pin_memory": True
    }
    train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
    valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(
                os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    os.makedirs(workdir, exist_ok=True)
    toml.dump({
        **config,
        **argsdict
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    if config.get("lr_scheduler"):
        sched_config = config["lr_scheduler"]
        lr_scheduler_fn = getattr(import_module(sched_config["package"]),
                                  sched_config["symbol"])(**sched_config)
    else:
        lr_scheduler_fn = None

    trainer = Trainer(model,
                      device,
                      train_loader,
                      valid_loader,
                      use_amp=half_supported() and not args.no_amp,
                      lr_scheduler_fn=lr_scheduler_fn,
                      restore_optim=args.restore_optim,
                      save_optim_every=args.save_optim_every,
                      grad_accum_split=args.grad_accum_split)

    if ',' in args.lr:
        lr = [float(x) for x in args.lr.split(',')]
    else:
        lr = float(args.lr)
    trainer.fit(workdir, args.epochs, lr)