# Example #1
def test_restoration_video_inference():
    """Exercise ``restoration_video_inference`` with two model families and
    with the pipeline fallbacks.

    Runs only when CUDA is available; otherwise it returns without asserting.

    Covered cases, in order:
      * recurrent framework (BasicVSR, ``window_size=0``);
      * sliding-window framework (EDVR, ``window_size=5``);
      * config without ``demo_pipeline`` (``test_pipeline`` is used);
      * config without ``test_pipeline`` and ``demo_pipeline``
        (``val_pipeline`` is used);
      * a pipeline whose first element is not ``GenerateSegmentIndices``,
        which must raise ``TypeError``.
    """
    if not torch.cuda.is_available():
        return

    img_dir = './tests/data/vimeo90k/00001/0266'
    start_idx = 1
    filename_tmpl = 'im{}.png'
    expected_shape = (1, 7, 3, 256, 448)

    # recurrent framework (BasicVSR)
    model = init_model(
        './configs/restorers/basicvsr/basicvsr_reds4.py',
        None,
        device='cuda')
    window_size = 0
    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)
    assert output.shape == expected_shape

    # sliding-window framework (EDVR)
    window_size = 5
    model = init_model(
        './configs/restorers/edvr/edvrm_wotsa_x4_g8_600k_reds.py',
        None,
        device='cuda')
    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)
    assert output.shape == expected_shape

    # without demo_pipeline
    model.cfg.test_pipeline = model.cfg.demo_pipeline
    model.cfg.pop('demo_pipeline')
    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)
    assert output.shape == expected_shape

    # without test_pipeline and demo_pipeline
    model.cfg.val_pipeline = model.cfg.test_pipeline
    model.cfg.pop('test_pipeline')
    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)
    assert output.shape == expected_shape

    # the first element in the pipeline must be 'GenerateSegmentIndices'.
    # Drop it *before* entering pytest.raises so that only the inference
    # call is expected to raise, not the config mutation.
    model.cfg.val_pipeline = model.cfg.val_pipeline[1:]
    with pytest.raises(TypeError):
        restoration_video_inference(model, img_dir, window_size, start_idx,
                                    filename_tmpl)


def main():
    """Restore a frame sequence and write every output frame as an image.

    NOTE(review): a second ``def main()`` appears later in this file and
    rebinds the name, so this definition is shadowed and never called.

    The restored frame with sequence index ``i`` (counted from
    ``args.start_idx``) is written to
    ``<args.output_dir>/<args.filename_tmpl.format(i)>``.
    """
    args = parse_args()

    model = init_model(
        args.config, args.checkpoint, device=torch.device('cuda', args.device))

    # The inference API takes (model, img_dir, window_size, start_idx,
    # filename_tmpl). Passing the template fourth would bind it to start_idx
    # and silently drop the real template.
    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size, args.start_idx,
                                         args.filename_tmpl)
    for i in range(args.start_idx, args.start_idx + output.size(1)):
        output_i = tensor2img(output[:, i - args.start_idx, :, :, :])
        save_path_i = f'{args.output_dir}/{args.filename_tmpl.format(i)}'
        mmcv.imwrite(output_i, save_path_i)


def main():
    """Demo for video restoration models.

    Runs the model on ``args.input_dir`` and saves the restored sequence to
    ``args.output_dir``. Video is accepted as input/output when
    ``input_dir``/``output_dir`` is set to the path of a video file. Output
    is written as a video when the extension of ``output_dir`` is in
    ``VIDEO_EXTENSIONS``. Otherwise each frame is saved as an image named
    ``args.filename_tmpl.format(i)``, with ``i`` starting at
    ``args.start_idx``.

    Using videos introduces compression, which lowers the visual quality.
    For full quality, save the frames as separate images (.png).
    """
    args = parse_args()

    model = init_model(args.config,
                       args.checkpoint,
                       device=torch.device('cuda', args.device))

    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size, args.start_idx,
                                         args.filename_tmpl, args.max_seq_len)

    file_extension = os.path.splitext(args.output_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:  # save as video
        h, w = output.shape[-2:]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(args.output_dir, fourcc, 25, (w, h))
        # cv2.destroyAllWindows() is intentionally not called: no HighGUI
        # window is opened here, and the call raises cv2.error on OpenCV
        # builds without GUI support (e.g. headless builds).
        try:
            for i in range(output.size(1)):
                img = tensor2img(output[:, i, :, :, :])
                video_writer.write(img.astype(np.uint8))
        finally:
            # Close the container even if writing a frame fails.
            video_writer.release()
    else:
        for i in range(args.start_idx, args.start_idx + output.size(1)):
            output_i = tensor2img(output[:, i - args.start_idx, :, :, :])
            save_path_i = f'{args.output_dir}/{args.filename_tmpl.format(i)}'
            mmcv.imwrite(output_i, save_path_i)