Ejemplo n.º 1
0
def main():
    """Segment one image from the command line and display the overlay.

    The image, config and checkpoint paths are positional arguments; the
    inference device, palette name and overlay opacity are optional flags.
    """
    cli = ArgumentParser()
    # The three required paths share the same shape, so register them in order.
    for arg_name, arg_help in (('img', 'Image file'),
                               ('config', 'Config file'),
                               ('checkpoint', 'Checkpoint file')):
        cli.add_argument(arg_name, help=arg_help)
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--palette',
        default='cityscapes',
        help='Color palette used for segmentation map')
    cli.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    opts = cli.parse_args()

    # Build the segmentor from the config and checkpoint, then run it on the
    # single input image.
    segmentor = init_segmentor(
        opts.config, opts.checkpoint, device=opts.device)
    seg_result = inference_segmentor(segmentor, opts.img)

    # Display the prediction using the palette selected by name.
    colors = get_palette(opts.palette)
    show_result_pyplot(
        segmentor, opts.img, seg_result, colors, opacity=opts.opacity)
Ejemplo n.º 2
0
def main():
    """Run the segmentor on every PNG/JPG file of a directory, in random order.

    Although its help text says "Image file", the ``img`` argument is passed
    to ``os.listdir`` and therefore has to be a directory. Each entry whose
    name ends in ``.png`` or ``.jpg`` (case-insensitive) is segmented and
    displayed with a fixed six-colour palette.

    NOTE(review): ``--palette`` is parsed but never consulted; the hard-coded
    palette below is always used. Confirm whether that is intentional.
    """
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='ade',
        help='Color palette used for segmentation map')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_segmentor(args.config, args.checkpoint, device=args.device)

    # The palette is the same for every image, so build it once instead of
    # once per loop iteration.
    palette = [[220, 220, 220], [17, 142, 35], [152, 251, 152],
               [0, 60, 100], [70, 130, 180], [220, 20, 20]]
    # Match on the file suffix rather than a substring: a substring test would
    # accept names such as "jpg_list.txt" and reject "PHOTO.JPG".
    image_suffixes = ('.png', '.jpg')

    demo_im_names = os.listdir(args.img)
    random.shuffle(demo_im_names)
    for im_name in demo_im_names:
        if not im_name.lower().endswith(image_suffixes):
            continue
        full_name = os.path.join(args.img, im_name)
        result = inference_segmentor(model, full_name)
        # show the results
        show_result_pyplot(model, full_name, result, palette)
Ejemplo n.º 3
0
def main():
    """Segment either one image or every file listed for ``img_dir``.

    All paths are optional flags whose defaults come from the module-level
    names ``img_path``, ``config`` and ``ckpt``. When ``--img`` is the empty
    string, the files returned by ``get_list_file_in_folder(img_dir)`` are
    processed in sorted order; otherwise only ``--img`` is processed.
    """
    cli = ArgumentParser()
    cli.add_argument('--img', help='Image file', default=img_path)
    cli.add_argument('--config', help='Config file', default=config)
    cli.add_argument('--checkpoint', help='Checkpoint file', default=ckpt)
    cli.add_argument(
        '--device', default='cuda:1', help='Device used for inference')
    cli.add_argument(
        '--palette',
        default=None,
        help='Color palette used for segmentation map')
    opts = cli.parse_args()

    # Build the segmentor from the config and checkpoint files.
    segmentor = init_segmentor(
        opts.config, opts.checkpoint, device=opts.device)

    # Single-image mode: handle it first and leave.
    if opts.img != '':
        seg = inference_segmentor(segmentor, opts.img)
        show_result_pyplot(
            segmentor, opts.img, seg, get_palette(opts.palette))
        return

    # Directory mode: walk the listed files in sorted order, echoing each
    # path before it is segmented.
    for file_name in sorted(get_list_file_in_folder(img_dir)):
        image_path = os.path.join(img_dir, file_name)
        print(image_path)
        seg = inference_segmentor(segmentor, image_path)
        show_result_pyplot(
            segmentor, image_path, seg, get_palette(opts.palette))
Ejemplo n.º 4
0
def main():
    """Segment one image, save the rendered result and display it.

    The value returned by ``model.show_result(..., show=False)`` is written
    to ``out.jpg`` in the current working directory, after which the
    prediction is shown with ``show_result_pyplot``.

    NOTE(review): ``--palette`` is accepted but not used anywhere in this
    function.
    """
    cli = ArgumentParser()
    # Required paths, registered in positional order.
    for arg_name, arg_help in (('img', 'Image file'),
                               ('config', 'Config file'),
                               ('checkpoint', 'Checkpoint file')):
        cli.add_argument(arg_name, help=arg_help)
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--palette',
        default='drive',
        help='Color palette used for segmentation map')
    opts = cli.parse_args()

    # Build the segmentor and run it on the single input image.
    segmentor = init_segmentor(
        opts.config, opts.checkpoint, device=opts.device)
    seg = inference_segmentor(segmentor, opts.img)

    # Persist the rendered result before opening the interactive view.
    rendered = segmentor.show_result(opts.img, seg, show=False)
    cv2.imwrite('out.jpg', rendered)
    show_result_pyplot(segmentor, opts.img, seg)
def main():
    """Segment one image using flag defaults and display the result.

    Every input is an optional flag, so the script can be launched without
    arguments; the defaults are the relative paths given below.
    """
    # Alternative defaults that were previously tried with this script:
    #   --img:
    #     "Image_20200925100338349.bmp"
    #     "demo.png"
    #   --config:
    #     "../configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes_custom_binary.py"
    #     "../configs/deeplabv3plus/deeplabv3plus_r101-d16-mg124_512x1024_40k_cityscapes.py"
    #   --checkpoint:
    #     "../tools/work_dirs/deeplabv3_r50-d8_512x1024_40k_cityscapes_custom_binary/iter_200.pth"
    #     "../checkpoints/deeplabv3plus_r101-d16-mg124_512x1024_40k_cityscapes_20200908_005644-cf9ce186.pth"
    #   --palette:
    #     'cityscapes'
    default_config = (
        "../configs/danet/danet_r50-d8_512x1024_40k_cityscapes_custom.py")
    default_checkpoint = (
        "../tools/work_dirs/danet_r50-d8_512x1024_40k_cityscapes_custom/"
        "iter_4000.pth")

    cli = ArgumentParser()
    cli.add_argument('--img', default="star.png", help='Image file')
    cli.add_argument('--config', default=default_config, help='Config file')
    cli.add_argument(
        '--checkpoint', default=default_checkpoint, help='Checkpoint file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--palette',
        default='cityscapes_custom',
        help='Color palette used for segmentation map')
    opts = cli.parse_args()

    # Build the segmentor and run it on the single input image.
    segmentor = init_segmentor(
        opts.config, opts.checkpoint, device=opts.device)
    seg = inference_segmentor(segmentor, opts.img)

    # Earlier debugging saved/showed the raw prediction instead:
    #   io.imsave("result.png", result[0])
    #   io.imshow(result[0])
    #   io.show()
    show_result_pyplot(segmentor, opts.img, seg, get_palette(opts.palette))
    """
def inference_model(config_name, checkpoint, args, logger=None):
    """Load a config, optionally enable augmented testing, and run inference.

    Args:
        config_name: Path handed to ``Config.fromfile``.
        checkpoint: Checkpoint passed to ``init_segmentor``.
        args: Object providing the attributes ``aug``, ``device``, ``img``
            and ``show``.
        logger: Optional logger. When augmented testing cannot be enabled,
            the problem is reported through ``logger.error`` if a logger is
            given and printed otherwise.

    Returns:
        The value returned by ``inference_segmentor`` for ``args.img``.
    """
    cfg = Config.fromfile(config_name)

    if args.aug:
        # The second entry of the test pipeline is the one that receives the
        # augmentation settings.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        elif logger is not None:
            logger.error(f'{config_name}: unable to start aug test')
        else:
            print(f'{config_name}: unable to start aug test', flush=True)

    # Build the segmentor from the (possibly modified) config and run it on
    # the single input image.
    segmentor = init_segmentor(cfg, checkpoint, device=args.device)
    prediction = inference_segmentor(segmentor, args.img)

    if args.show:
        show_result_pyplot(segmentor, args.img, prediction)
    return prediction
Ejemplo n.º 7
0
def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export Pytorch model to ONNX model and verify the outputs are same
    between Pytorch and ONNX.

    The model is moved to CPU and put in eval mode. ``model.forward`` is
    temporarily replaced for the export and is restored afterwards, even when
    the export raises.

    Args:
        model (nn.Module): Pytorch model we want to export.
        mm_inputs (dict): Contain the input tensors and img_metas information.
            The ``'imgs'`` and ``'img_metas'`` entries are popped, so the
            dict is modified in place.
        opset_version (int): The onnx op version. Default: 11.
        show (bool): Whether print the computation graph. When ``verify`` is
            also set, the ONNXRuntime and PyTorch segmentation results are
            displayed as well. Default: False.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between Pytorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axis.
            Default: False.

    Raises:
        AssertionError: If ``verify`` is set and the exported graph does not
            have exactly one non-initializer input, or the PyTorch and ONNX
            outputs differ beyond the tolerance.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    # With a list of decode heads, the class count is read from the last one.
    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    # Wrap every image (with an added leading axis) and every meta in its
    # own single-element container.
    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)
    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

    # Restore the original forward in a ``finally`` so that a failing export
    # does not leave the caller's model permanently patched.
    try:
        register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list, ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=show,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
            print(f'Successfully exported ONNX model: {output_file}')
    finally:
        model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale image for dynamic shape test
            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
            # concate flip image for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list, test_mode == 'whole')

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output
        # The feed input is whatever graph input is not also an initializer.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import cv2
            import os.path as osp
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                # No file on disk: build a uint8 image from the first input
                # tensor instead.
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results
        # Both results are divided by num_classes before the comparison.
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between Pytorch and ONNX')
        print('The outputs are same between Pytorch and ONNX')
def plot_result(model, img: str):
    """Segment ``img`` with ``model`` and display it using ``PALETTE``.

    Args:
        model: Segmentor passed to ``inference_segmentor`` and
            ``show_result_pyplot``.
        img: Image to segment and display.
    """
    segmentation = inference_segmentor(model, img)
    show_result_pyplot(model, img, segmentation, palette=PALETTE)