예제 #1
0
    def test_mnist_happy_path(self):
        """Send the recorded MNIST request to the server and check the response.

        Reads a serialized ``PredictRequest`` and the expected
        ``PredictResponse`` from ``self.test_data_path``, issues ``Predict``
        over an insecure gRPC channel to ``self.server_ip:self.server_port``,
        and then checks:

        * every expected output is present with the expected ``data_type``;
        * the full ``dims`` list of ``Plus214_Output_0`` matches;
        * its float32 ``raw_data`` payload matches element-wise within a
          relative tolerance of 0.001 (via ``test_util.compare_floats``).
        """
        output_name = 'Plus214_Output_0'
        input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb')
        output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.pb')

        with open(input_data_file, 'rb') as f:
            request_payload = f.read()

        request = predict_pb2.PredictRequest()
        request.ParseFromString(request_payload)
        uri = "{}:{}".format(self.server_ip, self.server_port)
        test_util.test_log(uri)
        with grpc.insecure_channel(uri) as channel:
            stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
            actual_result = stub.Predict(request)

        expected_result = predict_pb2.PredictResponse()
        with open(output_data_file, 'rb') as f:
            expected_result.ParseFromString(f.read())

        for k in expected_result.outputs:
            # Check membership first: indexing a protobuf message map with a
            # missing key inserts a default entry instead of raising, which
            # would turn "output missing" into an obscure data_type mismatch.
            self.assertIn(k, actual_result.outputs)
            self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type)

        actual_output = actual_result.outputs[output_name]
        expected_output = expected_result.outputs[output_name]

        # Compare the whole dims lists so a rank mismatch (e.g. an extra
        # trailing dimension) is caught, not only per-index differences.
        actual_dims = [int(d) for d in actual_output.dims]
        expected_dims = [int(d) for d in expected_output.dims]
        self.assertEqual(actual_dims, expected_dims)

        count = 1
        for d in actual_dims:
            count *= d

        actual_array = numpy.frombuffer(actual_output.raw_data, dtype=numpy.float32)
        expected_array = numpy.frombuffer(expected_output.raw_data, dtype=numpy.float32)
        self.assertEqual(len(actual_array), len(expected_array))
        self.assertEqual(len(actual_array), count)
        for actual_value, expected_value in zip(actual_array, expected_array):
            self.assertTrue(test_util.compare_floats(actual_value, expected_value, rel_tol=0.001))
예제 #2
0
def __get_inference(url: str, port_number: str, data: np.ndarray,
                    timeout: "float | None" = None) -> dict:
    """Run one ``Predict`` call against a gRPC prediction service.

    Args:
        url: Host part of the server address.
        port_number: Port of the server; joined to ``url`` as ``url:port``.
        data: Input handed to ``__get_request_message`` to build the request.
        timeout: Optional per-call deadline in seconds, forwarded to
            ``stub.Predict``. ``None`` (the default) keeps the previous
            behavior of waiting with no deadline.

    Returns:
        The value produced by ``__parse_response`` for the server's reply
        (annotated as ``dict``).
    """
    target = '{url}:{port_number}'.format(url=url, port_number=port_number)
    # The insecure channel is closed when the ``with`` block exits.
    with grpc.insecure_channel(target) as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        request_message = __get_request_message(data)
        raw_response = stub.Predict(request_message, timeout=timeout)
    # A unary call returns a fully received message, so it can be parsed
    # after the channel is closed.
    return __parse_response(raw_response)