Ejemplo n.º 1
0
 def __init__(self, trt_file: str, device_id: int):
     """Load a TensorRT engine from ``trt_file`` and wrap it for inference.

     Args:
         trt_file: Path to the TensorRT engine file, handed to
             ``mmcv.tensorrt.TRTWraper`` with a single input binding named
             ``'input'`` and a single output binding named ``'output'``.
         device_id: Stored as ``self.device_id``; presumably the CUDA
             device index used at inference time (TODO confirm against the
             forward pass, which is not visible here).
     """
     super().__init__()
     # Function-level import: mmcv.tensorrt is only loaded on construction.
     from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
     try:
         load_tensorrt_plugin()
     except (ImportError, ModuleNotFoundError):
         # Best-effort: a missing plugin library only matters if the engine
         # uses mmcv custom ops, so warn instead of failing. Implicit string
         # concatenation keeps source indentation out of the message (a
         # backslash continuation inside the literal embedded those spaces).
         warnings.warn('If input model has custom op from mmcv, you may '
                       'have to build mmcv with TensorRT from source.')
     model = TRTWraper(
         trt_file, input_names=['input'], output_names=['output'])
     self.device_id = device_id
     self.model = model
Ejemplo n.º 2
0
    def __init__(self, engine_file, class_names, device_id, output_names=None):
        """Wrap a TensorRT detection engine.

        Args:
            engine_file: Path to the TensorRT engine file, handed to
                ``mmcv.tensorrt.TRTWraper``.
            class_names: Forwarded to the parent class ``__init__``.
            device_id: Forwarded to the parent class ``__init__``.
            output_names: Deprecated and ignored. The outputs are always
                ``['dets', 'labels']``, plus ``'masks'`` when the engine has
                4 inputs/outputs in total. Passing any value other than
                ``None`` emits a warning.
        """
        super(TensorRTDetector, self).__init__(class_names, device_id)
        # Warn only when the deprecated argument was actually supplied; an
        # unconditional warning fired on every construction, even for
        # callers that never used ``output_names``.
        if output_names is not None:
            warnings.warn('`output_names` is deprecated and will be removed '
                          'in future releases.')
        # Function-level import: mmcv.tensorrt is only loaded on construction.
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Best-effort: only needed if the engine uses mmcv custom ops.
            # Implicit concatenation keeps indentation out of the message.
            warnings.warn('If input model has custom op from mmcv, you may '
                          'have to build mmcv with TensorRT from source.')

        # Any caller-supplied ``output_names`` is intentionally overridden.
        output_names = ['dets', 'labels']
        model = TRTWraper(engine_file, ['input'], output_names)
        with_masks = False
        # if TensorRT has totally 4 inputs/outputs, then
        # the detector should have `mask` output.
        if len(model.engine) == 4:
            model.output_names = output_names + ['masks']
            with_masks = True
        self.model = model
        self.with_masks = with_masks
Ejemplo n.º 3
0
    def __init__(self,
                 trt_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        """Build the recognizer from ``cfg.model`` and load a TensorRT engine.

        Args:
            trt_file: Path to the TensorRT engine file, handed to
                ``mmcv.tensorrt.TRTWrapper`` with one input binding named
                ``'input'`` and one output binding named ``'output'``.
            cfg: Config object whose ``model`` mapping supplies the keyword
                arguments for ``EncodeDecodeRecognizer.__init__`` (any
                ``'type'`` key is left out). Stored as ``self.cfg``; it is
                not modified.
            device_id: Stored as ``self.device_id``.
            show_score: Not used by this constructor.
        """
        # Copy the model settings minus 'type' instead of popping it from
        # ``cfg.model``: an in-place ``pop`` silently altered the caller's
        # config (and the ``self.cfg`` saved below).
        model_kwargs = {
            key: value
            for key, value in cfg.model.items() if key != 'type'
        }
        EncodeDecodeRecognizer.__init__(self, **model_kwargs)
        # Function-level import: mmcv.tensorrt is only loaded on construction.
        from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Best-effort: only needed if the engine uses mmcv custom ops.
            # Implicit concatenation keeps indentation out of the message.
            warnings.warn('If input model has custom op from mmcv, you may '
                          'have to build mmcv with TensorRT from source.')
        model = TRTWrapper(trt_file,
                           input_names=['input'],
                           output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg
Ejemplo n.º 4
0
    def __init__(self,
                 trt_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        """Build the recognizer from ``cfg`` and load a TensorRT engine.

        Args:
            trt_file: Path to the TensorRT engine file, handed to
                ``mmcv.tensorrt.TRTWrapper`` with one input binding named
                ``'input'`` and one output binding named ``'output'``.
            cfg: Config object. Its ``model`` sub-config supplies the
                preprocessor, backbone, encoder, decoder, loss,
                label_convertor and pretrained settings, while
                ``train_cfg``/``test_cfg`` are read from the top level.
                Stored as ``self.cfg``.
            device_id: Stored as ``self.device_id``.
            show_score: Not used by this constructor.
        """
        # NOTE(review): ``40`` fills the positional slot after test_cfg;
        # presumably max_seq_len -- confirm against the signature of
        # EncodeDecodeRecognizer.__init__.
        EncodeDecodeRecognizer.__init__(self, cfg.model.preprocessor,
                                        cfg.model.backbone, cfg.model.encoder,
                                        cfg.model.decoder, cfg.model.loss,
                                        cfg.model.label_convertor,
                                        cfg.train_cfg, cfg.test_cfg, 40,
                                        cfg.model.pretrained)
        # Function-level import: mmcv.tensorrt is only loaded on construction.
        from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Best-effort: only needed if the engine uses mmcv custom ops.
            # Implicit concatenation keeps indentation out of the message.
            warnings.warn('If input model has custom op from mmcv, you may '
                          'have to build mmcv with TensorRT from source.')
        model = TRTWrapper(trt_file,
                           input_names=['input'],
                           output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg
Ejemplo n.º 5
0
    def __init__(self,
                 trt_file: str,
                 cfg: Any,
                 device_id: int,
                 show_score: bool = False):
        """Build the text detector from ``cfg.model`` and load a TensorRT engine.

        Args:
            trt_file: Path to the TensorRT engine file, handed to
                ``mmcv.tensorrt.TRTWrapper`` with one input binding named
                ``'input'`` and one output binding named ``'output'``.
            cfg: Config object whose ``model`` sub-config supplies the
                backbone, neck, bbox_head, train_cfg, test_cfg and
                pretrained settings. Stored as ``self.cfg``.
            device_id: Stored as ``self.device_id``.
            show_score: Forwarded to ``TextDetectorMixin.__init__``.
        """
        SingleStageTextDetector.__init__(self, cfg.model.backbone,
                                         cfg.model.neck, cfg.model.bbox_head,
                                         cfg.model.train_cfg,
                                         cfg.model.test_cfg,
                                         cfg.model.pretrained)
        TextDetectorMixin.__init__(self, show_score)
        # Function-level import: mmcv.tensorrt is only loaded on construction.
        from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Best-effort: only needed if the engine uses mmcv custom ops.
            # Implicit concatenation keeps indentation out of the message.
            warnings.warn('If input model has custom op from mmcv, you may '
                          'have to build mmcv with TensorRT from source.')
        model = TRTWrapper(trt_file,
                           input_names=['input'],
                           output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg