    def test_batch_predict(self):
        # Setup Expected Response
        expected_response = {}
        expected_response = prediction_service_pb2.BatchPredictResult(
            **expected_response
        )
        operation = operations_pb2.Operation(
            name="operations/test_batch_predict", done=True
        )
        operation.response.Pack(expected_response)

        # Mock the API response
        channel = ChannelStub(responses=[operation])
        patch = mock.patch("google.api_core.grpc_helpers.create_channel")
        with patch as create_channel:
            create_channel.return_value = channel
            client = automl_v1beta1.PredictionServiceClient()

        # Setup Request
        name = client.model_path("[PROJECT]", "[LOCATION]", "[MODEL]")
        input_config = {}
        output_config = {}

        response = client.batch_predict(name, input_config, output_config)
        result = response.result()
        assert expected_response == result

        assert len(channel.requests) == 1
        expected_request = prediction_service_pb2.BatchPredictRequest(
            name=name, input_config=input_config, output_config=output_config
        )
        actual_request = channel.requests[0][1]
        assert expected_request == actual_request
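
# Hedged sketch (not part of the generated test above): the test passes empty
# config dicts, but in a real call ``BatchPredictInputConfig`` and
# ``BatchPredictOutputConfig`` carry source and destination locations. For a
# GCS-based batch prediction they would typically look like the following; the
# bucket and object names are hypothetical.
input_config = {"gcs_source": {"input_uris": ["gs://my-bucket/batch_input.csv"]}}
output_config = {"gcs_destination": {"output_uri_prefix": "gs://my-bucket/batch_output/"}}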
Example #2
    def batch_predict(
        self,
        name,
        input_config,
        output_config,
        params=None,
        retry=google.api_core.gapic_v1.method.DEFAULT,
        timeout=google.api_core.gapic_v1.method.DEFAULT,
        metadata=None,
    ):
        """
        Perform a batch prediction. Unlike the online ``Predict``, the batch
        prediction result is not immediately available in the response.
        Instead, a long-running operation object is returned. Users can poll
        the operation result via the ``GetOperation`` method. Once the
        operation is done, a ``BatchPredictResult`` is returned in the
        ``response`` field. Available for the following ML problems:

        -  Image Classification
        -  Image Object Detection
        -  Video Classification
        -  Video Object Tracking
        -  Text Extraction
        -  Tables

        Example:
            >>> from google.cloud import automl_v1beta1
            >>>
            >>> client = automl_v1beta1.PredictionServiceClient()
            >>>
            >>> name = client.model_path('[PROJECT]', '[LOCATION]', '[MODEL]')
            >>>
            >>> # TODO: Initialize `input_config`:
            >>> input_config = {}
            >>>
            >>> # TODO: Initialize `output_config`:
            >>> output_config = {}
            >>>
            >>> response = client.batch_predict(name, input_config, output_config)
            >>>
            >>> def callback(operation_future):
            ...     # Handle result.
            ...     result = operation_future.result()
            >>>
            >>> response.add_done_callback(callback)
            >>>
            >>> # Handle metadata.
            >>> metadata = response.metadata()

        Args:
            name (str): Name of the model requested to serve the batch prediction.
            input_config (Union[dict, ~google.cloud.automl_v1beta1.types.BatchPredictInputConfig]): Required. The input configuration for batch prediction.

                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
            output_config (Union[dict, ~google.cloud.automl_v1beta1.types.BatchPredictOutputConfig]): Required. The configuration specifying where output predictions should
                be written.

                If a dict is provided, it must be of the same form as the protobuf
                message :class:`~google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
            params (dict[str -> str]): Additional domain-specific parameters for the predictions; any string
                value must be at most 25000 characters long.

                -  For Text Classification:

                   ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
                   makes predictions for a text snippet, it will only produce results
                   that have at least this confidence score. The default is 0.5.

                -  For Image Classification:

                   ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
                   makes predictions for an image, it will only produce results that
                   have at least this confidence score. The default is 0.5.

                -  For Image Object Detection:

                   ``score_threshold`` - (float) When the model detects objects in the
                   image, it will only produce bounding boxes that have at least this
                   confidence score. Value in the 0 to 1 range; the default is 0.5.
                   ``max_bounding_box_count`` - (int64) No more than this number of
                   bounding boxes will be produced per image. Default is 100; the
                   requested value may be limited by the server.

                -  For Video Classification:

                   ``score_threshold`` - (float) A value from 0.0 to 1.0. When the model
                   makes predictions for a video, it will only produce results that have
                   at least this confidence score. The default is 0.5.
                   ``segment_classification`` - (boolean) Set to true to request
                   segment-level classification. AutoML Video Intelligence returns
                   labels and their confidence scores for the entire segment of the
                   video that the user specified in the request configuration. The
                   default is "true".
                   ``shot_classification`` - (boolean) Set to true to request shot-level
                   classification. AutoML Video Intelligence determines the boundaries
                   for each camera shot in the entire segment of the video that the user
                   specified in the request configuration. AutoML Video Intelligence
                   then returns labels and their confidence scores for each detected
                   shot, along with the start and end time of the shot. WARNING: Model
                   evaluation is not done for this classification type; its quality
                   depends on the training data, but there are no metrics provided to
                   describe that quality. The default is "false".
                   ``1s_interval_classification`` - (boolean) Set to true to request
                   classification for a video at one-second intervals. AutoML Video
                   Intelligence returns labels and their confidence scores for each
                   second of the entire segment of the video that the user specified in
                   the request configuration. WARNING: Model evaluation is not done for
                   this classification type; its quality depends on the training data,
                   but there are no metrics provided to describe that quality. The
                   default is "false".

                -  For Video Object Tracking:

                   ``score_threshold`` - (float) When the model detects objects in video
                   frames, it will only produce bounding boxes that have at least this
                   confidence score. Value in the 0 to 1 range; the default is 0.5.
                   ``max_bounding_box_count`` - (int64) No more than this number of
                   bounding boxes will be returned per frame. Default is 100; the
                   requested value may be limited by the server.
                   ``min_bounding_box_size`` - (float) Only bounding boxes whose
                   shortest edge is at least this long, as a relative value of the video
                   frame size, will be returned. Value in the 0 to 1 range. Default is 0.
            retry (Optional[google.api_core.retry.Retry]):  A retry object used
                to retry requests. If ``None`` is specified, requests will
                be retried using a default configuration.
            timeout (Optional[float]): The amount of time, in seconds, to wait
                for the request to complete. Note that if ``retry`` is
                specified, the timeout applies to each individual attempt.
            metadata (Optional[Sequence[Tuple[str, str]]]): Additional metadata
                that is provided to the method.

        Returns:
            A :class:`~google.cloud.automl_v1beta1.types._OperationFuture` instance.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: If the request
                    failed for any reason.
            google.api_core.exceptions.RetryError: If the request failed due
                    to a retryable error and retry attempts failed.
            ValueError: If the parameters are invalid.
        """
        # Wrap the transport method to add retry and timeout logic.
        if "batch_predict" not in self._inner_api_calls:
            self._inner_api_calls[
                "batch_predict"
            ] = google.api_core.gapic_v1.method.wrap_method(
                self.transport.batch_predict,
                default_retry=self._method_configs["BatchPredict"].retry,
                default_timeout=self._method_configs["BatchPredict"].timeout,
                client_info=self._client_info,
            )

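        # Build the BatchPredictRequest message from the supplied arguments.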
        request = prediction_service_pb2.BatchPredictRequest(
            name=name,
            input_config=input_config,
            output_config=output_config,
            params=params,
        )
        if metadata is None:
            metadata = []
        metadata = list(metadata)
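        # Attach a routing header derived from the resource name so the
        # backend can route the request.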
        try:
            routing_header = [("name", name)]
        except AttributeError:
            pass
        else:
            routing_metadata = google.api_core.gapic_v1.routing_header.to_grpc_metadata(
                routing_header)
            metadata.append(routing_metadata)

        operation = self._inner_api_calls["batch_predict"](
            request, retry=retry, timeout=timeout, metadata=metadata
        )
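        # Wrap the raw operation in an operation future that resolves to a
        # BatchPredictResult once the operation completes.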
        return google.api_core.operation.from_gapic(
            operation,
            self.transport._operations_client,
            prediction_service_pb2.BatchPredictResult,
            metadata_type=proto_operations_pb2.OperationMetadata,
        )
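
# Usage sketch (hedged, not part of the generated client above): the project,
# location, model ID, and GCS paths below are hypothetical placeholders.
# ``params`` values are passed as strings because the request field is a
# map<string, string>.
from google.cloud import automl_v1beta1


def run_batch_predict_example():
    client = automl_v1beta1.PredictionServiceClient()
    name = client.model_path("my-project", "us-central1", "my-model-id")
    input_config = {"gcs_source": {"input_uris": ["gs://my-bucket/batch_input.csv"]}}
    output_config = {"gcs_destination": {"output_uri_prefix": "gs://my-bucket/batch_output/"}}

    # Only produce predictions with a confidence score of at least 0.8.
    response = client.batch_predict(
        name, input_config, output_config, params={"score_threshold": "0.8"}
    )

    # ``batch_predict`` returns an operation future; ``result()`` blocks until
    # the long-running operation completes and then returns the
    # BatchPredictResult message.
    result = response.result()
    return result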