import numpy as np
import pytest

# Imports below assume the mmediting package layout.
from mmedit.datasets.pipelines import Compose, ImageToTensor


def test_compose():
    with pytest.raises(TypeError):
        Compose('LoadAlpha')

    target_keys = ['img', 'meta']

    img = np.random.randn(256, 256, 3)
    results = dict(img=img, abandoned_key=None, img_name='test_image.png')
    test_pipeline = [
        dict(type='Collect', keys=['img'], meta_keys=['img_name']),
        dict(type='ImageToTensor', keys=['img'])
    ]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert check_keys_equal(compose_results.keys(), target_keys)
    assert check_keys_equal(compose_results['meta'].data.keys(), ['img_name'])

    results = None
    image_to_tensor = ImageToTensor(keys=[])
    test_pipeline = [image_to_tensor]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert compose_results is None
    assert repr(compose) == (
        compose.__class__.__name__ + f'(\n    {image_to_tensor}\n)')
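# `check_keys_equal` is called by the test above but is not defined in this
# excerpt. A minimal sketch of the helper it presumably is: an
# order-insensitive comparison of two key collections.
def check_keys_equal(result_keys, target_keys):
    """Check whether both inputs contain the same set of keys."""
    return set(result_keys) == set(target_keys)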
def restoration_inference(model, img):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.

    Returns:
        Tensor: The predicted restoration result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove gt from test_pipeline
    keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(lq_path=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['output']
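# A hedged usage sketch for restoration_inference. `init_model` and
# `tensor2img` are the standard mmediting helpers; the config, checkpoint,
# and image paths below are hypothetical placeholders.
def _demo_restoration_inference():
    import mmcv
    from mmedit.apis import init_model
    from mmedit.core import tensor2img

    model = init_model(
        'configs/restorers/esrgan/esrgan_x4.py',  # hypothetical config
        'checkpoints/esrgan_x4.pth',              # hypothetical weights
        device='cuda:0')
    output = restoration_inference(model, 'demo/lq_image.png')
    # `output` is a float tensor; tensor2img converts it to a saveable image.
    mmcv.imwrite(tensor2img(output), 'demo/sr_image.png')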
def _prepare_input_img(model_type: str,
                       img_path: str,
                       config: dict,
                       rescale_shape: Optional[Iterable] = None) -> dict:
    """Prepare the input image.

    Args:
        model_type (str): Which kind of model the config belongs to, one of
            ['inpainting', 'mattor', 'restorer', 'synthesizer'].
        img_path (str): Image path to show or verify.
        config (dict): MMCV config, determined by the input config file.
        rescale_shape (Optional[Iterable]): To rescale the shape of the
            input tensor.

    Returns:
        dict: {'imgs': imgs, 'img_metas': img_metas}
    """
    # remove alpha from test_pipeline
    if model_type == 'mattor':
        keys_to_remove = ['alpha', 'ori_alpha']
    elif model_type == 'restorer':
        keys_to_remove = ['gt', 'gt_path']
    else:
        keys_to_remove = []  # nothing to strip for other model types
    for key in keys_to_remove:
        for pipeline in list(config.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                config.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    config.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(config.test_pipeline)
    # prepare data
    if model_type == 'mattor':
        raise RuntimeError('Invalid model_type!', model_type)
    if model_type == 'restorer':
        data = dict(lq_path=img_path)
    data = test_pipeline(data)

    if model_type == 'restorer':
        imgs = data['lq']
    else:
        imgs = data['img']
    img_metas = [data['meta']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs
def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        filename_tmpl (str): Template for file name.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # pipeline
    test_pipeline = [
        dict(
            type='GenerateSegmentIndices',
            interval_list=[1],
            filename_tmpl=filename_tmpl),
        dict(
            type='LoadImageFromFileList',
            io_backend='disk',
            key='lq',
            channel_order='rgb'),
        dict(type='RescaleToZeroOne', keys=['lq']),
        dict(type='FramesToTensor', keys=['lq']),
        dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
    ]

    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            # pad_sequence adds window_size // 2 frames at each end, so this
            # bound yields exactly one output per original frame
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
def generation_inference(model, img, img_unpaired=None):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.
        img_unpaired (str, optional): File path of the unpaired image.
            If not None, perform unpaired image generation. Default: None.

    Returns:
        np.ndarray: The predicted generation result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    if img_unpaired is None:
        data = dict(pair_path=img)
    else:
        data = dict(img_a_path=img, img_b_path=img_unpaired)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        results = model(test_mode=True, **data)
    # assemble the output image(s) according to show_input
    if img_unpaired is None:
        if model.show_input:
            output = np.concatenate([
                tensor2img(results['real_a'], min_max=(-1, 1)),
                tensor2img(results['fake_b'], min_max=(-1, 1)),
                tensor2img(results['real_b'], min_max=(-1, 1))
            ], axis=1)
        else:
            output = tensor2img(results['fake_b'], min_max=(-1, 1))
    else:
        if model.show_input:
            output = np.concatenate([
                tensor2img(results['real_a'], min_max=(-1, 1)),
                tensor2img(results['fake_b'], min_max=(-1, 1)),
                tensor2img(results['real_b'], min_max=(-1, 1)),
                tensor2img(results['fake_a'], min_max=(-1, 1))
            ], axis=1)
        else:
            if model.test_direction == 'a2b':
                output = tensor2img(results['fake_b'], min_max=(-1, 1))
            else:
                output = tensor2img(results['fake_a'], min_max=(-1, 1))
    return output
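# A hedged usage sketch for generation_inference: pass one image for paired
# models, or both `img` and `img_unpaired` for unpaired models. All paths
# below are hypothetical placeholders.
def _demo_generation_inference():
    import mmcv
    from mmedit.apis import init_model

    model = init_model(
        'configs/synthesizers/pix2pix/pix2pix_facades.py',  # hypothetical
        'checkpoints/pix2pix_facades.pth',                  # hypothetical
        device='cuda:0')
    output = generation_inference(model, 'demo/paired_input.png')
    mmcv.imwrite(output, 'demo/generated.png')  # output is a np.ndarray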
def matting_inference(model, img, trimap):
    """Inference image(s) with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): Image file path.
        trimap (str): Trimap file path.

    Returns:
        np.ndarray: The predicted alpha matte.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(merged_path=img, trimap_path=trimap)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['pred_alpha']
def matting_inference_file(model,
                           img,
                           trimap=None,
                           mask=None,
                           image_path="input file directly"):
    """Inference image(s) with the model, taking in-memory arrays as input.

    At least one of `trimap` and `mask` must be given.

    Args:
        model (nn.Module): The loaded model.
        img (np.ndarray): Input (merged) image, already loaded.
        trimap (np.ndarray, optional): Trimap array. Default: None.
        mask (np.ndarray, optional): Mask array. Default: None.
        image_path (str): Placeholder path stored in the results dict.

    Returns:
        dict: The raw model output, including the predicted alpha matte.
    """
    assert trimap is not None or mask is not None
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline, skipping the first two (file-loading) steps
    test_pipeline = cfg.test_pipeline[2:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(
        merged=img,
        mask=mask,
        ori_mask=mask,
        trimap=trimap,
        ori_trimap=trimap,
        ori_merged=img.copy(),
        merged_path=image_path,
        merged_ori_shape=img.shape)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result
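# A hedged usage sketch for matting_inference_file: images are passed as
# arrays rather than file paths. The config, checkpoint, and image paths are
# hypothetical placeholders, and the expected array shapes depend on the
# configured test pipeline.
def _demo_matting_inference_file():
    import mmcv
    from mmedit.apis import init_model

    model = init_model(
        'configs/mattors/gca/gca_comp1k.py',  # hypothetical config
        'checkpoints/gca.pth',                # hypothetical weights
        device='cuda:0')
    merged = mmcv.imread('demo/merged.png')                    # placeholder
    trimap = mmcv.imread('demo/trimap.png', flag='grayscale')  # placeholder
    result = matting_inference_file(
        model, merged, trimap=trimap, image_path='demo/merged.png')
    # 'pred_alpha' is the output key used by the path-based
    # matting_inference in this module
    pred_alpha = result['pred_alpha']
    return pred_alpha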
def inpainting_inference(model, masked_img, mask):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        masked_img (str): File path of image with mask.
        mask (str): Mask file path.

    Returns:
        Tensor: The predicted inpainting result.
    """
    device = next(model.parameters()).device  # model device

    infer_pipeline = [
        dict(type='LoadImageFromFile', key='masked_img'),
        dict(type='LoadMask', mask_mode='file', mask_config=dict()),
        dict(type='Pad', keys=['masked_img', 'mask'], mode='reflect'),
        dict(
            type='Normalize',
            keys=['masked_img'],
            mean=[127.5] * 3,
            std=[127.5] * 3,
            to_rgb=False),
        dict(type='GetMaskedImage', img_name='masked_img'),
        dict(
            type='Collect',
            keys=['masked_img', 'mask'],
            meta_keys=['masked_img_path']),
        dict(type='ImageToTensor', keys=['masked_img', 'mask'])
    ]

    # build the data pipeline
    test_pipeline = Compose(infer_pipeline)
    # prepare data
    data = dict(masked_img_path=masked_img, mask_path=mask)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['fake_img']
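# A hedged usage sketch for inpainting_inference (all paths are hypothetical
# placeholders): the function takes the masked image and the mask as files
# and returns the inpainted result as a tensor.
def _demo_inpainting_inference():
    import mmcv
    from mmedit.apis import init_model
    from mmedit.core import tensor2img

    model = init_model(
        'configs/inpainting/global_local/gl_places.py',  # hypothetical
        'checkpoints/gl_places.pth',                     # hypothetical
        device='cuda:0')
    fake_img = inpainting_inference(model, 'demo/masked_img.png',
                                    'demo/mask.png')
    mmcv.imwrite(tensor2img(fake_img, min_max=(-1, 1)), 'demo/inpainted.png')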
def matting_inference(model, img, trimap):
    """Inference image(s) with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): Image file path.
        trimap (str): Trimap file path.

    Returns:
        np.ndarray: The predicted alpha matte.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(merged_path=img, trimap_path=trimap)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['pred_alpha']
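# A hedged usage sketch for matting_inference (config, checkpoint, and image
# paths are hypothetical placeholders). The returned alpha matte is assumed
# to be a float array in [0, 1], matching how mmediting mattors report
# pred_alpha.
def _demo_matting_inference():
    import mmcv
    from mmedit.apis import init_model

    model = init_model(
        'configs/mattors/dim/dim_stage3.py',  # hypothetical config
        'checkpoints/dim.pth',                # hypothetical weights
        device='cuda:0')
    pred_alpha = matting_inference(model, 'demo/merged.png',
                                   'demo/trimap.png')
    mmcv.imwrite((pred_alpha * 255).astype('uint8'), 'demo/pred_alpha.png')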
def restoration_video_inference(model, img_dir, window_size, start_idx,
                                filename_tmpl):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        start_idx (int): The index corresponding to the first frame in the
            sequence.
        filename_tmpl (str): Template for file name.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # the first element in the pipeline must be 'GenerateSegmentIndices'
    if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
        raise TypeError('The first element in the pipeline must be '
                        f'"GenerateSegmentIndices", but got '
                        f'"{test_pipeline[0]["type"]}".')

    # specify start_idx and filename_tmpl
    test_pipeline[0]['start_idx'] = start_idx
    test_pipeline[0]['filename_tmpl'] = filename_tmpl

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
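# A hedged usage sketch for restoration_video_inference (config, checkpoint,
# and frame-directory paths are hypothetical placeholders). A non-positive
# window_size falls through to the recurrent branch of the check above.
def _demo_restoration_video_inference():
    from mmedit.apis import init_model

    model = init_model(
        'configs/restorers/basicvsr/basicvsr_reds4.py',  # hypothetical
        'checkpoints/basicvsr_reds4.pth',                # hypothetical
        device='cuda:0')
    # frames assumed to be named 00000000.png, 00000001.png, ...
    result = restoration_video_inference(
        model, 'data/demo_video_frames', window_size=0, start_idx=0,
        filename_tmpl='{:08d}.png')
    return result  # tensor of shape (1, t, c, h, w)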
if model_type == 'mattor':
    keys_to_remove = ['alpha', 'ori_alpha']
elif model_type == 'restorer':
    keys_to_remove = ['gt', 'gt_path']
for key in keys_to_remove:
    for pipeline in list(config.test_pipeline):
        if 'key' in pipeline and key == pipeline['key']:
            config.test_pipeline.remove(pipeline)
        if 'keys' in pipeline and key in pipeline['keys']:
            pipeline['keys'].remove(key)
            if len(pipeline['keys']) == 0:
                config.test_pipeline.remove(pipeline)
        if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
            pipeline['meta_keys'].remove(key)
# build the data pipeline
test_pipeline = Compose(config.test_pipeline)
# prepare data
if model_type == 'mattor':
    data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
elif model_type == 'restorer':
    data = dict(lq_path=args.img_path)
data = test_pipeline(data)

# convert model to onnx file
pytorch2onnx(
    model,
    data,
    model_type,
    opset_version=args.opset_version,
    show=args.show,
    output_file=args.output_file,
    verify=args.verify,
def restoration_face_inference(model, img, upscale_factor=1, face_size=1024):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.
        upscale_factor (int, optional): The number of times the input image
            is upsampled. Default: 1.
        face_size (int, optional): The size of the cropped and aligned faces.
            Default: 1024.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # remove gt from test_pipeline
    keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # face helper for detecting and aligning faces
    assert has_facexlib, 'Please install FaceXLib to use the demo.'
    face_helper = FaceRestoreHelper(
        upscale_factor,
        face_size=face_size,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        template_3points=True,
        save_ext='png',
        device=device)

    face_helper.read_image(img)
    # get face landmarks for each face
    face_helper.get_face_landmarks_5(
        only_center_face=False, eye_dist_threshold=None)
    # align and warp each face
    face_helper.align_warp_face()

    for i, img in enumerate(face_helper.cropped_faces):
        # prepare data
        data = dict(lq=img.astype(np.float32))
        data = test_pipeline(data)
        data = scatter(collate([data], samples_per_gpu=1), [device])[0]

        with torch.no_grad():
            output = model(test_mode=True, **data)['output'].clip_(0, 1)
        output = output.squeeze(0).permute(1, 2, 0)[:, :, [2, 1, 0]]
        output = output.cpu().numpy() * 255  # (0, 255)
        face_helper.add_restored_face(output)

    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image(upsample_img=None)

    return restored_img
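# A hedged usage sketch for restoration_face_inference (config, checkpoint,
# and image paths are hypothetical placeholders). FaceXLib must be installed,
# per the assertion inside the function.
def _demo_restoration_face_inference():
    import mmcv
    from mmedit.apis import init_model

    model = init_model(
        'configs/restorers/glean/glean_ffhq.py',  # hypothetical config
        'checkpoints/glean_ffhq.pth',             # hypothetical weights
        device='cuda:0')
    restored = restoration_face_inference(
        model, 'demo/face.png', upscale_factor=2, face_size=1024)
    mmcv.imwrite(restored, 'demo/face_restored.png')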
def video_interpolation_inference(model,
                                  input_dir,
                                  start_idx=0,
                                  end_idx=None,
                                  batch_size=4):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        input_dir (str): Directory of the input video.
        start_idx (int): The index corresponding to the first frame in the
            sequence. Default: 0.
        end_idx (int | None): The index corresponding to the last interpolated
            frame in the sequence. If it is None, interpolate to the last
            frame of video or sequence. Default: None.
        batch_size (int): Batch size. Default: 4.

    Returns:
        output (list[numpy.array]): The predicted interpolation result.
            It is an image sequence.
        input_fps (float): The fps of input video. If the input is an image
            sequence, input_fps=0.0.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    input_fps = 0.0
    file_extension = os.path.splitext(input_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(input_dir)
        input_fps = video_reader.fps
        images = []
        # load the images
        for img in video_reader[start_idx:end_idx]:
            images.append(np.flip(img, axis=2))  # BGR --> RGB
    else:
        files = os.listdir(input_dir)
        files = [osp.join(input_dir, f) for f in files]
        files.sort()
        files = files[start_idx:end_idx]
        images = [read_image(f) for f in files]

    data = dict(inputs=images, inputs_path=None, key=input_dir)

    # remove the data loading pipeline
    tmp_pipeline = []
    for pipeline in test_pipeline:
        if pipeline['type'] not in [
                'GenerateSegmentIndices', 'LoadImageFromFileList',
                'LoadImageFromFile'
        ]:
            tmp_pipeline.append(pipeline)
    test_pipeline = tmp_pipeline

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    data = [test_pipeline(data)]
    data = collate(data, samples_per_gpu=1)['inputs']
    # data.shape: [1, t, c, h, w]

    # forward the model
    output_list = []
    data = model.split_frames(data)
    input_tensors = data.clone().detach()
    with torch.no_grad():
        length = data.shape[0]
        for start in range(0, length, batch_size):
            end = start + batch_size
            output = model(
                data[start:end].to(device), test_mode=True)['output']
            if len(output.shape) == 4:
                output = output.unsqueeze(1)
            output_list.append(output.cpu())
        output_tensors = torch.cat(output_list, dim=0)
        result = model.merge_frames(input_tensors, output_tensors)

    return result, input_fps
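# `read_image` and `VIDEO_EXTENSIONS` are referenced above but not defined in
# this excerpt. A minimal sketch under the assumption that frames are loaded
# with mmcv and converted from BGR to RGB, mirroring the video branch above;
# the extension list is an assumption as well.
import mmcv
import numpy as np

VIDEO_EXTENSIONS = ('.mp4', '.mov')  # assumed set of recognized extensions


def read_image(filepath):
    """Read an image file and return it as an RGB array."""
    img = mmcv.imread(filepath)   # mmcv loads images in BGR order
    return np.flip(img, axis=2)   # BGR --> RGB, matching the video branch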
def restoration_video_inference(model,
                                img_dir,
                                window_size,
                                start_idx,
                                filename_tmpl,
                                max_seq_len=None):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        start_idx (int): The index corresponding to the first frame in the
            sequence.
        filename_tmpl (str): Template for file name.
        max_seq_len (int | None): The maximum sequence length that the model
            processes. If the sequence length is larger than this number,
            the sequence is split into multiple segments. If it is None,
            the entire sequence is processed at once.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    file_extension = osp.splitext(img_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(img_dir)
        # load the images
        data = dict(lq=[], lq_path=None, key=img_dir)
        for frame in video_reader:
            data['lq'].append(np.flip(frame, axis=2))  # BGR --> RGB

        # remove the data loading pipeline
        tmp_pipeline = []
        for pipeline in test_pipeline:
            if pipeline['type'] not in [
                    'GenerateSegmentIndices', 'LoadImageFromFileList'
            ]:
                tmp_pipeline.append(pipeline)
        test_pipeline = tmp_pipeline
    else:
        # the first element in the pipeline must be 'GenerateSegmentIndices'
        if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
            raise TypeError('The first element in the pipeline must be '
                            f'"GenerateSegmentIndices", but got '
                            f'"{test_pipeline[0]["type"]}".')

        # specify start_idx and filename_tmpl
        test_pipeline[0]['start_idx'] = start_idx
        test_pipeline[0]['filename_tmpl'] = filename_tmpl

        # prepare data
        sequence_length = len(glob.glob(osp.join(img_dir, '*')))
        img_dir_split = re.split(r'[\\/]', img_dir)
        key = img_dir_split[-1]
        lq_folder = reduce(osp.join, img_dir_split[:-1])
        data = dict(
            lq_path=lq_folder,
            gt_path='',
            key=key,
            sequence_length=sequence_length)

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(
                    model(lq=data_i, test_mode=True)['output'].cpu())
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            if max_seq_len is None:
                result = model(lq=data, test_mode=True)['output'].cpu()
            else:
                result = []
                for i in range(0, data.size(1), max_seq_len):
                    result.append(
                        model(
                            lq=data[:, i:i + max_seq_len],
                            test_mode=True)['output'].cpu())
                result = torch.cat(result, dim=1)

    return result
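# `pad_sequence` is used by the sliding-window branch above but not defined
# in this excerpt. A minimal sketch, assuming reflection padding of
# window_size // 2 frames at each end of the temporal axis; this is what
# makes the loop bound `data.size(1) - 2 * (window_size // 2)` recover
# exactly one output per original frame.
import torch


def pad_sequence(data, window_size):
    """Reflection-pad an (n, t, c, h, w) sequence along the t axis."""
    padding = window_size // 2
    data = torch.cat([
        data[:, 1:padding + 1].flip(1),    # reflect the leading frames
        data,
        data[:, -padding - 1:-1].flip(1),  # reflect the trailing frames
    ], dim=1)
    return data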