Example #1
0
    def load_model(self, descriptor):
        """Loads model on VisionBonnet.

        Args:
          descriptor: ModelDescriptor, meta info that defines model name,
            where to get the model and etc.

        Returns:
          Model identifier.

        Raises:
          ValueError: if the descriptor's input shape has an unsupported
            batch or depth value.
        """
        logging.info('Loading model "%s"...', descriptor.name)

        batch, height, width, depth = descriptor.input_shape
        # Validate with exceptions rather than assert: asserts are stripped
        # when Python runs with -O, which would silently skip these checks.
        if batch != 1:
            raise ValueError('Unsupported batch value: %d. Must be 1.' % batch)
        if depth != 3:
            raise ValueError('Unsupported depth value: %d. Must be 3.' % depth)
        mean, stddev = descriptor.input_normalizer

        request = protocol_pb2.Request()
        request.load_model.model_name = descriptor.name
        request.load_model.input_shape.batch = batch
        request.load_model.input_shape.height = height
        request.load_model.input_shape.width = width
        request.load_model.input_shape.depth = depth
        request.load_model.input_normalizer.mean = mean
        request.load_model.input_normalizer.stddev = stddev
        if descriptor.compute_graph:
            request.load_model.compute_graph = descriptor.compute_graph

        try:
            self._communicate(request)
        except InferenceException as e:
            # Load failures are logged but not fatal: the model may already
            # be present on the bonnet, so callers can still use the name.
            logging.warning(str(e))

        return descriptor.name
Example #2
0
    def load_model(self, descriptor):
        """Loads model on VisionBonnet.

        Args:
          descriptor: ModelDescriptor, meta info that defines model name,
            where to get the model and etc.

        Returns:
          Model identifier.

        Raises:
          ValueError: if the descriptor's input shape has an unsupported
            batch or depth value.
        """
        _check_firmware_info(self.get_firmware_info())
        mean, stddev = descriptor.input_normalizer
        batch, height, width, depth = descriptor.input_shape
        if batch != 1:
            # Bug fix: interpolate the offending value; the %d placeholder
            # was previously left unformatted in the raised message.
            raise ValueError('Unsupported batch value: %d. Must be 1.' % batch)

        if depth != 3:
            raise ValueError('Unsupported depth value: %d. Must be 3.' % depth)

        try:
            logger.info('Load model "%s".', descriptor.name)
            self._communicate(
                pb2.Request(load_model=pb2.Request.LoadModel(
                    model_name=descriptor.name,
                    input_shape=pb2.TensorShape(
                        batch=batch, height=height, width=width, depth=depth),
                    input_normalizer=pb2.TensorNormalizer(mean=mean,
                                                          stddev=stddev),
                    compute_graph=descriptor.compute_graph)))
        except InferenceException as e:
            # Load failures are logged but not fatal: the model may already
            # be loaded on the bonnet.
            logger.warning(str(e))

        return descriptor.name
  def image_inference(self, model_name, image, params=None):
    """Runs inference on image using model (identified by model_name).

    Args:
      model_name: string, unique identifier used to refer a model.
      image: PIL.Image,
      params: dict, additional parameters to run inference

    Returns:
      protocol_pb2.Response

    Raises:
      ValueError: if model_name is empty or image.mode is not 'RGB'.
    """
    # Validate with exceptions rather than assert: asserts are stripped
    # when Python runs with -O, which would silently skip these checks.
    if not model_name:
      raise ValueError('model_name must not be empty')
    if image.mode != 'RGB':
      raise ValueError('Only image.mode == RGB is supported.')

    logging.info('Image inference with model "%s"...', model_name)

    r, g, b = image.split()
    width, height = image.size

    request = protocol_pb2.Request()
    request.image_inference.model_name = model_name
    request.image_inference.tensor.shape.height = height
    request.image_inference.tensor.shape.width = width
    request.image_inference.tensor.shape.depth = 3
    # Planar layout: all red bytes, then green, then blue.
    request.image_inference.tensor.data = (
        _tobytes(r) + _tobytes(g) + _tobytes(b))

    for key, value in (params or {}).items():
      request.image_inference.params[key] = str(value)

    return self._communicate(request).result
Example #4
0
    def image_inference(self,
                        model_name,
                        image,
                        params=None,
                        sparse_configs=None):
        """Runs inference on image using model identified by model_name.

        Args:
          model_name: string, unique identifier used to refer a model.
          image: PIL.Image,
          params: dict, additional parameters to run inference
          sparse_configs: optional sparse-output configuration.

        Returns:
          pb2.Response.InferenceResult
        """
        _check_model_name(model_name)

        logger.info('Image inference on "%s".', model_name)
        inference = pb2.Request.ImageInference(
            model_name=model_name,
            tensor=_image_to_tensor(image),
            params=_get_params(params),
            sparse_configs=_get_sparse_configs(sparse_configs))
        response = self._communicate(pb2.Request(image_inference=inference))
        return response.inference_result
Example #5
0
    def start_camera_inference(self, model_name, params=None):
        """Starts inference running on VisionBonnet."""
        request = protocol_pb2.Request()
        request.start_camera_inference.model_name = model_name

        # Protobuf map values must be strings; stringify each parameter.
        extra = params or {}
        for key in extra:
            request.start_camera_inference.params[key] = str(extra[key])

        self._communicate(request)
Example #6
0
    def start_camera_inference(self, model_name, params=None):
        """Starts inference running on VisionBonnet."""
        _check_model_name(model_name)

        logger.info('Start camera inference on "%s".', model_name)
        start = pb2.Request.StartCameraInference(
            model_name=model_name, params=_get_params(params))
        self._communicate(pb2.Request(start_camera_inference=start))
Example #7
0
 def get_firmware_info(self):
     """Returns firmware version as (major, minor) tuple."""
     request = protocol_pb2.Request()
     request.get_firmware_info.SetInParent()
     try:
         response = self._communicate(request)
     except InferenceException:
         # Older firmware doesn't support this request; assume version 1.0.
         return (1, 0)
     info = response.firmware_info
     return (info.major_version, info.minor_version)
Example #8
0
    def unload_model(self, model_name):
        """Deletes model on VisionBonnet.

        Args:
          model_name: string, unique identifier used to refer a model.
        """
        logging.info('Unloading model "%s"...', model_name)

        req = protocol_pb2.Request()
        req.unload_model.model_name = model_name
        self._communicate(req)
Example #9
0
    def unload_model(self, model_name):
        """Deletes model on VisionBonnet.

        Args:
          model_name: string, unique identifier used to refer a model.
        """
        _check_model_name(model_name)

        logger.info('Unload model "%s".', model_name)
        unload = pb2.Request.UnloadModel(model_name=model_name)
        self._communicate(pb2.Request(unload_model=unload))
Example #10
0
    def load_model(self, descriptor):
        """Loads model on VisionBonnet.

        Args:
          descriptor: ModelDescriptor, meta info that defines model name,
            where to get the model and etc.

        Returns:
          Model identifier.

        Raises:
          ValueError: if the descriptor's input shape has an unsupported
            batch or depth value.
        """
        _check_firmware_info(self.get_firmware_info())

        logging.info('Loading model "%s"...', descriptor.name)

        batch, height, width, depth = descriptor.input_shape
        mean, stddev = descriptor.input_normalizer
        if batch != 1:
            # Bug fix: interpolate the offending value; the %d placeholder
            # was previously left unformatted in the raised message.
            raise ValueError('Unsupported batch value: %d. Must be 1.' % batch)

        if depth != 3:
            raise ValueError('Unsupported depth value: %d. Must be 3.' % depth)

        request = protocol_pb2.Request()
        request.load_model.model_name = descriptor.name
        request.load_model.input_shape.batch = batch
        request.load_model.input_shape.height = height
        request.load_model.input_shape.width = width
        request.load_model.input_shape.depth = depth
        request.load_model.input_normalizer.mean = mean
        request.load_model.input_normalizer.stddev = stddev
        if descriptor.compute_graph:
            request.load_model.compute_graph = descriptor.compute_graph

        try:
            self._communicate(request)
        except InferenceException as e:
            # Load failures are logged but not fatal: the model may already
            # be present on the bonnet.
            logging.warning(str(e))

        return descriptor.name
Example #11
0
    def image_inference(self, model_name, image, params=None):
        """Runs inference on image using model (identified by model_name).

        Args:
          model_name: string, unique identifier used to refer a model.
          image: PIL.Image,
          params: dict, additional parameters to run inference

        Returns:
          protocol_pb2.Response
        """
        if not model_name:
            raise ValueError('Model name must not be empty.')

        logging.info('Image inference with model "%s"...', model_name)

        width, height = image.size

        # Determine tensor depth and raw pixel payload from the image mode.
        mode = image.mode
        if mode == 'RGB':
            depth = 3
            # Planar layout: all red bytes, then green, then blue.
            bands = image.split()
            data = b''.join(band.tobytes() for band in bands)
        elif mode == 'L':
            depth = 1
            data = image.tobytes()
        else:
            raise InferenceException('Unsupported image format: %s. Must be L or RGB.' % mode)

        request = protocol_pb2.Request()
        request.image_inference.model_name = model_name
        request.image_inference.tensor.shape.height = height
        request.image_inference.tensor.shape.width = width
        request.image_inference.tensor.shape.depth = depth
        request.image_inference.tensor.data = data

        for key, value in (params or {}).items():
            request.image_inference.params[key] = str(value)

        return self._communicate(request).inference_result
Example #12
0
def get_camera_state(spicomm, timeout=None):
    """Fetches the current camera state over the given spicomm transport."""
    request_bytes = pb2.Request(
        get_camera_state=pb2.Request.GetCameraState()).SerializeToString()
    response = pb2.Response()
    response.ParseFromString(spicomm.transact(request_bytes, timeout))
    return response
Example #13
0
 def get_camera_state(self):
     """Queries VisionBonnet for its current camera state."""
     request = protocol_pb2.Request()
     request.get_camera_state.SetInParent()
     response = self._communicate(request)
     return response.camera_state
Example #14
0
 def stop_camera_inference(self):
     """Stops inference running on VisionBonnet."""
     req = protocol_pb2.Request()
     # Empty message: mark the field present so the request carries it.
     req.stop_camera_inference.SetInParent()
     self._communicate(req)
Example #15
0
 def camera_inference(self):
     """Returns the latest inference result from VisionBonnet."""
     req = protocol_pb2.Request()
     req.camera_inference.SetInParent()
     response = self._communicate(req)
     return response.inference_result
Example #16
0
def _request_bytes(*args, **kwargs):
    """Builds a pb2.Request from the given fields and serializes it."""
    request = pb2.Request(*args, **kwargs)
    return request.SerializeToString()