Пример #1
0
    def testInterceptedHeaderManipulationWithServerSideVerification(self):
        """Check the interceptor call order for a header-adding client chain.

        The channel is wrapped first with an interceptor that appends the
        ``secret: 42`` request header, then with two logging interceptors
        ('c1', 'c2'). After one unary-unary call, ``self._record`` must list
        the client entries followed by the server entries 's1', 's3', 's2'.
        """
        payload = b'\x07\x08'

        secret_header = _append_request_header_interceptor('secret', '42')
        header_channel = grpc.intercept_channel(self._channel, secret_header)

        first_logger = _LoggingInterceptor('c1', self._record)
        second_logger = _LoggingInterceptor('c2', self._record)
        logged_channel = grpc.intercept_channel(
            header_channel, first_logger, second_logger)

        # Discard anything recorded before the call under test.
        self._record[:] = []

        call_metadata = (
            ('test', 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),
        )
        _unary_unary_multi_callable(logged_channel).with_call(
            payload, metadata=call_metadata)

        expected_order = [
            'c1:intercept_unary_unary',
            'c2:intercept_unary_unary',
            's1:intercept_service',
            's3:intercept_service',
            's2:intercept_service',
        ]
        self.assertSequenceEqual(self._record, expected_order)
def test_five_retries_internal():
    """RetryInterceptor retries INTERNAL up to ``max_retry_count`` times.

    The service is built as ``_FailFirstAttempts(5, code=INTERNAL)``. With
    0-3 retries the call must still fail with INTERNAL; with 5 retries it
    must return DEFAULT_ZONE. The server is stopped in ``finally`` so a
    failing assertion does not leave it running for later tests.
    """
    request = zone_service_pb2.GetZoneRequest(zone_id="id")
    retriable_codes = (
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.RESOURCE_EXHAUSTED,
        grpc.StatusCode.INTERNAL,
    )
    service = _FailFirstAttempts(5, code=grpc.StatusCode.INTERNAL)
    server = grpc_server(service.handler)

    try:
        with default_channel() as channel:
            for max_retry_count in range(4):
                interceptor = RetryInterceptor(max_retry_count=max_retry_count, retriable_codes=retriable_codes)
                ch = grpc.intercept_channel(channel, interceptor)
                client = zone_service_pb2_grpc.ZoneServiceStub(ch)

                with pytest.raises(grpc.RpcError) as e:
                    client.Get(request)

                assert e.value.code() == grpc.StatusCode.INTERNAL
                # Presumably re-arms the 5 initial failures for the next round.
                service.reset(5)

            interceptor = RetryInterceptor(max_retry_count=5, retriable_codes=retriable_codes)
            ch = grpc.intercept_channel(channel, interceptor)
            client = zone_service_pb2_grpc.ZoneServiceStub(ch)
            res = client.Get(request)

            assert res == DEFAULT_ZONE
    finally:
        # Always release the server, even when an assertion above fails.
        server.stop(0)
Пример #3
0
    def _setup_client_channel_config(self):
        """Get gRPC client configuration from the server and set up channel and stub.

        Polls ``GetClientConfig`` on a temporary channel (wrapped with
        ``self.header_adder_int``) every 0.05 s until it succeeds or
        ``self.wait_ready_timeout`` seconds have elapsed. On success it stores
        ``push_max_nbytes``, ``enable_compression`` and ``optimization_target``
        in ``self.cfg``. It then opens ``self.channel`` / ``self.stub``
        configured from those values. The temporary channels are closed on
        every exit path.

        Raises:
            grpc.RpcError: if the request fails with a status other than
                UNAVAILABLE while ``self.wait_ready`` is True.
            ConnectionError: if no successful response arrives before
                ``self.wait_ready_timeout``.
        """
        tmp_insec_channel = grpc.insecure_channel(self.address)
        tmp_channel = grpc.intercept_channel(tmp_insec_channel, self.header_adder_int)
        try:
            tmp_stub = hangar_service_pb2_grpc.HangarServiceStub(tmp_channel)
            # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
            t_init, t_tot = time.monotonic(), 0
            while t_tot < self.wait_ready_timeout:
                try:
                    request = hangar_service_pb2.GetClientConfigRequest()
                    response = tmp_stub.GetClientConfig(request)
                    self.cfg['push_max_nbytes'] = int(response.config['push_max_nbytes'])
                    self.cfg['enable_compression'] = bool(int(response.config['enable_compression']))
                    self.cfg['optimization_target'] = response.config['optimization_target']
                except grpc.RpcError as err:
                    # NOTE(review): when ``self.wait_ready`` is False, non-UNAVAILABLE
                    # errors are swallowed and retried until the timeout -- confirm
                    # this is intended.
                    if not (err.code() == grpc.StatusCode.UNAVAILABLE) and (self.wait_ready is True):
                        logger.error(err)
                        raise
                else:
                    break
                time.sleep(0.05)
                t_tot = time.monotonic() - t_init
            else:
                err = ConnectionError(f'Server did not connect after: {self.wait_ready_timeout} sec.')
                logger.error(err)
                raise err
        finally:
            # Release the probe channels whether or not the config was obtained.
            tmp_channel.close()
            tmp_insec_channel.close()

        # NOTE(review): a bool is passed as the compression algorithm; True/False
        # act as ints 1/0 -- confirm they map to the intended algorithms.
        configured_channel = grpc.insecure_channel(
            self.address,
            options=[('grpc.default_compression_algorithm', self.cfg['enable_compression']),
                     ('grpc.optimization_target', self.cfg['optimization_target'])])
        self.channel = grpc.intercept_channel(configured_channel, self.header_adder_int)
        self.stub = hangar_service_pb2_grpc.HangarServiceStub(self.channel)
Пример #4
0
    def __init__(self,
                 address: Optional[str] = None,
                 interceptors: Optional[List[Union[
                     UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
                     StreamUnaryClientInterceptor,
                     StreamStreamClientInterceptor]]] = None):
        """Open a gRPC channel to the Dapr runtime and build the client stub.

        Args:
            address (str, optional): Dapr Runtime gRPC endpoint address. When
                empty, ``settings.DAPR_RUNTIME_HOST:settings.DAPR_GRPC_PORT``
                is used.
            interceptors (list of UnaryUnaryClientInterceptor or
                UnaryStreamClientInterceptor or
                StreamUnaryClientInterceptor or
                StreamStreamClientInterceptor, optional): gRPC interceptors,
                layered on top of the API-token interceptor (if any).
        """
        target = address or f"{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}"
        self._address = target

        channel = grpc.insecure_channel(target)  # type: ignore
        self._channel = channel

        token = settings.DAPR_API_TOKEN
        if token:
            # Attach the API token header to every call.
            token_interceptor = DaprClientInterceptor([('dapr-api-token', token)])
            channel = grpc.intercept_channel(channel, token_interceptor)  # type: ignore
            self._channel = channel

        if interceptors:
            # Wrapped separately (not merged with the token interceptor) to keep
            # the original nesting order.
            channel = grpc.intercept_channel(channel, *interceptors)  # type: ignore
            self._channel = channel

        self._stub = api_service_v1.DaprStub(channel)
Пример #5
0
def ChangeCoreProc():
    """Forward ChangeCore, or send Connect, along ``best_edge`` (GHS step).

    Does nothing unless ``fragmentLevel > 0``. If the edge to the node named
    ``best_edge[1]`` is already a BRANCH, a ChangeCore message is sent to it.
    Otherwise the node is recorded in ``waiting_to_connect_to``, a Connect
    message with this node's ``fragmentLevel`` is sent, and the edge is then
    marked BRANCH. Messages go to ``<node>:50050`` through a retrying
    interceptor (up to 100 attempts, exponential backoff, on UNAVAILABLE).
    """
    global SE
    global best_edge
    global BRANCH
    global fragmentLevel
    global waiting_to_connect_to

    def _messaging_stub(raw_channel):
        # Wrap the channel with the retrying interceptor and return a stub on it.
        retrying = RetryOnRpcErrorClientInterceptor(
            max_attempts=100,
            sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2),
            status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
        )
        return GHS_pb2_grpc.MessagingStub(grpc.intercept_channel(raw_channel, retrying))

    if not (fragmentLevel > 0):
        return

    peer = best_edge[1]
    if SE[int(peer[4:])] == BRANCH:
        with grpc.insecure_channel(peer + ':50050') as channel:
            stub = _messaging_stub(channel)
            stub.ChangeCore(GHS_pb2.ChangeCoreMSG(nodename=socket.gethostname()))
        return

    with grpc.insecure_channel(peer + ':50050') as channel:
        waiting_to_connect_to = peer
        stub = _messaging_stub(channel)
        stub.Connect(GHS_pb2.ConnectMSG(nodename=socket.gethostname(), fragmentLevel=fragmentLevel))
    SE[int(best_edge[1][4:])] = BRANCH
def test_retriable_codes():
    """Every code in ``retriable_codes`` is retried by RetryInterceptor.

    With ``retry_qty`` retries the call fails with
    ``retriable_codes[retry_qty]``. With ``len(retriable_codes)`` retries it
    returns DEFAULT_ZONE. The server is stopped in ``finally`` so a failing
    assertion does not leave it running for later tests.
    """
    retriable_codes = [grpc.StatusCode.RESOURCE_EXHAUSTED,
                       grpc.StatusCode.UNAVAILABLE,
                       grpc.StatusCode.DATA_LOSS]
    request = zone_service_pb2.GetZoneRequest(zone_id="id")

    service = _RetriableCodes(retriable_codes)
    server = grpc_server(service.handler)

    try:
        with default_channel() as channel:
            for retry_qty, expected_code in enumerate(retriable_codes):
                interceptor = RetryInterceptor(max_retry_count=retry_qty, retriable_codes=retriable_codes)
                ch = grpc.intercept_channel(channel, interceptor)
                client = zone_service_pb2_grpc.ZoneServiceStub(ch)

                with pytest.raises(grpc.RpcError) as e:
                    client.Get(request)

                assert e.value.code() == expected_code
                service.reset_state()

            interceptor = RetryInterceptor(max_retry_count=len(retriable_codes), retriable_codes=retriable_codes)
            ch = grpc.intercept_channel(channel, interceptor)
            client = zone_service_pb2_grpc.ZoneServiceStub(ch)
            assert client.Get(request) == DEFAULT_ZONE
    finally:
        # Always release the server, even when an assertion above fails.
        server.stop(0)
Пример #7
0
    def __init__(self,
                 address: Optional[str] = None,
                 tracer: Optional[Tracer] = None):
        """Open a gRPC channel to the Dapr runtime and build the client stub.

        Args:
            address (str, optional): Dapr Runtime gRPC endpoint address. When
                empty, ``settings.DAPR_RUNTIME_HOST:settings.DAPR_GRPC_PORT``
                is used.
            tracer (Tracer, optional): when given, calls are additionally
                wrapped with ``client_interceptor.OpenCensusClientInterceptor``.
        """
        self._address = address or f"{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}"
        self._channel = grpc.insecure_channel(self._address)  # type: ignore

        if settings.DAPR_API_TOKEN:
            token_pairs = [('dapr-api-token', settings.DAPR_API_TOKEN)]
            self._channel = grpc.intercept_channel(  # type: ignore
                self._channel, DaprClientInterceptor(token_pairs))

        if tracer:
            tracing = client_interceptor.OpenCensusClientInterceptor(tracer=tracer)
            self._channel = grpc.intercept_channel(  # type: ignore
                self._channel, tracing)

        self._stub = api_service_v1.DaprStub(self._channel)
Пример #8
0
 def client(self, stub_ctor, interceptor=None):
     """Return ``stub_ctor`` instantiated on the channel for its service.

     An explicit ``interceptor`` wins over ``self._default_interceptor``.
     When neither is set, the channel is used without interception.
     """
     channel = self._channels.channel(_service_for_ctor(stub_ctor))
     chosen = interceptor if interceptor is not None else self._default_interceptor
     if chosen is not None:
         channel = grpc.intercept_channel(channel, chosen)
     return stub_ctor(channel)
Пример #9
0
def get_yatai_service(
    channel_address=None,
    access_token=None,
    db_url=None,
    repo_base_url=None,
    s3_endpoint_url=None,
    default_namespace=None,
):
    """Return a remote YataiService gRPC stub or a local YataiService.

    ``channel_address`` and ``access_token`` fall back to the
    ``yatai_service`` config section. If the resulting address is non-empty
    after stripping, a ``YataiStub`` is returned whose channel adds an
    ``access_token`` header. ``grpcs``/``https`` URLs use TLS with the
    configured client certificate file or certifi's CA bundle. In that mode
    the local-only arguments are ignored (a warning is logged). Otherwise a
    local ``YataiService`` is built from ``db_url``, ``repo_base_url``,
    ``s3_endpoint_url`` and ``default_namespace``.
    """
    channel_address = channel_address or config('yatai_service').get('url')
    access_token = access_token or config('yatai_service').get('access_token')
    # A missing URL (None) means local mode, like an empty one, instead of
    # crashing on ``None.strip()``.
    channel_address = (channel_address or '').strip()
    if channel_address:
        from bentoml.yatai.proto.yatai_service_pb2_grpc import YataiStub

        if any([db_url, repo_base_url, s3_endpoint_url, default_namespace]):
            logger.warning(
                "Using remote YataiService at `%s`, local YataiService configs "
                "including db_url, repo_base_url, s3_endpoint_url and default_namespace "
                "will all be ignored.",
                channel_address,
            )

        logger.debug("Connecting YataiService gRPC server at: %s",
                     channel_address)
        scheme, addr = parse_grpc_url(channel_address)
        header_adder_interceptor = header_client_interceptor.header_adder_interceptor(
            'access_token', access_token)
        if scheme in ('grpcs', 'https'):
            client_cacert_path = (
                config().get('yatai_service', 'client_certificate_file')
                or certifi.where()  # default: Mozilla ca cert
            )
            with open(client_cacert_path, 'rb') as ca_cert_file:
                ca_cert = ca_cert_file.read()
            credentials = grpc.ssl_channel_credentials(ca_cert, None, None)
            channel = grpc.intercept_channel(
                grpc.secure_channel(addr, credentials),
                header_adder_interceptor)
        else:
            channel = grpc.intercept_channel(grpc.insecure_channel(addr),
                                             header_adder_interceptor)
        return YataiStub(channel)
    else:
        from bentoml.yatai.yatai_service_impl import YataiService

        logger.debug("Creating local YataiService instance")
        return YataiService(
            db_url=db_url,
            repo_base_url=repo_base_url,
            s3_endpoint_url=s3_endpoint_url,
            default_namespace=default_namespace,
        )
Пример #10
0
def ExecConnect(request):
    """Handle an incoming Connect message (GHS algorithm step).

    After waking this node if needed, compares the sender's fragment level
    with ours:
      * lower level: mark the edge BRANCH, record it in ``in_branch``, report
        the branch list to ``coordinator:50050``, send Initiate with our
        current level/ID/state to the sender, and bump ``findCount`` if we
        are in FIND;
      * edge still BASIC: defer the message by pushing it onto the left of
        ``MsgQueue``;
      * otherwise: send Initiate with ``fragmentLevel + 1``, the edge weight
        as fragment ID, and state FIND.

    Args:
        request: incoming message; ``fragmentLevel`` and ``nodename`` are read.
    """
    global fragmentLevel
    global fragmentID
    global SE
    global BRANCH
    global state
    global FIND
    global findCount
    global BASIC
    global MsgQueue
    global waiting_to_connect_to
    global in_branch

    WakeUpIfNeeded()

    level = request.fragmentLevel
    node = request.nodename

    print ('recieved Connect from ' + node + ' with fragmentLevel = ' + str(level))

    # node[4:] strips a 4-character prefix; presumably hostnames look like
    # 'nodeN' -- TODO confirm.
    if (level < fragmentLevel):
        SE[int(node[4:])] = BRANCH
        in_branch.append(node[4:])
        # Tell the coordinator about this node's updated branch list.
        with grpc.insecure_channel('coordinator:50050') as channel:
            interceptors = (RetryOnRpcErrorClientInterceptor(max_attempts=100, sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2), status_for_retry=(grpc.StatusCode.UNAVAILABLE,),),)
            intercept_channel = grpc.intercept_channel(channel, *interceptors)
            stub = GHS_pb2_grpc.MessagingStub(intercept_channel)
            branchMSG = GHS_pb2.BranchesMSG(nodename= socket.gethostname(), branches = in_branch)
            stub.Branches(branchMSG)
        # Absorb the lower-level fragment: send it our current level, ID and state.
        with grpc.insecure_channel(node + ':50050') as channel:
            interceptors = (RetryOnRpcErrorClientInterceptor(max_attempts=100, sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2), status_for_retry=(grpc.StatusCode.UNAVAILABLE,),),)
            intercept_channel = grpc.intercept_channel(channel, *interceptors)
            stub = GHS_pb2_grpc.MessagingStub(intercept_channel)
            init = GHS_pb2.InitiateMSG(nodename= socket.gethostname(), fragmentLevel= fragmentLevel, fragmentID = fragmentID, state= state)
            stub.Initiate(init)
        if (state == FIND):
            findCount = findCount + 1
    elif (SE[int(node[4:])] == BASIC):
        # Not ready to answer yet: requeue the Connect for later processing.
        ConMSG = Message("Connect", node, request)
        MsgQueue.appendleft(ConMSG)
    else:
        with grpc.insecure_channel(node + ':50050') as channel:
            interceptors = (RetryOnRpcErrorClientInterceptor(max_attempts=100, sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2), status_for_retry=(grpc.StatusCode.UNAVAILABLE,),),)
            intercept_channel = grpc.intercept_channel(channel, *interceptors)
            stub = GHS_pb2_grpc.MessagingStub(intercept_channel)
            # NOTE(review): the placeholder is the ``int`` type object itself; if
            # no edge to ``node`` is found, that type is sent as fragmentID,
            # which will likely fail -- confirm an edge always exists.
            weightOfNode = int
            # Look up the weight of our edge to ``node``.
            for x in edges[socket.gethostname()]:
                if ( x[0] == int(node[4:]) ):
                    weightOfNode = x[1]
            # Merge: start a new fragment one level up, identified by the edge weight.
            init = GHS_pb2.InitiateMSG(nodename= socket.gethostname(), fragmentLevel= (fragmentLevel + 1), fragmentID = weightOfNode, state = FIND)
            stub.Initiate(init)     
Пример #11
0
def make_insecure_channel(address,
                          mode=ChannelType.INTERNAL,
                          options=None,
                          compression=None):
    """Create an insecure gRPC channel for ``address``.

    * If ``check_address_valid(address)`` is true, dial ``address`` directly.
    * REMOTE mode dials ``EGRESS_URL``:
      - with ``EGRESS_HOST`` set, the authority is ``EGRESS_HOST`` and
        ``address`` is sent in an ``x-host`` header;
      - otherwise the authority is ``address`` itself;
      - if ``EGRESS_URL`` is unset, an error is logged and ``address`` is
        dialled directly.
    * INTERNAL mode dials ``address`` directly.

    In REMOTE mode ``options`` must be a list or None. It is copied before
    the authority option is added, so the caller's list is never modified.

    Raises:
        Exception: if ``options`` is not a list in REMOTE mode, or ``mode``
            is neither REMOTE nor INTERNAL.
    """
    if check_address_valid(address):
        return grpc.insecure_channel(address, options, compression)

    if mode == ChannelType.REMOTE:
        if not EGRESS_URL:
            logging.error("EGRESS_URL is invalid, "
                          "not found in environment variable.")
            return grpc.insecure_channel(address, options, compression)

        if options is None:
            options = []
        if not isinstance(options, list):
            raise Exception('grpc channel options must be list')
        # Copy so appending the authority below does not mutate the caller's list
        # (which would otherwise accumulate authorities across calls).
        options = list(options)

        logging.debug("EGRESS_URL is [%s]", EGRESS_URL)
        if EGRESS_HOST:
            options.append(('grpc.default_authority', EGRESS_HOST))
            header_adder = header_adder_interceptor('x-host', address)
            channel = grpc.insecure_channel(
                EGRESS_URL, options, compression)
            return grpc.intercept_channel(channel, header_adder)

        options.append(('grpc.default_authority', address))
        return grpc.insecure_channel(EGRESS_URL, options, compression)

    if mode == ChannelType.INTERNAL:
        return grpc.insecure_channel(address, options, compression)

    raise Exception("UNKNOWN Channel by uuid %s" % address)
Пример #12
0
def call_server():
    """Call each Greeter RPC shape once through ``PromClientInterceptor``.

    Runs unary-unary, unary-stream, stream-unary and bidi-stream calls against
    ``localhost:50051`` and logs every response. The channel is closed when
    done, even if an RPC raises.
    """
    channel = grpc.intercept_channel(grpc.insecure_channel("localhost:50051"),
                                     PromClientInterceptor())
    try:
        stub = hello_world_grpc.GreeterStub(channel)

        # Call the unary-unary.
        response = stub.SayHello(hello_world_pb2.HelloRequest(name="Unary"))
        _LOGGER.info("Unary response: %s", response.message)
        _LOGGER.info("")

        # Call the unary stream.
        _LOGGER.info("Running Unary Stream client")
        response_iter = stub.SayHelloUnaryStream(
            hello_world_pb2.HelloRequest(name="unary stream"))
        _LOGGER.info("Response for Unary Stream")
        for response in response_iter:
            _LOGGER.info("Unary Stream response item: %s", response.message)
        _LOGGER.info("")

        # Call the stream_unary.
        _LOGGER.info("Running Stream Unary client")
        response = stub.SayHelloStreamUnary(generate_requests("Stream Unary"))
        _LOGGER.info("Stream Unary response: %s", response.message)
        _LOGGER.info("")

        # Call stream & stream.
        _LOGGER.info("Running Bidi Stream client")
        response_iter = stub.SayHelloBidiStream(generate_requests("Bidi Stream"))
        for response in response_iter:
            _LOGGER.info("Bidi Stream response item: %s", response.message)
        _LOGGER.info("")
    finally:
        # Closing the intercepted channel also closes the underlying one.
        channel.close()
Пример #13
0
def ReportProc(request):
    """GHS Report step: move to FOUND and report once local searching is done.

    Acts only when ``findCount == 0`` and ``test_edge is None``. The node then
    enters the FOUND state. For every entry ``x`` of ``in_branch`` that is also
    in ``neighbor``, it sends a Report carrying ``best_wt`` and ``report_edge``
    to ``node<x>:50050`` through a retrying interceptor.

    Args:
        request: incoming message; only ``request.nodename`` is read.
    """
    global findCount
    global test_edge
    global state
    global FOUND
    global best_wt
    global in_branch
    global neighbor
    global report_edge
    global INFINITY
    global edges

    node = request.nodename
    if findCount == 0 and test_edge is None:
        # Bug fix: this was ``state == FOUND`` -- a comparison whose result was
        # discarded -- so the node never actually entered FOUND.
        state = FOUND
        for x in in_branch:
            if int(x) in neighbor:
                # NOTE(review): the candidate edge and its weight come from the
                # reporting ``node``, not from the neighbour ``x`` the Report is
                # sent to -- confirm this is intended.
                edge = [int(socket.gethostname()[4:]), int(node[4:])]
                for y in edges[socket.gethostname()]:
                    if y[0] == int(node[4:]):
                        weightOfNode = y[1]
                        if weightOfNode == best_wt or best_wt == INFINITY:
                            report_edge = edge

                with grpc.insecure_channel('node' + str(x) + ':50050') as channel:
                    interceptors = (RetryOnRpcErrorClientInterceptor(max_attempts=100, sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2), status_for_retry=(grpc.StatusCode.UNAVAILABLE,),),)
                    intercept_channel = grpc.intercept_channel(channel, *interceptors)
                    stub = GHS_pb2_grpc.MessagingStub(intercept_channel)
                    Report = GHS_pb2.ReportMSG(nodename=socket.gethostname(), weight=best_wt, edge=report_edge)
                    stub.Report(Report)
Пример #14
0
 def create_state_handler(self, api_service_descriptor):
     """Return the cached state handler for ``api_service_descriptor``.

     A falsy descriptor yields ``self._throwing_state_handler``. Otherwise one
     ``CachingStateHandler`` per descriptor URL is built lazily. Creation is
     done under ``self._lock`` with a re-check, so concurrent callers build it
     only once.
     """
     if not api_service_descriptor:
         return self._throwing_state_handler
     url = api_service_descriptor.url
     if url in self._state_handler_cache:
         return self._state_handler_cache[url]
     with self._lock:
         # Re-check under the lock: another thread may have built it meanwhile.
         if url not in self._state_handler_cache:
             # -1 removes gRPC's message-size limits in both directions; the
             # actual buffer size is controlled in a layer above.
             unlimited = [('grpc.max_receive_message_length', -1),
                          ('grpc.max_send_message_length', -1)]
             if self._credentials is not None:
                 _LOGGER.info('Creating secure state channel for %s.',
                              url)
                 channel = GRPCChannelFactory.secure_channel(
                     url, self._credentials, options=unlimited)
             else:
                 _LOGGER.info('Creating insecure state channel for %s.',
                              url)
                 channel = GRPCChannelFactory.insecure_channel(
                     url, options=unlimited)
             _LOGGER.info('State channel established.')
             # Add workerId to every call on the channel.
             channel = grpc.intercept_channel(channel, WorkerIdInterceptor())
             state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(channel)
             self._state_handler_cache[url] = CachingStateHandler(
                 self._state_cache, GrpcStateHandler(state_stub))
     return self._state_handler_cache[url]
Пример #15
0
    def __init__(self, control_address, worker_count, credentials=None):
        """Connect the control channel and set up pools, queues and factories.

        Waits up to 60 seconds for the control channel to become ready before
        wrapping it with ``WorkerIdInterceptor``.

        Args:
            control_address: address of the control service.
            worker_count: number of threads in the bundle-processing pool.
            credentials: channel credentials; None selects an insecure
                channel. Also passed to the data channel factory.
        """
        self._worker_count = worker_count
        self._worker_index = 0

        if credentials is not None:
            logging.info('Creating secure control channel.')
            self._control_channel = grpc.secure_channel(
                control_address, credentials)
        else:
            logging.info('Creating insecure control channel.')
            self._control_channel = grpc.insecure_channel(control_address)
        ready = grpc.channel_ready_future(self._control_channel)
        ready.result(timeout=60)
        logging.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor())
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials)
        self._state_handler_factory = GrpcStateHandlerFactory()
        self.workers = queue.Queue()

        # A single thread suffices for progress reports: generating them is
        # assumed not to do IO or wait on other resources, so extra threads
        # would add complexity without improving performance.
        self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
        self._process_thread_pool = futures.ThreadPoolExecutor(
            max_workers=self._worker_count)

        # Bookkeeping containers and queues, all initially empty.
        self._instruction_id_vs_worker = {}
        self._fns = {}
        self._responses = queue.Queue()
        self._process_bundle_queue = queue.Queue()
        self._unscheduled_process_bundle = set()

        logging.info('Initializing SDKHarness with %s workers.',
                     self._worker_count)
Пример #16
0
def WakeUpIfNeeded():
    """Wake a SLEEPING node and connect over its minimum BASIC edge (GHS).

    No-op unless ``state == SLEEPING``. Otherwise:
      * logs the wake-up;
      * sets ``state`` to FOUND and resets ``findCount``;
      * if ``FindMinBasicEdge()`` returns an edge whose first element is not
        INFINITY, sends Connect(``fragmentLevel``) to ``node<index>:50050``,
        marks that edge BRANCH, and records it in ``waiting_to_connect_to``.
    """
    global state
    global findCount
    global INFINITY
    global SLEEPING
    global SE
    global BRANCH
    global waiting_to_connect_to
    global fragmentLevel

    if state != SLEEPING:
        return

    print(socket.gethostname() + " is waking up!")
    # Original TODO: use a function that finds the min-weight BASIC node.
    # FindMinBasicEdge() may already be that -- TODO confirm and drop.
    minedge = FindMinBasicEdge()
    state = FOUND
    findCount = 0

    if minedge[0] == INFINITY:
        return

    target = minedge[0]
    with grpc.insecure_channel('node' + str(target) + ':50050') as channel:
        retrying = RetryOnRpcErrorClientInterceptor(
            max_attempts=100,
            sleeping_policy=ExponentialBackoff(init_backoff_ms=100, max_backoff_ms=1600, multiplier=2),
            status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
        )
        stub = GHS_pb2_grpc.MessagingStub(grpc.intercept_channel(channel, retrying))
        stub.Connect(GHS_pb2.ConnectMSG(nodename=socket.gethostname(), fragmentLevel=fragmentLevel))
    SE[target] = BRANCH
    waiting_to_connect_to = target
Пример #17
0
  def __init__(self, control_address, worker_count, credentials=None):
    """Open the control channel and initialise the harness state.

    Blocks up to 60 seconds for the control channel to be ready, then wraps
    it with ``WorkerIdInterceptor``.

    Args:
      control_address: address of the control service.
      worker_count: size of the bundle-processing thread pool.
      credentials: channel credentials, or None for an insecure channel;
        also passed to the data channel factory.
    """
    self._worker_count = worker_count
    self._worker_index = 0

    secure = credentials is not None
    if secure:
      logging.info('Creating secure control channel.')
      self._control_channel = grpc.secure_channel(control_address, credentials)
    else:
      logging.info('Creating insecure control channel.')
      self._control_channel = grpc.insecure_channel(control_address)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    logging.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor())
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials)
    self._state_handler_factory = GrpcStateHandlerFactory()
    self.workers = queue.Queue()

    # One thread is enough for progress reports: their generation is assumed
    # not to do IO or wait on other resources, so more threads would only add
    # complexity.
    self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
    self._process_thread_pool = futures.ThreadPoolExecutor(
        max_workers=self._worker_count)

    # Empty bookkeeping structures and queues.
    self._instruction_id_vs_worker = {}
    self._fns = {}
    self._responses = queue.Queue()
    self._process_bundle_queue = queue.Queue()
    self._unscheduled_process_bundle = set()
    logging.info('Initializing SDKHarness with %s workers.', self._worker_count)
Пример #18
0
  def create_data_channel_from_url(self, url):
    # type: (str) -> Optional[GrpcClientDataChannel]
    """Return the cached client data channel for ``url``, building it once.

    A falsy ``url`` yields None. The cache is re-checked under ``self._lock``
    so concurrent callers create the channel only once.
    """
    if not url:
      return None
    if url in self._data_channel_cache:
      return self._data_channel_cache[url]
    with self._lock:
      if url not in self._data_channel_cache:
        _LOGGER.info('Creating client data channel for %s', url)
        # -1 lifts gRPC's message-size limits; the buffer size is controlled
        # in a layer above.
        unlimited = [("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)]
        if self._credentials is not None:
          raw_channel = GRPCChannelFactory.secure_channel(
              url, self._credentials, options=unlimited)
        else:
          raw_channel = GRPCChannelFactory.insecure_channel(
              url, options=unlimited)
        # Add workerId to every call on the channel.
        tagged_channel = grpc.intercept_channel(
            raw_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_cache[url] = GrpcClientDataChannel(
            beam_fn_api_pb2_grpc.BeamFnDataStub(tagged_channel),
            self._data_buffer_time_limit_ms)
    return self._data_channel_cache[url]
Пример #19
0
def main():
    """Call Greeter.SayHello once through an OpenTracing client interceptor.

    Request/response payloads are logged into spans by default. Pass
    ``--no_log_payloads`` to disable this; ``--log_payloads`` is still
    accepted. The channel is closed after the call.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--log_payloads',
        action='store_true',
        default=True,
        help='log request/response objects to open-tracing spans')
    # ``store_true`` with ``default=True`` can never be switched off, so
    # provide an explicit opt-out writing to the same destination.
    parser.add_argument(
        '--no_log_payloads',
        dest='log_payloads',
        action='store_false',
        help='do not log request/response objects to open-tracing spans')
    args = parser.parse_args()

    config = Config(config={
        'sampler': {
            'type': 'const',
            'param': 1,
        },
        'logging': True,
    },
                    service_name='hello_world_client')
    tracer = config.initialize_tracer()
    tracer_interceptor = open_tracing_client_interceptor.OpenTracingClientInterceptor(
        tracer, log_payloads=args.log_payloads)
    with tracer.start_span("step1") as span:
        scope.set_active_span(span)
        time.sleep(0.01)
    channel = grpc.insecure_channel(HOST_PORT)
    channel = grpc.intercept_channel(channel, tracer_interceptor)
    try:
        stub = hello_world_pb2_grpc.GreeterStub(channel)
        response = stub.SayHello(hello_world_pb2.HelloRequest(name='you'))
    finally:
        # Closing the intercepted channel also closes the underlying one.
        channel.close()
    print("Message received: " + response.message)
Пример #20
0
 def _connect(self):
     """Create ``self.channel`` (with ``self.meta`` headers) and a Span stub on it."""
     TCLogger.info("span channel connect %s with meta:%s", self.address,
                   self.meta)
     raw_channel = grpc.insecure_channel(self.address)
     header_interceptor = GrpcClient.InterceptorAddHeader(self.meta)
     self.channel = grpc.intercept_channel(raw_channel, header_interceptor)
     self.span_stub = Service_pb2_grpc.SpanStub(self.channel)
Пример #21
0
  def _grpc_data_channel_test(self, time_based_flush=False):
    """Run ``self._data_channel_test`` over a real in-process gRPC server.

    Starts a BeamFnData server on a free local port and connects a client
    channel tagged with worker id 'worker_0'. When ``time_based_flush`` is
    set, both sides use a 100 ms data buffer time limit. Both data channels
    are closed and awaited afterwards, and the server is always stopped.
    """
    if time_based_flush:
      data_servicer = data_plane.BeamFnDataServicer(
          data_buffer_time_limit_ms=100)
    else:
      data_servicer = data_plane.BeamFnDataServicer()
    worker_id = 'worker_0'
    data_channel_service = \
      data_servicer.get_conn_by_worker_id(worker_id)

    server = grpc.server(UnboundedThreadPoolExecutor())
    beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
        data_servicer, server)
    test_port = server.add_insecure_port('[::]:0')
    server.start()

    try:
      grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
      # Add workerId to the grpc channel
      grpc_channel = grpc.intercept_channel(
          grpc_channel, WorkerIdInterceptor(worker_id))
      data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
      if time_based_flush:
        data_channel_client = data_plane.GrpcClientDataChannel(
            data_channel_stub, data_buffer_time_limit_ms=100)
      else:
        data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)

      try:
        self._data_channel_test(
            data_channel_service, data_channel_client, time_based_flush)
      finally:
        data_channel_client.close()
        data_channel_service.close()
        data_channel_client.wait()
        data_channel_service.wait()
    finally:
      # Stop the in-process server so its threads and port do not outlive
      # the test (previously it was never stopped).
      server.stop(0)
Пример #22
0
    def get_result(self, file_path):
        """Yield the final decoding results for the audio file at ``file_path``.

        :param file_path: path of the audio file streamed to the server.
        :return: iterator over the responses of ``AsrServiceStub.send``.
        """
        # The serialized request travels base64-encoded in the 'audio_meta' header.
        encoded_meta = base64.b64encode(self.request.SerializeToString())
        meta_interceptor = header_manipulator_client_interceptor.header_adder_interceptor(
            'audio_meta', encoded_meta)

        # To use CA authentication instead of an insecure channel:
        # with open(
        #         '/path/of/xxx.crt',
        #         'rb') as f:
        #     creds = grpc.ssl_channel_credentials(f.read())
        # with grpc.secure_channel(self.host, creds) as channel:

        keepalive_options = [('grpc.keepalive_timeout_ms', 1000000)]
        with grpc.insecure_channel(target=self.host,
                                   options=keepalive_options) as channel:
            asr_stub = audio_streaming_pb2_grpc.AsrServiceStub(
                grpc.intercept_channel(channel, meta_interceptor))
            result_stream = asr_stub.send(self.generate_file_stream(file_path),
                                          timeout=100000)
            for item in result_stream:
                yield item
Пример #23
0
    def create_data_channel(self, remote_grpc_port):
        """Return the cached data channel for ``remote_grpc_port``'s service URL.

        The channel is built lazily on first request. The cache is re-checked
        under ``self._lock`` so concurrent callers create it only once.
        """
        url = remote_grpc_port.api_service_descriptor.url
        if url in self._data_channel_cache:
            return self._data_channel_cache[url]
        with self._lock:
            if url not in self._data_channel_cache:
                logging.info('Creating client data channel for %s', url)
                # -1 lifts gRPC's message-size limits; the buffer size is
                # controlled in a layer above.
                unlimited = [("grpc.max_receive_message_length", -1),
                             ("grpc.max_send_message_length", -1)]
                if self._credentials is not None:
                    raw_channel = GRPCChannelFactory.secure_channel(
                        url, self._credentials, options=unlimited)
                else:
                    raw_channel = GRPCChannelFactory.insecure_channel(
                        url, options=unlimited)
                # Add workerId to every call on the channel.
                tagged_channel = grpc.intercept_channel(
                    raw_channel, WorkerIdInterceptor(self._worker_id))
                data_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(tagged_channel)
                self._data_channel_cache[url] = GrpcClientDataChannel(data_stub)
        return self._data_channel_cache[url]
Пример #24
0
 def create_state_handler(self, api_service_descriptor):
   """Return the cached GrpcStateHandler for ``api_service_descriptor``.

   A falsy descriptor yields ``self._throwing_state_handler``. Otherwise one
   handler per descriptor URL is built lazily. The cache is re-checked under
   ``self._lock`` so concurrent callers build it only once.
   """
   if not api_service_descriptor:
     return self._throwing_state_handler
   url = api_service_descriptor.url
   if url in self._state_handler_cache:
     return self._state_handler_cache[url]
   with self._lock:
     if url not in self._state_handler_cache:
       # -1 lifts gRPC's message-size limits in both directions; the actual
       # buffer size is controlled in a layer above.
       unlimited = [('grpc.max_receive_message_length', -1),
                    ('grpc.max_send_message_length', -1)]
       if self._credentials is not None:
         logging.info('Creating secure state channel for %s.', url)
         channel = GRPCChannelFactory.secure_channel(
             url, self._credentials, options=unlimited)
       else:
         logging.info('Creating insecure state channel for %s.', url)
         channel = GRPCChannelFactory.insecure_channel(
             url, options=unlimited)
       logging.info('State channel established.')
       # Add workerId to every call on the channel.
       channel = grpc.intercept_channel(channel, WorkerIdInterceptor())
       state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(channel)
       self._state_handler_cache[url] = GrpcStateHandler(state_stub)
   return self._state_handler_cache[url]
Пример #25
0
    def testTripleRequestMessagesClientInterceptor(self):
        """A request-tripling stream interceptor triples the response count.

        The same requests sent on the plain ``self._channel`` yield exactly
        one response each.
        """
        def triple(request_iterator):
            # Re-emit every incoming request three times, preserving order.
            for item in request_iterator:
                yield item
                yield item
                yield item

        tripler = _wrap_request_iterator_stream_interceptor(triple)
        tripled_channel = grpc.intercept_channel(self._channel, tripler)
        requests = tuple(b'\x07\x08'
                         for _ in range(test_constants.STREAM_LENGTH))
        call_metadata = (
            ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),
        )

        for channel, factor in ((tripled_channel, 3), (self._channel, 1)):
            stream_call = _stream_stream_multi_callable(channel)
            responses = tuple(stream_call(iter(requests), metadata=call_metadata))
            self.assertEqual(len(responses),
                             factor * test_constants.STREAM_LENGTH)
    def get_grpc_channel(self):
        """Return a gRPC channel to the emulator's local gRPC port.

        Message-size limits are set to -1 in both directions. When the
        emulator description contains ``grpc.token``, the channel is wrapped
        so each call carries an ``authorization: Bearer <token>`` header.

        Note: only insecure channels are produced for now.
        """
        # This should default to max.
        size_limit = -1
        target = "localhost:{}".format(self.get("grpc.port", 8554))
        limits = [
            ("grpc.max_send_message_length", size_limit),
            ("grpc.max_receive_message_length", size_limit),
        ]
        channel = grpc.insecure_channel(target, options=limits)

        if "grpc.token" not in self._description:
            logging.debug("Insecure channel to %s", target)
            return channel

        bearer = "Bearer {}".format(self.get("grpc.token", ""))
        logging.debug("Insecure Channel with token to: %s", target)
        auth_header = header_adder_interceptor("authorization", bearer)
        return grpc.intercept_channel(channel, auth_header)
Пример #27
0
def auth_api_session():
    """
    Create an Auth API for testing.

    Yields an ``(AuthStub, interceptor)`` pair. After each unary-unary call
    made through the stub, the interceptor's ``latest_headers`` holds that
    call's initial metadata as a dict.

    This needs to use the real server since it plays around with headers.
    A one-thread gRPC server is started on an OS-assigned localhost port
    ("localhost:0") with local credentials. It is stopped in ``finally``
    once the generator is finished.
    """

    class _MetadataKeeperInterceptor(grpc.UnaryUnaryClientInterceptor):
        """Remember the initial metadata of the most recent unary-unary call."""

        def __init__(self):
            self.latest_headers = {}

        def intercept_unary_unary(self, continuation, client_call_details,
                                  request):
            outcome = continuation(client_call_details, request)
            self.latest_headers = dict(outcome.initial_metadata())
            return outcome

    with futures.ThreadPoolExecutor(1) as worker_pool:
        server = grpc.server(worker_pool)
        bound_port = server.add_secure_port(
            "localhost:0", grpc.local_server_credentials())
        auth_pb2_grpc.add_AuthServicer_to_server(Auth(), server)
        server.start()

        try:
            target = f"localhost:{bound_port}"
            with grpc.secure_channel(
                    target, grpc.local_channel_credentials()) as raw_channel:
                keeper = _MetadataKeeperInterceptor()
                intercepted = grpc.intercept_channel(raw_channel, keeper)
                yield auth_pb2_grpc.AuthStub(intercepted), keeper
        finally:
            server.stop(None).wait()
def MakeTransport(client_class,
                  credentials,
                  address_override_func,
                  mtls_enabled=False):
    """Instantiate a gRPC transport for ``client_class``.

    The channel is created by the client's transport class. It is then
    wrapped with the standard interceptor chain. ``LoggingInterceptor`` is
    appended last, and only when ``core.log_http`` is enabled.

    Returns:
      A ``transport_class`` instance bound to the intercepted channel and
      the resolved address.
    """
    transport_class = client_class.get_transport_class()
    address = _GetAddress(client_class, address_override_func, mtls_enabled)

    base_channel = transport_class.create_channel(
        host=address,
        credentials=credentials,
        ssl_credentials=GetSSLCredentials(mtls_enabled),
        options=MakeChannelOptions())

    # gRPC gives interceptors control in the order they are listed.
    chain = [
        RequestReasonInterceptor(),
        UserAgentInterceptor(),
        TimeoutInterceptor(),
        IAMAuthHeadersInterceptor(),
        RPCDurationReporterInterceptor(),
        QuotaProjectInterceptor(credentials),
        APIEnablementInterceptor(),
    ]
    if properties.VALUES.core.log_http.GetBool():
        chain.append(LoggingInterceptor(credentials))

    return transport_class(
        channel=grpc.intercept_channel(base_channel, *chain), host=address)
Пример #29
0
def create_channel(
    target: str,
    options: Optional[List[Tuple[str, Any]]] = None,
    interceptors: Optional[List[ClientInterceptor]] = None,
) -> grpc.Channel:
    """Creates an insecure gRPC channel.

    The gRPC channel is created with the provided options and intercepts each
    invocation via the provided interceptors.

    Unless the caller already supplies them, the following default options are
    added, both set to the module-level ``grpc_max_msg_size`` (documented here
    as 100MB):
        - "grpc.max_send_message_length"
        - "grpc.max_receive_message_length"

    A caller-supplied value for either key takes precedence; the default is
    not appended as a duplicate.

    :param target: the server address.
    :param options: optional sequence (list or tuple) of key-value pairs to
        configure the channel. It is not modified.
    :param interceptors: optional sequence of client interceptors, given
        control in the order listed.
    :returns: a gRPC channel.

    """
    # The list of possible options is available here:
    # https://grpc.io/grpc/core/group__grpc__arg__keys.html
    # Copy into a list so tuples are accepted and the caller's object is
    # never mutated.
    merged_options = list(options or [])
    supplied_keys = {key for key, _ in merged_options}
    for key in ("grpc.max_send_message_length",
                "grpc.max_receive_message_length"):
        # Only fill in defaults the caller did not set explicitly.
        if key not in supplied_keys:
            merged_options.append((key, grpc_max_msg_size))
    channel = grpc.insecure_channel(target, merged_options)
    return grpc.intercept_channel(channel, *(interceptors or []))
Пример #30
0
    def test_custom_interceptor_exception(self):
        """A client interceptor raising a custom exception tags the client span.

        The exception must propagate to the caller and be recorded in the
        client span's error tags, while the server span stays error-free.
        """
        failing_interceptor = _RaiseExceptionClientInterceptor()
        target = "localhost:%d" % (_GRPC_PORT)
        with grpc.insecure_channel(target) as base_channel:
            with self.assertRaises(_CustomException):
                wrapped = grpc.intercept_channel(base_channel,
                                                 failing_interceptor)
                HelloStub(wrapped).SayHello(
                    HelloRequest(name="custom-exception"))

        client_span, server_span = self.get_spans_with_sync_and_assert(size=2)

        assert client_span.resource == "/helloworld.Hello/SayHello"
        assert client_span.error == 1
        assert client_span.get_tag(ERROR_MSG) == "custom"
        expected_type = "tests.contrib.grpc.test_grpc._CustomException"
        assert client_span.get_tag(ERROR_TYPE) == expected_type
        assert client_span.get_tag(ERROR_STACK) is not None
        assert client_span.get_tag("grpc.status.code") == "StatusCode.INTERNAL"

        # The server side never sees the client-side exception.
        assert server_span.resource == "/helloworld.Hello/SayHello"
        assert server_span.error == 0
        for error_tag in (ERROR_MSG, ERROR_TYPE, ERROR_STACK):
            assert server_span.get_tag(error_tag) is None
Пример #31
0
    def __init__(self):
        """Open the collector channel and create the reporting service clients.

        A TLS channel is used when ``config.force_tls`` is set, otherwise a
        plaintext one. Both use the same ``grpc.max_connection_age_grace_ms``
        option. When ``config.authentication`` is set, the channel is wrapped
        with ``header_adder_interceptor('authentication', ...)``.
        """
        # Presumably updated by the connectivity callback self._cb — verify.
        self.state = None

        if config.force_tls:
            grace_options = (('grpc.max_connection_age_grace_ms',
                              1000 * config.GRPC_TIMEOUT), )
            self.channel = grpc.secure_channel(
                config.collector_address,
                grpc.ssl_channel_credentials(),
                options=grace_options)
        else:
            grace_options = (('grpc.max_connection_age_grace_ms',
                              1000 * config.GRPC_TIMEOUT), )
            self.channel = grpc.insecure_channel(
                config.collector_address, options=grace_options)

        if config.authentication:
            auth_interceptor = header_adder_interceptor(
                'authentication', config.authentication)
            self.channel = grpc.intercept_channel(self.channel,
                                                  auth_interceptor)

        # try_to_connect=True asks the channel to start connecting right away.
        self.channel.subscribe(self._cb, try_to_connect=True)

        shared_channel = self.channel
        self.service_management = GrpcServiceManagementClient(shared_channel)
        self.traces_reporter = GrpcTraceSegmentReportService(shared_channel)
        self.profile_query = GrpcProfileTaskChannelService(shared_channel)
Пример #32
0
  def create_data_channel(self, remote_grpc_port):
    """Return the data channel for ``remote_grpc_port``, creating it on first use.

    Channels are cached per URL. Creation uses double-checked locking: a
    lock-free membership test, then a re-check under ``self._lock``, so at
    most one channel is built per URL.
    """
    url = remote_grpc_port.api_service_descriptor.url
    if url in self._data_channel_cache:
      return self._data_channel_cache[url]

    with self._lock:
      # Re-check: another thread may have created it while we waited.
      if url not in self._data_channel_cache:
        logging.info('Creating channel for %s', url)
        # -1 removes gRPC's size limits on data-plane messages; the actual
        # buffer size is controlled in a layer above.
        no_limits = [("grpc.max_receive_message_length", -1),
                     ("grpc.max_send_message_length", -1)]
        if self._credentials is None:
          raw_channel = GRPCChannelFactory.insecure_channel(
              url, options=no_limits)
        else:
          raw_channel = GRPCChannelFactory.secure_channel(
              url, self._credentials, options=no_limits)
        # Add workerId to the grpc channel.
        tagged_channel = grpc.intercept_channel(raw_channel,
                                                WorkerIdInterceptor())
        self._data_channel_cache[url] = GrpcClientDataChannel(
            beam_fn_api_pb2_grpc.BeamFnDataStub(tagged_channel))

    return self._data_channel_cache[url]
Пример #33
0
    def __init__(self, org_id, service_id, service_metadata, group,
                 service_stub, payment_strategy, options, mpe_contract,
                 account, sdk_web3):
        """Set up client state, the intercepted gRPC channel, payments and the stub.

        Args:
            org_id: organization identifier; stored on ``self``.
            service_id: service identifier; stored on ``self``.
            service_metadata: service metadata; stored on ``self``.
            group: mapping that must contain
                ``group["payment"]["payment_expiration_threshold"]``.
            service_stub: stub passed to ``self._generate_grpc_stub``.
            payment_strategy: stored on ``self``.
            options: stored on ``self`` (presumably consulted by
                ``_get_grpc_channel`` — verify).
            mpe_contract: passed to ``PaymentChannelProvider``; its
                ``contract.address`` is stored as ``self.mpe_address``.
            account: stored on ``self``.
            sdk_web3: passed to ``PaymentChannelProvider`` and stored.
        """
        self.org_id = org_id
        self.service_id = service_id
        self.options = options
        self.group = group
        self.service_metadata = service_metadata

        self.payment_strategy = payment_strategy
        self.expiry_threshold = self.group["payment"][
            "payment_expiration_threshold"]
        # NOTE(review): statement order matters. The attributes above are set
        # before the helper methods below run, presumably because they read
        # them — confirm before reordering.
        # Double underscore: Python name-mangles this attribute per class.
        self.__base_grpc_channel = self._get_grpc_channel()
        # Calls on grpc_channel go through the interceptor built from
        # self._intercept_call (presumably payment metadata injection — verify).
        self.grpc_channel = grpc.intercept_channel(
            self.__base_grpc_channel,
            generic_client_interceptor.create(self._intercept_call))
        self.payment_channel_provider = PaymentChannelProvider(
            sdk_web3, self._generate_payment_channel_state_service_client(),
            mpe_contract)
        self.service = self._generate_grpc_stub(service_stub)
        self.payment_channels = []
        self.last_read_block = 0
        self.account = account
        self.sdk_web3 = sdk_web3
        self.mpe_address = mpe_contract.contract.address
Пример #34
0
def run():
    """Call Greeter.SayHello through a DefaultValueClientInterceptor and print the reply.

    The interceptor is built with a ``HelloReply`` default. It presumably
    supplies that reply when the call fails — see
    default_value_client_interceptor. The channel is opened in a ``with``
    block so it is closed once the call returns, instead of being leaked.
    """
    default_value = helloworld_pb2.HelloReply(
        message='Hello from your local interceptor!')
    default_value_interceptor = default_value_client_interceptor.DefaultValueClientInterceptor(
        default_value)
    with grpc.insecure_channel('localhost:50051') as channel:
        intercepted = grpc.intercept_channel(channel, default_value_interceptor)
        stub = helloworld_pb2_grpc.GreeterStub(intercepted)
        response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'))
    print("Greeter client received: " + response.message)
Пример #35
0
def run():
    """Send one SayHello request through a header-adding interceptor and print the reply.

    The interceptor comes from
    ``header_adder_interceptor('one-time-password', '42')``.
    """
    otp_interceptor = header_manipulator_client_interceptor.header_adder_interceptor(
        'one-time-password', '42')
    # NOTE(gRPC Python Team): .close() is possible on a channel and should be
    # used in circumstances in which the with statement does not fit the needs
    # of the code.
    with grpc.insecure_channel('localhost:50051') as raw_channel:
        greeter = helloworld_pb2_grpc.GreeterStub(
            grpc.intercept_channel(raw_channel, otp_interceptor))
        reply = greeter.SayHello(helloworld_pb2.HelloRequest(name='you'))
    print("Greeter client received: " + reply.message)
Пример #36
0
def send_confirmation_email(email, order):
  """Ask the email service to send an order confirmation to ``email``.

  The call goes through ``tracer_interceptor``. RPC failures are logged
  (details, then status name and value) rather than raised. The channel is
  closed when the call finishes, so repeated calls no longer leak one channel
  each.
  """
  with grpc.insecure_channel('0.0.0.0:8080') as raw_channel:
    channel = grpc.intercept_channel(raw_channel, tracer_interceptor)
    stub = demo_pb2_grpc.EmailServiceStub(channel)
    try:
      stub.SendOrderConfirmation(demo_pb2.SendOrderConfirmationRequest(
        email = email,
        order = order
      ))
      logger.info('Request sent.')
    except grpc.RpcError as err:
      logger.error(err.details())
      # Lazy %-formatting; renders the same text as the former str.format.
      logger.error('%s, %s', err.code().name, err.code().value)
Пример #37
0
def run():
    """Call Greeter.SayHello through a DefaultValueClientInterceptor and print the reply."""
    fallback_reply = helloworld_pb2.HelloReply(
        message='Hello from your local interceptor!')
    interceptor = default_value_client_interceptor.DefaultValueClientInterceptor(
        fallback_reply)
    # NOTE(gRPC Python Team): .close() is possible on a channel and should be
    # used in circumstances in which the with statement does not fit the needs
    # of the code.
    with grpc.insecure_channel('localhost:50051') as base:
        stub = helloworld_pb2_grpc.GreeterStub(
            grpc.intercept_channel(base, interceptor))
        response = stub.SayHello(helloworld_pb2.HelloRequest(name='you'))
    print("Greeter client received: " + response.message)
Пример #38
0
    def testDefectiveClientInterceptor(self):
        """A defective client interceptor surfaces as an INTERNAL future error."""
        broken_channel = grpc.intercept_channel(self._channel,
                                                _DefectiveClientInterceptor())

        call_future = _unary_unary_multi_callable(broken_channel).future(
            b'\x07\x08',
            metadata=(
                ('test', 'InterceptedUnaryRequestBlockingUnaryResponse'),))

        self.assertIsNotNone(call_future.exception())
        self.assertEqual(call_future.code(), grpc.StatusCode.INTERNAL)
Пример #39
0
  def __init__(self, log_service_descriptor):
    """Start streaming log entries to the service at ``log_service_descriptor.url``.

    The channel is wrapped with ``WorkerIdInterceptor``. The ``Logging`` RPC
    is started with ``self._write_log_entries()`` as its request stream. A
    daemon thread named 'read_log_control_messages' consumes the responses
    via ``self._read_log_control_messages``.
    """
    super(FnApiLogRecordHandler, self).__init__()
    raw_channel = grpc.insecure_channel(log_service_descriptor.url)
    self._log_channel = grpc.intercept_channel(raw_channel,
                                               WorkerIdInterceptor())
    self._logging_stub = beam_fn_api_pb2_grpc.BeamFnLoggingStub(
        self._log_channel)
    # Created before the RPC starts; presumably drained by
    # _write_log_entries — keep this ordering.
    self._log_entry_queue = queue.Queue()

    control_stream = self._logging_stub.Logging(self._write_log_entries())

    def _consume_control_stream():
      self._read_log_control_messages(control_stream)

    reader = threading.Thread(
        target=_consume_control_stream, name='read_log_control_messages')
    reader.daemon = True
    self._reader = reader
    reader.start()
Пример #40
0
  def __init__(self, log_service_descriptor):
    """Start streaming log entries once the logging channel is ready.

    Waits up to 60 seconds for the channel to ``log_service_descriptor.url``
    to become ready. ``Future.result`` raises on timeout. It then wraps the
    channel with ``WorkerIdInterceptor`` and starts the ``Logging`` RPC fed
    by ``self._write_log_entries()``. A daemon thread consumes the responses.
    """
    super(FnApiLogRecordHandler, self).__init__()
    raw_channel = grpc.insecure_channel(log_service_descriptor.url)
    # Make sure the channel is ready to avoid [BEAM-4649].
    readiness = grpc.channel_ready_future(raw_channel)
    readiness.result(timeout=60)
    self._log_channel = grpc.intercept_channel(raw_channel,
                                               WorkerIdInterceptor())
    self._logging_stub = beam_fn_api_pb2_grpc.BeamFnLoggingStub(
        self._log_channel)
    # Created before the RPC starts; presumably drained by
    # _write_log_entries — keep this ordering.
    self._log_entry_queue = queue.Queue()

    control_stream = self._logging_stub.Logging(self._write_log_entries())

    def _consume_control_stream():
      self._read_log_control_messages(control_stream)

    reader = threading.Thread(
        target=_consume_control_stream, name='read_log_control_messages')
    reader.daemon = True
    self._reader = reader
    reader.start()
    def _intercept_channel(self, *interceptors):
        """ Experimental. Bind gRPC interceptors to the gRPC channel.

        Replaces ``self.channel`` with an intercepted channel and rebuilds
        ``self._text_to_speech_stub`` on it. It then clears
        ``self._inner_api_calls`` so those cached calls are created afresh.

        Args:
            interceptors (*Union[grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor]):
              Zero or more gRPC interceptors. Interceptors are given control in the order
              they are listed.
        Raises:
            TypeError: If interceptor does not derive from any of
              UnaryUnaryClientInterceptor,
              UnaryStreamClientInterceptor,
              StreamUnaryClientInterceptor, or
              StreamStreamClientInterceptor.
        """
        self.channel = grpc.intercept_channel(self.channel, *interceptors)
        self._text_to_speech_stub = (cloud_tts_pb2_grpc.TextToSpeechStub(
            self.channel))
        # Cached wrapped calls presumably reference the previous stub; drop
        # them so they are rebuilt against the new one — verify.
        self._inner_api_calls.clear()
Пример #42
0
    def testInterceptedUnaryRequestFutureUnaryResponse(self):
        """Future-style unary-unary calls run client then server interceptors in order."""
        self._record[:] = []
        intercepted = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor('c1', self._record),
            _LoggingInterceptor('c2', self._record))

        call_future = _unary_unary_multi_callable(intercepted).future(
            b'\x07\x08',
            metadata=(('test', 'InterceptedUnaryRequestFutureUnaryResponse'),))
        call_future.result()

        expected_trace = [
            'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
            's1:intercept_service', 's2:intercept_service'
        ]
        self.assertSequenceEqual(self._record, expected_trace)
Пример #43
0
    def testInterceptedUnaryRequestStreamResponse(self):
        """Unary-stream calls run client then server interceptors in order."""
        self._record[:] = []
        intercepted = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor('c1', self._record),
            _LoggingInterceptor('c2', self._record))

        responses = _unary_stream_multi_callable(intercepted)(
            b'\x37\x58',
            metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
        # Drain the stream so the whole call completes.
        list(responses)

        expected_trace = [
            'c1:intercept_unary_stream', 'c2:intercept_unary_stream',
            's1:intercept_service', 's2:intercept_service'
        ]
        self.assertSequenceEqual(self._record, expected_trace)
Пример #44
0
    def get_conn(self):
        """Build a gRPC channel from the connection's host, port and ``auth_type``.

        Supported ``auth_type`` values:
            * ``NO_AUTH``: plaintext channel.
            * ``SSL`` / ``TLS``: TLS channel whose root certificates are read
              (as PEM bytes) from the ``credential_pem_file`` field.
            * ``JWT_GOOGLE``: channel authorized with
              ``OnDemandCredentials`` derived from ``google_auth.default()``.
            * ``OATH_GOOGLE``: channel authorized with
              ``google_auth.default()`` credentials for the comma-separated
              ``scopes`` field.
            * ``CUSTOM``: channel returned by
              ``self.custom_connection_func(self.conn)``.

        Each entry of ``self.interceptors`` (if any) is then applied by
        wrapping the channel one interceptor at a time, in list order.

        Returns:
            The (possibly intercepted) gRPC channel.

        Raises:
            AirflowConfigException: if ``auth_type`` is ``CUSTOM`` but no
                connection function is set, or if it is missing/unsupported.
        """
        base_url = self.conn.host

        if self.conn.port:
            base_url = base_url + ":" + str(self.conn.port)

        auth_type = self._get_field("auth_type")

        if auth_type == "NO_AUTH":
            channel = grpc.insecure_channel(base_url)
        elif auth_type in ("SSL", "TLS"):
            credential_file_name = self._get_field("credential_pem_file")
            # grpc.ssl_channel_credentials expects PEM *bytes*; read in binary
            # mode and close the file instead of leaking the handle.
            with open(credential_file_name, "rb") as credential_file:
                creds = grpc.ssl_channel_credentials(credential_file.read())
            channel = grpc.secure_channel(base_url, creds)
        elif auth_type == "JWT_GOOGLE":
            credentials, _ = google_auth.default()
            jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(
                credentials)
            channel = google_auth_transport_grpc.secure_authorized_channel(
                jwt_creds, None, base_url)
        elif auth_type == "OATH_GOOGLE":
            scopes = self._get_field("scopes").split(",")
            credentials, _ = google_auth.default(scopes=scopes)
            request = google_auth_transport_requests.Request()
            channel = google_auth_transport_grpc.secure_authorized_channel(
                credentials, request, base_url)
        elif auth_type == "CUSTOM":
            if not self.custom_connection_func:
                raise AirflowConfigException(
                    "Customized connection function not set, not able to establish a channel")
            channel = self.custom_connection_func(self.conn)
        else:
            # Implicit concatenation: the old backslash continuation leaked
            # the next line's indentation into the message.
            raise AirflowConfigException(
                "auth_type not supported or not provided, channel cannot be "
                "established, given value: %s" % str(auth_type))

        if self.interceptors:
            # Wrapping one at a time: the last interceptor in the list ends up
            # as the outermost wrapper of the channel.
            for interceptor in self.interceptors:
                channel = grpc.intercept_channel(channel,
                                                 interceptor)

        return channel
Пример #45
0
    def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
        """Blocking ``with_call`` unary-unary calls run client then server interceptors."""
        intercepted = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor('c1', self._record),
            _LoggingInterceptor('c2', self._record))

        # Start from an empty trace right before the call.
        self._record[:] = []

        call_metadata = (
            ('test', 'InterceptedUnaryRequestBlockingUnaryResponseWithCall'),)
        _unary_unary_multi_callable(intercepted).with_call(
            b'\x07\x08', metadata=call_metadata)

        expected_trace = [
            'c1:intercept_unary_unary', 'c2:intercept_unary_unary',
            's1:intercept_service', 's2:intercept_service'
        ]
        self.assertSequenceEqual(self._record, expected_trace)
Пример #46
0
    def testInterceptedStreamRequestStreamResponse(self):
        """Stream-stream calls run client then server interceptors in order."""
        outgoing = iter((b'\x77\x58',) * test_constants.STREAM_LENGTH)

        self._record[:] = []
        intercepted = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor('c1', self._record),
            _LoggingInterceptor('c2', self._record))

        replies = _stream_stream_multi_callable(intercepted)(
            outgoing,
            metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
        # Drain the response stream so the call completes.
        list(replies)

        expected_trace = [
            'c1:intercept_stream_stream', 'c2:intercept_stream_stream',
            's1:intercept_service', 's2:intercept_service'
        ]
        self.assertSequenceEqual(self._record, expected_trace)
Пример #47
0
    def testInterceptedStreamRequestBlockingUnaryResponseWithCall(self):
        """Blocking ``with_call`` stream-unary calls run client then server interceptors."""
        outgoing = iter((b'\x07\x08',) * test_constants.STREAM_LENGTH)

        self._record[:] = []
        intercepted = grpc.intercept_channel(
            self._channel,
            _LoggingInterceptor('c1', self._record),
            _LoggingInterceptor('c2', self._record))

        call_metadata = (
            ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)
        _stream_unary_multi_callable(intercepted).with_call(
            outgoing, metadata=call_metadata)

        expected_trace = [
            'c1:intercept_stream_unary', 'c2:intercept_stream_unary',
            's1:intercept_service', 's2:intercept_service'
        ]
        self.assertSequenceEqual(self._record, expected_trace)
Пример #48
0
    def testTripleRequestMessagesClientInterceptor(self):
        """An interceptor that sends each request three times triples the replies.

        Sending the same requests over the raw, unintercepted channel serves
        as the control: it yields one response per request.
        """

        def triple(request_iterator):
            # Re-emit every incoming request three times.
            for item in request_iterator:
                yield item
                yield item
                yield item

        tripling_channel = grpc.intercept_channel(
            self._channel, _wrap_request_iterator_stream_interceptor(triple))
        requests = (b'\x07\x08',) * test_constants.STREAM_LENGTH
        call_metadata = (
            ('test', 'InterceptedStreamRequestBlockingUnaryResponseWithCall'),)

        tripled = tuple(_stream_stream_multi_callable(tripling_channel)(
            iter(requests), metadata=call_metadata))
        self.assertEqual(len(tripled), 3 * test_constants.STREAM_LENGTH)

        untouched = tuple(_stream_stream_multi_callable(self._channel)(
            iter(requests), metadata=call_metadata))
        self.assertEqual(len(untouched), test_constants.STREAM_LENGTH)