def __init__(self, cfg_or_path, device=torch.device("cuda:0")):
    """Initialize the inpainting and super-resolution models from a config.

    Args:
        cfg_or_path (str or dict): Path to a TOML config file, or an
            already-loaded config mapping containing the keys read below.
        device (torch.device): Device on which to place both models.
    """
    if isinstance(cfg_or_path, str):
        cfg = load_toml_file(cfg_or_path)
    else:
        cfg = cfg_or_path

    self.inpainting_control_size = cfg["inpainting_control_size"]

    # deepfill_v2 inpainting model files
    self.inpainting_cfg_path = cfg["inpainting_cfg_path"]
    self.inpainting_ckpt_path = cfg["inpainting_ckpt_path"]

    # super-resolution model files
    self.sr_cfg_path = cfg["sr_cfg_path"]
    self.sr_ckpt_path = cfg["sr_ckpt_path"]

    self.temp_dir = cfg["temp_dir"]

    # str(device) is the idiomatic spelling of device.__str__()
    self.inpainting_model = init_model(self.inpainting_cfg_path,
                                       self.inpainting_ckpt_path,
                                       device=str(device))
    self.sr_model = init_model(self.sr_cfg_path, self.sr_ckpt_path,
                               device=str(device))
    self.device = device
    self.cfg = cfg

    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() / os.makedirs() pair.
    os.makedirs(self.temp_dir, exist_ok=True)
def test_restoration_video_inference():
    """Smoke-test ``restoration_video_inference`` on a recurrent model
    (BasicVSR) and a sliding-window model (EDVR), exercising the pipeline
    fallback order (demo_pipeline -> test_pipeline -> val_pipeline) and
    mp4-file input. Runs only when a CUDA device is available.
    """
    if torch.cuda.is_available():
        # recurrent framework (BasicVSR)
        model = init_model(
            './configs/restorers/basicvsr/basicvsr_reds4.py',
            None,
            device='cuda')
        img_dir = './tests/data/vimeo90k/00001/0266'
        window_size = 0
        start_idx = 1
        filename_tmpl = 'im{}.png'
        output = restoration_video_inference(model, img_dir, window_size,
                                             start_idx, filename_tmpl)
        assert output.shape == (1, 7, 3, 256, 448)

        # sliding-window framework (EDVR)
        window_size = 5
        model = init_model(
            './configs/restorers/edvr/edvrm_wotsa_x4_g8_600k_reds.py',
            None,
            device='cuda')
        output = restoration_video_inference(model, img_dir, window_size,
                                             start_idx, filename_tmpl)
        assert output.shape == (1, 7, 3, 256, 448)

        # without demo_pipeline: inference falls back to test_pipeline
        model.cfg.test_pipeline = model.cfg.demo_pipeline
        model.cfg.pop('demo_pipeline')
        output = restoration_video_inference(model, img_dir, window_size,
                                             start_idx, filename_tmpl)
        assert output.shape == (1, 7, 3, 256, 448)

        # without test_pipeline and demo_pipeline: falls back to val_pipeline
        model.cfg.val_pipeline = model.cfg.test_pipeline
        model.cfg.pop('test_pipeline')
        output = restoration_video_inference(model, img_dir, window_size,
                                             start_idx, filename_tmpl)
        assert output.shape == (1, 7, 3, 256, 448)

        # the first element in the pipeline must be 'GenerateSegmentIndices'
        with pytest.raises(TypeError):
            model.cfg.val_pipeline = model.cfg.val_pipeline[1:]
            output = restoration_video_inference(model, img_dir, window_size,
                                                 start_idx, filename_tmpl)

        # video (mp4) input
        model = init_model(
            './configs/restorers/basicvsr/basicvsr_reds4.py',
            None,
            device='cuda')
        img_dir = './tests/data/test_inference.mp4'
        window_size = 0
        start_idx = 1
        filename_tmpl = 'im{}.png'
        output = restoration_video_inference(model, img_dir, window_size,
                                             start_idx, filename_tmpl)
        assert output.shape == (1, 5, 3, 256, 256)
def test_video_interpolation_inference():
    """Smoke-test ``video_interpolation_inference`` with a CAIN model on
    both a directory of frames and an mp4 video, on CPU and — when
    available — CUDA.

    The original code duplicated the whole test body verbatim for 'cpu'
    and 'cuda'; it is factored into a single helper here.
    """

    def _run_checks(device):
        # One full pass of the checks for the given device string.
        model = init_model(
            './configs/video_interpolators/cain/cain_b5_320k_vimeo-triplet.py',
            None,
            device=device)
        # Inject a demo pipeline so inference can load frame lists from disk.
        model.cfg['demo_pipeline'] = [
            dict(
                type='LoadImageFromFileList',
                io_backend='disk',
                key='inputs',
                channel_order='rgb'),
            dict(type='RescaleToZeroOne', keys=['inputs']),
            dict(type='FramesToTensor', keys=['inputs']),
            dict(
                type='Collect',
                keys=['inputs'],
                meta_keys=['inputs_path', 'key'])
        ]

        # directory-of-frames input
        input_dir = './tests/data/vimeo90k/00001/0266'
        output, fps = video_interpolation_inference(
            model, input_dir, batch_size=10)
        assert isinstance(output, list)
        assert isinstance(fps, float)

        # video (mp4) input
        input_dir = './tests/data/test_inference.mp4'
        output, fps = video_interpolation_inference(model, input_dir)
        assert isinstance(output, list)
        assert isinstance(fps, float)

    _run_checks('cpu')
    if torch.cuda.is_available():
        _run_checks('cuda')
def main():
    """Entry point: run single-image (or reference-based) restoration."""
    args = parse_args()

    if not os.path.isfile(args.img_path):
        raise ValueError('It seems that you did not input a valid '
                         '"image_path". Please double check your input, or '
                         'you may want to use "restoration_video_demo.py" '
                         'for video restoration.')
    if args.ref_path and not os.path.isfile(args.ref_path):
        raise ValueError('It seems that you did not input a valid '
                         '"ref_path". Please double check your input, or '
                         'you may want to use "ref_path=None" '
                         'for single restoration.')

    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)

    if args.ref_path:
        # reference-based super-resolution (Ref-SR)
        result = restoration_inference(model, args.img_path, args.ref_path)
    else:
        # single-image super-resolution (SISR)
        result = restoration_inference(model, args.img_path)

    result = tensor2img(result)
    mmcv.imwrite(result, args.save_path)
    if args.imshow:
        mmcv.imshow(result, 'predicted restoration result')
def main():
    """Entry point: run image generation and save/show the result."""
    args = parse_args()
    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)
    result = generation_inference(model, args.img_path, args.unpaired_path)
    mmcv.imwrite(result, args.save_path)
    if args.imshow:
        mmcv.imshow(result, 'predicted generation result')
def main():
    """Entry point: run image inpainting and save/show the BGR result."""
    args = parse_args()
    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)
    output = inpainting_inference(model, args.masked_img_path, args.mask_path)
    # tensor2img maps [-1, 1] to [0, 255]; [..., ::-1] flips RGB -> BGR
    output = tensor2img(output, min_max=(-1, 1))[..., ::-1]
    mmcv.imwrite(output, args.save_path)
    if args.imshow:
        mmcv.imshow(output, 'predicted inpainting result')
def __init__(self, cfg_or_path, device=torch.device("cuda:0")):
    """Initialize the person-segmentation and matting models.

    Args:
        cfg_or_path: the config object, or a path to a TOML file; it
            contains the following information:
            seg_cfg_path="./assets/configs/detection/point_rend/point_rend_r50_caffe_fpn_mstrain_3x_coco.py",
            seg_ckpt_path="./assets/checkpoints/detection/point_rend_r50_caffe_fpn_mstrain_3x_coco-e0ebb6b7.pth",
            matting_cfg_path="./assets/configs/editing/mattors/gca/gca_r34_4x10_200k_comp1k.py",
            matting_ckpt_path="./assets/checkpoints/mattors/gca_r34_4x10_200k_comp1k_SAD-34.77_20200604_213848-4369bea0.pth",
            person_label_index = 0
            temp_dir="./assets/temp"
            trimap_control_size = 300
            matting_image_size = 512
            morph_kernel_size = 3
            erode_iter_num = 2
            dilate_iter_num = 7
        device (torch.device): device on which both models are placed.
    """
    if isinstance(cfg_or_path, str):
        cfg = EasyDict(load_toml_file(cfg_or_path))
    else:
        cfg = cfg_or_path

    self.trimap_control_size = cfg.trimap_control_size
    self.matting_image_size = cfg.matting_image_size
    self.erode_iter_num = cfg.erode_iter_num
    self.dilate_iter_num = cfg.dilate_iter_num
    self.morph_kernel_size = cfg.morph_kernel_size

    # point_rend_r50_caffe_fpn_mstrain_3x_coco (person segmentation)
    self.detection_config_file = cfg.seg_cfg_path
    self.detection_checkpoint_file = cfg.seg_ckpt_path
    self.person_label_index = cfg.person_label_index

    # gca_r34_4x10_200k_comp1k (matting)
    self.editing_config_file = cfg.matting_cfg_path
    self.editing_checkpoint_file = cfg.matting_ckpt_path

    self.device = device
    self.detection_model = init_detector(self.detection_config_file,
                                         self.detection_checkpoint_file,
                                         device=device)
    # str(device) is the idiomatic spelling of device.__str__()
    self.matting_model = init_model(self.editing_config_file,
                                    self.editing_checkpoint_file,
                                    device=str(device))

    self.temp_dir = cfg.temp_dir
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() / os.makedirs() pair.
    os.makedirs(self.temp_dir, exist_ok=True)
def main():
    """Entry point: run single-image restoration and save/show the result."""
    args = parse_args()
    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)
    result = tensor2img(restoration_inference(model, args.img_path))
    mmcv.imwrite(result, args.save_path)
    if args.imshow:
        mmcv.imshow(result, 'predicted restoration result')
def main():
    """Entry point: predict an alpha matte and save/show it."""
    args = parse_args()
    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)
    # inference returns alpha in [0, 1]; scale to 8-bit range for saving
    alpha = matting_inference(model, args.img_path, args.trimap_path) * 255
    mmcv.imwrite(alpha, args.save_path)
    if args.imshow:
        mmcv.imshow(alpha, 'predicted alpha matte')
def main():
    """Entry point: restore a frame sequence and dump each frame as a PNG."""
    args = parse_args()
    model = init_model(
        args.config,
        args.checkpoint,
        device=torch.device('cuda', args.device))
    # NOTE(review): other call sites pass (window_size, start_idx,
    # filename_tmpl); here filename_tmpl sits where start_idx usually goes —
    # confirm against the local restoration_video_inference signature.
    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size,
                                         args.filename_tmpl)
    # output is (n, t, c, h, w); write each frame index t as an image
    for frame_idx in range(output.size(1)):
        frame = tensor2img(output[:, frame_idx, :, :, :])
        mmcv.imwrite(frame, f'{args.output_dir}/{frame_idx:08d}.png')
def main():
    """Entry point: run FBA matting over a directory of images and save the
    predicted alpha and foreground maps.

    Fixes the misspelled local ``perd_fg`` and discards the unused
    background estimate explicitly.
    """
    args = parse_args()
    model = init_model(
        args.config,
        args.checkpoint,
        device=torch.device('cuda', args.device))
    for img in os.listdir(args.img_path):
        # fba_inference returns (alpha, foreground, background); the
        # background estimate is not saved here.
        pred_alpha, pred_fg, _ = fba_inference(
            model,
            os.path.join(args.img_path, img),
            os.path.join(args.trimap_path, img))
        # NOTE(review): output directories are hard-coded; consider making
        # them CLI arguments.
        mmcv.imwrite(
            pred_alpha * 255,
            os.path.join('data/portrait/results/debug-nograd-noexcel/alpha',
                         img))
        mmcv.imwrite(
            pred_fg * 255,
            os.path.join('data/portrait/results/debug-nograd-noexcel/fg',
                         img))
def initialize(self, context):
    """TorchServe handler hook: build the model from the packaged
    config and checkpoint described by the serving context.
    """
    print('MMEditHandler.initialize is called')
    properties = context.system_properties

    # Prefer GPU when available; keep the map_location string for later use.
    use_cuda = torch.cuda.is_available()
    self.map_location = 'cuda' if use_cuda else 'cpu'
    if use_cuda:
        self.device = torch.device(
            self.map_location + ':' + str(properties.get('gpu_id')))
    else:
        self.device = torch.device(self.map_location)

    self.manifest = context.manifest
    model_dir = properties.get('model_dir')
    checkpoint = os.path.join(model_dir,
                              self.manifest['model']['serializedFile'])
    self.config_file = os.path.join(model_dir, 'config.py')
    self.model = init_model(self.config_file, checkpoint, self.device)
    self.initialized = True
def main():
    """Entry point: batch-predict alpha mattes for a directory of images."""
    args = parse_args()
    model = init_model(args.config, args.checkpoint,
                       device=torch.device('cuda', args.device))
    for img in tqdm(os.listdir(args.img_path)):
        # scale alpha from [0, 1] to 8-bit range before writing
        alpha = matting_inference(
            model,
            os.path.join(args.img_path, img),
            os.path.join(args.trimap_path, img)) * 255
        # NOTE(review): output directory is hard-coded.
        mmcv.imwrite(
            alpha,
            os.path.join(
                "data/GS_Video/results/ngs-iter_78000-SAD-7.081.pth/alpha",
                img))
def main():
    """Entry point: run face restoration on a single image."""
    args = parse_args()

    if not os.path.isfile(args.img_path):
        raise ValueError('It seems that you did not input a valid '
                         '"image_path". Please double check your input, or '
                         'you may want to use "restoration_video_demo.py" '
                         'for video restoration.')

    model = init_model(
        args.config,
        args.checkpoint,
        device=torch.device('cuda', args.device))
    result = restoration_face_inference(model, args.img_path,
                                        args.upscale_factor, args.face_size)
    mmcv.imwrite(result, args.save_path)
    if args.imshow:
        mmcv.imshow(result, 'predicted restoration result')
def main():
    """Entry point: predict an alpha matte and save (optionally show) it.

    Leftover debug code removed: a loop printing every state-dict key and
    several commented-out statements (rename_pth call, print of the
    prediction) that served no purpose in the demo.
    """
    args = parse_args()
    model = init_model(args.config, args.checkpoint,
                       device=torch.device('cuda', args.device))
    # scale alpha from [0, 1] to 8-bit range for saving
    pred_alpha = matting_inference(model, args.img_path,
                                   args.trimap_path) * 255
    mmcv.imwrite(pred_alpha, args.save_path)
    if args.imshow:
        mmcv.imshow(pred_alpha, 'predicted alpha matte')
def main():
    """Demo for video interpolation models.

    Both 'input_dir' and 'output_dir' may point to a video file, in which
    case frames are read/written through a codec. Video compression lowers
    visual quality; save separate images (.png) if you want actual quality.
    """
    args = parse_args()
    model = init_model(args.config, args.checkpoint,
                       device=torch.device('cuda', args.device))
    output, fps = video_interpolation_inference(model, args.input_dir,
                                                args.start_idx, args.end_idx,
                                                args.batch_size)

    # Decide the output frame rate: multiply the source fps, or take an
    # explicit --fps, falling back to the source fps.
    if args.fps_multiplier:
        assert args.fps_multiplier > 0, '`fps_multiplier` cannot be negative'
        assert fps > 0, 'the input is not a video'
        fps = args.fps_multiplier * fps
    else:
        fps = args.fps if args.fps > 0 else fps

    if os.path.splitext(args.output_dir)[1] in VIDEO_EXTENSIONS:
        # save as a single video file
        h, w = output[0].shape[:2]
        writer = cv2.VideoWriter(args.output_dir,
                                 cv2.VideoWriter_fourcc(*'mp4v'), fps,
                                 (w, h))
        for frame in output:
            writer.write(frame)
        cv2.destroyAllWindows()
        writer.release()
    else:
        # save as individual images
        for idx, frame in enumerate(output):
            save_path = f'{args.output_dir}/{args.filename_tmpl.format(idx)}'
            mmcv.imwrite(frame, save_path)
def main():
    """Demo for video restoration models.

    'input_dir' and 'output_dir' may point to a video file, in which case
    frames are read/written through a codec. Video compression lowers the
    visual quality; save separate images (.png) if you want actual quality.
    """
    args = parse_args()
    model = init_model(args.config, args.checkpoint,
                       device=torch.device('cuda', args.device))
    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size, args.start_idx,
                                         args.filename_tmpl,
                                         args.max_seq_len)

    if os.path.splitext(args.output_dir)[1] in VIDEO_EXTENSIONS:
        # write every frame of the (n, t, c, h, w) tensor into a 25-fps video
        h, w = output.shape[-2:]
        writer = cv2.VideoWriter(args.output_dir,
                                 cv2.VideoWriter_fourcc(*'mp4v'), 25, (w, h))
        for t in range(output.size(1)):
            frame = tensor2img(output[:, t, :, :, :])
            writer.write(frame.astype(np.uint8))
        cv2.destroyAllWindows()
        writer.release()
    else:
        # write frames as images, numbered from start_idx onward
        for t in range(output.size(1)):
            frame = tensor2img(output[:, t, :, :, :])
            idx = args.start_idx + t
            save_path = f'{args.output_dir}/{args.filename_tmpl.format(idx)}'
            mmcv.imwrite(frame, save_path)
src_language = b_boxes_manager.detect_src_lang( translation_service=config["translation_service"]) if src_language == dst_language: copyfile(image_path, result_path) else: torch.cuda.empty_cache() # Inpainting t0 = time.time() ## create mask for inpainting mask = b_boxes_manager.mask(image) ## Save temp mask file temporarly mask_path = "temp.png" Image.fromarray(mask).save("temp.png") ## init the inpainter model = init_model(config["Inpainter"]["configuration"], config["Inpainter"]["pretrained_model"], device=device) ## inpaint result = inpainting_inference(model, image_path, mask_path)[0] result = tensor2img(result, min_max=(-1, 1)) os.remove("temp.png") print("Inpainting time : {:.3f}".format(time.time() - t0)) # Translation t0 = time.time() result = b_boxes_manager.fill_translation( result, same_block_th=config["same_block_th"], translation_service=config["translation_service"], src_language=src_language, dest_language=dst_language,