model = build_model(hps, log)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scaler = GradScaler()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size,
                                      gamma=args.gamma)
state = {k: v for k, v in args._get_kwargs()}

if args.load_step == 0:
    # new model
    global_epoch = 0
    global_step = 0
    actnorm_init(train_loader, model, device)
else:
    # saved model
    model, optimizer, scheduler, global_epoch, global_step = load_checkpoint(
        args.load_step, model, optimizer, scheduler)
    log.write('\n ! --- load the model and continue training --- ! \n')
    log_train.write(
        '\n ! --- load the model and continue training --- ! \n')
    log_eval.write(
        '\n ! --- load the model and continue training --- ! \n')
    log.flush()
    log_train.flush()
    log_eval.flush()
    for param_group in optimizer.param_groups:
        print('lr', param_group['lr'])
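# A minimal sketch (an assumption, not this project's actual implementation)
# of the data-dependent pass that `actnorm_init` above typically performs in
# flow models: run a single batch through the model so each ActNorm layer can
# set its scale and bias from that batch's activation statistics before
# training. The (x, c) batch layout — audio plus conditioning features — and
# the idea that ActNorm layers self-initialize on their first forward pass
# are assumptions here.
import torch


def actnorm_init_sketch(train_loader, model, device):
    model.train()
    with torch.no_grad():                # initialization needs no gradients
        x, c = next(iter(train_loader))  # one representative batch
        x = x.to(device)
        c = c.to(device)
        model(x, c)                      # ActNorm layers self-initialize on the first forward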
def main_worker(gpu, ngpus_per_node, args):
    global global_step
    global start_time

    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.distributed:
            # global rank = node rank * GPUs per node + local GPU index
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    hps = Hyperparameters(args)
    sample_path, save_path, load_path, log_path = mkdir(args)

    # Only the first rank on each node writes logs; other ranks get None handles.
    if not args.distributed or (args.rank % ngpus_per_node == 0):
        log, log_train, log_eval = get_logger(log_path, args.model_name)
    else:
        log, log_train, log_eval = None, None, None

    model = build_model(hps, log)

    if args.distributed:
        # Multiple processes, single GPU per process
        if args.gpu is not None:
            def _transform_(m):
                return nn.parallel.DistributedDataParallel(
                    m, device_ids=[args.gpu], output_device=args.gpu,
                    check_reduction=True)

            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model.multi_gpu_wrapper(_transform_)
            # Split the batch size across the processes on this node.
            args.bsz = int(args.bsz / ngpus_per_node)
            args.workers = 0
        else:
            assert 0, "DistributedDataParallel constructor should always set the single device scope"
    elif args.gpu is not None:
        # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # Single process, multiple GPUs per process
        def _transform_(m):
            return nn.DataParallel(m)

        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    train_loader, test_loader, synth_loader = load_dataset(args)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scaler = GradScaler()
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size,
                                          gamma=args.gamma)
    state = {k: v for k, v in args._get_kwargs()}

    if args.load_step == 0:
        # new model: data-dependent ActNorm initialization
        global_epoch = 0
        global_step = 0
        actnorm_init(train_loader, model, args.gpu)
    else:
        # saved model: resume from checkpoint
        model, optimizer, scheduler, global_epoch, global_step = load_checkpoint(
            args.load_step, load_path, model, optimizer, scheduler)
        if log is not None:
            log.write('\n ! --- load the model and continue training --- ! \n')
            log_train.write(
                '\n ! --- load the model and continue training --- ! \n')
            log_eval.write(
                '\n ! --- load the model and continue training --- ! \n')
            log.flush()
            log_train.flush()
            log_eval.flush()

    start_time = time.time()
    dateTime = datetime.datetime.fromtimestamp(start_time).strftime(
        '%Y-%m-%d %H:%M:%S')
    print('training starts at ', dateTime)

    for epoch in range(global_epoch + 1, args.epochs + 1):
        training_epoch_loss = train(args.gpu, epoch, train_loader, synth_loader,
                                    sample_path, model, optimizer, scaler,
                                    scheduler, log_train, args)
        with torch.no_grad():
            eval_epoch_loss = evaluate(args.gpu, epoch, test_loader, model,
                                       log_eval)

        if log is not None:
            state['training_loss'] = training_epoch_loss
            state['eval_loss'] = eval_epoch_loss
            state['epoch'] = epoch
            log.write('%s\n' % json.dumps(state))
            log.flush()

        # Checkpoint and synthesize samples from the logging rank only.
        if not args.distributed or (args.rank % ngpus_per_node == 0):
            save_checkpoint(save_path, model, optimizer, scaler, scheduler,
                            global_step, epoch)
            print('Epoch {} Model Saved! Loss : {:.4f}'.format(
                epoch, eval_epoch_loss))
            with torch.no_grad():
                synthesize(args.gpu, sample_path, synth_loader, model,
                           args.num_sample, args.sr)
        gc.collect()

    if log is not None:
        log_train.close()
        log_eval.close()
        log.close()
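# Hedged sketch of a typical entry point for `main_worker` above, following
# the standard torch.multiprocessing spawn pattern for DistributedDataParallel.
# The `main` wrapper itself is an assumption (not part of this script); the
# fields it reads (args.distributed, args.world_size, args.gpu) are the ones
# `main_worker` already uses, with args.world_size taken to be the number of
# nodes before it is scaled to the total process count.
import torch
import torch.multiprocessing as mp


def main(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # One process per GPU: each spawned process calls
        # main_worker(local_gpu_index, ngpus_per_node, args).
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # Single-process path: run main_worker directly on args.gpu.
        main_worker(args.gpu, ngpus_per_node, args)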