def setup_optimizers(self):
    train_opt = self.opt['train']
    optim_params = []
    optim_params_lowlr = []
    for k, v in self.net_g.named_parameters():
        if v.requires_grad:
            # offset and dcn layers use a smaller learning rate
            if k.startswith('module.offsets') or k.startswith('module.dcns'):
                optim_params_lowlr.append(v)
            else:
                optim_params.append(v)
        else:
            logger = get_root_logger()
            logger.warning(f'Params {k} will not be optimized.')
    ratio = 0.1

    optim_type = train_opt['optim_g'].pop('type')
    if optim_type == 'Adam':
        self.optimizer_g = torch.optim.Adam(
            [{
                'params': optim_params
            }, {
                'params': optim_params_lowlr,
                'lr': train_opt['optim_g']['lr'] * ratio
            }], **train_opt['optim_g'])
    # elif optim_type == 'SGD':
    #     self.optimizer_g = torch.optim.SGD(optim_params,
    #                                        **train_opt['optim_g'])
    else:
        raise NotImplementedError(
            f'optimizer {optim_type} is not supported yet.')
    self.optimizers.append(self.optimizer_g)
def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation
    dataset_cls = None
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset
def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    model_cls = None
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_model paths.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # ignore pretrained model paths during resuming
        if opt['path'].get('pretrain_model_g', None) is not None or opt[
                'path'].get('pretrain_model_d', None) is not None:
            logger.warning(
                'pretrain_model path will be ignored during resuming.')

        # set pretrained model paths
        opt['path']['pretrain_model_g'] = osp.join(opt['path']['models'],
                                                   f'net_g_{resume_iter}.pth')
        logger.info(
            f"Set pretrain_model_g to {opt['path']['pretrain_model_g']}")

        opt['path']['pretrain_model_d'] = osp.join(opt['path']['models'],
                                                   f'net_d_{resume_iter}.pth')
        logger.info(
            f"Set pretrain_model_d to {opt['path']['pretrain_model_d']}")
def optimize_parameters(self, current_iter):
    if self.train_tsa_iter:
        if current_iter == 1:
            logger = get_root_logger()
            logger.info(
                f'Only train TSA module for {self.train_tsa_iter} iters.')
            for name, param in self.net_g.named_parameters():
                if 'fusion' not in name:
                    param.requires_grad = False
        elif current_iter == self.train_tsa_iter:
            logger = get_root_logger()
            logger.warning('Train all the parameters.')
            for param in self.net_g.parameters():
                param.requires_grad = True

    super(EDVRModel, self).optimize_parameters(current_iter)
def __getitem__(self, index):
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'),
                                      **self.io_backend_opt)

    # load gt image
    gt_path = self.paths[index]
    # avoid errors caused by high latency in reading files
    retry = 3
    while retry > 0:
        try:
            img_bytes = self.file_client.get(gt_path)
        except Exception as e:
            logger = get_root_logger()
            logger.warning(
                f'File client error: {e}, remaining retry times: {retry - 1}')
            # change another file to read (randint is inclusive on both ends)
            index = random.randint(0, self.__len__() - 1)
            gt_path = self.paths[index]
            time.sleep(1)  # sleep 1s for occasional server congestion
        else:
            break
        finally:
            retry -= 1
    img_gt = imfrombytes(img_bytes, float32=True)

    # random horizontal flip
    img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
    # normalize
    normalize(img_gt, self.mean, self.std, inplace=True)
    return {'gt': img_gt, 'gt_path': gt_path}
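# A minimal, self-contained sketch of the retry pattern used above (the helper
# name `read_with_retry` and the fixed 1 s pause are illustrative, not part of
# the original code): the try's `else` branch runs only when no exception was
# raised, and `finally` decrements the counter on every attempt.
import random
import time


def read_with_retry(paths, index, reader, retries=3):
    """Try reader(paths[index]); on failure, retry with a random index."""
    data = None
    while retries > 0:
        try:
            data = reader(paths[index])
        except OSError:
            index = random.randint(0, len(paths) - 1)  # pick another file
            time.sleep(1)  # brief pause for transient congestion
        else:
            break
        finally:
            retries -= 1
    if data is None:
        raise OSError('Failed to read a file after all retries.')
    return data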
def __init__(self, opt):
    super(Vimeo90KDataset, self).__init__()
    self.opt = opt
    self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
        opt['dataroot_lq'])

    with open(opt['meta_info_file'], 'r') as fin:
        self.keys = [line.split(' ')[0] for line in fin]

    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    self.is_lmdb = False
    if self.io_backend_opt['type'] == 'lmdb':
        self.is_lmdb = True
        self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
        self.io_backend_opt['client_keys'] = ['lq', 'gt']

    # indices of input images
    self.neighbor_list = [
        i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])
    ]

    # temporal augmentation configs
    self.random_reverse = opt['random_reverse']
    logger = get_root_logger()
    logger.info(f'Random reverse is {self.random_reverse}.')
def setup_optimizers(self):
    train_opt = self.opt['train']
    dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
    logger = get_root_logger()
    logger.info(f'Multiply the learning rate for dcn with {dcn_lr_mul}.')
    if dcn_lr_mul == 1:
        optim_params = self.net_g.parameters()
    else:  # separate dcn params and normal params for different lr
        normal_params = []
        dcn_params = []
        for name, param in self.net_g.named_parameters():
            if 'dcn' in name:
                dcn_params.append(param)
            else:
                normal_params.append(param)
        optim_params = [
            {  # add normal params first
                'params': normal_params,
                'lr': train_opt['optim_g']['lr']
            },
            {
                'params': dcn_params,
                'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
            },
        ]

    optim_type = train_opt['optim_g'].pop('type')
    self.optimizer_g = self.get_optimizer(optim_type, optim_params,
                                          **train_opt['optim_g'])
    self.optimizers.append(self.optimizer_g)
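# A minimal sketch of per-parameter-group learning rates in torch.optim.Adam
# (the two-layer network and the 2x multiplier are made up for illustration):
# each group dict may override the default lr passed as a keyword argument,
# which is exactly how the dcn group above gets its scaled rate.
import torch
import torch.nn as nn

_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 3, 3))
_base_lr = 1e-4
_optimizer = torch.optim.Adam(
    [
        {'params': _net[0].parameters()},                      # uses default lr
        {'params': _net[1].parameters(), 'lr': _base_lr * 2},  # scaled group
    ],
    lr=_base_lr)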
def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
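# Hypothetical usage of build_network (assumes BasicSR's ARCH_REGISTRY has
# been populated, e.g. by importing basicsr.archs; 'MSRResNet' is one
# registered architecture, used here purely as an illustration):
#
#     net_opt = {'type': 'MSRResNet', 'num_in_ch': 3, 'num_out_ch': 3,
#                'num_feat': 64, 'num_block': 16, 'upscale': 4}
#     net_g = build_network(net_opt)  # logs: Network [MSRResNet] is created.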
def load_network(self, net, load_path, strict=True, param_key='params'):
    """Load network.

    Args:
        net (nn.Module): Network.
        load_path (str): The path of networks to be loaded.
        strict (bool): Whether strictly loaded.
        param_key (str): The parameter key of loaded network. If set to
            None, use the root 'path'. Default: 'params'.
    """
    logger = get_root_logger()
    net = self.get_bare_model(net)
    load_net = torch.load(load_path,
                          map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in load_net and 'params' in load_net:
            logger.info(f'Loading: {param_key} does not exist, use params.')
            param_key = 'params'
        load_net = load_net[param_key]
    logger.info(
        f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
    )
    # remove the unnecessary 'module.' prefix left by DataParallel wrappers
    for k, v in deepcopy(load_net).items():
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
    self._print_different_keys_loading(net, load_net, strict)
    net.load_state_dict(load_net, strict=strict)
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
    """Print keys with different name or different size when loading models.

    1. Print keys with different names.
    2. If strict=False, print the same key but with different tensor size.
       It also ignores keys with different sizes (they are not loaded).

    Args:
        crt_net (torch model): Current network.
        load_net (dict): Loaded network.
        strict (bool): Whether strictly loaded. Default: True.
    """
    crt_net = self.get_bare_model(crt_net)
    crt_net = crt_net.state_dict()
    crt_net_keys = set(crt_net.keys())
    load_net_keys = set(load_net.keys())

    logger = get_root_logger()
    if crt_net_keys != load_net_keys:
        logger.warning('Current net - loaded net:')
        for v in sorted(list(crt_net_keys - load_net_keys)):
            logger.warning(f'  {v}')
        logger.warning('Loaded net - current net:')
        for v in sorted(list(load_net_keys - crt_net_keys)):
            logger.warning(f'  {v}')

    # check the size for the same keys
    if not strict:
        common_keys = crt_net_keys & load_net_keys
        for k in common_keys:
            if crt_net[k].size() != load_net[k].size():
                logger.warning(
                    f'Size different, ignore [{k}]: crt_net: '
                    f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                load_net[k + '.ignore'] = load_net.pop(k)
def init_training_settings(self):
    train_opt = self.opt['train']

    self.ema_decay = train_opt.get('ema_decay', 0)
    if self.ema_decay > 0:
        logger = get_root_logger()
        logger.info(
            f'Use Exponential Moving Average with decay: {self.ema_decay}')
        # define network net_g with Exponential Moving Average (EMA)
        # net_g_ema is used only for testing on one GPU and saving.
        # There is no need to wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path,
                              self.opt['path'].get('strict_load_g', True),
                              'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight
        self.net_g_ema.eval()

    # define network net_d
    self.net_d = build_network(self.opt['network_d'])
    self.net_d = self.model_to_device(self.net_d)
    self.print_network(self.net_d)

    # load pretrained models
    load_path = self.opt['path'].get('pretrain_network_d', None)
    if load_path is not None:
        param_key = self.opt['path'].get('param_key_d', 'params')
        self.load_network(self.net_d, load_path,
                          self.opt['path'].get('strict_load_d', True),
                          param_key)

    self.net_g.train()
    self.net_d.train()

    # define losses
    if train_opt.get('pixel_opt'):
        self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
    else:
        self.cri_pix = None

    if train_opt.get('perceptual_opt'):
        self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(
            self.device)
    else:
        self.cri_perceptual = None

    if train_opt.get('gan_opt'):
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

    self.net_d_iters = train_opt.get('net_d_iters', 1)
    self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()
def _log_validation_metric_values(self, current_iter, dataset_name,
                                  tb_logger):
    log_str = f'Validation {dataset_name}\n'
    for metric, value in self.metric_results.items():
        log_str += f'\t # {metric}: {value:.4f}\n'
    logger = get_root_logger()
    logger.info(log_str)
    if tb_logger:
        for metric, value in self.metric_results.items():
            tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = parse(args.opt, is_train=False)

    # distributed testing settings
    if args.launcher == 'none':  # disabled distributed testing
        opt['dist'] = False
        print('Disabled distributed testing.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)

    opt = dict_to_nonedict(opt)

    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'],
                        f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr',
                             log_level=logging.INFO,
                             log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info(
            f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        model.validation(test_loader,
                         current_iter=opt['name'],
                         tb_logger=None,
                         save_img=opt['val']['save_img'])
def _log_validation_metric_values(self, current_iter, dataset_name,
                                  tb_logger):
    # ---------------- calculate the average values for each folder, and for each metric ---------------- #
    # average all frames for each sub-folder
    # metric_results_avg is a dict: {
    #     'folder1': tensor (len(metrics)),
    #     'folder2': tensor (len(metrics))
    # }
    metric_results_avg = {
        folder: torch.mean(tensor, dim=0).cpu()
        for (folder, tensor) in self.metric_results.items()
    }
    # total_avg_results is a dict: {
    #     'metric1': float,
    #     'metric2': float
    # }
    total_avg_results = {
        metric: 0
        for metric in self.opt['val']['metrics'].keys()
    }
    for folder, tensor in metric_results_avg.items():
        for idx, metric in enumerate(total_avg_results.keys()):
            total_avg_results[metric] += metric_results_avg[folder][idx].item()
    # average among folders
    for metric in total_avg_results.keys():
        total_avg_results[metric] /= len(metric_results_avg)
        # update the best metric result
        self._update_best_metric_result(dataset_name, metric,
                                        total_avg_results[metric],
                                        current_iter)

    # ------------------------------------------ log the metric ------------------------------------------ #
    log_str = f'Validation {dataset_name}\n'
    for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
        log_str += f'\t # {metric}: {value:.4f}'
        for folder, tensor in metric_results_avg.items():
            log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
        if hasattr(self, 'best_metric_results'):
            log_str += (
                f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                f'{self.best_metric_results[dataset_name][metric]["iter"]} iter'
            )
        log_str += '\n'

    logger = get_root_logger()
    logger.info(log_str)
    if tb_logger:
        for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
            tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
            for folder, tensor in metric_results_avg.items():
                tb_logger.add_scalar(f'metrics/{metric}/{folder}',
                                     tensor[metric_idx].item(), current_iter)
def dist_validation(self, dataloader, current_iter, tb_logger, save_img,
                    rgb2bgr, use_image):
    import os  # ideally a module-level import
    # only rank 0 runs the (single-GPU) validation; other ranks skip it
    if os.environ.get('LOCAL_RANK', '0') == '0':
        return self.nondist_validation(dataloader, current_iter, tb_logger,
                                       save_img, rgb2bgr, use_image)
    else:
        return 0.
def test(self):
    with torch.no_grad():
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            self.output, _ = self.net_g_ema(self.lq)
        else:
            logger = get_root_logger()
            logger.warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _ = self.net_g(self.lq)
            self.net_g.train()
def __init__(self, opt):
    super(REDSDataset, self).__init__()
    self.opt = opt
    self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(
        opt['dataroot_lq'])
    self.flow_root = Path(
        opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
    assert opt['num_frame'] % 2 == 1, (
        f'num_frame should be odd number, but got {opt["num_frame"]}')
    self.num_frame = opt['num_frame']
    self.num_half_frames = opt['num_frame'] // 2

    self.keys = []
    with open(opt['meta_info_file'], 'r') as fin:
        for line in fin:
            folder, frame_num, _ = line.split(' ')
            self.keys.extend(
                [f'{folder}/{i:08d}' for i in range(int(frame_num))])

    # remove the video clips used in validation
    if opt['val_partition'] == 'REDS4':
        val_partition = ['000', '011', '015', '020']
    elif opt['val_partition'] == 'official':
        val_partition = [f'{v:03d}' for v in range(240, 270)]
    else:
        raise ValueError(
            f'Wrong validation partition {opt["val_partition"]}. '
            f"Supported ones are ['official', 'REDS4'].")
    self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]

    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    self.is_lmdb = False
    if self.io_backend_opt['type'] == 'lmdb':
        self.is_lmdb = True
        if self.flow_root is not None:
            self.io_backend_opt['db_paths'] = [
                self.lq_root, self.gt_root, self.flow_root
            ]
            self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
        else:
            self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']

    # temporal augmentation configs
    self.interval_list = opt['interval_list']
    self.random_reverse = opt['random_reverse']
    interval_str = ','.join(str(x) for x in opt['interval_list'])
    logger = get_root_logger()
    logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                f'random reverse is {self.random_reverse}.')
def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.
    """
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
def setup_optimizers(self):
    train_opt = self.opt['train']
    optim_params = []
    for k, v in self.net_g.named_parameters():
        if v.requires_grad:
            optim_params.append(v)
        else:
            logger = get_root_logger()
            logger.warning(f'Params {k} will not be optimized.')

    optim_type = train_opt['optim_g'].pop('type')
    self.optimizer_g = self.get_optimizer(optim_type, optim_params,
                                          **train_opt['optim_g'])
    self.optimizers.append(self.optimizer_g)
def build_loss(opt):
    """Build loss from options.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Loss type.
    """
    opt = deepcopy(opt)
    loss_type = opt.pop('type')
    loss = LOSS_REGISTRY.get(loss_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
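# Hypothetical usage of build_loss (assumes LOSS_REGISTRY is populated, e.g.
# by importing basicsr.losses; 'L1Loss' with loss_weight/reduction follows
# BasicSR conventions and is shown only as an illustration):
#
#     loss_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
#     cri_pix = build_loss(loss_opt)  # logs: Loss [L1Loss] is created.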
def forward(self, x, feat):
    out = self.conv_offset(feat)
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    offset = torch.cat((o1, o2), dim=1)
    mask = torch.sigmoid(mask)

    offset_absmean = torch.mean(torch.abs(offset))
    if offset_absmean > 50:
        logger = get_root_logger()
        logger.warning(
            f'Offset abs mean is {offset_absmean}, larger than 50.')

    return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                 self.stride, self.padding, self.dilation,
                                 self.groups, self.deformable_groups)
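# A self-contained sketch of the offset/mask split used above (the sizes are
# made up for illustration): conv_offset predicts 3 * deformable_groups * k*k
# channels, which torch.chunk splits into two offset halves and a mask.
import torch

dg, k = 1, 3  # deformable groups, kernel size (assumed values)
out = torch.randn(2, 3 * dg * k * k, 16, 16)  # stand-in for conv_offset(feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)  # (2, 2*dg*k*k, 16, 16): x/y offsets
mask = torch.sigmoid(mask)           # (2, dg*k*k, 16, 16): values in [0, 1]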
def optimize_parameters(self, current_iter):
    if self.fix_flow_iter:
        logger = get_root_logger()
        if current_iter == 1:
            logger.info(
                f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.'
            )
            for name, param in self.net_g.named_parameters():
                if 'spynet' in name or 'edvr' in name:
                    param.requires_grad_(False)
        elif current_iter == self.fix_flow_iter:
            logger.warning('Train all the parameters.')
            self.net_g.requires_grad_(True)

    super(VideoRecurrentModel, self).optimize_parameters(current_iter)
def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(
        f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} '
        'is built.')
    return dataset
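# Hypothetical usage of build_dataset (assumes DATASET_REGISTRY is populated,
# e.g. by importing basicsr.data; the dataset type and paths below are
# placeholders, not a verified configuration):
#
#     dataset_opt = {'name': 'DemoSet', 'type': 'PairedImageDataset',
#                    'dataroot_gt': 'datasets/demo/GT',   # placeholder
#                    'dataroot_lq': 'datasets/demo/LQ',   # placeholder
#                    'io_backend': {'type': 'disk'},
#                    'phase': 'val', 'scale': 4}
#     dataset = build_dataset(dataset_opt)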
def save_network(self, net, net_label, current_iter, param_key='params'):
    """Save networks.

    Args:
        net (nn.Module | list[nn.Module]): Network(s) to be saved.
        net_label (str): Network label.
        current_iter (int): Current iter number.
        param_key (str | list[str]): The parameter key(s) to save network.
            Default: 'params'.
    """
    if current_iter == -1:
        current_iter = 'latest'
    save_filename = f'{net_label}_{current_iter}.pth'
    save_path = os.path.join(self.opt['path']['models'], save_filename)

    net = net if isinstance(net, list) else [net]
    param_key = param_key if isinstance(param_key, list) else [param_key]
    assert len(net) == len(
        param_key), 'The lengths of net and param_key should be the same.'

    save_dict = {}
    for net_, param_key_ in zip(net, param_key):
        net_ = self.get_bare_model(net_)
        # copy into a new dict so that stripping the 'module.' prefix does
        # not mutate the dict while it is being iterated
        state_dict = {}
        for key, param in net_.state_dict().items():
            if key.startswith('module.'):  # remove unnecessary 'module.'
                key = key[7:]
            state_dict[key] = param.cpu()
        save_dict[param_key_] = state_dict

    # avoid occasional writing errors
    retry = 3
    while retry > 0:
        try:
            torch.save(save_dict, save_path)
        except Exception as e:
            logger = get_root_logger()
            logger.warning(
                f'Save model error: {e}, remaining retry times: {retry - 1}')
            time.sleep(1)
        else:
            break
        finally:
            retry -= 1
    else:
        # the while-else branch runs only when all retries were exhausted
        # without a successful save (i.e. the loop never hit `break`)
        logger = get_root_logger()
        logger.warning(f'Still cannot save {save_path}. Just ignore it.')
def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # mkdir and initialize loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'],
                        f"test_{opt['name']}_{get_time_str()}.log")
    logger = get_root_logger(logger_name='basicsr',
                             log_level=logging.INFO,
                             log_file=log_file)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    # create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set,
                                        dataset_opt,
                                        num_gpu=opt['num_gpu'],
                                        dist=opt['dist'],
                                        sampler=None,
                                        seed=opt['manual_seed'])
        logger.info(
            f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
        test_loaders.append(test_loader)

    # create model
    model = create_model(opt)

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info(f'Testing {test_set_name}...')
        rgb2bgr = opt['val'].get('rgb2bgr', True)
        # whether to use uint8 images to compute metrics
        use_image = opt['val'].get('use_image', True)
        model.validation(test_loader,
                         current_iter=opt['name'],
                         tb_logger=None,
                         save_img=opt['val']['save_img'],
                         rgb2bgr=rgb2bgr,
                         use_image=use_image)
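# Example invocation of this test entry point (the config path is a
# placeholder; any BasicSR-style test YAML works):
#
#     python basicsr/test.py -opt options/test/your_config.yml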
def _log_validation_metric_values(self, current_iter, dataset_name,
                                  tb_logger):
    # average all frames for each sub-folder
    # metric_results_avg is a dict: {
    #     'folder1': tensor (len(metrics)),
    #     'folder2': tensor (len(metrics))
    # }
    metric_results_avg = {
        folder: torch.mean(tensor, dim=0).cpu()
        for (folder, tensor) in self.metric_results.items()
    }
    # total_avg_results is a dict: {
    #     'metric1': float,
    #     'metric2': float
    # }
    total_avg_results = {
        metric: 0
        for metric in self.opt['val']['metrics'].keys()
    }
    for folder, tensor in metric_results_avg.items():
        for idx, metric in enumerate(total_avg_results.keys()):
            total_avg_results[metric] += metric_results_avg[folder][idx].item()

    # averaging among folders is disabled here, so the logged totals are
    # sums over folders
    # for metric in total_avg_results.keys():
    #     total_avg_results[metric] /= len(metric_results_avg)

    log_str = f'Validation {dataset_name}\n'
    for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
        log_str += f'\t # {metric}: {value:.4f}'
        for folder, tensor in metric_results_avg.items():
            log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
        log_str += '\n'

    logger = get_root_logger()
    logger.info(log_str)
    if tb_logger:
        for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
            tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
            for folder, tensor in metric_results_avg.items():
                tb_logger.add_scalar(f'metrics/{metric}/{folder}',
                                     tensor[metric_idx].item(), current_iter)
def init_training_settings(self):
    self.net_g.train()
    train_opt = self.opt["train"]

    self.ema_decay = train_opt.get("ema_decay", 0)
    if self.ema_decay > 0:
        logger = get_root_logger()
        logger.info(
            f"Use Exponential Moving Average with decay: {self.ema_decay}")
        # define network net_g with Exponential Moving Average (EMA)
        # net_g_ema is used only for testing on one GPU and saving.
        # There is no need to wrap with DistributedDataParallel
        self.net_g_ema = build_network(self.opt["network_g"]).to(self.device)
        # load pretrained model
        load_path = self.opt["path"].get("pretrain_network_g", None)
        if load_path is not None:
            self.load_network(
                self.net_g_ema,
                load_path,
                self.opt["path"].get("strict_load_g", True),
                "params_ema",
            )
        else:
            self.model_ema(0)  # copy net_g weight
        self.net_g_ema.eval()

    # define losses
    if train_opt.get("pixel_opt"):
        self.cri_pix = build_loss(train_opt["pixel_opt"]).to(self.device)
    else:
        self.cri_pix = None

    if train_opt.get("perceptual_opt"):
        self.cri_perceptual = build_loss(train_opt["perceptual_opt"]).to(
            self.device)
    else:
        self.cri_perceptual = None

    if self.cri_pix is None and self.cri_perceptual is None:
        raise ValueError("Both pixel and perceptual losses are None.")

    # set up optimizers and schedulers
    self.setup_optimizers()
    self.setup_schedulers()
def __init__(self, opt):
    super(VidTestDataset, self).__init__()
    self.opt = opt
    self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
    # file client (io backend)
    self.file_client = None
    self.io_backend_opt = opt['io_backend']
    assert self.io_backend_opt[
        'type'] != 'lmdb', 'No need to use lmdb during validation/test.'

    logger = get_root_logger()
    logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')

    self.data_info = {
        'lq_path': [],
        'gt_path': [],
        'clip_name': [],
        'max_idx': [],
    }
    self.lq_frames, self.gt_frames = {}, {}

    self.clip_list = os.listdir(osp.abspath(self.gt_root))
    self.clip_list.sort()
    for clip_name in self.clip_list:
        lq_frames_path = osp.join(self.lq_root, clip_name)
        lq_frames_path = sorted(
            list(scandir(lq_frames_path, full_path=True)))
        gt_frames_path = osp.join(self.gt_root, clip_name)
        gt_frames_path = sorted(
            list(scandir(gt_frames_path, full_path=True)))

        max_idx = len(lq_frames_path)
        assert max_idx == len(gt_frames_path), (
            f'Different number of images in lq ({max_idx})'
            f' and gt folders ({len(gt_frames_path)})')

        self.data_info['lq_path'].extend(lq_frames_path)
        self.data_info['gt_path'].extend(gt_frames_path)
        self.data_info['clip_name'].append(clip_name)
        self.data_info['max_idx'].append(max_idx)
        self.lq_frames[clip_name] = lq_frames_path
        self.gt_frames[clip_name] = gt_frames_path
def setup_optimizers(self):
    train_opt = self.opt['train']
    optim_params = []
    for k, v in self.net_g.named_parameters():
        if v.requires_grad:
            optim_params.append(v)
        else:
            logger = get_root_logger()
            logger.warning(f'Params {k} will not be optimized.')

    optim_type = train_opt['optim_g'].pop('type')
    if optim_type == 'Adam':
        self.optimizer_g = torch.optim.Adam(optim_params,
                                            **train_opt['optim_g'])
    else:
        raise NotImplementedError(
            f'optimizer {optim_type} is not supported yet.')
    self.optimizers.append(self.optimizer_g)