def main():
    """Train a wind-speed estimation model on the cyclone image dataset.

    Hyper-parameters come from the module-level ``args`` namespace.  In
    ``args.exp`` mode the model jointly classifies an anchor bucket and
    regresses the residual wind value; otherwise it regresses wind speed
    directly.  Writes per-epoch checkpoints, a text log and a JSON loss
    history under ``args.save_path``.
    """
    anchors = [30, 54, 95]  # wind-speed anchor values consumed by val_epoch in exp mode
    shuffle = not (args.no_shuffle)
    exp = args.exp
    warm_up_epoch = 3  # epochs of frozen-backbone warm-up before full training

    # ----- Load and process data -----
    # NOTE(review): `if args.fold:` treats fold 0 the same as "no fold" —
    # confirm fold numbering starts at 1.
    if args.fold:
        df_train = pd.read_csv(args.data_path + 'k_fold/official_train_fold%d.csv' % (args.fold))
        df_val = pd.read_csv(args.data_path + 'k_fold/official_val_fold%d.csv' % (args.fold))
    else:
        df_train = pd.read_csv(args.data_path + 'official_train.csv')
        df_val = pd.read_csv(args.data_path + 'official_val.csv')
    train = df_train.image_path.to_list()
    val = df_val.image_path.to_list()
    if exp:
        # Classification target is the anchor column; the regression target
        # (exp_wind) is passed to the dataset separately.
        y_train = df_train.anchor.to_list()
        y_val = df_val.anchor.to_list()
        reg_train_gt = df_train.exp_wind.to_list()
        reg_val_gt = df_val.exp_wind.to_list()
    else:
        y_train = df_train.wind_speed.to_list()
        y_val = df_val.wind_speed.to_list()

    train_transform, val_transform = get_transform(args.image_size)
    train_dataset = WindDataset(
        image_list=train,
        target=y_train,
        exp_target=reg_train_gt if exp else None,
        transform=train_transform)
    val_dataset = WindDataset(
        image_list=val,
        target=y_val,
        exp_target=reg_val_gt if exp else None,
        transform=val_transform)

    train_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=shuffle,
        num_workers=args.num_workers, drop_last=True)
    # NOTE(review): drop_last=True on the validation loader discards the tail
    # batch, so reported RMSE skips a few samples — confirm this is intended.
    val_loader = DataLoader(
        dataset=val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, drop_last=True)
    # Warm-up batches are much larger since only the head is trainable.
    warm_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size * 14, shuffle=shuffle,
        num_workers=args.num_workers, drop_last=True)

    # ----- Load model -----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    last_epoch = 0
    if not exp:
        model = Effnet_Wind_B7()
    else:
        model = Effnet_Wind_B5_exp_6()

    # ----- Optimizer -----
    if args.opt == 'radam':
        optimizer = RAdam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=args.weight_decay,
        )
    elif args.opt == 'adamw':
        optimizer = AdamW(model.parameters(), args.lr)
    elif args.opt == 'adam':
        optimizer = Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        optimizer = SGD(model.parameters(), args.lr, momentum=0.9, nesterov=True,
                        weight_decay=args.weight_decay)

    if args.weights:
        # Resume epoch counter is encoded in the checkpoint filename.
        last_epoch = extract_number(args.weights)
        try:
            checkpoint = torch.load(args.weights)
            model.load_state_dict(checkpoint['model_state_dict'])
            # Only restore optimizer state when the optimizer type matches.
            if checkpoint['pre_opt'] == args.opt:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print(optimizer)
        except Exception:
            # Was a bare `except:`; narrowed so Ctrl-C still interrupts.
            # Fallback: the file is a plain state_dict, not a checkpoint dict.
            model.load_state_dict(torch.load(args.weights))
    else:
        model.apply(reset_m_batchnorm)
    model.to(device)

    # ----- Loss function -----
    if exp:
        criterion = JointLoss2()
    else:
        criterion = RMSELoss()

    # ----- Generate log and visualization -----
    save_path = args.save_path
    log_cache = (args.batch_size, args.image_size, shuffle, exp)
    write_log(args.save_path, model, optimizer, criterion, log_cache)
    plot_dict = {'train': list(), 'val': list()}
    log_train_path = save_path + 'training_log.txt'
    plot_train_path = save_path + 'log.json'
    write_mode = 'w'
    if os.path.exists(log_train_path) and os.path.exists(plot_train_path):
        # Resuming: append to the text log and truncate the plotted loss
        # history to the epoch we restart from.
        write_mode = 'a'
        with open(plot_train_path, 'r') as j:
            plot_dict = json.load(j)
        plot_dict['train'] = plot_dict['train'][:last_epoch]
        plot_dict['val'] = plot_dict['val'][:last_epoch]

    # ----- Training -----
    print('Start warm up')
    model.freeze_except_last()
    for epoch in range(warm_up_epoch):
        warm_up(
            model=model,
            dataloader=warm_loader,
            optimizer=optimizer,
            criterion=criterion,
            device=device,
        )
    model.unfreeze()

    with open(log_train_path, write_mode) as f:
        for epoch in range(1, args.epoch + 1):
            print('Epoch:', epoch + last_epoch)
            f.write('Epoch: %d\n' % (epoch + last_epoch))
            loss = train_epoch(model=model, dataloader=train_loader,
                               optimizer=optimizer, criterion=criterion,
                               device=device, exp=exp)
            RMSE = val_epoch(model=model, dataloader=val_loader, device=device,
                             exp=exp, anchors=anchors)
            if not exp:
                f.write('Training loss: %.4f\n' % (loss))
                f.write('RMSE val: %.4f\n' % (RMSE))
                print('RMSE loss: %.4f' % (loss))
                print('RMSE val: %.4f' % (RMSE))
            else:
                # In exp mode both metrics come back as tuples.
                loss, classify, regress = loss
                RMSE, accuracy = RMSE
                f.write('Training loss: %.4f\n' % (loss))
                f.write('Classification loss: %.4f\n' % (classify))
                f.write('Regression loss: %.4f\n' % (regress))
                f.write('Accuracy val: %.4f\n' % (accuracy))
                f.write('RMSE val: %.4f\n' % (RMSE))
                print('Training loss: %.4f' % (loss))
                print('Classification loss: %.4f' % (classify))
                print('Regression loss: %.4f' % (regress))
                print('Accuracy val: %.4f' % (accuracy))
                print('RMSE val: %.4f' % (RMSE))
            save_name = save_path + 'epoch%d.pth' % (epoch + last_epoch)
            save_pth(save_name, epoch + last_epoch, model, optimizer, args.opt)
            # Persist the loss curves after every epoch so a crash loses
            # at most one epoch of history.
            plot_dict['train'].append(loss)
            plot_dict['val'].append(RMSE)
            with open(plot_train_path, 'w') as j:
                json.dump(plot_dict, j)
def train(args, cfg):
    """Train a burst super-resolution model as specified by ``cfg``.

    Builds the dataset for the configured track ('synthetic' or 'real'),
    the optimizer and warm-up LR schedule, optionally resumes from
    ``args.resume_iter`` or loads ``cfg.SOLVER.PRETRAIN_MODEL``, then
    hands everything to ``do_train``.
    """
    device = torch.device('cuda')
    model = ModelWithLoss(cfg).to(device)
    print('------------Model Architecture-------------')
    print(model)

    print('Loading Datasets...')
    data_loader = {}
    if cfg.SOLVER.AUGMENTATION:
        train_transforms = SyntheticTransforms()
    else:
        train_transforms = ToTensor()
    if cfg.DATASET.TRACK == 'synthetic':
        train_dataset = SyntheticBurst(ZurichRAW2RGB(cfg.DATASET.TRAIN_SYNTHETIC),
                                       crop_sz=cfg.SOLVER.PATCH_SIZE,
                                       burst_size=cfg.MODEL.BURST_SIZE,
                                       transform=train_transforms)
    elif cfg.DATASET.TRACK == 'real':
        # Real-track crops are taken in RAW space: 1/8 of the SR patch size.
        train_dataset = BurstSRDataset(cfg.DATASET.REAL, split='train',
                                       crop_sz=cfg.SOLVER.PATCH_SIZE // 8,
                                       burst_size=cfg.MODEL.BURST_SIZE)
    else:
        # Was an implicit NameError further down; fail fast instead.
        raise ValueError('Unknown DATASET.TRACK: {}'.format(cfg.DATASET.TRACK))

    sampler = RandomSampler(train_dataset)
    batch_sampler = BatchSampler(sampler=sampler,
                                 batch_size=cfg.SOLVER.BATCH_SIZE,
                                 drop_last=True)
    # Wrap so iteration is bounded by MAX_ITER rather than epochs.
    batch_sampler = IterationBasedBatchSampler(batch_sampler,
                                               num_iterations=cfg.SOLVER.MAX_ITER)
    train_loader = DataLoader(train_dataset, num_workers=args.num_workers,
                              batch_sampler=batch_sampler, pin_memory=True)
    data_loader['train'] = train_loader

    # Optimize only parameters that require grad (frozen parts excluded).
    if cfg.SOLVER.OPTIMIZER == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=cfg.SOLVER.LR)
    elif cfg.SOLVER.OPTIMIZER == 'adabound':
        optimizer = AdaBound(filter(lambda p: p.requires_grad, model.parameters()),
                             lr=cfg.SOLVER.LR, final_lr=cfg.SOLVER.FINAL_LR)
    else:
        # Was an implicit NameError when the scheduler was built; fail fast.
        raise ValueError('Unknown SOLVER.OPTIMIZER: {}'.format(cfg.SOLVER.OPTIMIZER))

    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.LR, cfg.SOLVER.LR_STEP,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITER)

    if args.resume_iter != 0:
        model_path = os.path.join(cfg.OUTPUT_DIR, 'model',
                                  'iteration_{}.pth'.format(args.resume_iter))
        print(f'Resume from {model_path}')
        # Reuse model_path instead of rebuilding the identical join expression.
        model.model.load_state_dict(fix_model_state_dict(torch.load(model_path)))
        if model.flow_refine:
            # dirname(model_path) ends in ".../model"; [:-5] strips that
            # trailing "model" so the sibling FR_model/ dir can be appended.
            FR_model_path = os.path.dirname(model_path)[:-5] + "FR_model/" + 'iteration_{}.pth'.format(args.resume_iter)
            model.FR_model.load_state_dict(torch.load(FR_model_path))
        if model.denoise_burst:
            denoise_model_path = os.path.dirname(model_path)[:-5] + "denoise_model/" + 'iteration_{}.pth'.format(args.resume_iter)
            model.denoise_model.load_state_dict(torch.load(denoise_model_path))
        optimizer.load_state_dict(torch.load(os.path.join(
            cfg.OUTPUT_DIR, 'optimizer', 'iteration_{}.pth'.format(args.resume_iter))))
        scheduler.load_state_dict(torch.load(os.path.join(
            cfg.OUTPUT_DIR, 'scheduler', 'iteration_{}.pth'.format(args.resume_iter))))
    elif cfg.SOLVER.PRETRAIN_MODEL != '':
        model_path = cfg.SOLVER.PRETRAIN_MODEL
        print(f'load pretrain model from {model_path}')
        model.model.load_state_dict(fix_model_state_dict(torch.load(model_path)))
        if model.flow_refine:
            FR_model_path = os.path.dirname(model_path)[:-5] + "FR_model/" + os.path.basename(cfg.SOLVER.PRETRAIN_MODEL)
            model.FR_model.load_state_dict(torch.load(FR_model_path))
        if model.denoise_burst:
            denoise_model_path = os.path.dirname(model_path)[:-5] + "denoise_model/" + os.path.basename(cfg.SOLVER.PRETRAIN_MODEL)
            model.denoise_model.load_state_dict(torch.load(denoise_model_path))

    if cfg.SOLVER.SYNC_BATCHNORM:
        model = convert_model(model).to(device)
    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpus)))

    if not args.debug:
        summary_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    else:
        summary_writer = None

    do_train(args, cfg, model, optimizer, scheduler, data_loader, device, summary_writer)
def train(rank: int, cfg: DictConfig):
    """Train the GAN vocoder (generator + discriminator) on one process.

    ``rank`` is the local process / GPU index.  Rank 0 additionally creates
    the checkpoint directory and handles stdout logging, checkpointing,
    TensorBoard summaries, and validation.
    """
    print(OmegaConf.to_yaml(cfg))
    if cfg.train.n_gpu > 1:
        init_process_group(backend=cfg.train.dist_config['dist_backend'],
                           init_method=cfg.train.dist_config['dist_url'],
                           world_size=cfg.train.dist_config['world_size'] * cfg.train.n_gpu,
                           rank=rank)
    device = torch.device(
        'cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')

    generator = Generator(sum(cfg.model.feature_dims), *cfg.model.cond_dims,
                          **cfg.model.generator).to(device)
    discriminator = Discriminator(**cfg.model.discriminator).to(device)
    if rank == 0:
        print(generator)
        os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
        print("checkpoints directory : ", cfg.train.ckpt_dir)

    # Fresh-run defaults.  Bug fix: only rank 0 creates ckpt_dir, so on other
    # ranks the `isdir` branch below may not run; these names were previously
    # unbound in that case (NameError later on).
    steps = 1
    last_epoch = -1
    state_dict_do = None
    if os.path.isdir(cfg.train.ckpt_dir):
        cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
        cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'd_')
        if cp_g is not None and cp_do is not None:
            # Resume both networks and the bookkeeping counters.
            state_dict_g = load_checkpoint(cp_g, device)
            state_dict_do = load_checkpoint(cp_do, device)
            generator.load_state_dict(state_dict_g['generator'])
            discriminator.load_state_dict(state_dict_do['discriminator'])
            steps = state_dict_do['steps'] + 1
            last_epoch = state_dict_do['epoch']

    if cfg.train.n_gpu > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device)

    optim_g = RAdam(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    optim_d = RAdam(discriminator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)

    train_filelist = load_dataset_filelist(cfg.dataset.train_list)
    trainset = FeatureDataset(cfg.dataset, train_filelist, cfg.data)
    train_sampler = DistributedSampler(trainset) if cfg.train.n_gpu > 1 else None
    # Bug fix: DataLoader rejects shuffle=True together with a sampler, so the
    # original crashed whenever n_gpu > 1.  Shuffle only in the single-process
    # case; DistributedSampler shuffles by itself.
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train.batch_size,
                              num_workers=cfg.train.num_workers,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        val_filelist = load_dataset_filelist(cfg.dataset.test_list)
        valset = FeatureDataset(cfg.dataset, val_filelist, cfg.data,
                                segmented=False)
        # Bug fix: the original passed sampler=train_sampler (a sampler over
        # the *training* set) here; validation iterates valset sequentially.
        val_loader = DataLoader(valset,
                                batch_size=1,
                                num_workers=cfg.train.num_workers,
                                shuffle=False,
                                sampler=None,
                                pin_memory=True)
        sw = SummaryWriter(os.path.join(cfg.train.ckpt_dir, 'logs'))

    generator.train()
    discriminator.train()
    for epoch in range(max(0, last_epoch), cfg.train.epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if cfg.train.n_gpu > 1:
            train_sampler.set_epoch(epoch)

        for y, x_noised_features, x_noised_cond in train_loader:
            if rank == 0:
                start_b = time.time()
            y = y.to(device, non_blocking=True)
            x_noised_features = x_noised_features.transpose(1, 2).to(
                device, non_blocking=True)
            x_noised_cond = x_noised_cond.to(device, non_blocking=True)

            # Two independent noise draws; their outputs feed the
            # diversity-promoting term of the STFT loss below.
            z1 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1],
                             device=device)
            z2 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1],
                             device=device)
            y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
            y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)

            # Discriminator step (detach so G gets no gradient here).
            real_scores, fake_scores = discriminator(y), discriminator(
                y_hat1.detach())
            d_loss = discriminator_loss(real_scores, fake_scores)
            optim_d.zero_grad()
            d_loss.backward(retain_graph=True)
            optim_d.step()

            # Generator step.
            # NOTE(review): `criterion` is not defined in this function —
            # presumably a module-level STFT loss; confirm.
            g_stft_loss = criterion(y, y_hat1) + criterion(
                y, y_hat2) - criterion(y_hat1, y_hat2)
            g_adv_loss = adversarial_loss(fake_scores)
            g_loss = g_adv_loss + g_stft_loss
            optim_g.zero_grad()
            g_loss.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % cfg.train.stdout_interval == 0:
                    with torch.no_grad():
                        print('Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'
                              .format(steps, g_loss, g_stft_loss,
                                      time.time() - start_b))

                # checkpointing
                if steps % cfg.train.checkpoint_interval == 0:
                    ckpt_dir = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'generator': (generator.module if cfg.train.n_gpu > 1
                                          else generator).state_dict()
                        })
                    ckpt_dir = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'discriminator':
                                (discriminator.module if cfg.train.n_gpu > 1
                                 else discriminator).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch
                        })

                # Tensorboard summary logging
                if steps % cfg.train.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", g_loss, steps)
                    sw.add_scalar("training/gen_stft_error", g_stft_loss, steps)

                # Validation
                if steps % cfg.train.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, (y, x_noised_features,
                                x_noised_cond) in enumerate(val_loader):
                            y_hat = generator(
                                x_noised_features.transpose(1, 2).to(device),
                                x_noised_cond.to(device))
                            val_err_tot += criterion(y, y_hat).item()
                            if j <= 4:
                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_hat[0], steps,
                                             cfg.data.sample_rate)
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps,
                                             cfg.data.sample_rate)
                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/stft_error", val_err, steps)
                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))