예제 #1
0
def test_predict(transport: str = "grpc"):
    """Check that ``predict`` forwards the request and wraps the reply in a pager.

    The transport's ``predict`` callable is patched, so no real RPC is issued.
    The fake ``PredictResponse`` field values must be readable on the returned
    ``PredictPager``.
    """
    anon_creds = credentials.AnonymousCredentials()
    svc = PredictionServiceClient(credentials=anon_creds, transport=transport)

    # proto3 fields are all optional at runtime and the RPC is mocked,
    # so an empty request is sufficient.
    req = prediction_service.PredictRequest()

    stub_type = type(svc._transport.predict)
    with mock.patch.object(stub_type, "__call__") as fake_call:
        # Canned reply returned by the patched stub.
        fake_call.return_value = prediction_service.PredictResponse(
            recommendation_token="recommendation_token_value",
            items_missing_in_catalog=["items_missing_in_catalog_value"],
            dry_run=True,
            next_page_token="next_page_token_value",
        )

        pager = svc.predict(req)

        # The stub must have been invoked exactly once, with our request.
        assert len(fake_call.mock_calls) == 1
        _, call_args, _ = fake_call.mock_calls[0]
        assert call_args[0] == req

    # The reply is surfaced through a pager that exposes the response fields.
    assert isinstance(pager, pagers.PredictPager)
    assert pager.recommendation_token == "recommendation_token_value"
    assert pager.items_missing_in_catalog == ["items_missing_in_catalog_value"]
    assert pager.dry_run is True
    assert pager.next_page_token == "next_page_token_value"
예제 #2
0
def test_transport_instance():
    """A pre-built transport instance is adopted by the client unchanged."""
    grpc_transport = transports.PredictionServiceGrpcTransport(
        credentials=credentials.AnonymousCredentials(),
    )
    svc = PredictionServiceClient(transport=grpc_transport)
    # Identity, not equality: the client must not wrap or copy the transport.
    assert svc._transport is grpc_transport
예제 #3
0
def test_predict_pages():
    """Check that ``predict(...).pages`` walks every page and stops at the last.

    The mocked stub serves four responses chained by ``next_page_token``. The
    last one leaves the token unset (empty string). If the pager asked for a
    fifth page, the trailing ``RuntimeError`` in ``side_effect`` would be raised.
    """
    # Pass a credentials *instance*; the bare class is not a credentials object.
    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials())

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(type(client._transport.predict), "__call__") as call:
        # Set the response to a series of pages.
        call.side_effect = (
            prediction_service.PredictResponse(
                results=[
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                ],
                next_page_token="abc",
            ),
            prediction_service.PredictResponse(results=[], next_page_token="def"),
            prediction_service.PredictResponse(
                results=[prediction_service.PredictResponse.PredictionResult()],
                next_page_token="ghi",
            ),
            prediction_service.PredictResponse(
                results=[
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                ]
            ),
            RuntimeError,
        )
        pages = list(client.predict(request={}).pages)

        expected_tokens = ["abc", "def", "ghi", ""]
        # zip() stops at the shorter input. Without this count check, a pager
        # that yielded too few pages would pass the loop below silently.
        assert len(pages) == len(expected_tokens)
        for page, token in zip(pages, expected_tokens):
            assert page.raw_page.next_page_token == token
예제 #4
0
def test_predict_pager():
    """Check that iterating ``predict(...)`` flattens results across all pages.

    The mocked stub serves four pages holding 3, 0, 1 and 2 results, so the
    pager must yield six ``PredictionResult`` items in total. The trailing
    ``RuntimeError`` in ``side_effect`` would be raised if the pager requested a
    page beyond the last one.
    """
    # Pass a credentials *instance*; the bare class is not a credentials object.
    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials())

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(type(client._transport.predict), "__call__") as call:
        # Set the response to a series of pages.
        call.side_effect = (
            prediction_service.PredictResponse(
                results=[
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                ],
                next_page_token="abc",
            ),
            prediction_service.PredictResponse(results=[], next_page_token="def"),
            prediction_service.PredictResponse(
                results=[prediction_service.PredictResponse.PredictionResult()],
                next_page_token="ghi",
            ),
            prediction_service.PredictResponse(
                results=[
                    prediction_service.PredictResponse.PredictionResult(),
                    prediction_service.PredictResponse.PredictionResult(),
                ]
            ),
            RuntimeError,
        )
        # Iterating the pager directly yields individual results, not pages.
        results = list(client.predict(request={}))
        assert len(results) == 6
        assert all(
            isinstance(r, prediction_service.PredictResponse.PredictionResult)
            for r in results
        )
예제 #5
0
def test_prediction_service_auth_adc():
    """Without explicit credentials, the client obtains them via ``auth.default``.

    ``auth.default`` is patched, and the test asserts it is called once with
    the cloud-platform scope.
    """
    expected_scopes = ("https://www.googleapis.com/auth/cloud-platform",)
    with mock.patch.object(auth, "default") as default_creds:
        default_creds.return_value = credentials.AnonymousCredentials(), None
        PredictionServiceClient()
        default_creds.assert_called_once_with(scopes=expected_scopes)
예제 #6
0
def test_prediction_service_client_client_options_from_dict():
    """A plain dict passed as ``client_options`` sets the transport host.

    ``get_transport_class`` is patched so the test can inspect the arguments
    the client uses to build its transport.
    """
    target = "google.cloud.recommendationengine_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class"
    with mock.patch(target) as get_transport_class:
        fake_transport = mock.MagicMock()
        get_transport_class.return_value = fake_transport

        PredictionServiceClient(client_options={"api_endpoint": "squid.clam.whelk"})

        fake_transport.assert_called_once_with(
            credentials=None, host="squid.clam.whelk"
        )
예제 #7
0
def test_prediction_service_host_with_port():
    """An ``api_endpoint`` that includes a port is used verbatim as the host."""
    endpoint = "recommendationengine.googleapis.com:8000"
    opts = client_options.ClientOptions(api_endpoint=endpoint)
    svc = PredictionServiceClient(
        credentials=credentials.AnonymousCredentials(),
        client_options=opts,
        transport="grpc",
    )
    assert svc._transport._host == "recommendationengine.googleapis.com:8000"
예제 #8
0
def test_credentials_transport_error():
    """Supplying both ``credentials`` and a transport instance raises ``ValueError``."""
    prebuilt = transports.PredictionServiceGrpcTransport(
        credentials=credentials.AnonymousCredentials()
    )
    with pytest.raises(ValueError):
        # The constructed client is never used; only the exception matters.
        PredictionServiceClient(
            credentials=credentials.AnonymousCredentials(),
            transport=prebuilt,
        )
예제 #9
0
def test_prediction_service_client_client_options():
    """Check the default endpoint, then that ``ClientOptions`` overrides it.

    ``get_transport_class`` is patched so the test can inspect the host the
    client hands to its transport.
    """
    default_endpoint = PredictionServiceClient.DEFAULT_OPTIONS.api_endpoint
    assert default_endpoint == "recommendationengine.googleapis.com"

    custom = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
    target = "google.cloud.recommendationengine_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class"
    with mock.patch(target) as get_transport_class:
        fake_transport = mock.MagicMock()
        get_transport_class.return_value = fake_transport

        PredictionServiceClient(client_options=custom)

        fake_transport.assert_called_once_with(
            credentials=None, host="squid.clam.whelk"
        )
예제 #10
0
def test_transport_grpc_default():
    """When no ``transport`` is given, the client uses the gRPC transport."""
    anon = credentials.AnonymousCredentials()
    chosen = PredictionServiceClient(credentials=anon)._transport
    assert isinstance(chosen, transports.PredictionServiceGrpcTransport)