コード例 #1
0
ファイル: client.py プロジェクト: yanshanjing/ML-Model-CI
class RpcClient(object):
    """gRPC client that talks to a ``PredictStub`` service on ``localhost``.

    Usable as a context manager: the underlying channel is closed on exit.
    """

    def __init__(self, port):
        """Open an insecure channel to ``localhost:<port>`` and build a ``PredictStub`` on it."""
        self.channel = grpc.insecure_channel("localhost:" + str(port))
        self.stub = PredictStub(self.channel)

    def service_request(
            self,
            request: Union[InferRequest, Generator[InferRequest, None, None]],
            stream=False
    ):
        """Send ``request`` through the unary ``Infer`` RPC, or ``StreamInfer`` when ``stream`` is true.

        Args:
            request: a single ``InferRequest``, or, when ``stream`` is true, a generator
                that yields ``InferRequest`` messages.
            stream: use the streaming RPC instead of the unary one.

        Returns:
            Whatever the stub call returns: the ``Infer`` response, or the result of
            ``StreamInfer`` (presumably an iterator of responses — TODO confirm against the proto).
        """
        if stream:
            return self._service_request_stream(request)
        return self._service_request(request)

    def _service_request(self, request: InferRequest):
        """Issue one unary ``Infer`` call and return its response."""
        response = self.stub.Infer(request)
        return response

    def _service_request_stream(self, request_generator: Generator[InferRequest, None, None]):
        """Issue a ``StreamInfer`` call fed by ``request_generator`` and return the stub's result."""
        responses = self.stub.StreamInfer(request_generator)
        return responses

    @staticmethod
    def make_request(model_name, inputs: Iterable[Union[np.ndarray, torch.Tensor]], meta=None):
        """Pack ``inputs`` into an ``InferRequest`` addressed to ``model_name``.

        Each input is serialised to raw bytes. Shape and dtype are read from the first
        input only, so all inputs are assumed to share them (this is not checked). They are
        combined with ``torch_flag`` and the caller's ``meta`` through ``json_update``, and the
        result is sent as a JSON string.

        Args:
            model_name: name of the model the request targets.
            inputs: non-empty iterable of numpy arrays or of torch tensors. The kind is decided
                by the first element.
            meta: optional extra metadata; an empty dict is used when omitted.

        Returns:
            The assembled ``InferRequest``.

        Raises:
            ValueError: if ``inputs`` is empty, or if its first element is neither a numpy
                array nor a torch tensor.
        """
        inputs = list(inputs)
        if not inputs:
            # Report a clear error instead of an IndexError from ``inputs[0]``.
            raise ValueError('Argument `inputs` must contain at least one numpy array or torch Tensor')
        example = inputs[0]
        if meta is None:
            meta = dict()

        if isinstance(example, np.ndarray):
            to_byte = bytes
            torch_flag = False
        elif isinstance(example, torch.Tensor):
            def to_byte(tensor):
                # detach() and cpu() let tensors that require grad or live on a GPU be
                # converted; for plain CPU tensors they leave the data unchanged.
                return bytes(tensor.detach().cpu().numpy())
            torch_flag = True
        else:
            raise ValueError(
                'Argument `inputs` is expected to be an iterable of numpy arrays, or an iterable of torch Tensors')

        raw_input = [to_byte(item) for item in inputs]
        shape = example.shape
        dtype = type_to_data_type(example.dtype).value
        meta = json_update({'shape': shape, 'dtype': dtype, 'torch_flag': torch_flag}, meta)

        return InferRequest(model_name=model_name, raw_input=raw_input, meta=json.dumps(meta))

    def close(self):
        """Close the underlying gRPC channel."""
        self.channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the channel; exceptions are not suppressed (returns None).
        self.close()
コード例 #2
0
class CVONNXClient(BaseModelInspector):
    """``BaseModelInspector`` subclass that sends preprocessed image batches over gRPC.

    Batches are built once in the constructor. Requests go to
    ``localhost:ONNX_GRPC_PORT`` with the model name ``'resnet50'``.
    """

    def __init__(self, repeat_data, batch_num=1, batch_size=1, asynchronous=None):
        super().__init__(repeat_data=repeat_data, batch_num=batch_num, batch_size=batch_size, asynchronous=asynchronous)
        self.batches = self.__client_batch_request()  # FIXME: creating batches twice will increase the data preprocessing time
        # Keep the channel itself (as RpcClient does) so its owner can close it.
        self.channel = grpc.insecure_channel(f"localhost:{ONNX_GRPC_PORT}")
        self.stub = PredictStub(self.channel)

    def data_preprocess(self):
        """No-op override: preprocessing happens in ``__client_batch_request``."""
        pass

    def __client_batch_request(self):
        """Build ``batch_num`` preprocessed batches, each holding ``batch_size`` copies of ``raw_data``.

        Returns:
            A list of ``batch_num`` lists. Each inner list holds ``batch_size`` numpy arrays
            produced by ``transform_image``.
        """
        print('ONNX: start data preprocessing...')
        if self.batch_num <= 0:
            return []
        # raw_data gets a new leading axis plus three slices, so it is presumably a single
        # (H, W, C) image — TODO confirm against BaseModelInspector.
        batch = np.repeat(self.raw_data[np.newaxis, :, :, :], self.batch_size, axis=0)
        # Every batch comes from the same input and the transform pipeline is deterministic,
        # so run it once. Each batch still gets its own independent arrays.
        transformed = self.transform_image(images=batch)
        return [[array.copy() for array in transformed] for _ in range(self.batch_num)]

    # TODO: this will be moved to data_preprocessor function
    def transform_image(self, images):
        """Apply the resize / center-crop / normalise pipeline to each image; return numpy arrays."""
        t = transforms.Compose(
            [transforms.ToPILImage(), transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), torch.Tensor.numpy])
        return list(map(t, images))

    def infer(self, input_batch):
        """Send one preprocessed batch as a single ``Infer`` request and return the response.

        Args:
            input_batch: non-empty sequence of numpy arrays. Shape and dtype are read from the
                first element, so all elements are assumed to share them.

        Returns:
            The response of ``self.stub.Infer``.
        """
        example = input_batch[0]
        meta = dict()
        raw_input = list(map(bytes, input_batch))
        shape = example.shape
        dtype = type_to_data_type(example.dtype).value
        # NOTE(review): torch_flag is True although numpy arrays are sent (they come from
        # torch tensors in transform_image) — confirm how the server interprets it.
        meta = json_update({'shape': shape, 'dtype': dtype, 'torch_flag': True}, meta)
        return self.stub.Infer(InferRequest(model_name='resnet50', raw_input=raw_input, meta=json.dumps(meta)))
コード例 #3
0
 def __init__(self, port):
     """Connect to ``localhost:<port>`` over an insecure gRPC channel and wrap it in a ``PredictStub``."""
     address = f"localhost:{port!s}"
     channel = grpc.insecure_channel(address)
     self.channel = channel
     self.stub = PredictStub(channel)
コード例 #4
0
 def __init__(self, repeat_data, batch_num=1, batch_size=1, asynchronous=None):
     """Initialise the base inspector, prebuild the request batches, then create the ``PredictStub``."""
     super().__init__(
         repeat_data=repeat_data,
         batch_num=batch_num,
         batch_size=batch_size,
         asynchronous=asynchronous,
     )
     # FIXME: creating batches twice will increase the data preprocessing time
     self.batches = self.__client_batch_request()
     target = f"localhost:{ONNX_GRPC_PORT}"
     self.stub = PredictStub(grpc.insecure_channel(target))