示例#1
0
def load_model(dirname, device, weights=None, half=False):
    """
    Load a model from disk.

    :param dirname: directory containing ``config.toml`` and one or more
        ``weights_<n>.tar`` checkpoints. If it is not an existing directory
        but ``<__dir__>/models/<dirname>`` is, that bundled path is used.
    :param device: device spec passed to ``torch.device``; the checkpoint is
        loaded with ``map_location`` set to it and the model is moved onto it.
    :param weights: checkpoint number ``n`` of ``weights_<n>.tar``; when
        falsy, the highest-numbered checkpoint in ``dirname`` is loaded.
    :param half: if true, convert the model with ``.half()`` after loading.
    :returns: the loaded model, switched to eval mode.
    :raises FileNotFoundError: if ``weights`` is falsy and ``dirname`` holds
        no ``weights_<n>.tar`` file with a numeric ``n``.
    """
    if not os.path.isdir(dirname) and os.path.isdir(
            os.path.join(__dir__, "models", dirname)):
        dirname = os.path.join(__dir__, "models", dirname)

    if weights:
        checkpoint = os.path.join(dirname, 'weights_%s.tar' % weights)
    else:  # take the latest checkpoint
        # Match the basename exactly: stray files picked up by the glob
        # (e.g. 'weights_best.tar') are ignored instead of crashing int(),
        # and the escaped '.' only matches a literal dot.
        numbered = []
        for path in glob(os.path.join(dirname, "weights_*.tar")):
            match = re.fullmatch(r"weights_([0-9]+)\.tar",
                                 os.path.basename(path))
            if match:
                numbered.append((int(match.group(1)), path))
        if not numbered:
            raise FileNotFoundError("no model weights found in '%s'" % dirname)
        # Keep the path that was actually found so zero-padded names such as
        # 'weights_007.tar' still resolve after the numeric comparison.
        checkpoint = max(numbered)[1]

    device = torch.device(device)
    config = os.path.join(dirname, 'config.toml')
    model = Model(toml.load(config))
    model.to(device)

    state_dict = torch.load(checkpoint, map_location=device)
    # Drop a leading 'module.' from each key (presumably added when the
    # checkpoint was saved from an nn.DataParallel-style wrapper). Only the
    # prefix is removed: keys that merely contain 'module.' elsewhere, such
    # as 'encoder.submodule.weight', keep their names.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())

    model.load_state_dict(new_state_dict)

    if half:
        model = model.half()
    model.eval()
    return model
示例#2
0
def load_model(dirname, device, weights=None):
    """
    Load a model from disk.

    :param dirname: directory containing ``config.toml`` and one or more
        ``weights_<n>.tar`` checkpoints.
    :param device: device spec passed to ``torch.device``; the checkpoint is
        loaded with ``map_location`` set to it and the model is moved onto it.
    :param weights: checkpoint number ``n`` of ``weights_<n>.tar``; when
        falsy, the highest-numbered checkpoint in ``dirname`` is loaded.
    :returns: the loaded model, switched to eval mode.
    :raises FileNotFoundError: if ``weights`` is falsy and ``dirname`` holds
        no ``weights_<n>.tar`` file with a numeric ``n``.
    """
    if weights:
        checkpoint = os.path.join(dirname, 'weights_%s.tar' % weights)
    else:  # take the latest checkpoint
        # Match the basename exactly: stray files picked up by the glob
        # (e.g. 'weights_best.tar') are ignored instead of crashing int(),
        # and the escaped '.' only matches a literal dot.
        numbered = []
        for path in glob(os.path.join(dirname, "weights_*.tar")):
            match = re.fullmatch(r"weights_([0-9]+)\.tar",
                                 os.path.basename(path))
            if match:
                numbered.append((int(match.group(1)), path))
        # Fail with a clear message rather than max()'s "empty sequence".
        if not numbered:
            raise FileNotFoundError("no model weights found in '%s'" % dirname)
        # Keep the path that was actually found so zero-padded names such as
        # 'weights_007.tar' still resolve after the numeric comparison.
        checkpoint = max(numbered)[1]

    device = torch.device(device)
    config = os.path.join(dirname, 'config.toml')
    model = Model(toml.load(config))
    model.to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()
    return model
示例#3
0
文件: train.py 项目: lcerdeira/bonito
def main(args):
    """
    Train a model and write its artefacts to ``args.training_directory``.

    Reads these attributes from ``args``: ``training_directory``, ``force``,
    ``seed``, ``device``, ``chunks``, ``directory``, ``validation_split``,
    ``batch``, ``config``, ``lr``, ``amp`` and ``epochs``.

    Writes into the training directory:
      * ``config.toml`` - the model config merged with the training
        arguments (under the ``training`` key) and the chunk-data config,
        if one exists (later sources win on key clashes);
      * ``weights_<epoch>.tar`` - the model ``state_dict`` after each epoch;
      * ``training.csv`` - one row of metrics per completed epoch.

    Exits with status 1 if the training directory already exists and
    ``args.force`` is not set, or if ``args.amp`` is set but AMP cannot be
    initialised. Ctrl-C during an epoch stops training; that epoch is
    neither saved nor logged.
    """
    # sys.exit rather than the exit() builtin, which the site module only
    # guarantees for interactive sessions.
    import sys

    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists." % workdir)
        sys.exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory)

    # Despite its name, validation_split is the fraction of chunks used for
    # TRAINING: the first `split` chunks train, the remainder validate.
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split],
                                 targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:],
                                targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             num_workers=4,
                             pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(
        args.directory if args.directory else __data__, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    print("[loading model]")
    model = Model(config)

    # Resume from an existing 'weights.tar' if present. map_location lets a
    # checkpoint saved on another device (e.g. CUDA) load on this one.
    weights = os.path.join(workdir, 'weights.tar')
    if os.path.exists(weights):
        model.load_state_dict(torch.load(weights, map_location=device))

    model.to(device)
    model.train()

    os.makedirs(workdir, exist_ok=True)
    # The with-block guarantees the config is flushed and closed before
    # training starts.
    with open(os.path.join(workdir, 'config.toml'), 'w') as config_file:
        toml.dump({
            **config,
            **argsdict,
            **chunk_config
        }, config_file)

    optimizer = AdamW(model.parameters(), amsgrad=True, lr=args.lr)

    if args.amp:
        try:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O1",
                                              verbosity=0)
        except NameError:
            # NOTE(review): presumably `amp` is imported from apex at module
            # level and left undefined when that import fails - confirm.
            print(
                "[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex"
            )
            sys.exit(1)

    # The scheduler is stepped once per epoch below (train() is not given
    # it), so T_max must be in epochs as well. A per-iteration T_max of
    # epochs * len(train_loader) would leave the learning rate almost at
    # args.lr for the whole run.
    scheduler = CosineAnnealingLR(optimizer, args.epochs)

    for epoch in range(1, args.epochs + 1):

        try:
            train_loss, duration = train(model,
                                         device,
                                         train_loader,
                                         optimizer,
                                         use_amp=args.amp)
            val_loss, val_mean, val_median = test(model, device, test_loader)
        except KeyboardInterrupt:
            break

        print(
            "[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%"
            .format(epoch, workdir, val_loss, val_mean, val_median))

        torch.save(model.state_dict(),
                   os.path.join(workdir, "weights_%s.tar" % epoch))
        _append_training_log(workdir, epoch, duration, train_loss, val_loss,
                             val_mean, val_median)

        scheduler.step()


def _append_training_log(workdir, epoch, duration, train_loss, val_loss,
                         val_mean, val_median):
    """Append one epoch's metrics to ``<workdir>/training.csv``.

    The header row is written only when ``epoch == 1``.
    """
    with open(os.path.join(workdir, 'training.csv'), 'a',
              newline='') as csvfile:
        csvw = csv.writer(csvfile, delimiter=',')
        if epoch == 1:
            csvw.writerow([
                'time', 'duration', 'epoch', 'train_loss',
                'validation_loss', 'validation_mean', 'validation_median'
            ])
        csvw.writerow([
            datetime.today(),
            int(duration),
            epoch,
            train_loss,
            val_loss,
            val_mean,
            val_median,
        ])