Example #1
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(limit=args.chunks, shuffle=args.shuffle))
    dataloader = DataLoader(testdata, batch_size=args.batchsize)

    for w in [int(i) for i in args.weights.split(',')]:

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        predictions = []
        t0 = time.perf_counter()

        for data, *_ in dataloader:
            with torch.no_grad():
                log_probs = model(data.to(args.device))
                predictions.append(log_probs.exp().cpu().numpy())

        duration = time.perf_counter() - t0

        references = [
            decode_ref(target, model.alphabet)
            for target in dataloader.dataset.targets
        ]
        sequences = [
            decode_ctc(post, model.alphabet)
            for post in np.concatenate(predictions)
        ]
        accuracies = list(starmap(accuracy, zip(references, sequences)))

        if args.poa: poas.append(sequences)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]

        consensuses = poa(poas)
        duration = time.perf_counter() - t0

        accuracies = list(starmap(accuracy, zip(references, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Example #2
def main(args):

    workdir = os.path.expanduser(args.tuning_directory)

    if os.path.exists(workdir) and not args.force:
        print("* error: %s exists." % workdir)
        exit(1)

    os.makedirs(workdir, exist_ok=True)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, directory=args.directory)
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                                 targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                                targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    def objective(trial):

        config = toml.load(args.config)

        lr = 1e-3
        #config['block'][0]['stride'] = [trial.suggest_int('stride', 4, 6)]

        # C1
        config['block'][0]['kernel'] = [
            int(trial.suggest_discrete_uniform('c1_kernel', 1, 129, 2))
        ]
        config['block'][0]['filters'] = trial.suggest_int(
            'c1_filters', 1, 1024)

        # B1 - B5
        for i in range(1, 6):
            config['block'][i]['repeat'] = trial.suggest_int(
                'b%s_repeat' % i, 1, 9)
            config['block'][i]['filters'] = trial.suggest_int(
                'b%s_filters' % i, 1, 512)
            config['block'][i]['kernel'] = [
                int(trial.suggest_discrete_uniform('b%s_kernel' % i, 1, 129,
                                                   2))
            ]

        # C2
        config['block'][-2]['kernel'] = [
            int(trial.suggest_discrete_uniform('c2_kernel', 1, 129, 2))
        ]
        config['block'][-2]['filters'] = trial.suggest_int(
            'c2_filters', 1, 1024)

        # C3
        config['block'][-1]['kernel'] = [
            int(trial.suggest_discrete_uniform('c3_kernel', 1, 129, 2))
        ]
        config['block'][-1]['filters'] = trial.suggest_int(
            'c3_filters', 1, 1024)

        model = load_symbol(config, 'Model')(config)
        num_params = sum(p.numel() for p in model.parameters())

        print("[trial %s]" % trial.number)

        if num_params > args.max_params:
            print("[pruned] network too large")
            raise optuna.exceptions.TrialPruned()

        model.to(args.device)
        model.train()

        os.makedirs(workdir, exist_ok=True)

        optimizer = AdamW(model.parameters(), amsgrad=True, lr=lr)
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        scheduler = CosineAnnealingLR(optimizer,
                                      args.epochs * len(train_loader))

        for epoch in range(1, args.epochs + 1):

            try:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             use_amp=True)
                val_loss, val_mean, val_median = test(model, device,
                                                      test_loader)
                print(
                    "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
                    .format(epoch, workdir, val_loss, val_mean, val_median))
            except KeyboardInterrupt:
                exit()
            except Exception as e:
                print("[pruned] exception: %s" % e)
                raise optuna.exceptions.TrialPruned()

            if np.isnan(val_loss): val_loss = 9.9
            trial.report(val_loss, epoch)

            if trial.should_prune():
                print("[pruned] unpromising")
                raise optuna.exceptions.TrialPruned()

        trial.set_user_attr('seed', args.seed)
        trial.set_user_attr('val_loss', val_loss)
        trial.set_user_attr('val_mean', val_mean)
        trial.set_user_attr('val_median', val_median)
        trial.set_user_attr('train_loss', train_loss)
        trial.set_user_attr('batchsize', args.batch)
        trial.set_user_attr('model_params', num_params)

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % trial.number))
        with open(os.path.join(workdir, 'config_%s.toml' % trial.number),
                  'w') as f:
            toml.dump(config, f)

        print("[loss] %.4f" % val_loss)
        return val_loss

    print("[starting study]")

    optuna.logging.set_verbosity(optuna.logging.WARNING)

    study = optuna.create_study(direction='minimize',
                                storage='sqlite:///%s' %
                                os.path.join(workdir, 'tune.db'),
                                study_name='bonito-study',
                                load_if_exists=True,
                                pruner=SuccessiveHalvingPruner())

    study.optimize(objective, n_trials=args.trials)
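
Once the study has finished, the results can be read back from the same
SQLite storage with the standard Optuna API; the study name and database
path below match the create_study call above:

import os
import optuna

study = optuna.load_study(
    study_name='bonito-study',
    storage='sqlite:///%s' % os.path.join(workdir, 'tune.db'),
)
best = study.best_trial                 # trial with the lowest validation loss
print(best.params)                      # e.g. c1_kernel, b1_repeat, ...
print(best.user_attrs['val_mean'])      # accuracy stored via set_user_attr
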
Example #3
def main(args):

    poas = []
    init(args.seed, args.device)

    print("* loading data")
    testdata = ChunkDataSet(
        *load_data(
            limit=args.chunks, shuffle=args.shuffle,
            directory=args.directory, validation=True
        )
    )
    dataloader = DataLoader(testdata, batch_size=args.batchsize)
    def accuracy_with_cov(ref, seq):
        return accuracy(ref, seq, min_coverage=args.min_coverage)

    for w in [int(i) for i in args.weights.split(',')]:

        seqs = []

        print("* loading model", w)
        model = load_model(args.model_directory, args.device, weights=w)

        print("* calling")
        t0 = time.perf_counter()

        with torch.no_grad():
            for data, *_ in dataloader:
                if half_supported():
                    data = data.type(torch.float16).to(args.device)
                else:
                    data = data.to(args.device)

                log_probs = model(data)

                if hasattr(model, 'decode_batch'):
                    seqs.extend(model.decode_batch(log_probs))
                else:
                    seqs.extend([model.decode(p) for p in permute(log_probs, 'TNC', 'NTC')])

        duration = time.perf_counter() - t0

        refs = [decode_ref(target, model.alphabet) for target in dataloader.dataset.targets]
        accuracies = [accuracy_with_cov(ref, seq) if len(seq) else 0. for ref, seq in zip(refs, seqs)]

        if args.poa: poas.append(seqs)

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
        print("* samples/s %.2E" % (args.chunks * data.shape[2] / duration))

    if args.poa:

        print("* doing poa")
        t0 = time.perf_counter()
        # group each sequence prediction per model together
        poas = [list(seq) for seq in zip(*poas)]
        consensuses = poa(poas)
        duration = time.perf_counter() - t0
        accuracies = list(starmap(accuracy_with_cov, zip(refs, consensuses)))

        print("* mean      %.2f%%" % np.mean(accuracies))
        print("* median    %.2f%%" % np.median(accuracies))
        print("* time      %.2f" % duration)
Example #4
def main(args):

    init(args.seed, args.device)

    if args.model_directory in models and args.model_directory not in os.listdir(
            __models__):
        sys.stderr.write("> downloading model\n")
        File(__models__, models[args.model_directory]).download()

    sys.stderr.write(f"> loading model {args.model_directory}\n")
    try:
        model = load_model(
            args.model_directory,
            args.device,
            weights=int(args.weights),
            chunksize=args.chunksize,
            overlap=args.overlap,
            batchsize=args.batchsize,
            quantize=args.quantize,
            use_koi=True,
        )
    except FileNotFoundError:
        sys.stderr.write(f"> error: failed to load {args.model_directory}\n")
        sys.stderr.write(f"> available models:\n")
        for model in sorted(models):
            sys.stderr.write(f" - {model}\n")
        exit(1)

    if args.verbose:
        sys.stderr.write(
            f"> model basecaller params: {model.config['basecaller']}\n")

    basecall = load_symbol(args.model_directory, "basecall")

    mods_model = None
    if args.modified_base_model is not None or args.modified_bases is not None:
        sys.stderr.write("> loading modified base model\n")
        mods_model = load_mods_model(args.modified_bases, args.model_directory,
                                     args.modified_base_model)
        sys.stderr.write(f"> {mods_model[1]['alphabet_str']}\n")

    if args.reference:
        sys.stderr.write("> loading reference\n")
        aligner = Aligner(args.reference, preset='ont-map', best_n=1)
        if not aligner:
            sys.stderr.write("> failed to load/build index\n")
            exit(1)
    else:
        aligner = None

    fmt = biofmt(aligned=args.reference is not None)

    if args.reference and args.reference.endswith(
            ".mmi") and fmt.name == "cram":
        sys.stderr.write(
            "> error: reference cannot be a .mmi when outputting cram\n")
        exit(1)
    elif args.reference and fmt.name == "fastq":
        sys.stderr.write(
            f"> warning: did you really want {fmt.aligned} {fmt.name}?\n")
    else:
        sys.stderr.write(f"> outputting {fmt.aligned} {fmt.name}\n")

    if args.save_ctc and not args.reference:
        sys.stderr.write(
            "> a reference is needed to output ctc training data\n")
        exit(1)

    if fmt.name != 'fastq':
        groups = get_read_groups(args.reads_directory,
                                 args.model_directory,
                                 n_proc=8,
                                 recursive=args.recursive,
                                 read_ids=column_to_set(args.read_ids),
                                 skip=args.skip,
                                 cancel=process_cancel())
    else:
        groups = []

    reads = get_reads(args.reads_directory,
                      n_proc=8,
                      recursive=args.recursive,
                      read_ids=column_to_set(args.read_ids),
                      skip=args.skip,
                      cancel=process_cancel())

    if args.max_reads:
        reads = take(reads, args.max_reads)

    if args.save_ctc:
        reads = (chunk for read in reads for chunk in read_chunks(
            read,
            chunksize=model.config["basecaller"]["chunksize"],
            overlap=model.config["basecaller"]["overlap"]))
        ResultsWriter = CTCWriter
    else:
        ResultsWriter = Writer

    results = basecall(model,
                       reads,
                       reverse=args.revcomp,
                       batchsize=model.config["basecaller"]["batchsize"],
                       chunksize=model.config["basecaller"]["chunksize"],
                       overlap=model.config["basecaller"]["overlap"])

    if mods_model is not None:
        results = process_itemmap(partial(call_mods, mods_model), results)
    if aligner:
        results = align_map(aligner, results, n_thread=os.cpu_count())

    writer = ResultsWriter(
        fmt.mode,
        tqdm(results, desc="> calling", unit=" reads", leave=False),
        aligner=aligner,
        group_key=args.model_directory,
        ref_fn=args.reference,
        groups=groups,
    )

    t0 = perf_counter()
    writer.start()
    writer.join()
    duration = perf_counter() - t0
    num_samples = sum(num_samples for read_id, num_samples in writer.log)

    sys.stderr.write("> completed reads: %s\n" % len(writer.log))
    sys.stderr.write("> duration: %s\n" %
                     timedelta(seconds=np.round(duration)))
    sys.stderr.write("> samples per second %.1E\n" % (num_samples / duration))
    sys.stderr.write("> done\n")
Example #5
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(
            directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    valid_loader = DataLoader(ChunkDataSet(*valid_data),
                              batch_size=args.batch,
                              num_workers=4,
                              pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, 'config.toml'), 'w') as f:
        toml.dump({**config, **argsdict, **chunk_config}, f)

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    last_epoch = load_state(workdir,
                            args.device,
                            model,
                            optimizer,
                            use_amp=args.amp)

    lr_scheduler = func_scheduler(optimizer,
                                  cosine_decay_schedule(1.0, 0.1),
                                  args.epochs * len(train_loader),
                                  warmup_steps=500,
                                  start_step=last_epoch * len(train_loader))

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(
                    workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             criterion=criterion,
                                             use_amp=args.amp,
                                             lr_scheduler=lr_scheduler,
                                             loss_log=loss_log)

            model_state = (model.module.state_dict() if args.multi_gpu
                           else model.state_dict())
            torch.save(model_state,
                       os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(model,
                                                  device,
                                                  valid_loader,
                                                  criterion=criterion)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(
                OrderedDict([('time', datetime.today()),
                             ('duration', int(duration)), ('epoch', epoch),
                             ('train_loss', train_loss),
                             ('validation_loss', val_loss),
                             ('validation_mean', val_mean),
                             ('validation_median', val_median)]))
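
func_scheduler and cosine_decay_schedule are bonito helpers; the same
warmup-then-cosine-decay behaviour can be sketched with PyTorch's standard
LambdaLR (decaying the rate to 0.1x over the run, with the 500-step linear
warmup used above):

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine(total_steps, warmup_steps=500, end_factor=0.1):
    def fn(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return end_factor + (1 - end_factor) * 0.5 * (1 + math.cos(math.pi * t))
    return fn

model = torch.nn.Linear(8, 8)                # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
scheduler = LambdaLR(optimizer, warmup_cosine(total_steps=10000))
# call scheduler.step() once per optimizer step
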
Example #6
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, targets, lengths = load_data(limit=args.chunks, shuffle=True, directory=args.directory)

    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], targets[:split], lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], targets[split:], lengths[split:])
    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, 'config.toml'), 'w') as f:
        toml.dump({**config, **argsdict, **chunk_config}, f)

    print("[loading model]")
    model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch*len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer, criterion=criterion,
                use_amp=args.amp, lr_scheduler=lr_scheduler
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
        torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
        torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

        with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])
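
The per-epoch weights_*.tar and optim_*.tar checkpoints written above can be
restored with plain torch.load; a sketch of manual resumption (load_state in
the snippet wraps equivalent logic inside bonito):

import os
import torch

epoch = 10                                   # checkpoint to resume from
model.load_state_dict(
    torch.load(os.path.join(workdir, 'weights_%s.tar' % epoch),
               map_location=device))
optimizer.load_state_dict(
    torch.load(os.path.join(workdir, 'optim_%s.tar' % epoch)))
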
Example #7
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." %
              workdir)
        exit(1)

    init(args.seed, args.device, (not args.nondeterministic))
    device = torch.device(args.device)

    print("[loading data]")
    try:
        train_loader_kwargs, valid_loader_kwargs = load_numpy(
            args.chunks, args.directory)
    except FileNotFoundError:
        train_loader_kwargs, valid_loader_kwargs = load_script(
            args.directory,
            seed=args.seed,
            chunks=args.chunks,
            valid_chunks=args.valid_chunks)

    loader_kwargs = {
        "batch_size": args.batch,
        "num_workers": 4,
        "pin_memory": True
    }
    train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
    valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

    if args.pretrained:
        dirname = args.pretrained
        if not os.path.isdir(dirname) and os.path.isdir(
                os.path.join(__models__, dirname)):
            dirname = os.path.join(__models__, dirname)
        config_file = os.path.join(dirname, 'config.toml')
    else:
        config_file = args.config

    config = toml.load(config_file)

    argsdict = dict(training=vars(args))

    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, 'config.toml'), 'w') as f:
        toml.dump({**config, **argsdict}, f)

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    if config.get("lr_scheduler"):
        sched_config = config["lr_scheduler"]
        lr_scheduler_fn = getattr(import_module(sched_config["package"]),
                                  sched_config["symbol"])(**sched_config)
    else:
        lr_scheduler_fn = None

    trainer = Trainer(model,
                      device,
                      train_loader,
                      valid_loader,
                      use_amp=half_supported() and not args.no_amp,
                      lr_scheduler_fn=lr_scheduler_fn,
                      restore_optim=args.restore_optim,
                      save_optim_every=args.save_optim_every,
                      grad_accum_split=args.grad_accum_split)

    if ',' in args.lr:
        lr = [float(x) for x in args.lr.split(',')]
    else:
        lr = float(args.lr)
    trainer.fit(workdir, args.epochs, lr)
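
A comma-separated --lr yields a list here, which Trainer.fit presumably maps
onto per-group learning rates; the underlying PyTorch mechanism is optimizer
parameter groups, sketched with a stand-in two-part model:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
optimizer = torch.optim.AdamW([
    {'params': model[0].parameters(), 'lr': 1e-3},   # first group
    {'params': model[1].parameters(), 'lr': 5e-4},   # second group
])
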
Example #8
def main(args):

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory)

    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                                 targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                                targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(
        args.directory if args.directory else __data__, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    print("[loading model]")
    model = Model(config)

    weights = os.path.join(workdir, 'weights.tar')
    if os.path.exists(weights): model.load_state_dict(torch.load(weights))

    model.to(device)
    model.train()

    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, 'config.toml'), 'w') as f:
        toml.dump({**config, **argsdict, **chunk_config}, f)

    optimizer = AdamW(model.parameters(), amsgrad=True, lr=args.lr)

    if args.amp:
        try:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O1",
                                              verbosity=0)
        except NameError:
            print(
                "[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex"
            )
            exit(1)

    scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    for epoch in range(1, args.epochs + 1):

        try:
            train_loss, duration = train(model,
                                         device,
                                         train_loader,
                                         optimizer,
                                         use_amp=args.amp)
            val_loss, val_mean, val_median = test(model, device, test_loader)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % epoch))
        with open(os.path.join(workdir, 'training.csv'), 'a',
                  newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(),
                int(duration),
                epoch,
                train_loss,
                val_loss,
                val_mean,
                val_median,
            ])

        scheduler.step()
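
The NameError guard around amp.initialize implies that amp is imported
conditionally at module level; the usual pattern for an optional Apex
dependency looks like this (a sketch, not bonito's exact import block):

try:
    from apex import amp
except ImportError:
    pass  # amp stays undefined; using --amp then raises NameError above
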
Example #9
def main(args):

    workdir = os.path.expanduser(args.tuning_directory)

    if os.path.exists(workdir) and not args.force:
        print("* error: %s exists." % workdir)
        exit(1)

    os.makedirs(workdir, exist_ok=True)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(directory=os.path.join(args.directory,
                                                      'validation'),
                               limit=10000)
    else:
        print("[validation set not found: splitting training set]")
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(ChunkDataSet(*train_data),
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(ChunkDataSet(*valid_data),
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    def objective(trial):

        config = toml.load(args.config)

        lr = trial.suggest_loguniform('learning_rate', 1e-5, 1e-1)

        model = load_symbol(config, 'Model')(config)

        num_params = sum(p.numel() for p in model.parameters())

        print("[trial %s]" % trial.number)

        model.to(args.device)
        model.train()

        os.makedirs(workdir, exist_ok=True)

        # native AMP via GradScaler; the Apex amp.initialize call that was
        # mixed in here is dropped, since the two AMP implementations
        # should not be combined
        scaler = GradScaler(enabled=True)
        optimizer = AdamW(model.parameters(), amsgrad=False, lr=lr)

        if hasattr(model, 'seqdist'):
            criterion = model.seqdist.ctc_loss
        else:
            criterion = None

        # decay, warmup_steps and warmup_ratio were undefined in the original
        # snippet; fixed values matching the other examples are assumed here
        decay = 0.1
        warmup_steps = 500
        warmup_ratio = None

        lr_scheduler = func_scheduler(
            optimizer,
            cosine_decay_schedule(1.0, decay),
            args.epochs * len(train_loader),
            warmup_steps=warmup_steps,
            warmup_ratio=warmup_ratio,
        )

        for epoch in range(1, args.epochs + 1):

            try:
                train_loss, duration = train(model,
                                             device,
                                             train_loader,
                                             optimizer,
                                             scaler=scaler,
                                             use_amp=True,
                                             criterion=criterion)
                val_loss, val_mean, val_median = test(model,
                                                      device,
                                                      test_loader,
                                                      criterion=criterion)
                print(
                    "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
                    .format(epoch, workdir, val_loss, val_mean, val_median))
            except KeyboardInterrupt:
                exit()
            except Exception as e:
                print("[pruned] exception: %s" % e)
                raise optuna.exceptions.TrialPruned()

            if np.isnan(val_loss): val_loss = 9.9
            trial.report(val_loss, epoch)

            if trial.should_prune():
                print("[pruned] unpromising")
                raise optuna.exceptions.TrialPruned()

        trial.set_user_attr('val_loss', val_loss)
        trial.set_user_attr('val_mean', val_mean)
        trial.set_user_attr('val_median', val_median)
        trial.set_user_attr('train_loss', train_loss)
        trial.set_user_attr('model_params', num_params)

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % trial.number))
        with open(os.path.join(workdir, 'config_%s.toml' % trial.number),
                  'w') as f:
            toml.dump(config, f)

        print("[loss] %.4f" % val_loss)
        return val_loss

    print("[starting study]")

    optuna.logging.set_verbosity(optuna.logging.WARNING)

    study = optuna.create_study(direction='minimize',
                                storage='sqlite:///%s' %
                                os.path.join(workdir, 'tune.db'),
                                study_name='bonito-study',
                                load_if_exists=True,
                                pruner=SuccessiveHalvingPruner())

    study.optimize(objective, n_trials=args.trials)
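
A note on the search-space API: trial.suggest_loguniform (used here) and
trial.suggest_discrete_uniform (used in Example #2) are deprecated in recent
Optuna releases in favour of suggest_float, so the modern equivalents would
be:

lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
kernel = int(trial.suggest_float('c1_kernel', 1, 129, step=2))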