Code example #1
def unregister_cleanup_shm_regions(shm_regions, shm_handles,
                                   precreated_shm_regions, outputs,
                                   use_system_shared_memory,
                                   use_cuda_shared_memory):
    # Lazy shm imports...
    if use_system_shared_memory:
        import tritonclient.utils.shared_memory as shm
    if use_cuda_shared_memory:
        import tritonclient.utils.cuda_shared_memory as cudashm

    if not (use_system_shared_memory or use_cuda_shared_memory):
        return None

    triton_client = httpclient.InferenceServerClient(
        f"{_tritonserver_ipaddr}:8000")

    if use_cuda_shared_memory:
        triton_client.unregister_cuda_shared_memory(shm_regions[0] + '_data')
        triton_client.unregister_cuda_shared_memory(shm_regions[1] + '_data')
        cudashm.destroy_shared_memory_region(shm_handles[0])
        cudashm.destroy_shared_memory_region(shm_handles[1])
    else:
        triton_client.unregister_system_shared_memory(shm_regions[0] + '_data')
        triton_client.unregister_system_shared_memory(shm_regions[1] + '_data')
        shm.destroy_shared_memory_region(shm_handles[0])
        shm.destroy_shared_memory_region(shm_handles[1])

    if precreated_shm_regions is None:
        i = 0
        if "OUTPUT0" in outputs:
            if use_cuda_shared_memory:
                triton_client.unregister_cuda_shared_memory(shm_regions[2] +
                                                            '_data')
                cudashm.destroy_shared_memory_region(shm_handles[2])
            else:
                triton_client.unregister_system_shared_memory(shm_regions[2] +
                                                              '_data')
                shm.destroy_shared_memory_region(shm_handles[2])
            i += 1
        if "OUTPUT1" in outputs:
            if use_cuda_shared_memory:
                triton_client.unregister_cuda_shared_memory(shm_regions[2 +
                                                                        i] +
                                                            '_data')
                cudashm.destroy_shared_memory_region(shm_handles[3])
            else:
                triton_client.unregister_system_shared_memory(shm_regions[2 +
                                                                          i] +
                                                              '_data')
                shm.destroy_shared_memory_region(shm_handles[3])
Code example #2
def get_triton_client():
    # set up Triton connection
    TRITONURL = "triton:8000"
    # TODO: check that the server is always available ...
    try:
        # Specify large enough concurrency to handle the
        # number of requests.
        concurrency = 1
        triton_client = httpclient.InferenceServerClient(
            url=TRITONURL, concurrency=concurrency)
        logger.info(f"Server ready? {triton_client.is_server_ready()}")
    except Exception as e:
        logger.error("client creation failed: " + str(e))
        raise
    return triton_client
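
These snippets assume the standard Triton client imports, which the listing omits. A minimal sketch of the assumed setup (module paths are from the real tritonclient package; the localhost URL is an assumption):

# Imports assumed by the examples on this page.
import numpy as np
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype, InferenceServerException

# Hypothetical readiness check against a local server before running a snippet.
client = httpclient.InferenceServerClient(url="localhost:8000")
if client.is_server_live() and client.is_server_ready():
    print("Triton is ready")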
Code example #3
 def test_bool(self):
     model_name = 'identity_bool'
     with self._shm_leak_detector.Probe() as shm_probe:
         with httpclient.InferenceServerClient("localhost:8000") as client:
             input_data = np.array([[True, False, True]], dtype=bool)
             inputs = [
                 httpclient.InferInput("INPUT0", input_data.shape,
                                       np_to_triton_dtype(input_data.dtype))
             ]
             inputs[0].set_data_from_numpy(input_data)
             result = client.infer(model_name, inputs)
             output0 = result.as_numpy('OUTPUT0')
             self.assertIsNotNone(output0)
             self.assertTrue(np.all(output0 == input_data))
Code example #4
    def test_infer(self):
        try:
            triton_client = httpclient.InferenceServerClient(
                url="localhost:8000")
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit(1)

        model_name = "libtorch_int32_int32_int32"

        inputs = []
        outputs = []
        inputs.append(httpclient.InferInput('INPUT__0', [1, 16], "INT32"))
        inputs.append(httpclient.InferInput('INPUT__1', [1, 16], "INT32"))

        # Create the data for the two input tensors. Initialize the first
        # to unique integers and the second to all negative ones.
        input0_data = np.arange(start=0, stop=16, dtype=np.int32)
        input0_data = np.expand_dims(input0_data, axis=0)
        input1_data = np.full(shape=(1, 16), fill_value=-1, dtype=np.int32)

        # Initialize the data
        inputs[0].set_data_from_numpy(input0_data, binary_data=True)
        inputs[1].set_data_from_numpy(input1_data, binary_data=True)

        outputs.append(
            httpclient.InferRequestedOutput('OUTPUT__0', binary_data=True))
        outputs.append(
            httpclient.InferRequestedOutput('OUTPUT__1', binary_data=True))

        results = triton_client.infer(model_name, inputs, outputs=outputs)

        output0_data = results.as_numpy('OUTPUT__0')
        output1_data = results.as_numpy('OUTPUT__1')

        # Validate the results by comparing with precomputed values.
        for i in range(16):
            print(
                str(input0_data[0][i]) + " - " + str(input1_data[0][i]) +
                " = " + str(output0_data[0][i]))
            print(
                str(input0_data[0][i]) + " + " + str(input1_data[0][i]) +
                " = " + str(output1_data[0][i]))
            if (input0_data[0][i] - input1_data[0][i]) != output0_data[0][i]:
                print("sync infer error: incorrect difference")
                sys.exit(1)
            if (input0_data[0][i] + input1_data[0][i]) != output1_data[0][i]:
                print("sync infer error: incorrect sum")
                sys.exit(1)
Code example #5
 def test_init_args(self):
     model_name = "init_args"
     shape = [2, 2]
     with httpclient.InferenceServerClient("localhost:8000") as client:
         input_data = np.zeros(shape, dtype=np.float32)
         inputs = [
             httpclient.InferInput("IN", input_data.shape,
                                   np_to_triton_dtype(input_data.dtype))
         ]
         inputs[0].set_data_from_numpy(input_data)
         result = client.infer(model_name, inputs)
         # output response in this model is the number of keys in the args
         self.assertTrue(
             result.as_numpy("OUT") == 7,
             "Number of keys in the init args is not correct")
Code example #6
 def test_unregister_before_register(self):
     # Create a valid system shared memory region and unregister before register
     if _protocol == "http":
         triton_client = httpclient.InferenceServerClient(_url, verbose=True)
     else:
         triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
     shm_op0_handle = shm.create_shared_memory_region(
         "dummy_data", "/dummy_data", 8)
     triton_client.unregister_system_shared_memory("dummy_data")
     shm_status = triton_client.get_system_shared_memory_status()
     if _protocol == "http":
         self.assertTrue(len(shm_status) == 0)
     else:
         self.assertTrue(len(shm_status.regions) == 0)
     shm.destroy_shared_memory_region(shm_op0_handle)
Code example #7
    def test_dlpack_tensor(self):
        model_name = "dlpack_test"
        with httpclient.InferenceServerClient("localhost:8000") as client:
            # Input data is not used.
            input_data = np.array([1], dtype=np.float32)
            inputs = [
                httpclient.InferInput("INPUT0", input_data.shape,
                                      np_to_triton_dtype(input_data.dtype))
            ]
            inputs[0].set_data_from_numpy(input_data)
            result = client.infer(model_name, inputs)
            output0 = result.as_numpy('OUTPUT0')

            # The model returns 1 if the tests passed successfully.
            # Otherwise, it returns 0.
            self.assertTrue(output0 == [1])
Code example #8
 def test_http_get_settings(self):
     # Model trace settings will be the same as global trace settings since
     # no update has been made.
     initial_settings = {
         "trace_file": "global_unittest.log",
         "trace_level": ["TIMESTAMPS"],
         "trace_rate": "1",
         "trace_count": "-1",
         "log_frequency": "0"
     }
     triton_client = httpclient.InferenceServerClient("localhost:8000")
     self.assertEqual(initial_settings,
                      triton_client.get_trace_settings(model_name="simple"),
                      "Unexpected initial model trace settings")
     self.assertEqual(initial_settings, triton_client.get_trace_settings(),
                      "Unexpected initial global settings")
Code example #9
 def test_unicode(self):
     model_name = "string"
     shape = [1]
     with httpclient.InferenceServerClient("localhost:8000") as client:
         utf8 = '😀'
         input_data = np.array([bytes(utf8, encoding='utf-8')],
                               dtype=np.bytes_)
         inputs = [
             httpclient.InferInput("INPUT0", shape,
                                   np_to_triton_dtype(input_data.dtype))
         ]
         inputs[0].set_data_from_numpy(input_data)
         result = client.infer(model_name, inputs)
         output0 = result.as_numpy('OUTPUT0')
         self.assertTrue(output0 is not None)
         self.assertTrue(output0[0] == input_data)
Code example #10
 def __init__(self, uri):
     """
     Initializes the deployment plugin, sets the triton model repo
     """
     super(TritonPlugin, self).__init__(target_uri=uri)
     self.server_config = Config()
     triton_url, self.triton_model_repo = self._get_triton_server_config()
     self.supported_flavors = ['triton', 'onnx']  # need to add other flavors
     # URL cleaning for constructing Triton client
     ssl = False
     if triton_url.startswith("http://"):
         triton_url = triton_url[len("http://"):]
     elif triton_url.startswith("https://"):
         triton_url = triton_url[len("https://"):]
         ssl = True
     self.triton_client = tritonhttpclient.InferenceServerClient(
         url=triton_url, ssl=ssl)
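
For illustration, a hypothetical instantiation of this plugin; the target URI is an assumption, and TritonPlugin and Config come from the surrounding project:

# Hypothetical usage: the constructor above strips the https:// prefix
# and sets ssl=True before building the InferenceServerClient.
plugin = TritonPlugin("https://triton.example.com:8443")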
Code example #11
 def check_server_initial_state(self):
     # Helper function to make sure the trace settings are properly
     # initialized / reset before actually running the test case.
     # Note that this function uses the HTTP client, so the HTTP trace
     # setting test cases must pass for this initial-state check to be
     # trustworthy before other test cases run.
     initial_settings = {
         "trace_file": "global_unittest.log",
         "trace_level": ["TIMESTAMPS"],
         "trace_rate": "1",
         "trace_count": "-1",
         "log_frequency": "0"
     }
     triton_client = httpclient.InferenceServerClient("localhost:8000")
     self.assertEqual(initial_settings,
                      triton_client.get_trace_settings(model_name="simple"))
     self.assertEqual(initial_settings, triton_client.get_trace_settings())
Code example #12
    def test_wrong_implicit_state_name(self):
        triton_client = tritonhttpclient.InferenceServerClient(
            "localhost:8000")
        inputs = []
        inputs.append(tritonhttpclient.InferInput('INPUT', [1], 'INT32'))
        inputs.append(tritonhttpclient.InferInput('TEST_CASE', [1], 'INT32'))
        inputs[0].set_data_from_numpy(
            np.random.randint(5, size=[1], dtype=np.int32))
        inputs[1].set_data_from_numpy(np.asarray([0], dtype=np.int32))

        with self.assertRaises(InferenceServerException) as e:
            triton_client.infer(model_name="wrong_internal_state",
                                inputs=inputs,
                                sequence_id=2,
                                sequence_start=True)

        self.assertEqual(str(e.exception),
                         "state 'undefined_state' is not a valid state name.")
Code example #13
    def client_task(self):
        with httpclient.InferenceServerClient('localhost:8000') as client:
            input0_data = np.array(['CN1C=NC2=C1C(=O)N(C(=O)N2C)C'
                                    ]).astype(np.object_)
            inputs = [
                httpclient.InferInput("INPUT0", input0_data.shape,
                                      np_to_triton_dtype(input0_data.dtype)),
            ]

            inputs[0].set_data_from_numpy(input0_data)
            outputs = [
                httpclient.InferRequestedOutput("OUTPUT0"),
            ]
            response = client.infer(TirtonLocust.model_name,
                                    inputs,
                                    request_id=str(1),
                                    outputs=outputs)
            result = response.get_response()
Code example #14
 def test_valid_create_set_register(self):
     # Create a valid system shared memory region, fill data in it and register
     if _protocol == "http":
         triton_client = httpclient.InferenceServerClient(_url, verbose=True)
     else:
         triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
     shm_op0_handle = shm.create_shared_memory_region(
         "dummy_data", "/dummy_data", 8)
     shm.set_shared_memory_region(shm_op0_handle,
                                  [np.array([1, 2], dtype=np.float32)])
     triton_client.register_system_shared_memory("dummy_data", "/dummy_data",
                                                 8)
     shm_status = triton_client.get_system_shared_memory_status()
     if _protocol == "http":
         self.assertTrue(len(shm_status) == 1)
     else:
         self.assertTrue(len(shm_status.regions) == 1)
     shm.destroy_shared_memory_region(shm_op0_handle)
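
As a companion to this test, a hedged sketch of reading the region contents back out; get_contents_as_numpy is the real tritonclient helper, and the handle and shape come from the code above:

import numpy as np
import tritonclient.utils.shared_memory as shm

# Read back the two float32 values written by set_shared_memory_region.
contents = shm.get_contents_as_numpy(shm_op0_handle, np.float32, [2])
assert (contents == np.array([1, 2], dtype=np.float32)).all()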
Code example #15
 def test_model_reload(self):
     model_name = "identity_fp32"
     ensemble_model_name = 'simple_' + model_name
     with httpclient.InferenceServerClient("localhost:8000") as client:
         for _ in range(5):
             self.assertFalse(client.is_model_ready(model_name))
             # Load the model before the ensemble model to make sure reloading the
             # model works properly in Python backend.
             client.load_model(model_name)
             client.load_model(ensemble_model_name)
             self.assertTrue(client.is_model_ready(model_name))
             self.assertTrue(client.is_model_ready(ensemble_model_name))
             self.send_identity_request(client, model_name)
             self.send_identity_request(client, ensemble_model_name)
             client.unload_model(ensemble_model_name)
             client.unload_model(model_name)
             self.assertFalse(client.is_model_ready(model_name))
             self.assertFalse(client.is_model_ready(ensemble_model_name))
Code example #16
 def tearDown(self):
     # Clear all trace settings back to the initial state.
     # Note that tearDown() uses the HTTP client, so the HTTP trace
     # setting test cases must pass to confirm that tearDown() executed
     # properly and did not affect the starting state of other test
     # cases.
     clear_settings = {
         "trace_file": None,
         "trace_level": None,
         "trace_rate": None,
         "trace_count": None,
         "log_frequency": None
     }
     triton_client = httpclient.InferenceServerClient("localhost:8000")
     triton_client.update_trace_settings(model_name="simple",
                                         settings=clear_settings)
     triton_client.update_trace_settings(model_name=None,
                                         settings=clear_settings)
Code example #17
    def _test_no_outputs_helper(self,
                                use_grpc=True,
                                use_http=True,
                                use_streaming=True):

        if use_grpc:
            triton_client = grpcclient.InferenceServerClient(
                url="localhost:8001", verbose=True)
            self._prepare_request("grpc")
            result = triton_client.infer(model_name=self.model_name_,
                                         inputs=self.inputs_,
                                         outputs=self.outputs_,
                                         client_timeout=1)
            # The response should not contain any outputs
            self.assertEqual(result.as_numpy('OUTPUT0'), None)

        if use_http:
            triton_client = httpclient.InferenceServerClient(
                url="localhost:8000", verbose=True, network_timeout=2.0)
            self._prepare_request("http")
            result = triton_client.infer(model_name=self.model_name_,
                                         inputs=self.inputs_,
                                         outputs=self.outputs_)
            # The response should not contain any outputs
            self.assertEqual(result.as_numpy('OUTPUT0'), None)

        if use_streaming:
            triton_client = grpcclient.InferenceServerClient(
                url="localhost:8001", verbose=True)
            self._prepare_request("grpc")
            user_data = UserData()

            triton_client.stop_stream()
            triton_client.start_stream(callback=partial(callback, user_data),
                                       stream_timeout=1)
            triton_client.async_stream_infer(model_name=self.model_name_,
                                             inputs=self.inputs_,
                                             outputs=self.outputs_)
            result = user_data._completed_requests.get()
            if type(result) == InferenceServerException:
                raise result

            # The response should not contain any outputs
            self.assertEqual(result.as_numpy('OUTPUT0'), None)
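
The streaming branch relies on a UserData holder and a callback defined elsewhere in the test module. A minimal sketch of what they typically look like in Triton's streaming client examples (the names match the snippet above; the bodies are assumptions):

import queue

class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()

def callback(user_data, result, error):
    # The streaming API delivers either a result or an error per response.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)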
Code example #18
File: shm_util.py Project: zz397fl/server
def unregister_cleanup_shm_regions(shm_regions, shm_handles,
                                   precreated_shm_regions, outputs,
                                   use_system_shared_memory,
                                   use_cuda_shared_memory):
    if not (use_system_shared_memory or use_cuda_shared_memory):
        return None

    triton_client = httpclient.InferenceServerClient("localhost:8000")

    if use_cuda_shared_memory:
        triton_client.unregister_cuda_shared_memory(shm_regions[0] + '_data')
        triton_client.unregister_cuda_shared_memory(shm_regions[1] + '_data')
        cudashm.destroy_shared_memory_region(shm_handles[0])
        cudashm.destroy_shared_memory_region(shm_handles[1])
    else:
        triton_client.unregister_system_shared_memory(shm_regions[0] + '_data')
        triton_client.unregister_system_shared_memory(shm_regions[1] + '_data')
        shm.destroy_shared_memory_region(shm_handles[0])
        shm.destroy_shared_memory_region(shm_handles[1])

    if precreated_shm_regions is None:
        i = 0
        if "OUTPUT0" in outputs:
            if use_cuda_shared_memory:
                triton_client.unregister_cuda_shared_memory(shm_regions[2] +
                                                            '_data')
                cudashm.destroy_shared_memory_region(shm_handles[2])
            else:
                triton_client.unregister_system_shared_memory(shm_regions[2] +
                                                              '_data')
                shm.destroy_shared_memory_region(shm_handles[2])
            i += 1
        if "OUTPUT1" in outputs:
            if use_cuda_shared_memory:
                triton_client.unregister_cuda_shared_memory(shm_regions[2 +
                                                                        i] +
                                                            '_data')
                cudashm.destroy_shared_memory_region(shm_handles[3])
            else:
                triton_client.unregister_system_shared_memory(shm_regions[2 +
                                                                          i] +
                                                              '_data')
                shm.destroy_shared_memory_region(shm_handles[3])
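
A hypothetical call site for this helper, assuming a matching create-and-register step happened earlier (region names and sizes below are illustrative):

# Create four illustrative 64-byte system shm regions, then clean them up
# with the helper above (names and sizes are assumptions).
handles = [
    shm.create_shared_memory_region(f"{name}_data", f"/{name}_data", 64)
    for name in ("input0", "input1", "output0", "output1")
]
unregister_cleanup_shm_regions(
    shm_regions=["input0", "input1", "output0", "output1"],
    shm_handles=handles,
    precreated_shm_regions=None,
    outputs=["OUTPUT0", "OUTPUT1"],
    use_system_shared_memory=True,
    use_cuda_shared_memory=False)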
Code example #19
 def test_unregisterall(self):
     # Unregister all shared memory blocks
     shm_handles = self._configure_sever()
     if _protocol == "http":
         triton_client = httpclient.InferenceServerClient(_url, verbose=True)
     else:
         triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
     status_before = triton_client.get_system_shared_memory_status()
     if _protocol == "http":
         self.assertTrue(len(status_before) == 4)
     else:
         self.assertTrue(len(status_before.regions) == 4)
     triton_client.unregister_system_shared_memory()
     status_after = triton_client.get_system_shared_memory_status()
     if _protocol == "http":
         self.assertTrue(len(status_after) == 0)
     else:
         self.assertTrue(len(status_after.regions) == 0)
     self._cleanup_server(shm_handles)
Code example #20
    def test_string(self):
        model_name = "string_fixed"
        shape = [1]

        # Each iteration performs inference through a new
        # client connection.
        for i in range(3):
            with httpclient.InferenceServerClient("localhost:8000") as client:
                sample_input = '123456'
                input_data = np.array([sample_input], dtype=np.object_)
                inputs = [
                    httpclient.InferInput("INPUT0", shape,
                                          np_to_triton_dtype(input_data.dtype))
                ]
                inputs[0].set_data_from_numpy(input_data)
                result = client.infer(model_name, inputs)
                output0 = result.as_numpy('OUTPUT0')
                self.assertTrue(output0 is not None)
                self.assertTrue(output0[0] == input_data.astype(np.bytes_))
Code example #21
 def test_unregister_after_inference(self):
     # Unregister after inference
     error_msg = []
     shm_handles = self._configure_sever()
     self._basic_inference(shm_handles[0], shm_handles[1], shm_handles[2],
                           shm_handles[3], error_msg)
     if len(error_msg) > 0:
         raise Exception(str(error_msg))
     if _protocol == "http":
         triton_client = httpclient.InferenceServerClient(_url, verbose=True)
     else:
         triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
     triton_client.unregister_system_shared_memory("output0_data")
     shm_status = triton_client.get_system_shared_memory_status()
     if _protocol == "http":
         self.assertTrue(len(shm_status) == 3)
     else:
         self.assertTrue(len(shm_status.regions) == 3)
     self._cleanup_server(shm_handles)
Code example #22
    def test_resnet50(self):
        try:
            triton_client = httpclient.InferenceServerClient(
                url="localhost:8000")
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit(1)

        image_filename = "../images/vulture.jpeg"
        model_name = "resnet50_plan"
        batch_size = 32

        img = Image.open(image_filename)
        image_data = self._preprocess(img, np.int8)
        image_data = np.expand_dims(image_data, axis=0)

        batched_image_data = image_data
        for i in range(1, batch_size):
            batched_image_data = np.concatenate(
                (batched_image_data, image_data), axis=0)

        inputs = [
            httpclient.InferInput('input_tensor_0', [batch_size, 3, 224, 224],
                                  'INT8')
        ]
        inputs[0].set_data_from_numpy(batched_image_data, binary_data=True)

        outputs = [
            httpclient.InferRequestedOutput('topk_layer_output_index',
                                            binary_data=True)
        ]

        results = triton_client.infer(model_name, inputs, outputs=outputs)

        output_data = results.as_numpy('topk_layer_output_index')
        print(output_data)

        # Validate the results by comparing with precomputed values.
        # The VULTURE class corresponds to index 23.
        EXPECTED_CLASS_INDEX = 23
        for i in range(batch_size):
            self.assertEqual(output_data[i][0][0], EXPECTED_CLASS_INDEX)
Code example #23
 def test_too_big_shm(self):
     # Shared memory input region larger than needed - Throws error
     error_msg = []
     shm_handles = self._configure_sever()
     shm_ip2_handle = shm.create_shared_memory_region(
         "input2_data", "/input2_data", 128)
     if _protocol == "http":
         triton_client = httpclient.InferenceServerClient(_url, verbose=True)
     else:
         triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
     triton_client.register_system_shared_memory("input2_data",
                                                 "/input2_data", 128)
     self._basic_inference(shm_handles[0], shm_ip2_handle, shm_handles[2],
                           shm_handles[3], error_msg, "input2_data", 128)
     if len(error_msg) > 0:
         self.assertTrue(
             "unexpected total byte size 128 for input 'INPUT1', expecting 64"
             in error_msg[-1])
     shm_handles.append(shm_ip2_handle)
     self._cleanup_server(shm_handles)
Code example #24
File: python_test.py Project: niqbal996/server
    def test_string(self):
        model_name = "string_fixed"
        shape = [1]

        for i in range(6):
            with httpclient.InferenceServerClient("localhost:8000") as client:
                input_data = np.array(['123456'], dtype=np.object_)
                inputs = [
                    httpclient.InferInput("INPUT0", shape,
                                          np_to_triton_dtype(input_data.dtype))
                ]
                inputs[0].set_data_from_numpy(input_data)
                result = client.infer(model_name, inputs)
                output0 = result.as_numpy('OUTPUT0')
                self.assertIsNotNone(output0)

                if i % 2 == 0:
                    self.assertEqual(output0[0], input_data.astype(np.bytes_))
                else:
                    self.assertEqual(output0.size, 0)
Code example #25
 def test_infer_pymodel_error(self):
     model_name = "wrong_model"
     shape = [2, 2]
     with httpclient.InferenceServerClient("localhost:8000") as client:
         input_data = (16384 * np.random.randn(*shape)).astype(np.uint32)
         inputs = [
             httpclient.InferInput("IN", input_data.shape,
                                   np_to_triton_dtype(input_data.dtype))
         ]
         inputs[0].set_data_from_numpy(input_data)
         try:
             client.infer(model_name, inputs)
         except InferenceServerException as e:
             self.assertTrue(
                 e.message().startswith("GRPC Execute Failed, message:"),
                 "Exception message is not correct")
         else:
             self.fail("InferenceServerException was not raised")
Code example #26
    def setup_trt_client(self, url):
        def check_connection(trt_client):
            # Treat a failed liveness/readiness check and a client error
            # alike as a connection failure.
            try:
                if (trt_client.is_server_live()
                        and trt_client.is_server_ready()):
                    return
            except Exception:
                pass
            raise ValueError(
                FAILED_TO_CONNECT_TO_TRITON_SERVER.format(
                    url, self.protocol))

        if self.protocol == 'http':
            trt_client = httpclient.InferenceServerClient(url=url,
                                                          verbose=False)
        elif self.protocol == 'grpc':
            trt_client = grpcclient.InferenceServerClient(url=url,
                                                          verbose=False)

        check_connection(trt_client)
        return trt_client
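
FAILED_TO_CONNECT_TO_TRITON_SERVER is a message template defined elsewhere in the module; a hedged sketch of its likely shape (the wording is an assumption):

# Hypothetical message template consumed by check_connection above.
FAILED_TO_CONNECT_TO_TRITON_SERVER = (
    "Failed to connect to Triton server at {} over {}")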
Code example #27
    def test_no_implicit_state(self):
        triton_client = tritonhttpclient.InferenceServerClient(
            "localhost:8000")
        inputs = []
        inputs.append(tritonhttpclient.InferInput('INPUT', [1], 'INT32'))
        inputs.append(tritonhttpclient.InferInput('TEST_CASE', [1], 'INT32'))
        inputs[0].set_data_from_numpy(
            np.random.randint(5, size=[1], dtype=np.int32))
        inputs[1].set_data_from_numpy(np.asarray([0], dtype=np.int32))

        with self.assertRaises(InferenceServerException) as e:
            triton_client.infer(model_name="no_implicit_state",
                                inputs=inputs,
                                sequence_id=1,
                                sequence_start=True)

        self.assertEqual(
            str(e.exception),
            "unable to add state 'undefined_state'. State configuration is missing for model 'no_implicit_state'."
        )
Code example #28
    def __init__(self, path_or_bytes, sess_options=None, providers=[]):
        self.client = tritonhttpclient.InferenceServerClient("localhost:8000")
        model_metadata = self.client.get_model_metadata(
            model_name=path_or_bytes)

        self.request_count = 0
        self.model_name = path_or_bytes
        self.inputs = []
        self.outputs = []
        self.dtype_mapping = {}

        for (src, dest) in (
            (model_metadata["inputs"], self.inputs),
            (model_metadata["outputs"], self.outputs),
        ):
            for element in src:
                dest.append(NodeArg(element["name"], element["shape"]))
                self.dtype_mapping[element["name"]] = element["datatype"]

        self.triton_enabled = True
Code example #29
    def test_unicode(self):
        model_name = "string"
        shape = [1]

        for i in range(3):
            with self._shm_leak_detector.Probe() as shm_probe:
                with httpclient.InferenceServerClient(
                        "localhost:8000") as client:
                    utf8 = '😀'
                    input_data = np.array([bytes(utf8, encoding='utf-8')],
                                          dtype=np.bytes_)
                    inputs = [
                        httpclient.InferInput(
                            "INPUT0", shape,
                            np_to_triton_dtype(input_data.dtype))
                    ]
                    inputs[0].set_data_from_numpy(input_data)
                    result = client.infer(model_name, inputs)
                    output0 = result.as_numpy('OUTPUT0')
                    self.assertIsNotNone(output0)
                    self.assertEqual(output0[0], input_data)
Code example #30
    def test_request_output(self):
        triton_client = tritonhttpclient.InferenceServerClient(
            "localhost:8000")
        inputs = []
        inputs.append(tritonhttpclient.InferInput('INPUT', [1], 'INT32'))
        inputs[0].set_data_from_numpy(np.asarray([1], dtype=np.int32))

        outputs = []
        outputs.append(tritonhttpclient.InferRequestedOutput('OUTPUT_STATE'))
        outputs.append(tritonhttpclient.InferRequestedOutput('OUTPUT'))

        for backend in BACKENDS.split(" "):
            result = triton_client.infer(
                model_name=f"{backend}_nobatch_sequence_int32_output",
                inputs=inputs,
                outputs=outputs,
                sequence_id=1,
                sequence_start=True,
                sequence_end=True)
            self.assertEqual(result.as_numpy('OUTPUT_STATE')[0], 1)
            self.assertEqual(result.as_numpy('OUTPUT')[0], 1)
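
BACKENDS is a space-separated string defined elsewhere in the test module; a hedged sketch of how it is typically populated (the variable name matches the snippet above, the default value is an assumption):

import os

# Space-separated list of backends under test, e.g. "onnx plan libtorch".
BACKENDS = os.environ.get("BACKENDS", "onnx libtorch")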