Exemplo n.º 1
0
def test_create_new_channel_after_timeout_expires():
    """Check when PredictionClient reuses its channel and when it loads a new one.

    The channel factory (``_channel_func``) and the clock
    (``_get_datetime_now``) are replaced with fakes. This lets the test count
    how many channels the client loads and control what "now" is.

    With ``channel_shutdown_timeout`` set to one minute, the expected load
    counts below say the following:
    - At 70 s total elapsed, but only 20 s since the previous call, the first
      channel is still reused.
    - After a further 70 s without calls, a second channel is loaded.
    """
    import shutil  # only needed here, for temp-dir cleanup

    channel_mock_loaded = {'value': 0}

    def unary_unary(id, request_serializer, response_deserializer):
        # Fake for the channel's unary_unary stub factory. Only the Predict
        # method gets a canned response, carrying the tensor [[1, 2, 3]].
        result = mock.MagicMock()
        if id == '/tensorflow.serving.PredictionService/Predict':
            return_data = np.asarray([[1, 2, 3]])
            return_tensor = tf.contrib.util.make_tensor_proto(
                return_data, types_pb2.DT_FLOAT, return_data.shape)
            result.outputs = {"output_alias": return_tensor}
        return lambda req, timeout: result

    def load_channel_mock():
        # Count every channel the client asks for; always hand back the mock.
        channel_mock_loaded['value'] += 1
        return channel_mock

    channel_mock = mock.Mock()
    channel_mock.unary_unary = mock.MagicMock(side_effect=unary_unary)

    now = datetime.now()

    tmp_dir = tempfile.mkdtemp()
    try:
        image_file_path = os.path.join(tmp_dir, "img.png")
        with open(image_file_path, "w") as image_file:
            image_file.write("abc")

        client = PredictionClient(
            "localhost",
            50051,
            channel_shutdown_timeout=timedelta(minutes=1))
        client._channel_func = load_channel_mock
        # Intentional late-binding closure: every rebinding of `now` below
        # is what the client observes as the current time.
        client._get_datetime_now = lambda: now

        # (seconds to advance the fake clock, expected channel loads so far)
        steps = [(0, 1), (50, 1), (20, 1), (70, 2)]
        for advance, expected_loads in steps:
            now += timedelta(seconds=advance)
            result = client.score_image(image_file_path)
            # Exact comparison: unlike all(zip(...)), this fails on an empty
            # or wrongly sized result.
            np.testing.assert_array_equal(result, [1, 2, 3])
            assert channel_mock_loaded['value'] == expected_loads
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
Exemplo n.º 2
0
def test_retrying_rpc_exception():
    """Check that PredictionClient recovers from an RpcError on Predict.

    The first Predict stub handed out by the fake channel raises
    ``grpc.RpcError`` when called. Every later stub returns the tensor
    [[11, 22]]. The test expects the following:
    - ``score_numpy_array`` still returns the successful result.
    - The channel factory is called exactly twice.
    - The channel is closed exactly once.
    """
    first_call = [True]

    channel_mock_loaded = {'value': 0}
    channel_mock_closed = {'value': 0}

    def raise_rpc_error(req, timeout):
        # Same call signature as the successful stub below.
        raise grpc.RpcError()

    def unary_unary(id, request_serializer, response_deserializer):
        # Fake for the channel's unary_unary stub factory.
        result = mock.MagicMock()
        if id == '/tensorflow.serving.PredictionService/Predict':
            # The failure is decided when the stub is created, not when it is
            # invoked: only the first Predict stub ever created is broken.
            if first_call[0]:
                first_call[0] = False
                return raise_rpc_error

            return_data = np.asarray([[11, 22]])
            return_tensor = tf.contrib.util.make_tensor_proto(
                return_data, types_pb2.DT_FLOAT, return_data.shape)
            result.outputs = {"output_alias": return_tensor}
        return lambda req, timeout: result

    def load_channel_mock():
        # Count every channel the client asks for; always hand back the mock.
        channel_mock_loaded['value'] += 1
        return channel_mock

    def close_channel_mock():
        channel_mock_closed['value'] += 1

    now = datetime.now()

    channel_mock = mock.Mock()
    channel_mock.unary_unary = mock.MagicMock(side_effect=unary_unary)
    channel_mock.close = close_channel_mock

    client = PredictionClient("localhost",
                              50051,
                              channel_shutdown_timeout=timedelta(minutes=1))
    client._channel_func = load_channel_mock
    client._get_datetime_now = lambda: now

    result = client.score_numpy_array(np.asarray([[1, 2]], dtype='f'))
    # Exact comparison: unlike all(zip(...)), this fails on an empty or
    # wrongly sized result.
    np.testing.assert_array_equal(result[0], [11, 22])

    # Presumably the initial channel plus a fresh one for the retry.
    assert channel_mock_loaded['value'] == 2
    assert channel_mock_closed['value'] == 1