def set_configuration(self, cfg_in):
        """
        Apply a new configuration and build the bioharn predictor.

        Merges ``cfg_in`` into this object's configuration. Every key already
        present in ``self._kwiver_config`` is then mirrored into it as a
        string, except ``thresh``, which is parsed to ``float`` and also stored
        in ``self._thresh``. A ``batch_size`` of ``"auto"`` is resolved from
        the available GPU memory. Finally a
        ``bioharn.detect_predict.DetectPredictor`` is constructed and its model
        is loaded.

        Args:
            cfg_in: configuration to merge in. It is handed to
                ``_vital_config_update`` (presumably a vital config block --
                TODO confirm against callers).

        Returns:
            bool: always True.
        """
        cfg = self.get_configuration()

        # Imports used across this func
        import os
        import torch
        from bioharn import detect_predict

        # HACK: merge config doesn't support dictionary input
        _vital_config_update(cfg, cfg_in)

        for key in self._kwiver_config.keys():
            self._kwiver_config[key] = str(cfg.get_value(key))
        self._kwiver_config['thresh'] = float(self._kwiver_config['thresh'])

        self._thresh = self._kwiver_config['thresh']

        if self._kwiver_config['batch_size'] == "auto":
            # Default batch size; raised below when the GPU(s) have enough
            # memory.
            self._kwiver_config['batch_size'] = 2

            if torch.cuda.is_available():
                # Every config value was stringified above, so test for a
                # numeric device index textually rather than comparing to int.
                xpu = self._kwiver_config['xpu']
                if xpu.isdigit():
                    # A single explicit GPU index: size the batch for that GPU.
                    gpu_mem = torch.cuda.get_device_properties(
                        int(xpu)).total_memory
                else:
                    # Otherwise size for the smallest visible GPU so the batch
                    # fits on every device; 0 when no device is visible.
                    self._gpu_count = torch.cuda.device_count()
                    gpu_mem = min(
                        (torch.cuda.get_device_properties(i).total_memory
                         for i in range(self._gpu_count)),
                        default=0)
                # total_memory is reported in bytes.
                if gpu_mem > 9e9:
                    self._kwiver_config['batch_size'] = 4
                elif gpu_mem >= 7e9:
                    self._kwiver_config['batch_size'] = 3

        if os.name == 'nt':
            # Windows only: disable torchvision NMS inside kwimage (presumably
            # a platform workaround -- TODO confirm the original motivation).
            os.environ["KWIMAGE_DISABLE_TORCHVISION_NMS"] = "1"

        pred_config = detect_predict.DetectPredictConfig()
        # NOTE(review): batch_size is an int when resolved from "auto" but the
        # raw config string otherwise; presumably DetectPredictConfig coerces.
        pred_config['batch_size'] = self._kwiver_config['batch_size']
        pred_config['deployed'] = self._kwiver_config['deployed']
        # Fall back to CPU when CUDA is unavailable, regardless of 'xpu'.
        if torch.cuda.is_available():
            pred_config['xpu'] = self._kwiver_config['xpu']
        else:
            pred_config['xpu'] = "cpu"
        self.predictor = detect_predict.DetectPredictor(pred_config)

        # Load the model now (presumably it is otherwise loaded on first use).
        self.predictor._ensure_model()
        return True
# Example #2 (score: 0)
    def set_configuration(self, cfg_in):
        """
        Apply a new configuration and build the bioharn predictor.

        Merges ``cfg_in`` into this object's configuration. Every key already
        present in ``self._kwiver_config`` is then mirrored into it as a
        string, except ``thresh``, which is parsed to ``float`` and also stored
        in ``self._thresh``. A ``batch_size`` of ``"auto"`` is resolved from
        the available GPU memory. Finally a
        ``bioharn.detect_predict.DetectPredictor`` is constructed and its model
        is loaded.

        Args:
            cfg_in: configuration to merge in. It is handed to
                ``_vital_config_update`` (presumably a vital config block --
                TODO confirm against callers).

        Returns:
            bool: always True.
        """
        import torch

        cfg = self.get_configuration()

        # HACK: merge config doesn't support dictionary input
        _vital_config_update(cfg, cfg_in)

        for key in self._kwiver_config.keys():
            self._kwiver_config[key] = str(cfg.get_value(key))
        self._kwiver_config['thresh'] = float(self._kwiver_config['thresh'])

        self._thresh = self._kwiver_config['thresh']

        if self._kwiver_config['batch_size'] == "auto":
            # Default batch size; raised below when the GPU(s) have enough
            # memory.
            self._kwiver_config['batch_size'] = 2
            if torch.cuda.is_available():
                # Every config value was stringified above, so test for a
                # numeric device index textually rather than comparing to int.
                xpu = self._kwiver_config['xpu']
                if xpu.isdigit():
                    # A single explicit GPU index: size the batch for that GPU.
                    gpu_mem = torch.cuda.get_device_properties(
                        int(xpu)).total_memory
                else:
                    # Otherwise size for the smallest visible GPU so the batch
                    # fits on every device; 0 when no device is visible.
                    self._gpu_count = torch.cuda.device_count()
                    gpu_mem = min(
                        (torch.cuda.get_device_properties(i).total_memory
                         for i in range(self._gpu_count)),
                        default=0)
                # total_memory is reported in bytes.
                if gpu_mem > 9e9:
                    self._kwiver_config['batch_size'] = 4
                elif gpu_mem >= 7e9:
                    self._kwiver_config['batch_size'] = 3

        from bioharn import detect_predict
        pred_config = detect_predict.DetectPredictConfig()
        # NOTE(review): batch_size is an int when resolved from "auto" but the
        # raw config string otherwise; presumably DetectPredictConfig coerces.
        pred_config['batch_size'] = self._kwiver_config['batch_size']
        pred_config['deployed'] = self._kwiver_config['deployed']
        # Fall back to CPU when CUDA is unavailable, regardless of 'xpu'.
        if torch.cuda.is_available():
            pred_config['xpu'] = self._kwiver_config['xpu']
        else:
            pred_config['xpu'] = "cpu"
        self.predictor = detect_predict.DetectPredictor(pred_config)

        # Load the model now (presumably it is otherwise loaded on first use).
        self.predictor._ensure_model()
        return True