def main(args):
    """Train a model from scratch (or resume) on chunked signal data.

    Loads chunks/targets/lengths, splits them into train/validation sets,
    trains for ``args.epochs`` epochs, and writes per-epoch weights,
    optimizer state, and a ``training.csv`` log into
    ``args.training_directory``.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    chunks, targets, lengths = load_data(
        limit=args.chunks, shuffle=True, directory=args.directory
    )

    # NOTE(review): the first `split` rows become the *training* set, so
    # `validation_split` is effectively the training fraction here (e.g. 0.97
    # keeps 97% for training) — confirm against the argparse help text.
    split = np.floor(chunks.shape[0] * args.validation_split).astype(np.int32)
    train_dataset = ChunkDataSet(chunks[:split], targets[:split], lengths[:split])
    test_dataset = ChunkDataSet(chunks[split:], targets[split:], lengths[split:])

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch, shuffle=True,
        num_workers=4, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch, num_workers=4, pin_memory=True
    )

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    # Merge in the chunk-preparation config (if the data directory has one)
    # so the saved config records how the training data was produced.
    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # fix: was toml.load(os.path.join(chunk_config_file)) — a pointless
        # single-argument os.path.join.
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    # fix: close the config file handle deterministically (was an unclosed
    # open() passed straight to toml.dump).
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict, **chunk_config}, conf_file)

    print("[loading model]")
    model = load_symbol(config, 'Model')(config)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    # Resume from the latest checkpoint in workdir, if any; returns 0 for a
    # fresh run (presumably — confirm against load_state).
    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch * len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # DataParallel hides module attributes; re-expose the ones used below.
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer, criterion=criterion,
                use_amp=args.amp, lr_scheduler=lr_scheduler
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
        torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))
        torch.save(optimizer.state_dict(), os.path.join(workdir, "optim_%s.tar" % epoch))

        log_path = os.path.join(workdir, 'training.csv')
        # fix: write the header whenever the log file is new/empty rather than
        # only on epoch 1 — a resumed run (last_epoch > 0) previously produced
        # a headerless CSV, and re-running from epoch 1 duplicated the header.
        needs_header = not os.path.exists(log_path) or os.path.getsize(log_path) == 0
        with open(log_path, 'a', newline='') as csvfile:
            csvw = csv.writer(csvfile, delimiter=',')
            if needs_header:
                csvw.writerow([
                    'time', 'duration', 'epoch', 'train_loss',
                    'validation_loss', 'validation_mean', 'validation_median'
                ])
            csvw.writerow([
                datetime.today(), int(duration), epoch,
                train_loss, val_loss, val_mean, val_median,
            ])
def main(args):
    """Train (or resume/fine-tune) a model, logging losses via CSVLogger.

    Uses a dedicated ``validation`` subdirectory of the data directory when
    present; otherwise carves the last 3% off the training set. Writes
    per-epoch weights, ``losses_<epoch>.csv``, and a cumulative
    ``training.csv`` into ``args.training_directory``.
    """
    workdir = os.path.expanduser(args.training_directory)

    if os.path.exists(workdir) and not args.force:
        print("[error] %s exists, use -f to force continue training." % workdir)
        exit(1)

    init(args.seed, args.device)
    device = torch.device(args.device)

    print("[loading data]")
    train_data = load_data(limit=args.chunks, directory=args.directory)
    if os.path.exists(os.path.join(args.directory, 'validation')):
        valid_data = load_data(directory=os.path.join(args.directory, 'validation'))
    else:
        print("[validation set not found: splitting training set]")
        # Keep the first 97% for training, hold out the rest for validation.
        split = np.floor(len(train_data[0]) * 0.97).astype(np.int32)
        valid_data = [x[split:] for x in train_data]
        train_data = [x[:split] for x in train_data]

    train_loader = DataLoader(
        ChunkDataSet(*train_data), batch_size=args.batch, shuffle=True,
        num_workers=4, pin_memory=True
    )
    valid_loader = DataLoader(
        ChunkDataSet(*valid_data), batch_size=args.batch,
        num_workers=4, pin_memory=True
    )

    config = toml.load(args.config)
    argsdict = dict(training=vars(args))

    # Merge in the chunk-preparation config (if the data directory has one)
    # so the saved config records how the training data was produced.
    chunk_config = {}
    chunk_config_file = os.path.join(args.directory, 'config.toml')
    if os.path.isfile(chunk_config_file):
        # fix: was toml.load(os.path.join(chunk_config_file)) — a pointless
        # single-argument os.path.join.
        chunk_config = toml.load(chunk_config_file)

    os.makedirs(workdir, exist_ok=True)
    # fix: close the config file handle deterministically (was an unclosed
    # open() passed straight to toml.dump).
    with open(os.path.join(workdir, 'config.toml'), 'w') as conf_file:
        toml.dump({**config, **argsdict, **chunk_config}, conf_file)

    print("[loading model]")
    if args.pretrained:
        print("[using pretrained model {}]".format(args.pretrained))
        model = load_model(args.pretrained, device, half=False)
    else:
        model = load_symbol(config, 'Model')(config)

    optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

    # Resume from the latest checkpoint in workdir, if any; returns 0 for a
    # fresh run (presumably — confirm against load_state).
    last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, 0.1), args.epochs * len(train_loader),
        warmup_steps=500, start_step=last_epoch * len(train_loader)
    )

    if args.multi_gpu:
        from torch.nn import DataParallel
        model = DataParallel(model)
        # DataParallel hides module attributes; re-expose the ones used below.
        model.decode = model.module.decode
        model.alphabet = model.module.alphabet

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    for epoch in range(1 + last_epoch, args.epochs + 1 + last_epoch):

        try:
            with CSVLogger(os.path.join(workdir, 'losses_{}.csv'.format(epoch))) as loss_log:
                train_loss, duration = train(
                    model, device, train_loader, optimizer, criterion=criterion,
                    use_amp=args.amp, lr_scheduler=lr_scheduler, loss_log=loss_log
                )

            # Checkpoint before validation so an interrupt during test()
            # doesn't lose the epoch's weights.
            model_state = model.state_dict() if not args.multi_gpu else model.module.state_dict()
            torch.save(model_state, os.path.join(workdir, "weights_%s.tar" % epoch))

            val_loss, val_mean, val_median = test(
                model, device, valid_loader, criterion=criterion
            )
        except KeyboardInterrupt:
            break

        print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
            epoch, workdir, val_loss, val_mean, val_median
        ))

        with CSVLogger(os.path.join(workdir, 'training.csv')) as training_log:
            training_log.append(OrderedDict([
                ('time', datetime.today()),
                ('duration', int(duration)),
                ('epoch', epoch),
                ('train_loss', train_loss),
                ('validation_loss', val_loss),
                ('validation_mean', val_mean),
                ('validation_median', val_median),
            ]))
def objective(trial):
    """Optuna objective: train a model with a sampled learning rate and
    return the final validation loss (lower is better).

    Closes over ``args``, ``workdir``, ``train_loader``, ``test_loader``,
    ``device``, ``decay``, ``warmup_steps`` and ``warmup_ratio`` from the
    enclosing scope. Reports per-epoch validation loss so the pruner can
    stop unpromising trials early.
    """
    config = toml.load(args.config)

    lr = trial.suggest_loguniform('learning_rate', 1e-5, 1e-1)

    model = load_symbol(config, 'Model')(config)
    num_params = sum(p.numel() for p in model.parameters())

    print("[trial %s]" % trial.number)

    model.to(args.device)
    model.train()

    os.makedirs(workdir, exist_ok=True)

    # NOTE(review): both a native torch GradScaler and apex amp.initialize
    # are set up here; whether train() uses one or both depends on its
    # implementation — confirm only one AMP mechanism is actually active.
    scaler = GradScaler(enabled=True)
    optimizer = AdamW(model.parameters(), amsgrad=False, lr=lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)

    if hasattr(model, 'seqdist'):
        criterion = model.seqdist.ctc_loss
    else:
        criterion = None

    lr_scheduler = func_scheduler(
        optimizer, cosine_decay_schedule(1.0, decay), args.epochs * len(train_loader),
        warmup_steps=warmup_steps, warmup_ratio=warmup_ratio,
    )

    for epoch in range(1, args.epochs + 1):
        try:
            train_loss, duration = train(
                model, device, train_loader, optimizer,
                scaler=scaler, use_amp=True, criterion=criterion
            )
            val_loss, val_mean, val_median = test(
                model, device, test_loader, criterion=criterion
            )
            print("[epoch {}] directory={} loss={:.4f} mean_acc={:.3f}% median_acc={:.3f}%".format(
                epoch, workdir, val_loss, val_mean, val_median
            ))
        except KeyboardInterrupt:
            exit()
        except Exception as e:
            # fix: the caught exception was silently discarded, making real
            # crashes indistinguishable from pruning — surface it and chain
            # it onto the TrialPruned for debugging.
            print("[pruned] exception: %r" % e)
            raise optuna.exceptions.TrialPruned() from e

        # Replace NaN with a large sentinel so the pruner/optimizer can still
        # rank the trial instead of choking on NaN.
        if np.isnan(val_loss):
            val_loss = 9.9

        trial.report(val_loss, epoch)
        if trial.should_prune():
            print("[pruned] unpromising")
            raise optuna.exceptions.TrialPruned()

    # Record trial results and artifacts for later inspection.
    trial.set_user_attr('val_loss', val_loss)
    trial.set_user_attr('val_mean', val_mean)
    trial.set_user_attr('val_median', val_median)
    trial.set_user_attr('train_loss', train_loss)
    trial.set_user_attr('model_params', num_params)

    torch.save(model.state_dict(), os.path.join(workdir, "weights_%s.tar" % trial.number))
    with open(os.path.join(workdir, 'config_%s.toml' % trial.number), 'w') as conf_file:
        toml.dump(config, conf_file)

    print("[loss] %.4f" % val_loss)
    return val_loss