import os
import re
import csv
import toml
import torch
import numpy as np

from glob import glob
from datetime import datetime
from collections import OrderedDict
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

try:
    from apex import amp  # optional; main() reports an error if AMP is requested without it
except ImportError:
    pass

# Model, ChunkDataSet, load_data, init, train, and test, along with the
# __dir__ and __data__ path constants, are provided elsewhere in this package.


def load_model(dirname, device, weights=None, half=False):
    """ Load a model from disk """
    if not os.path.isdir(dirname) and os.path.isdir(os.path.join(__dir__, "models", dirname)):
        dirname = os.path.join(__dir__, "models", dirname)

    if not weights:  # take the latest checkpoint
        weight_files = glob(os.path.join(dirname, "weights_*.tar"))
        if not weight_files:
            raise FileNotFoundError("no model weights found in '%s'" % dirname)
        weights = max([int(re.sub(r".*_([0-9]+)\.tar", r"\1", w)) for w in weight_files])

    device = torch.device(device)
    config = os.path.join(dirname, 'config.toml')
    weights = os.path.join(dirname, 'weights_%s.tar' % weights)
    model = Model(toml.load(config))
    model.to(device)

    # strip the 'module.' prefix left behind by DataParallel checkpoints
    state_dict = torch.load(weights, map_location=device)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)

    if half: model = model.half()
    model.eval()
    return model
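# Example (a minimal sketch): load the most recent checkpoint from a model
# directory and run inference on a dummy batch. The directory name and the
# (batch, channel, samples) input shape are illustrative assumptions; the
# real shape depends on the model's config.toml.
#
#   model = load_model("example_model", "cuda", half=True)
#   with torch.no_grad():
#       scores = model(torch.randn(32, 1, 4000, device="cuda").half())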
def main(args):
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, chunk_lengths, targets, target_lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory
    )

    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], chunk_lengths[:split], targets[:split], target_lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], chunk_lengths[split:], targets[split:], target_lengths[split:])
    train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch, num_workers=4, pin_memory=True)

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    chunk_config = {}
    chunk_config_file = os.path.join(args.directory if args.directory else __data__, 'config.toml')
    if os.path.isfile(chunk_config_file):
        chunk_config = toml.load(chunk_config_file)

    print("[loading model]")
    model = Model(config)

    # resume from starter weights if a previous run left them in the workdir
    weights = os.path.join(workdir, 'weights.tar')
    if os.path.exists(weights):
        model.load_state_dict(torch.load(weights))

    model.to(device)
    model.train()

    # snapshot the full training configuration alongside the checkpoints
    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict, **chunk_config}, conf_file)

    optimizer = AdamW(model.parameters(), amsgrad=True, lr=args.lr)

    if args.amp:
        try:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        except NameError:
            print("[error]: Cannot use AMP: Apex package needs to be installed manually, See https://github.com/NVIDIA/apex")
            exit(1)

    # cosine learning rate schedule over the full run
    scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader))

    for epoch in range(1, args.epochs + 1):

        try:
            train_loss, duration = train(model, device, train_loader, optimizer, use_amp=args.amp)
            val_loss, val_mean, val_median = test(model, device, test_loader)
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        torch.save(model.state_dict(), os.path.join(workdir, "weights_%s.tar" % epoch))

        with open(os.path.join(workdir, 'training.csv'), 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if epoch == 1:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])

        scheduler.step()
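# Example invocation (a sketch): main() reads exactly the attributes shown
# below; the values themselves are illustrative assumptions. Note that
# validation_split is the fraction of chunks used for *training* (the
# remainder is held out for validation), since train_dataset takes [:split].
#
#   from argparse import Namespace
#
#   main(Namespace(
#       training_directory="~/runs/example", force=False, seed=25,
#       device="cuda", chunks=None, directory=None, validation_split=0.97,
#       batch=32, config="config.toml", amp=False, lr=1e-3, epochs=5,
#   ))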