Example #1
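A synchronous Run method that sends string inputs to a Triton Inference Server as BYTES tensors and returns the named outputs as numpy arrays. The snippet assumes numpy is imported as np and InferInput comes from tritonclient.grpc (the same class also exists in tritonclient.http), and that the surrounding class holds a connected client plus the model's input and output names.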
    def Run(self, inputs):
        """
        Args:
            inputs: list, Each value corresponds to an input name of self._input_names
        Returns: 
            results: dict, {name : numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
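            # Triton represents BYTES tensors client-side as numpy object
            # arrays, so encode each string to UTF-8 bytes first.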
            data = np.array([[x.encode('utf-8')] for x in data],
                            dtype=np.object_)
            infer_input = InferInput(self._input_names[idx], [len(data), 1],
                                     "BYTES")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t,
        )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results
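
A minimal standalone sketch of the same call pattern, assuming a Triton server on localhost:8001 and a deployed model with one BYTES input named INPUT_0 and one output named OUTPUT_0; the address, model name, and tensor names are illustrative assumptions, not part of the example above:

import numpy as np
from tritonclient.grpc import (InferenceServerClient, InferInput,
                               InferRequestedOutput)

client = InferenceServerClient("localhost:8001")
texts = ["hello", "world"]
# Encode strings as a [batch, 1] column of UTF-8 bytes; BYTES tensors are
# passed to the client as numpy object arrays.
data = np.array([[t.encode("utf-8")] for t in texts], dtype=np.object_)
infer_input = InferInput("INPUT_0", [len(data), 1], "BYTES")
infer_input.set_data_from_numpy(data)

result = client.infer(
    model_name="my_text_model",  # assumption: any model taking BYTES input
    inputs=[infer_input],
    outputs=[InferRequestedOutput("OUTPUT_0")],
)
output = result.as_numpy("OUTPUT_0")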
Example #2
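An iterator that verifies the server is live, reads the model's input and output metadata, and yields (ids, x, y_pred, y_real) tuples by sending one synchronous request per batch from a dataloader. It relies on InferenceServerClient, InferInput, and InferRequestedOutput from tritonclient.grpc, plus triton_to_np_dtype, presumably imported along the lines of "from tritonclient import utils as client_utils".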
    def __iter__(self):
        client = InferenceServerClient(self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(client)
        if error:
            raise RuntimeError(
                f"Could not communicate with Triton Server: {error}")

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = client.get_model_config(self._model_name,
                                               self._model_version)
        model_metadata = client.get_model_metadata(self._model_name,
                                                   self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        inputs = {tm.name: tm for tm in model_metadata.inputs}
        outputs = {tm.name: tm for tm in model_metadata.outputs}
        output_names = list(outputs)
        outputs_req = [InferRequestedOutput(name) for name in outputs]

        for ids, x, y_real in self._dataloader:
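            # Build one InferInput per model input for this batch, casting
            # the data to the numpy dtype matching the declared Triton dtype.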
            infer_inputs = []
            for name in inputs:
                data = x[name]
                infer_input = InferInput(name, data.shape,
                                         inputs[name].datatype)

                target_np_dtype = client_utils.triton_to_np_dtype(
                    inputs[name].datatype)
                data = data.astype(target_np_dtype)

                infer_input.set_data_from_numpy(data)
                infer_inputs.append(infer_input)

            results = client.infer(
                model_name=self._model_name,
                model_version=self._model_version,
                inputs=infer_inputs,
                outputs=outputs_req,
                client_timeout=self._response_wait_t,
            )
            y_pred = {name: results.as_numpy(name) for name in output_names}
            yield ids, x, y_pred, y_real
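
Used as an iterator, such a runner can drive a simple evaluation loop. A sketch under stated assumptions: the class name SyncGRPCTritonRunner and its constructor arguments are guesses based on the attributes the method reads (_server_url, _model_name, _model_version, _dataloader, _response_wait_t), and OUTPUT__0 stands in for whatever output the model actually declares:

runner = SyncGRPCTritonRunner(
    server_url="localhost:8001",
    model_name="my_model",
    model_version="1",
    dataloader=my_dataloader,  # yields (ids, {input_name: ndarray}, y_real)
    response_wait_t=120,
)
correct = total = 0
for ids, x, y_pred, y_real in runner:
    scores = y_pred["OUTPUT__0"]  # assumption: a single logits output
    correct += int((scores.argmax(axis=-1) == y_real).sum())
    total += len(ids)
print(f"accuracy: {correct / total:.4f}")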
Example #3
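Two methods from an asynchronous gRPC runner: v2_request_transform builds a raw KServe v2 ModelInferRequest protobuf by hand (borrowing InferInput's private _get_tensor/_get_content helpers), and req_loop issues async_infer calls while capping the number of in-flight requests via the threading.Condition stored in self._sync.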
    def v2_request_transform(self, input_tensors):
        request = ModelInferRequest()
        request.model_name = self.name
        input_0 = InferInput("INPUT__0", input_tensors.shape, "FP32")
        input_0.set_data_from_numpy(input_tensors)
        request.inputs.extend([input_0._get_tensor()])
        if input_0._get_content() is not None:
            request.raw_input_contents.extend([input_0._get_content()])
        return request

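The request built above is the raw KServe v2 protobuf, so it can be sent through the generated gRPC stub directly. A sketch, assuming the service_pb2_grpc module shipped with tritonclient.grpc, a plaintext channel on localhost:8001, and that transformer is an instance of the class defining v2_request_transform:

import grpc
from tritonclient.grpc import service_pb2_grpc

channel = grpc.insecure_channel("localhost:8001")
stub = service_pb2_grpc.GRPCInferenceServiceStub(channel)
request = transformer.v2_request_transform(input_tensors)  # method above
response = stub.ModelInfer(request)  # returns a ModelInferResponse protobuf
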
    def req_loop(self):
        client = InferenceServerClient(self._server_url, verbose=self._verbose)
        self._errors = self._verify_triton_state(client)
        if self._errors:
            return

        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = client.get_model_config(self._model_name,
                                               self._model_version)
        model_metadata = client.get_model_metadata(self._model_name,
                                                   self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        inputs = {tm.name: tm for tm in model_metadata.inputs}
        outputs = {tm.name: tm for tm in model_metadata.outputs}
        output_names = list(outputs)
        outputs_req = [InferRequestedOutput(name) for name in outputs]

        self._num_waiting_for = 0

        for ids, x, y_real in self._dataloader:
            infer_inputs = []
            for name in inputs:
                data = x[name]
                infer_input = InferInput(name, data.shape,
                                         inputs[name].datatype)

                target_np_dtype = client_utils.triton_to_np_dtype(
                    inputs[name].datatype)
                data = data.astype(target_np_dtype)

                infer_input.set_data_from_numpy(data)
                infer_inputs.append(infer_input)

            with self._sync:
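                # Backpressure: block here until the number of in-flight
                # requests drops below self._max_unresp_reqs.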

                def _check_can_send():
                    return self._num_waiting_for < self._max_unresp_reqs

                can_send = self._sync.wait_for(_check_can_send,
                                               timeout=self._response_wait_t)
                if not can_send:
                    error_msg = f"Runner could not send new requests for {self._response_wait_t}s"
                    self._errors.append(error_msg)
                    break

                callback = functools.partial(AsyncGRPCTritonRunner._on_result,
                                             self, ids, x, y_real,
                                             output_names)
                client.async_infer(
                    model_name=self._model_name,
                    model_version=self._model_version,
                    inputs=infer_inputs,
                    outputs=outputs_req,
                    callback=callback,
                )
                self._num_waiting_for += 1

        # wait until all outstanding requests have been answered
        with self._sync:

            def _all_processed():
                LOGGER.debug(
                    f"wait for {self._num_waiting_for} unprocessed jobs")
                return self._num_waiting_for == 0

            self._processed_all = self._sync.wait_for(
                _all_processed, self.DEFAULT_MAX_FINISH_WAIT_S)
            if not self._processed_all:
                error_msg = f"Runner {self._response_wait_t}s timeout received while waiting for results from server"
                self._errors.append(error_msg)
        LOGGER.debug("Finished request thread")