import functools
import logging
from typing import Optional

from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput

# Logger name is assumed; the original module header is not part of this excerpt.
LOGGER = logging.getLogger(__name__)


def __iter__(self):
    client = InferenceServerClient(self._server_url, verbose=self._verbose)
    error = self._verify_triton_state(client)
    if error:
        raise RuntimeError(f"Could not communicate with Triton Server: {error}")
    LOGGER.debug(
        f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
        f"are up and ready!"
    )

    model_config = client.get_model_config(self._model_name, self._model_version)
    model_metadata = client.get_model_metadata(self._model_name, self._model_version)
    LOGGER.info(f"Model config {model_config}")
    LOGGER.info(f"Model metadata {model_metadata}")

    inputs = {tm.name: tm for tm in model_metadata.inputs}
    outputs = {tm.name: tm for tm in model_metadata.outputs}
    output_names = list(outputs)
    outputs_req = [InferRequestedOutput(name) for name in outputs]

    for ids, x, y_real in self._dataloader:
        infer_inputs = []
        for name in inputs:
            data = x[name]
            infer_input = InferInput(name, data.shape, inputs[name].datatype)
            # Cast the batch to the numpy dtype matching the input's Triton datatype.
            target_np_dtype = client_utils.triton_to_np_dtype(inputs[name].datatype)
            data = data.astype(target_np_dtype)
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=outputs_req,
            client_timeout=self._response_wait_t,
        )
        y_pred = {name: results.as_numpy(name) for name in output_names}
        yield ids, x, y_pred, y_real
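
# The _verify_triton_state helper called above is not part of this excerpt.
# A minimal sketch of what it presumably does, assuming it probes the standard
# tritonclient liveness/readiness calls and returns an error message string,
# or None when everything is healthy. (req_loop below treats the return value
# as a list, so the async runner's class likely carries a list-returning
# variant of the same checks.)
def _verify_triton_state(self, triton_client):
    if not triton_client.is_server_live():
        return f"Triton server {self._server_url} is not live"
    elif not triton_client.is_server_ready():
        return f"Triton server {self._server_url} is not ready"
    elif not triton_client.is_model_ready(self._model_name, self._model_version):
        return f"Model {self._model_name}:{self._model_version} is not ready"
    return None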
def __init__(
    self,
    server_url: str,
    model_name: str,
    model_version: str,
    *,
    dataloader,
    verbose=False,
    resp_wait_s: Optional[float] = None,
):
    self._server_url = server_url
    self._model_name = model_name
    self._model_version = model_version
    # The iteration methods consume self._dataloader, so it must be kept here.
    self._dataloader = dataloader
    self._verbose = verbose
    # DEFAULT_MAX_RESP_WAIT_S is a class-level constant (not shown in this excerpt).
    self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

    self._client = InferenceServerClient(self._server_url, verbose=self._verbose)
    error = self._verify_triton_state(self._client)
    if error:
        raise RuntimeError(f"Could not communicate with Triton Server: {error}")
    LOGGER.debug(
        f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
        f"are up and ready!"
    )

    model_config = self._client.get_model_config(self._model_name, self._model_version)
    model_metadata = self._client.get_model_metadata(self._model_name, self._model_version)
    LOGGER.info(f"Model config {model_config}")
    LOGGER.info(f"Model metadata {model_metadata}")

    # Cache input/output metadata so requests don't have to re-query the server.
    self._inputs = {tm.name: tm for tm in model_metadata.inputs}
    self._input_names = list(self._inputs)
    self._outputs = {tm.name: tm for tm in model_metadata.outputs}
    self._output_names = list(self._outputs)
    self._outputs_req = [InferRequestedOutput(name) for name in self._outputs]
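
# Usage sketch for the synchronous runner assembled from __init__ and __iter__
# above. The SyncGRPCTritonRunner class name, server address, model name, and
# the "input__0" tensor name are illustrative assumptions; only the
# (ids, x, y_real) batch structure is implied by the code itself.
def _example_sync_usage():
    import numpy as np

    batches = [
        (
            ["sample-0", "sample-1"],
            {"input__0": np.random.rand(2, 16).astype(np.float32)},
            np.zeros(2),
        ),
    ]
    runner = SyncGRPCTritonRunner(
        "localhost:8001",
        "my_model",
        "1",
        dataloader=batches,
        resp_wait_s=60,
    )
    for ids, x, y_pred, y_real in runner:
        print(ids, {name: value.shape for name, value in y_pred.items()})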
def req_loop(self):
    client = InferenceServerClient(self._server_url, verbose=self._verbose)
    # _verify_triton_state is expected to return a list of error strings here
    # (empty when healthy), since errors are appended to it below.
    self._errors = self._verify_triton_state(client)
    if self._errors:
        return
    LOGGER.debug(
        f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
        f"are up and ready!"
    )

    model_config = client.get_model_config(self._model_name, self._model_version)
    model_metadata = client.get_model_metadata(self._model_name, self._model_version)
    LOGGER.info(f"Model config {model_config}")
    LOGGER.info(f"Model metadata {model_metadata}")

    inputs = {tm.name: tm for tm in model_metadata.inputs}
    outputs = {tm.name: tm for tm in model_metadata.outputs}
    output_names = list(outputs)
    outputs_req = [InferRequestedOutput(name) for name in outputs]

    self._num_waiting_for = 0
    for ids, x, y_real in self._dataloader:
        infer_inputs = []
        for name in inputs:
            data = x[name]
            infer_input = InferInput(name, data.shape, inputs[name].datatype)
            target_np_dtype = client_utils.triton_to_np_dtype(inputs[name].datatype)
            data = data.astype(target_np_dtype)
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        # Apply backpressure: block until the number of in-flight requests
        # drops below the cap (self._max_unresp_reqs, set in __init__).
        with self._sync:

            def _check_can_send():
                return self._num_waiting_for < self._max_unresp_reqs

            can_send = self._sync.wait_for(_check_can_send, timeout=self._response_wait_t)
            if not can_send:
                error_msg = f"Runner could not send new requests for {self._response_wait_t}s"
                self._errors.append(error_msg)
                break

            callback = functools.partial(
                AsyncGRPCTritonRunner._on_result, self, ids, x, y_real, output_names
            )
            client.async_infer(
                model_name=self._model_name,
                model_version=self._model_version,
                inputs=infer_inputs,
                outputs=outputs_req,
                callback=callback,
            )
            self._num_waiting_for += 1

    # Wait until all outstanding requests have been answered (or time out).
    with self._sync:

        def _all_processed():
            LOGGER.debug(f"wait for {self._num_waiting_for} unprocessed jobs")
            return self._num_waiting_for == 0

        self._processed_all = self._sync.wait_for(_all_processed, self.DEFAULT_MAX_FINISH_WAIT_S)
        if not self._processed_all:
            error_msg = (
                f"Runner {self.DEFAULT_MAX_FINISH_WAIT_S}s timeout received "
                f"while waiting for results from server"
            )
            self._errors.append(error_msg)
    LOGGER.debug("Finished request thread")
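
# The _on_result callback bound via functools.partial above is not part of this
# excerpt. A minimal sketch, assuming completed results are handed to consumers
# through a queue.Queue stored in self._results (created in the async runner's
# __init__ alongside self._sync = threading.Condition() and self._errors = []).
# tritonclient invokes the callback as callback(result, error).
def _on_result(self, ids, x, y_real, output_names, result, error):
    with self._sync:
        if error:
            self._errors.append(error)
        else:
            y_pred = {name: result.as_numpy(name) for name in output_names}
            self._results.put((ids, x, y_pred, y_real))
        # One fewer request in flight; wake up req_loop waiting on the condition.
        self._num_waiting_for -= 1
        self._sync.notify_all()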