Example 1
def chunk_dataset(reads, chunk_len, num_chunks=None):
    all_chunks = (
        (chunk, target) for read in reads for chunk, target in
        get_chunks(reads[read], regular_break_points(len(reads[read]['Dacs']), chunk_len))
    )
    chunks, targets = zip(*tqdm(take(all_chunks, num_chunks), total=num_chunks))
    targets, target_lens = pad_lengths(targets)  # pad ragged targets to a fixed-width array
    return ChunkDataSet(chunks, targets, target_lens)
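ChunkDataSet itself is not shown in any of these examples. The following is a minimal sketch inferred only from how it is used above, not taken from the original source; the added channel dimension is what the ds.chunks.squeeze(1) calls in the later examples undo. Some examples construct it with (chunks, targets, lengths) and others with (chunks, chunk_lengths, targets, target_lengths); the three-argument form is sketched here.

import numpy as np
from torch.utils.data import Dataset

class ChunkDataSet(Dataset):
    """Hypothetical sketch of the dataset wrapper assumed by these examples."""

    def __init__(self, chunks, targets, lengths):
        # add a channel dimension so each chunk has shape (1, chunk_len)
        self.chunks = np.expand_dims(np.asarray(chunks), axis=1)
        self.targets = np.asarray(targets)
        self.lengths = np.asarray(lengths)

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, i):
        return (
            self.chunks[i].astype(np.float32),
            self.targets[i].astype(np.int64),
            self.lengths[i],
        )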
Example 2
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(limit=args.chunks, shuffle=args.shuffle))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        for data, *_ in dataloader:
            with torch.no_grad():
                log_probs = model(data.to(args.device))
                predictions.append(log_probs.exp().cpu().numpy())

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            decode_ctc(post, model.alphabet)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(starmap(accuracy, zip(references, sequences)))

        if args.poa: poas.append(sequences)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]

        consensuses = poa(poas)
        duration = time.perf_counter() - t0

        accuracies = list(starmap(accuracy, zip(references, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Example 3
    def run(self):

        chunks = []
        targets = []
        target_lens = []

        while True:

            job = self.queue.get()
            if job is None: break
            chunks_, predictions = job

            # convert logprobs to probs
            predictions = np.exp(predictions.astype(np.float32))

            for chunk, pred in zip(chunks_, predictions):

                try:
                    sequence = self.model.decode(pred)
                except Exception:  # skip chunks that fail to decode
                    continue

                if not sequence:
                    continue

                for mapping in self.aligner.map(sequence):
                    cov = (mapping.q_en - mapping.q_st) / len(sequence)
                    acc = mapping.mlen / mapping.blen
                    refseq = self.aligner.seq(mapping.ctg, mapping.r_st + 1,
                                              mapping.r_en)
                    if 'N' in refseq: continue
                    if mapping.strand == -1: refseq = revcomp(refseq)
                    break
                else:
                    continue

                if acc > self.min_accuracy and cov > self.min_accuracy:
                    chunks.append(chunk.squeeze())
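                    # encode the reference as integers via ASCII codes:
                    # A(65)->1, C(67)->2, G(71)->3, T(84)->4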
                    targets.append([
                        int(x) for x in refseq.translate({
                            65: '1',
                            67: '2',
                            71: '3',
                            84: '4'
                        })
                    ])
                    target_lens.append(len(refseq))

        if len(chunks) == 0: return

        chunks = np.array(chunks, dtype=np.float32)
        chunk_lens = np.full(chunks.shape[0], chunks.shape[1], dtype=np.int16)

        targets_ = np.zeros((chunks.shape[0], max(target_lens)),
                            dtype=np.uint8)
        for idx, target in enumerate(targets):
            targets_[idx, :len(target)] = target
        target_lens = np.array(target_lens, dtype=np.uint16)

        training = ChunkDataSet(chunks, chunk_lens, targets_, target_lens)
        training = filter_chunks(training)

        output_directory = '.' if sys.stdout.isatty() else dirname(
            realpath('/dev/fd/1'))
        np.save(os.path.join(output_directory, "chunks.npy"),
                training.chunks.squeeze(1))
        np.save(os.path.join(output_directory, "chunk_lengths.npy"),
                training.chunk_lengths)
        np.save(os.path.join(output_directory, "references.npy"),
                training.targets)
        np.save(os.path.join(output_directory, "reference_lengths.npy"),
                training.target_lengths)

        sys.stderr.write("> written ctc training data\n")
        sys.stderr.write("  - chunks.npy with shape (%s)\n" %
                         ','.join(map(str,
                                      training.chunks.squeeze(1).shape)))
        sys.stderr.write("  - chunk_lengths.npy with shape (%s)\n" %
                         ','.join(map(str, training.chunk_lengths.shape)))
        sys.stderr.write("  - references.npy with shape (%s)\n" %
                         ','.join(map(str, training.targets.shape)))
        sys.stderr.write("  - reference_lengths.npy shape (%s)\n" %
                         ','.join(map(str, training.target_lengths.shape)))
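The four .npy files written above are presumably what load_data in the other examples reads back. A minimal sketch under that assumption (file names are taken from this writer, everything else is illustrative; other examples use a three-array variant without chunk_lengths, and Example 5 also passes a validation flag):

import os
import numpy as np

def load_data(limit=None, directory='.', shuffle=False):
    """Sketch: reload the arrays written by the CTC writer above."""
    chunks = np.load(os.path.join(directory, "chunks.npy"))
    chunk_lengths = np.load(os.path.join(directory, "chunk_lengths.npy"))
    targets = np.load(os.path.join(directory, "references.npy"))
    target_lengths = np.load(os.path.join(directory, "reference_lengths.npy"))

    if shuffle:
        order = np.random.permutation(len(chunks))
        chunks, chunk_lengths = chunks[order], chunk_lengths[order]
        targets, target_lengths = targets[order], target_lengths[order]

    if limit:
        chunks, chunk_lengths = chunks[:limit], chunk_lengths[:limit]
        targets, target_lengths = targets[:limit], target_lengths[:limit]

    return chunks, chunk_lengths, targets, target_lengths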
Example 4
def main(args):

    workdir = os.path.expanduser(args.tuning_directory)

    if os.path.exists(workdir) and not args.force:
        print("* error: %s exists." % workdir)
        exit(1)

    os.makedirs(workdir, exist_ok=True)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, directory=args.directory)
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                                 targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                                targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    def objective(trial):

        config = toml.load(args.config)

        lr = 1e-3
        #config['block'][0]['stride'] = [trial.suggest_int('stride', 4, 6)]

        # C1
        config['block'][0]['kernel'] = [
            int(trial.suggest_discrete_uniform('c1_kernel', 1, 129, 2))
        ]
        config['block'][0]['filters'] = trial.suggest_int(
            'c1_filters', 1, 1024)

        # B1 - B5
        for i in range(1, 6):
            config['block'][i]['repeat'] = trial.suggest_int(
                'b%s_repeat' % i, 1, 9)
            config['block'][i]['filters'] = trial.suggest_int(
                'b%s_filters' % i, 1, 512)
            config['block'][i]['kernel'] = [
                int(trial.suggest_discrete_uniform('b%s_kernel' % i, 1, 129,
                                                   2))
            ]

        # C2
        config['block'][-2]['kernel'] = [
            int(trial.suggest_discrete_uniform('c2_kernel', 1, 129, 2))
        ]
        config['block'][-2]['filters'] = trial.suggest_int(
            'c2_filters', 1, 1024)

        # C3
        config['block'][-1]['kernel'] = [
            int(trial.suggest_discrete_uniform('c3_kernel', 1, 129, 2))
        ]
        config['block'][-1]['filters'] = trial.suggest_int(
            'c3_filters', 1, 1024)

        model = load_symbol(config, 'Model')(config)
        num_params = sum(p.numel() for p in model.parameters())

        print("[trial %s]" % trial.number)

        if num_params > args.max_params:
            print("[pruned] network too large")
            raise optuna.exceptions.TrialPruned()

        model.to(args.device)
        model.train()

        os.makedirs(workdir, exist_ok=True)

        optimizer = AdamW(model.parameters(), amsgrad=True, lr=lr)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        scheduler = CosineAnnealingLR(optimizer,
                                      args.epochs * len(train_loader))

        for epoch in range(1, args.epochs + 1):

            try:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             use_amp=True)
                val_loss, val_mean, val_median = test(model, device,
                                                      test_loader)
                print(
                    "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
                    .format(epoch, workdir, val_loss, val_mean, val_median))
            except KeyboardInterrupt:
                exit()
            except Exception:
                print("[pruned] exception")
                raise optuna.exceptions.TrialPruned()

            if np.isnan(val_loss): val_loss = 9.9
            trial.report(val_loss, epoch)

            if trial.should_prune():
                print("[pruned] unpromising")
                raise optuna.exceptions.TrialPruned()

        trial.set_user_attr('seed', args.seed)
        trial.set_user_attr('val_loss', val_loss)
        trial.set_user_attr('val_mean', val_mean)
        trial.set_user_attr('val_median', val_median)
        trial.set_user_attr('train_loss', train_loss)
        trial.set_user_attr('batchsize', args.batch)
        trial.set_user_attr('model_params', num_params)

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % trial.number))
        toml.dump(
            config,
            open(os.path.join(workdir, 'config_%s.toml' % trial.number), 'w'))

        print("[loss] %.4f" % val_loss)
        return val_loss

    print("[starting study]")

    optuna.logging.set_verbosity(optuna.logging.WARNING)

    study = optuna.create_study(direction='minimize',
                                storage='sqlite:///%s' %
                                os.path.join(workdir, 'tune.db'),
                                study_name='bonito-study',
                                load_if_exists=True,
                                pruner=SuccessiveHalvingPruner())

    study.optimize(objective, n_trials=args.trials)
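Because the study is stored in sqlite and created with load_if_exists=True, it can be reopened later to resume tuning or inspect results. A brief illustration using the standard Optuna API (the tuning directory path is hypothetical):

import os
import optuna

workdir = os.path.expanduser("~/bonito-tune")  # hypothetical tuning directory
study = optuna.create_study(
    direction='minimize',
    study_name='bonito-study',
    storage='sqlite:///%s' % os.path.join(workdir, 'tune.db'),
    load_if_exists=True,
)

best = study.best_trial
print("best val_loss:", best.value)
print("params:", best.params)
print("model size:", best.user_attrs.get('model_params'))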
Example 5
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    accuracy_with_cov = lambda ref, seq: accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        if args.poa: poas.append(seqs)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Example 6
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(
            directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    valid_loader = DataLoader(ChunkDataSet(*valid_data),
                              batch_size=args.batch,
                              num_workers=4,
                              pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    toml.dump({
        **config,
        **argsdict,
        **chunk_config
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    last_epoch = load_state(workdir,
                            args.device,
                            model,
                            optimizer,
                            use_amp=args.amp)

    lr_scheduler = func_scheduler(optimizer,
                                  cosine_decay_schedule(1.0, 0.1),
                                  args.epochs * len(train_loader),
                                  warmup_steps=500,
                                  start_step=last_epoch * len(train_loader))

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(
                    workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             criterion=criterion,
                                             use_amp=args.amp,
                                             lr_scheduler=lr_scheduler,
                                             loss_log=loss_log)

            model_state = model.state_dict(
            ) if not args.multi_gpu else model.module.state_dict()
            torch.save(model_state,
                       os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model,
                                                  device,
                                                  valid_loader,
                                                  criterion=criterion)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(
                OrderedDict([('time', datetime.today()),
                             ('duration', int(duration)), ('epoch', epoch),
                             ('train_loss', train_loss),
                             ('validation_loss', val_loss),
                             ('validation_mean', val_mean),
                             ('validation_median', val_median)]))
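CSVLogger is used here as a context manager whose append() takes an ordered mapping; it is not defined in these examples. A minimal sketch consistent with that usage (the real class may differ):

import csv
import os

class CSVLogger:
    """Sketch: append dict-like rows to a CSV file, writing the header once."""

    def __init__(self, path):
        self.path = path
        self.new_file = not os.path.exists(path)

    def __enter__(self):
        self.fh = open(self.path, 'a', newline='')
        self.writer = None
        return self

    def append(self, row):
        if self.writer is None:
            self.writer = csv.DictWriter(self.fh, fieldnames=list(row.keys()))
            if self.new_file:
                self.writer.writeheader()
                self.new_file = False
        self.writer.writerow(row)

    def __exit__(self, *exc):
        self.fh.close()
        return False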
Example 7
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, targets, lengths = load_data(limit=args.chunks, shuffle=True, directory=args.directory)

    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], targets[:split], lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], targets[split:], lengths[split:])
    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    toml.dump({**config, **argsdict, **chunk_config}, open(os.path.join(workdir, 'config.toml'), 'w'))

    print("[loading model]")
    model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch*len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer, criterion=criterion,
                use_amp=args.amp, lr_scheduler=lr_scheduler
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
        torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
        torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

        with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])
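func_scheduler and cosine_decay_schedule are only called in these examples, never defined. A minimal sketch of what they might look like, built on torch's LambdaLR and matching the call signature used above: a learning-rate multiplier that warms up linearly and then decays along a cosine from 1.0 to 0.1 over the full run (the warmup_ratio argument seen in Example 11 is omitted for brevity):

import numpy as np
from torch.optim.lr_scheduler import LambdaLR

def cosine_decay_schedule(start, end):
    """Sketch: multiplier decaying from `start` to `end` as progress goes 0 -> 1."""
    return lambda t: end + 0.5 * (start - end) * (1 + np.cos(t * np.pi))

def func_scheduler(optimizer, func, total_steps, warmup_steps=None, start_step=0):
    """Sketch: wrap `func` in a LambdaLR with optional linear warmup and a step offset."""
    def lr_lambda(step):
        step = step + start_step
        if warmup_steps and step < warmup_steps:
            return (step + 1) / warmup_steps * func(0.0)
        return func(min(step / total_steps, 1.0))
    return LambdaLR(optimizer, lr_lambda)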
Example 8
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory)

    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                                 targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                                targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(
        args.directory if args.directory else __data__, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    print("[loading model]")
    model = Model(config)

    weights = os.path.join(workdir, 'weights.tar')
    if os.path.exists(weights): model.load_state_dict(torch.load(weights))

    model.to(device)
    model.train()

    os.makedirs(workdir, exist_ok=True)
    toml.dump({
        **config,
        **argsdict,
        **chunk_config
    }, open(os.path.join(workdir, 'config.toml'), 'w'))

    optimizer = AdamW(model.parameters(), amsgrad=True, lr=args.lr)

    if args.amp:
        try:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O1",
                                              verbosity=0)
        except NameError:
            print(
                "[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex"
            )
            exit(1)

    scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    for epoch in range(1, args.epochs + 1):

        try:
            train_loss, duration = train(model,
                                         device,
                                         train_loader,
                                         optimizer,
                                         use_amp=args.amp)
            val_loss, val_mean, val_median = test(model, device, test_loader)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % epoch))
        with open(os.path.join(workdir, 'training.csv'), 'a',
                  newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(),
                int(duration),
                epoch,
                train_loss,
                val_loss,
                val_mean,
                val_median,
            ])

        scheduler.step()
Example 9
def select_indices(ds, idx):
    return ChunkDataSet(
        ds.chunks.squeeze(1)[idx], ds.targets[idx], ds.lengths[idx]
    )
Example 10
def filter_chunks(ds, idx):
    filtered = ChunkDataSet(
        ds.chunks.squeeze(1)[idx], ds.targets[idx], ds.lengths[idx])
    filtered.targets = filtered.targets[:, :filtered.lengths.max()]
    return filtered
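Both helpers take an index over the dataset (an integer list or boolean mask). As a hypothetical usage example, the same length criterion applied to the summary file in Example 12 (keep chunks whose reference length is within 2.5 standard deviations of the mean) could be expressed as:

import numpy as np

# toy arrays standing in for real training data (shapes are illustrative)
chunks = np.random.rand(100, 4000).astype(np.float32)
targets = np.random.randint(1, 5, size=(100, 400), dtype=np.uint8)
lengths = np.random.randint(300, 400, size=100, dtype=np.uint16)

ds = ChunkDataSet(chunks, targets, lengths)  # as sketched after Example 1

mu, sd = np.mean(ds.lengths), np.std(ds.lengths)
keep = [i for i, n in enumerate(ds.lengths) if mu - 2.5 * sd < n < mu + 2.5 * sd]
filtered = filter_chunks(ds, keep)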
Example 11
def main(args):

    workdir = os.path.expanduser(args.tuning_directory)

    if os.path.exists(workdir) and not args.force:
        print("* error: %s exists." % workdir)
        exit(1)

    os.makedirs(workdir, exist_ok=True)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(directory=os.path.join(args.directory,
                                                      'validation'),
                               limit=10000)
    else:
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(ChunkDataSet(*valid_data),
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    def objective(trial):

        config = toml.load(args.config)

        lr = trial.suggest_loguniform('learning_rate', 1e-5, 1e-1)

        # decay, warmup_steps and warmup_ratio are used in the scheduler call
        # below but were not defined in this snippet; the values here mirror
        # the other training scripts in this set (warmup_ratio is a guess).
        decay = 0.1
        warmup_steps = 500
        warmup_ratio = 0.1

        model = load_symbol(config, 'Model')(config)

        num_params = sum(p.numel() for p in model.parameters())

        print("[trial %s]" % trial.number)

        model.to(args.device)
        model.train()

        os.makedirs(workdir, exist_ok=True)

        scaler = GradScaler(enabled=True)
        optimizer = AdamW(model.parameters(), amsgrad=False, lr=lr)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if hasattr(model, 'seqdist'):
            criterion = model.seqdist.ctc_loss
        else:
            criterion = None

        lr_scheduler = func_scheduler(
            optimizer,
            cosine_decay_schedule(1.0, decay),
            args.epochs * len(train_loader),
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
        )

        for epoch in range(1, args.epochs + 1):

            try:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             scaler=scaler,
                                             use_amp=True,
                                             criterion=criterion)
                val_loss, val_mean, val_median = test(model,
                                                      device,
                                                      test_loader,
                                                      criterion=criterion)
                print(
                    "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
                    .format(epoch, workdir, val_loss, val_mean, val_median))
            except KeyboardInterrupt:
                exit()
            except Exception:
                print("[pruned] exception")
                raise optuna.exceptions.TrialPruned()

            if np.isnan(val_loss): val_loss = 9.9
            trial.report(val_loss, epoch)

            if trial.should_prune():
                print("[pruned] unpromising")
                raise optuna.exceptions.TrialPruned()

        trial.set_user_attr('val_loss', val_loss)
        trial.set_user_attr('val_mean', val_mean)
        trial.set_user_attr('val_median', val_median)
        trial.set_user_attr('train_loss', train_loss)
        trial.set_user_attr('model_params', num_params)

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % trial.number))
        toml.dump(
            config,
            open(os.path.join(workdir, 'config_%s.toml' % trial.number), 'w'))

        print("[loss] %.4f" % val_loss)
        return val_loss

    print("[starting study]")

    optuna.logging.set_verbosity(optuna.logging.WARNING)

    study = optuna.create_study(direction='minimize',
                                storage='sqlite:///%s' %
                                os.path.join(workdir, 'tune.db'),
                                study_name='bonito-study',
                                load_if_exists=True,
                                pruner=SuccessiveHalvingPruner())

    study.optimize(objective, n_trials=args.trials)
Example 12
    def run(self):

        chunks = []
        targets = []
        lengths = []

        for read, ctc_data in self.iterator:

            seq = ctc_data['sequence']
            qstring = ctc_data['qstring']
            mean_qscore = ctc_data['mean_qscore']
            mapping = ctc_data.get('mapping', False)

            self.log.append((read.read_id, len(read.signal)))

            if len(seq) == 0 or mapping is None:
                continue

            cov = (mapping.q_en - mapping.q_st) / len(seq)
            acc = mapping.mlen / mapping.blen
            refseq = self.aligner.seq(mapping.ctg, mapping.r_st, mapping.r_en)

            if acc < self.min_accuracy or cov < self.min_coverage or 'N' in refseq:
                continue

            write_sam(read.read_id,
                      seq,
                      qstring,
                      mapping,
                      fd=self.fd,
                      unaligned=mapping is None)
            with open(summary_file(), 'a') as summary:
                write_summary_row(read,
                                  len(seq),
                                  mean_qscore,
                                  alignment=mapping,
                                  fd=summary)

            if mapping.strand == -1:
                refseq = revcomp(refseq)

            # encode the reference as integers via ASCII codes:
            # A(65)->1, C(67)->2, G(71)->3, T(84)->4
            target = [
                int(x) for x in refseq.translate({
                    65: '1',
                    67: '2',
                    71: '3',
                    84: '4'
                })
            ]
            targets.append(target)
            chunks.append(read.signal)
            lengths.append(len(target))

        if len(chunks) == 0:
            sys.stderr.write("> no suitable ctc data to write\n")
            return

        chunks = np.array(chunks, dtype=np.float16)
        targets_ = np.zeros((chunks.shape[0], max(lengths)), dtype=np.uint8)
        for idx, target in enumerate(targets):
            targets_[idx, :len(target)] = target
        lengths = np.array(lengths, dtype=np.uint16)

        training = ChunkDataSet(chunks, targets_, lengths)
        training = filter_chunks(training)

        mu, sd = np.mean(lengths), np.std(lengths)
        idx = [
            i for i, n in enumerate(lengths)
            if mu - 2.5 * sd < n < mu + 2.5 * sd
        ]
        summary = pd.read_csv(summary_file(), sep='\t')
        summary[summary.index.isin(idx)].to_csv(summary_file(),
                                                sep='\t',
                                                index=False)

        output_directory = '.' if sys.stdout.isatty() else dirname(
            realpath('/dev/fd/1'))
        np.save(os.path.join(output_directory, "chunks.npy"),
                training.chunks.squeeze(1))
        np.save(os.path.join(output_directory, "references.npy"),
                training.targets)
        np.save(os.path.join(output_directory, "reference_lengths.npy"),
                training.lengths)

        sys.stderr.write("> written ctc training data\n")
        sys.stderr.write("  - chunks.npy with shape (%s)\n" %
                         ','.join(map(str,
                                      training.chunks.squeeze(1).shape)))
        sys.stderr.write("  - references.npy with shape (%s)\n" %
                         ','.join(map(str, training.targets.shape)))
        sys.stderr.write("  - reference_lengths.npy shape (%s)\n" %
                         ','.join(map(str, training.lengths.shape)))
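revcomp is used by both writers but not shown. A minimal sketch for upper-case DNA (the real helper may also handle ambiguity codes and lower case):

_COMPLEMENT = str.maketrans('ACGT', 'TGCA')

def revcomp(seq):
    """Sketch: reverse-complement an upper-case DNA string."""
    return seq.translate(_COMPLEMENT)[::-1]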