def __init__(self, data_opt, **kwargs):
    """Folder dataset with paired data; supports both BI & BD degradation."""
    super(ValidationDataset, self).__init__(data_opt, **kwargs)

    # collect sequence keys from the GT directory (deduplicated, sorted)
    self.keys = sorted(set(os.listdir(self.gt_seq_dir)))
    if data_opt['name'].startswith('Actors'):
        # Actors-style datasets keep their frames one level deeper
        self.keys = [key + '/frames' for key in self.keys]

    # pre-build the blur kernel used for degradation
    self.kernel = create_kernel({
        'dataset': {
            'degradation': {
                'sigma': self.sigma
            }
        },
        'device': 'cuda'
    })

    # optionally restrict keys to those listed in the filter file
    if self.filter_file:
        with open(self.filter_file, 'r') as f:
            selected = {line.strip() for line in f}
        self.keys = sorted(selected & set(self.keys))
def prepare_training_data(self, data):
    """Prepare gt/lr tensors for training.

    For BI degradation, the pre-generated lr data is used directly.
    For BD degradation, lr data is generated on the fly (on GPU) and the
    border of the gt data is removed so gt and lr stay spatially aligned.
    """
    degradation_type = self.opt['dataset']['degradation']['type']

    if degradation_type == 'BI':
        # lr data is already provided by the data loader
        self.gt_data = data['gt'].to(self.device)
        self.lr_data = data['lr'].to(self.device)

    elif degradation_type == 'BD':
        # generate lr data on the fly (on gpu)
        # set params
        scale = self.opt['scale']
        sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)
        border_size = int(sigma * 3.0)

        gt_data = data['gt'].to(self.device)  # with border
        n, t, c, gt_h, gt_w = gt_data.size()
        lr_h = (gt_h - 2 * border_size) // scale
        lr_w = (gt_w - 2 * border_size) // scale

        # create blurring kernel (cached on the instance after first use)
        if self.blur_kernel is None:
            self.blur_kernel = create_kernel(sigma).to(self.device)
        blur_kernel = self.blur_kernel

        # generate lr data
        gt_data = gt_data.view(n * t, c, gt_h, gt_w)
        lr_data = downsample_bd(gt_data, blur_kernel, scale, pad_data=False)
        lr_data = lr_data.view(n, t, c, lr_h, lr_w)

        # remove gt border; the spatial crop makes the tensor
        # non-contiguous, so `.view` would raise a RuntimeError here —
        # `.reshape` copies when needed and is the correct call
        gt_data = gt_data[..., border_size: border_size + scale * lr_h,
                               border_size: border_size + scale * lr_w]
        gt_data = gt_data.reshape(n, t, c, scale * lr_h, scale * lr_w)

        self.gt_data, self.lr_data = gt_data, lr_data  # tchw|float32
def downscale_data(opt):
    """Generate and save degraded (low-res) frames for every 'all*' dataset in opt.

    Supports three degradation types read from opt['dataset']['degradation']['type']:
      - 'BD': blur + downsample using a pre-built kernel
      - 'BI': downsampling via data_utils.BI_downsample
      - 'Style': stylization with a pretrained SimpleGenerator (cartoon model)
    Each degraded frame is written as an image file under a path derived
    from the dataset/actor/segment fields of opt.
    """
    for dataset_idx in sorted(opt['dataset'].keys()):
        # only process dataset entries whose key starts with 'all'
        if not dataset_idx.startswith('all'):
            continue
        loader = create_dataloader(opt, dataset_idx=dataset_idx)
        degradation_type = opt['dataset']['degradation']['type']
        if degradation_type == 'BD':
            # blur kernel built once per dataset
            kernel = data_utils.create_kernel(opt)
        if degradation_type == 'Style':
            # load the stylization generator once per dataset
            path = opt['cartoon_model']
            cartoonizer = SimpleGenerator().to(torch.device(opt['device']))
            cartoonizer.load_weights(path)
            cartoonizer.eval()
        for item in tqdm(loader, ascii=True):
            if degradation_type == 'BD':
                data = prepare_data(opt, item, kernel)
            elif degradation_type == 'BI':
                data = data_utils.BI_downsample(opt, item)
            elif degradation_type == 'Style':
                image = item['gt'][0]
                image = resize(image)
                image = image.to(torch.device(opt['device']))
                with torch.no_grad():
                    stylized_image = cartoonizer(image).unsqueeze(0)
                # map generator output from [-1, 1] to [0, 1]
                stylized_image = (stylized_image + 1) * 0.5
                data = {'gt': image.unsqueeze(0), 'lr': stylized_image}
            lr_data = data['lr']
            gt_data = data['gt']
            # drop leading batch/time dims, chw -> hwc numpy image in [0, 1]
            img = lr_data.squeeze(0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            # output dir: data/<name>/<subset>/<actor>/<type>_<degradation>/<segment>/frames
            path = osp.join(
                'data', opt['dataset']['common']['name'], opt['data_subset'],
                opt['dataset'][dataset_idx]['actor_name'],
                opt['data_type'] + '_' + opt['dataset']['degradation']['type'],
                opt['dataset'][dataset_idx]['segment'], 'frames')
            os.makedirs(path, exist_ok=True)
            path = osp.join(path, item['frame_key'][0])
            img = img * 255.0
            # NOTE(review): cv2.imwrite expects BGR input, but the data is
            # converted BGR2RGB first — presumably the tensors are BGR-ordered
            # upstream; verify against the data loader before changing.
            img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
            cv2.imwrite(path, img)
def prepare_inference_data(self, data):
    """Prepare lr data for inference (w/o loading on device)."""
    deg_type = self.opt['dataset']['degradation']['type']

    if deg_type == 'BI':
        self.lr_data = data['lr']
    elif deg_type == 'BD':
        if 'lr' in data:
            # lr frames were supplied directly by the loader
            self.lr_data = data['lr']
        else:
            # generate lr data on the fly (on cpu)
            # TODO: do frame-wise downsampling on gpu for acceleration?
            frames = data['gt']  # thwc|uint8

            # set params
            scale = self.opt['scale']
            sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)

            # build the blurring kernel once and cache it on the instance
            if self.blur_kernel is None:
                self.blur_kernel = create_kernel(sigma)
            kernel = self.blur_kernel.cpu()

            # thwc|uint8 -> tchw|float32 in [0, 1], then blur + downsample
            frames = frames.permute(0, 3, 1, 2).float() / 255.0
            downsampled = downsample_bd(frames, kernel, scale, pad_data=True)
            self.lr_data = downsampled.permute(0, 2, 3, 1)  # thwc|float32

    # thwc -> tchw
    self.lr_data = self.lr_data.permute(0, 3, 1, 2)  # tchw|float32
def _evaluate(opt, model, logger, model_idx):
    """Run inference on every dataset with prefix `test` and compute/save metrics."""
    for dataset_idx in sorted(opt['dataset'].keys()):
        # use dataset with prefix `test`
        if not dataset_idx.startswith('test'):
            continue

        ds_name = opt['dataset'][dataset_idx]['name']
        logger.info('Testing on {}: {}'.format(dataset_idx, ds_name))

        # create data loader
        test_loader = create_dataloader(opt, dataset_idx=dataset_idx)

        # define metric calculator
        metric_calculator = MetricCalculator(opt)

        # infer and compute metrics for each sequence
        for data in test_loader:
            # fetch data
            lr_data = data['lr'][0]
            seq_idx = data['seq_idx'][0]
            frm_idx = [frm_idx[0] for frm_idx in data['frm_idx']]

            # infer
            hr_seq = model.infer(lr_data)  # thwc|rgb|uint8

            # save results (optional)
            if opt['test']['save_res']:
                res_dir = osp.join(opt['test']['res_dir'], ds_name, model_idx)
                res_seq_dir = osp.join(res_dir, seq_idx)
                data_utils.save_sequence(
                    res_seq_dir, hr_seq, frm_idx, to_bgr=True)

            # compute metrics for the current sequence
            true_seq_dir = osp.join(
                opt['dataset'][dataset_idx]['gt_seq_dir'], seq_idx)
            metric_calculator.compute_sequence_metrics(
                seq_idx, true_seq_dir, '', pred_seq=hr_seq)

        # save/print metrics
        if opt['test'].get('save_json'):
            # save results to json file
            json_path = osp.join(
                opt['test']['json_dir'], '{}_avg.json'.format(ds_name))
            metric_calculator.save_results(model_idx, json_path, override=True)
        else:
            # print directly
            metric_calculator.display_results()


def train(opt):
    """Train a model per opt, periodically logging, checkpointing and testing."""
    # logging
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('=' * 20, '=' * 20))
    base_utils.print_options(opt, logger)

    # create data loader
    train_loader = create_dataloader(opt, dataset_idx='train')

    # create downsampling kernels for BD degradation
    kernel = data_utils.create_kernel(opt)

    # create model
    model = define_model(opt)

    # training configs
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    # `num_iter` replaces the original local `iter`, which shadowed the builtin
    start_iter, num_iter = opt['train']['start_iter'], 0
    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))

    # train
    # `finished` fixes the original code, whose `break` only left the inner
    # loop and re-logged 'Finish training' once per remaining epoch
    finished = False
    for epoch in range(total_epoch):
        if finished:
            break
        for data in train_loader:
            # update iter
            num_iter += 1
            curr_iter = start_iter + num_iter
            if num_iter > total_iter:
                logger.info('Finish training')
                finished = True
                break

            # update learning rate
            model.update_learning_rate()

            # prepare data
            data = prepare_data(opt, data, kernel)

            # train for a mini-batch
            model.train(data)

            # update running log
            model.update_running_log()

            # log
            if log_freq > 0 and num_iter % log_freq == 0:
                # basic info
                msg = '[epoch: {} | iter: {}'.format(epoch, curr_iter)
                for lr_type, lr in model.get_current_learning_rate().items():
                    msg += ' | {}: {:.2e}'.format(lr_type, lr)
                msg += '] '
                # loss info
                log_dict = model.get_running_log()
                msg += ', '.join([
                    '{}: {:.3e}'.format(k, v) for k, v in log_dict.items()])
                logger.info(msg)

            # save model
            if ckpt_freq > 0 and num_iter % ckpt_freq == 0:
                model.save(curr_iter)

            # evaluate performance on every `test*` dataset
            if test_freq > 0 and num_iter % test_freq == 0:
                model_idx = 'G_iter{}'.format(curr_iter)
                _evaluate(opt, model, logger, model_idx)
def train(opt):
    """Train a model per opt, with sigma scheduling for BD degradation."""
    # logging
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('=' * 20, '=' * 20))
    base_utils.print_options(opt, logger)

    # create data loader
    train_loader = create_dataloader(opt, dataset_idx='train')

    # create downsampling kernels for BD degradation
    kernel = data_utils.create_kernel(opt)

    # create model
    model = define_model(opt)

    # training configs
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    curr_iter = opt['train']['start_iter']
    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']
    # sigma schedule (used for BD degradation only)
    sigma_freq = opt['dataset']['degradation'].get('sigma_freq', 0)
    sigma_inc = opt['dataset']['degradation'].get('sigma_inc', 0)
    sigma_max = opt['dataset']['degradation'].get('sigma_max', 10)

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))
    print('device count:', torch.cuda.device_count())

    # train
    # `finished` fixes the original code, whose `break` only left the inner
    # loop and re-logged 'Finish training' once per remaining epoch
    finished = False
    for epoch in range(total_epoch):
        for data in tqdm(train_loader):
            # update iter
            curr_iter += 1
            if curr_iter > total_iter:
                logger.info('Finish training')
                finished = True
                break

            # update learning rate
            model.update_learning_rate()

            # prepare data
            data = prepare_data(opt, data, kernel)

            # train for a mini-batch
            model.train(data)

            # update running log
            model.update_running_log()

            # log
            if log_freq > 0 and curr_iter % log_freq == 0:
                # basic info
                msg = '[epoch: {} | iter: {}'.format(epoch, curr_iter)
                for lr_type, lr in model.get_current_learning_rate().items():
                    msg += ' | {}: {:.2e}'.format(lr_type, lr)
                msg += '] '
                # loss info
                log_dict = model.get_running_log()
                msg += ', '.join([
                    '{}: {:.3e}'.format(k, v) for k, v in log_dict.items()])
                if opt['dataset']['degradation']['type'] == 'BD':
                    msg += ' | Sigma: {}'.format(
                        opt['dataset']['degradation']['sigma'])
                logger.info(msg)

            # save model
            if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
                model.save(curr_iter)

            # evaluate performance
            if test_freq > 0 and curr_iter % test_freq == 0:
                # setup model index (tag with current sigma for BD)
                model_idx = 'G_iter{}'.format(curr_iter)
                if opt['dataset']['degradation']['type'] == 'BD':
                    model_idx = model_idx + str(
                        opt['dataset']['degradation']['sigma'])
                # for each validation set: use dataset with prefix `validate`
                # (the original comment wrongly said prefix `test`)
                for dataset_idx in sorted(opt['dataset'].keys()):
                    if not dataset_idx.startswith('validate'):
                        continue
                    validate(opt, model, logger, dataset_idx, model_idx)

        if finished:
            break

        # schedule sigma: every sigma_freq epochs, increase blur sigma
        if opt['dataset']['degradation']['type'] == 'BD':
            if sigma_freq > 0 and (epoch + 1) % sigma_freq == 0:
                current_sigma = opt['dataset']['degradation']['sigma']
                opt['dataset']['degradation']['sigma'] = min(
                    current_sigma + sigma_inc, sigma_max)
                kernel = data_utils.create_kernel(opt)
                # __getitem__ in custom dataset class uses some crop that
                # depends on sigma; it is crucial to change this crop size
                # accordingly whenever sigma changes
                train_loader.dataset.change_cropsize(
                    opt['dataset']['degradation']['sigma'])
                print('kernel changed')