Code example #1
def test_compose():
    with pytest.raises(TypeError):
        Compose('LoadAlpha')

    target_keys = ['img', 'meta']

    img = np.random.randn(256, 256, 3)
    results = dict(img=img, abandoned_key=None, img_name='test_image.png')
    test_pipeline = [
        dict(type='Collect', keys=['img'], meta_keys=['img_name']),
        dict(type='ImageToTensor', keys=['img'])
    ]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert check_keys_equal(compose_results.keys(), target_keys)
    assert check_keys_equal(compose_results['meta'].data.keys(), ['img_name'])

    results = None
    image_to_tensor = ImageToTensor(keys=[])
    test_pipeline = [image_to_tensor]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert compose_results is None

    assert repr(compose) == (compose.__class__.__name__ +
                             f'(\n    {image_to_tensor}\n)')
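
The test relies on a helper, check_keys_equal, that this excerpt does not define. A minimal sketch of a plausible, order-insensitive implementation (an assumption, not part of the excerpt):

def check_keys_equal(result_keys, target_keys):
    # Hypothetical helper: compare the key collections as sets so that
    # ordering does not matter.
    return set(result_keys) == set(target_keys)
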
Code example #2
def restoration_inference(model, img):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.

    Returns:
        Tensor: The predicted restoration result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove gt from test_pipeline
    keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(lq_path=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['output']
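
A minimal usage sketch for restoration_inference; init_model is assumed from mmedit.apis, and the config/checkpoint/image paths are placeholders, not taken from the example:

from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/restorers/esrgan/esrgan.py',
                   'checkpoints/esrgan.pth', device='cuda:0')
output = restoration_inference(model, 'demo/lq_image.png')  # Tensor
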
Code example #3
def _prepare_input_img(model_type: str,
                       img_path: str,
                       config: dict,
                       rescale_shape: Optional[Iterable] = None) -> dict:
    """Prepare the input image

    Args:
        model_type (str): which kind of model config belong to, \
            one of ['inpainting', 'mattor', 'restorer', 'synthesizer'].
        img_path (str): image path to show or verify.
        config (dict): MMCV config, determined by the inpupt config file.
        rescale_shape (Optional[Iterable]): to rescale the shape of the \
            input tensor.

    Returns:
        dict: {'imgs': imgs, 'img_metas': img_metas}
    """
    # remove annotation keys (alpha / gt) from test_pipeline
    if model_type == 'mattor':
        keys_to_remove = ['alpha', 'ori_alpha']
    elif model_type == 'restorer':
        keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(config.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                config.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    config.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)

    # build the data pipeline
    test_pipeline = Compose(config.test_pipeline)
    # prepare data
    if model_type == 'mattor':
        # matting models are not supported by this helper
        raise RuntimeError('Invalid model_type!', model_type)
    if model_type == 'restorer':
        data = dict(lq_path=img_path)

    data = test_pipeline(data)

    if model_type == 'restorer':
        imgs = data['lq']
    else:
        imgs = data['img']
    img_metas = [data['meta']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs
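
A sketch of how _prepare_input_img might be called from an export script, assuming an mmcv config that defines test_pipeline; all paths are placeholders:

import mmcv

# Placeholder paths for illustration only.
config = mmcv.Config.fromfile('configs/restorers/srcnn/srcnn.py')
mm_inputs = _prepare_input_img('restorer', 'demo/lq_image.png', config)
imgs, img_metas = mm_inputs['imgs'], mm_inputs['img_metas']
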
Code example #4
def restoration_video_inference(model, img_dir, window_size, filename_tmpl):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        filename_tmpl (str): Template for file name.

    Returns:
        Tensor: The predicted restoration result.
    """

    device = next(model.parameters()).device  # model device

    # pipeline
    test_pipeline = [
        dict(type='GenerateSegmentIndices',
             interval_list=[1],
             filename_tmpl=filename_tmpl),
        dict(type='LoadImageFromFileList',
             io_backend='disk',
             key='lq',
             channel_order='rgb'),
        dict(type='RescaleToZeroOne', keys=['lq']),
        dict(type='FramesToTensor', keys=['lq']),
        dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
    ]

    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(lq_path=lq_folder,
                gt_path='',
                key=key,
                sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
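
pad_sequence is not defined in this excerpt. A minimal sketch under the assumption that it reflection-pads the temporal axis by window_size // 2 frames on each side, so every original frame can sit at the centre of a window:

import torch

def pad_sequence(data, window_size):
    # data: (n, t, c, h, w); mirror the frames next to each edge,
    # excluding the edge frame itself, by window_size // 2 per side.
    padding = window_size // 2
    front = data[:, 1:padding + 1].flip(1)
    back = data[:, -padding - 1:-1].flip(1)
    return torch.cat([front, data, back], dim=1)
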
Code example #5
def generation_inference(model, img, img_unpaired=None):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.
        img_unpaired (str, optional): File path of the unpaired image.
            If not None, perform unpaired image generation. Default: None.

    Returns:
        np.ndarray: The predicted generation result.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    if img_unpaired is None:
        data = dict(pair_path=img)
    else:
        data = dict(img_a_path=img, img_b_path=img_unpaired)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        results = model(test_mode=True, **data)
    # process generation shown mode
    if img_unpaired is None:
        if model.show_input:
            output = np.concatenate([
                tensor2img(results['real_a'], min_max=(-1, 1)),
                tensor2img(results['fake_b'], min_max=(-1, 1)),
                tensor2img(results['real_b'], min_max=(-1, 1))
            ], axis=1)
        else:
            output = tensor2img(results['fake_b'], min_max=(-1, 1))
    else:
        if model.show_input:
            output = np.concatenate([
                tensor2img(results['real_a'], min_max=(-1, 1)),
                tensor2img(results['fake_b'], min_max=(-1, 1)),
                tensor2img(results['real_b'], min_max=(-1, 1)),
                tensor2img(results['fake_a'], min_max=(-1, 1))
            ], axis=1)
        else:
            if model.test_direction == 'a2b':
                output = tensor2img(results['fake_b'], min_max=(-1, 1))
            else:
                output = tensor2img(results['fake_a'], min_max=(-1, 1))
    return output
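
Because generation_inference returns an np.ndarray, its result can be written straight to disk. A usage sketch with placeholder paths, assuming init_model from mmedit.apis:

import mmcv
from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/synthesizers/cyclegan/cyclegan.py',
                   'checkpoints/cyclegan.pth', device='cuda:0')
output = generation_inference(model, 'demo/input_a.png')
mmcv.imwrite(output, 'demo/fake_b.png')
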
Code example #6
File: test_matting.py  Project: wchstrife/mmediting
def matting_inference(model, img, trimap):
    """Inference image(s) with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): Image file path.
        trimap (str): Trimap file path.

    Returns:
        np.ndarray: The predicted alpha matte.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(merged_path=img, trimap_path=trimap)
    data = test_pipeline(data)

    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['pred_alpha']
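
A usage sketch for matting_inference, assuming init_model from mmedit.apis and placeholder paths:

from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/mattors/dim/dim.py',
                   'checkpoints/dim.pth', device='cuda:0')
pred_alpha = matting_inference(model, 'demo/merged.png', 'demo/trimap.png')
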
Code example #7
def matting_inference_file(model,
                           img,
                           trimap=None,
                           mask=None,
                           image_path="input file directly"):
    """Inference image(s) with the model.

    Args:
        model (nn.Module): The loaded model.
        img (np.ndarray): Loaded image; the file-loading steps of the
            pipeline are skipped, so a path is not accepted here.
        trimap (np.ndarray, optional): Loaded trimap. Either ``trimap`` or
            ``mask`` must be given.
        mask (np.ndarray, optional): Loaded mask. Either ``trimap`` or
            ``mask`` must be given.
        image_path (str): Nominal image path recorded in the data dict.

    Returns:
        np.ndarray: The predicted alpha matte.
    """
    assert trimap is not None or mask is not None
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline, skipping the first two (file-loading)
    # transforms since the input arrays are provided directly
    test_pipeline = cfg.test_pipeline[2:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(merged=img,
                mask=mask,
                ori_mask=mask,
                trimap=trimap,
                ori_trimap=trimap,
                ori_merged=img.copy(),
                merged_path=image_path,
                merged_ori_shape=img.shape)

    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)
    return result
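
Since this variant skips the file-loading transforms, the caller loads the arrays itself. A sketch assuming mmcv.imread, init_model from mmedit.apis, and placeholder paths:

import mmcv
from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/mattors/gca/gca.py',
                   'checkpoints/gca.pth', device='cuda:0')
img = mmcv.imread('demo/merged.png')
trimap = mmcv.imread('demo/trimap.png', flag='grayscale')
result = matting_inference_file(model, img, trimap=trimap,
                                image_path='demo/merged.png')
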
Code example #8
def inpainting_inference(model, masked_img, mask):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        masked_img (str): File path of image with mask.
        mask (str): Mask file path.

    Returns:
        Tensor: The predicted inpainting result.
    """
    device = next(model.parameters()).device  # model device

    infer_pipeline = [
        dict(type='LoadImageFromFile', key='masked_img'),
        dict(type='LoadMask', mask_mode='file', mask_config=dict()),
        dict(type='Pad', keys=['masked_img', 'mask'], mode='reflect'),
        dict(
            type='Normalize',
            keys=['masked_img'],
            mean=[127.5] * 3,
            std=[127.5] * 3,
            to_rgb=False),
        dict(type='GetMaskedImage', img_name='masked_img'),
        dict(
            type='Collect',
            keys=['masked_img', 'mask'],
            meta_keys=['masked_img_path']),
        dict(type='ImageToTensor', keys=['masked_img', 'mask'])
    ]

    # build the data pipeline
    test_pipeline = Compose(infer_pipeline)
    # prepare data
    data = dict(masked_img_path=masked_img, mask_path=mask)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['fake_img']
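
The returned tensor can be converted to an image with mmediting's tensor2img. A usage sketch with placeholder paths:

import mmcv
from mmedit.apis import init_model
from mmedit.core import tensor2img

# Placeholder paths for illustration only.
model = init_model('configs/inpainting/global_local/gl_places.py',
                   'checkpoints/gl.pth', device='cuda:0')
fake_img = inpainting_inference(model, 'demo/masked_img.png', 'demo/mask.png')
mmcv.imwrite(tensor2img(fake_img, min_max=(-1, 1)), 'demo/result.png')
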
Code example #9
def matting_inference(model, img, trimap):
    """Inference image(s) with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): Image file path.
        trimap (str): Trimap file path.

    Returns:
        np.ndarray: The predicted alpha matte.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # remove alpha from test_pipeline
    keys_to_remove = ['alpha', 'ori_alpha']
    for key in keys_to_remove:
        for pipeline in list(cfg.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                cfg.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    cfg.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)
    # prepare data
    data = dict(merged_path=img, trimap_path=trimap)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(test_mode=True, **data)

    return result['pred_alpha']
Code example #10
def restoration_video_inference(model, img_dir, window_size, start_idx,
                                filename_tmpl):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        start_idx (int): The index corresponds to the first frame in the
            sequence.
        filename_tmpl (str): Template for file name.

    Returns:
        Tensor: The predicted restoration result.
    """

    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # the first element in the pipeline must be 'GenerateSegmentIndices'
    if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
        raise TypeError('The first element in the pipeline must be '
                        f'"GenerateSegmentIndices", but got '
                        f'"{test_pipeline[0]["type"]}".')

    # specify start_idx and filename_tmpl
    test_pipeline[0]['start_idx'] = start_idx
    test_pipeline[0]['filename_tmpl'] = filename_tmpl

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
Code example #11
File: pytorch2onnx.py  Project: ywu40/mmediting
    if model_type == 'mattor':
        keys_to_remove = ['alpha', 'ori_alpha']
    elif model_type == 'restorer':
        keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(config.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                config.test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    config.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(config.test_pipeline)
    # prepare data
    if model_type == 'mattor':
        data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
    elif model_type == 'restorer':
        data = dict(lq_path=args.img_path)
    data = test_pipeline(data)

    # convert model to onnx file
    pytorch2onnx(model,
                 data,
                 model_type,
                 opset_version=args.opset_version,
                 show=args.show,
                 output_file=args.output_file,
                 verify=args.verify)
Code example #12
def restoration_face_inference(model, img, upscale_factor=1, face_size=1024):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img (str): File path of input image.
        upscale_factor (int, optional): The number of times the input image
            is upsampled. Default: 1.
        face_size (int, optional): The size of the cropped and aligned faces.
            Default: 1024.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # remove gt from test_pipeline
    keys_to_remove = ['gt', 'gt_path']
    for key in keys_to_remove:
        for pipeline in list(test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                test_pipeline.remove(pipeline)
            if 'keys' in pipeline and key in pipeline['keys']:
                pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
                pipeline['meta_keys'].remove(key)
    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # face helper for detecting and aligning faces
    assert has_facexlib, 'Please install FaceXLib to use the demo.'
    face_helper = FaceRestoreHelper(upscale_factor,
                                    face_size=face_size,
                                    crop_ratio=(1, 1),
                                    det_model='retinaface_resnet50',
                                    template_3points=True,
                                    save_ext='png',
                                    device=device)

    face_helper.read_image(img)
    # get face landmarks for each face
    face_helper.get_face_landmarks_5(only_center_face=False,
                                     eye_dist_threshold=None)
    # align and warp each face
    face_helper.align_warp_face()

    for i, img in enumerate(face_helper.cropped_faces):
        # prepare data
        data = dict(lq=img.astype(np.float32))
        data = test_pipeline(data)
        data = scatter(collate([data], samples_per_gpu=1), [device])[0]

        with torch.no_grad():
            output = model(test_mode=True, **data)['output'].clip_(0, 1)

        output = output.squeeze(0).permute(1, 2, 0)[:, :, [2, 1, 0]]
        output = output.cpu().numpy() * 255  # (0, 255)
        face_helper.add_restored_face(output)

    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image(upsample_img=None)

    return restored_img
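
A usage sketch for restoration_face_inference; FaceXLib must be installed, init_model is assumed from mmedit.apis, and all paths are placeholders:

from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/restorers/glean/glean_ffhq.py',
                   'checkpoints/glean.pth', device='cuda:0')
restored = restoration_face_inference(model, 'demo/face.png',
                                      upscale_factor=2)
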
Code example #13
def video_interpolation_inference(model,
                                  input_dir,
                                  start_idx=0,
                                  end_idx=None,
                                  batch_size=4):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        input_dir (str): Directory of the input video.
        start_idx (int): The index corresponding to the first frame in the
            sequence. Default: 0.
        end_idx (int | None): The index corresponding to the last interpolated
            frame in the sequence. If it is None, interpolate to the last
            frame of video or sequence. Default: None.
        batch_size (int): Batch size. Default: 4.

    Returns:
        output (list[np.ndarray]): The predicted interpolation result.
            It is an image sequence.
        input_fps (float): The fps of the input video. If the input is an
            image sequence, input_fps is 0.0.
    """

    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    input_fps = 0.0
    file_extension = os.path.splitext(input_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(input_dir)
        input_fps = video_reader.fps
        images = []
        # load the images
        for img in video_reader[start_idx:end_idx]:
            images.append(np.flip(img, axis=2))  # BGR --> RGB
    else:
        files = os.listdir(input_dir)
        files = [osp.join(input_dir, f) for f in files]
        files.sort()
        files = files[start_idx:end_idx]
        images = [read_image(f) for f in files]

    data = dict(inputs=images, inputs_path=None, key=input_dir)

    # remove the data loading pipeline
    tmp_pipeline = []
    for pipeline in test_pipeline:
        if pipeline['type'] not in [
                'GenerateSegmentIndices', 'LoadImageFromFileList',
                'LoadImageFromFile'
        ]:
            tmp_pipeline.append(pipeline)
    test_pipeline = tmp_pipeline

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    data = [test_pipeline(data)]
    data = collate(data, samples_per_gpu=1)['inputs']
    # data.shape: [1, t, c, h, w]

    # forward the model
    output_list = []
    data = model.split_frames(data)
    input_tensors = data.clone().detach()
    with torch.no_grad():
        length = data.shape[0]
        for start in range(0, length, batch_size):
            end = start + batch_size
            output = model(data[start:end].to(device),
                           test_mode=True)['output']
            if len(output.shape) == 4:
                output = output.unsqueeze(1)
            output_list.append(output.cpu())

    output_tensors = torch.cat(output_list, dim=0)

    result = model.merge_frames(input_tensors, output_tensors)

    return result, input_fps
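
A usage sketch for video_interpolation_inference that writes the interpolated sequence back to disk; init_model is assumed from mmedit.apis and the paths are placeholders:

import mmcv
from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/video_interpolators/cain/cain.py',
                   'checkpoints/cain.pth', device='cuda:0')
frames, input_fps = video_interpolation_inference(model, 'demo/input.mp4')
for i, frame in enumerate(frames):
    mmcv.imwrite(frame, f'demo/out/{i:08d}.png')
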
Code example #14
def restoration_video_inference(model,
                                img_dir,
                                window_size,
                                start_idx,
                                filename_tmpl,
                                max_seq_len=None):
    """Inference image with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value smaller than 0 means using recurrent framework.
        start_idx (int): The index corresponds to the first frame in the
            sequence.
        filename_tmpl (str): Template for file name.
        max_seq_len (int | None): The maximum sequence length that the model
            processes. If the sequence length is larger than this number,
            the sequence is split into multiple segments. If it is None,
            the entire sequence is processed at once.

    Returns:
        Tensor: The predicted restoration result.
    """

    device = next(model.parameters()).device  # model device

    # build the data pipeline
    if model.cfg.get('demo_pipeline', None):
        test_pipeline = model.cfg.demo_pipeline
    elif model.cfg.get('test_pipeline', None):
        test_pipeline = model.cfg.test_pipeline
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    file_extension = osp.splitext(img_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(img_dir)
        # load the images
        data = dict(lq=[], lq_path=None, key=img_dir)
        for frame in video_reader:
            data['lq'].append(np.flip(frame, axis=2))

        # remove the data loading pipeline
        tmp_pipeline = []
        for pipeline in test_pipeline:
            if pipeline['type'] not in [
                    'GenerateSegmentIndices', 'LoadImageFromFileList'
            ]:
                tmp_pipeline.append(pipeline)
        test_pipeline = tmp_pipeline
    else:
        # the first element in the pipeline must be 'GenerateSegmentIndices'
        if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
            raise TypeError('The first element in the pipeline must be '
                            f'"GenerateSegmentIndices", but got '
                            f'"{test_pipeline[0]["type"]}".')

        # specify start_idx and filename_tmpl
        test_pipeline[0]['start_idx'] = start_idx
        test_pipeline[0]['filename_tmpl'] = filename_tmpl

        # prepare data
        sequence_length = len(glob.glob(osp.join(img_dir, '*')))
        img_dir_split = re.split(r'[\\/]', img_dir)
        key = img_dir_split[-1]
        lq_folder = reduce(osp.join, img_dir_split[:-1])
        data = dict(lq_path=lq_folder,
                    gt_path='',
                    key=key,
                    sequence_length=sequence_length)

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']
    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'].cpu())
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            if max_seq_len is None:
                result = model(lq=data, test_mode=True)['output'].cpu()
            else:
                result = []
                for i in range(0, data.size(1), max_seq_len):
                    result.append(
                        model(lq=data[:, i:i + max_seq_len],
                              test_mode=True)['output'].cpu())
                result = torch.cat(result, dim=1)
    return result
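
A usage sketch covering both execution paths: window_size > 0 selects the sliding-window branch, while window_size=0 takes the recurrent branch with optional max_seq_len chunking. init_model is assumed from mmedit.apis; paths are placeholders:

from mmedit.apis import init_model

# Placeholder paths for illustration only.
model = init_model('configs/restorers/basicvsr/basicvsr_reds4.py',
                   'checkpoints/basicvsr.pth', device='cuda:0')
result = restoration_video_inference(model, 'demo/sequence', window_size=0,
                                     start_idx=0, filename_tmpl='{:08d}.png',
                                     max_seq_len=20)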