def __init__(self, host, port, timeout, user, password, creds=None, options=None):
    """Create the gRPC channel, stub and login metadata for a device.

    A TLS channel is built when ``creds`` is given; otherwise an insecure
    channel to ``host``/``port`` is used.

    :param host: The ip address for the device
    :param port: The port for the device
    :param timeout: how long before the rpc call times out (coerced with ``int``)
    :param user: Username for device login
    :param password: Password for device login
    :param creds: Input of the pem file, handed to
        ``implementations.ssl_channel_credentials``; ``None`` selects the
        insecure channel
    :param options: TLS server name, sent as ``grpc.ssl_target_name_override``
    :type host: str
    :type port: int
    :type timeout: int
    :type user: str
    :type password: str
    :type creds: str
    :type options: str
    """
    if creds is not None:
        self._target = '%s:%d' % (host, port)
        self._creds = implementations.ssl_channel_credentials(creds)
        self._options = options
        # Only send the target-name override when one was supplied: a None
        # value is not an acceptable channel argument value.
        if options is not None:
            channel_options = (('grpc.ssl_target_name_override', options),)
        else:
            channel_options = ()
        channel = grpc.secure_channel(
            self._target, self._creds, channel_options)
        self._channel = implementations.Channel(channel)
    else:
        self._host = host
        self._port = port
        self._channel = implementations.insecure_channel(
            self._host, self._port)
    self._stub = ems_grpc_pb2.beta_create_gRPCConfigOper_stub(self._channel)
    self._timeout = int(timeout)
    # Sent with each RPC as call metadata by the methods that use this client.
    self._metadata = [('username', user), ('password', password)]
def not_really_secure_channel(
        host, port, channel_credentials, server_host_override):
    """Open a credentialed Channel whose SSL target name is overridden.

    The underlying channel is created with grpc.secure_channel, but the name
    used for SSL host name checking is replaced by server_host_override.

    Args:
      host: The name of the remote host to which to connect.
      port: The port of the remote host to which to connect.
      channel_credentials: The implementations.ChannelCredentials with which to
        connect.
      server_host_override: The target name used for SSL host name checking.

    Returns:
      An implementations.Channel to the remote host through which RPCs may be
        conducted.
    """
    address = '%s:%d' % (host, port)
    override_option = (b'grpc.ssl_target_name_override', server_host_override)
    wrapped = grpc.secure_channel(
        address, channel_credentials, (override_option,))
    return implementations.Channel(wrapped)
def not_really_secure_channel(host, port, client_credentials,
                              server_host_override):
    """Open a Channel with credentials and an overridden SSL host name.

    Args:
      host: The name of the remote host to which to connect.
      port: The port of the remote host to which to connect.
      client_credentials: The implementations.ClientCredentials with which to
        connect.
      server_host_override: The target name used for SSL host name checking.

    Returns:
      An implementations.Channel to the remote host through which RPCs may be
        conducted.
    """
    address = '%s:%d' % (host, port)
    low_credentials = client_credentials._intermediary_low_credentials
    low_channel = _intermediary_low.Channel(
        address,
        low_credentials,
        server_host_override=server_host_override)
    # The low-level channel is passed along with its internal handle.
    return implementations.Channel(low_channel._internal, low_channel)
# Connection settings for the target device.
host = '10.75.58.60'
port = 57400
options = 'ems.cisco.com'
ca_cert = 'ems.pem'  # credential file scp from devices

# Read the certificate inside a context manager so the file handle is closed
# as soon as its contents are in memory.
with open(ca_cert) as cert_file:
    pem_text = cert_file.read()

target = '%s:%d' % (host, port)
creds = implementations.ssl_channel_credentials(
    pem_text.encode('utf-8'))  # args with byte type
channel = grpc.secure_channel(
    target, creds, (('grpc.ssl_target_name_override', options),))
channel = implementations.Channel(channel)
stub = ems_grpc_pb2.beta_create_gRPCConfigOper_stub(channel)

sub_id = 'test_sub'  # Telemetry MDT subscription
sub_args = ems_grpc_pb2.CreateSubsArgs(ReqId=1, encode=3, subidstr=sub_id)
timeout = float(100000)
metadata = [('username', 'cisco'), ('password', 'cisco')]

stream = stub.CreateSubs(sub_args, timeout=timeout, metadata=metadata)
for segment in stream:
    # Each streamed segment carries a serialized Telemetry message in `data`.
    telemetry_pb = telemetry_pb2.Telemetry()
    telemetry_pb.ParseFromString(segment.data)
    # Print Json Message
    print(MessageToJson(telemetry_pb))
def _connect(self):
    """
    Open the gRPC channel to the target host and mark the connection as up.

    A secure channel is used when at least one certificate file option is
    set; otherwise an insecure channel is created.

    :return: None
    """
    if not HAS_GRPC:
        raise AnsibleError(
            "grpcio is required to use the gRPC connection type. Please run 'pip install grpcio'"
        )

    # NOTE(review): the "host" option is looked up, but the play context's
    # remote address is the value actually used from here on.
    host = self.get_option("host")
    host = self._play_context.remote_addr

    if self.connected:
        self.queue_message(
            "log", "gRPC connection to host %s already exist" % host
        )
        return

    port = self.get_option("port")
    if port is None:
        self._target = host
    else:
        self._target = "%s:%d" % (host, port)
    self._timeout = self.get_option("persistent_command_timeout")
    self._login_credentials = [
        ("username", self.get_option("remote_user")),
        ("password", self.get_option("password")),
    ]

    name_override = self.get_option("ssl_target_name_override")
    self._channel_options = (
        [("grpc.ssl_target_name_override", name_override)]
        if name_override
        else None
    )

    private_key_file = self.get_option("private_key_file")
    root_certificates_file = self.get_option("root_certificates_file")
    certificate_chain_file = self.get_option("certificate_chain_file")

    # Keyword name expected by ssl_channel_credentials -> file to load,
    # listed in the order the files are read.
    cert_sources = (
        ("root_certificates", root_certificates_file),
        ("private_key", private_key_file),
        ("certificate_chain", certificate_chain_file),
    )
    certs = {}
    try:
        for keyword, path in cert_sources:
            if path:
                with open(path, "rb") as handle:
                    certs[keyword] = handle.read()
    except Exception as e:
        raise AnsibleConnectionFailure(
            "Failed to read certificate keys: %s" % e
        )

    if certs:
        channel = secure_channel(
            self._target,
            ssl_channel_credentials(**certs),
            options=self._channel_options,
        )
    else:
        channel = insecure_channel(
            self._target, options=self._channel_options
        )

    self.queue_message(
        "vvv",
        "ESTABLISH GRPC CONNECTION FOR USER: %s on PORT %s TO %s"
        % (self.get_option("remote_user"), port, host),
    )
    self._channel = implementations.Channel(channel)
    self.queue_message(
        "vvvv", "grpc connection has completed successfully"
    )
    self._connected = True
def _start_model_server(model_path, flags):
    """Start the model server as a subprocess and connect a stub to it.

    When `model_path` starts with "@", no subprocess is started: the rest of
    the string is used as the host:port of the server to connect to.

    Args:
      model_path: string path to pass to model server in --model_base_path
        flag (its directory name is what is actually passed), or
        "@host:port" to connect without spawning a server.
      flags: a _FlagConfig object containing configuration for the webapp

    Returns:
      model_server: model server process, or None for an "@host:port" path
      stub: grpc stub to PredictionService running on model server

    Raises:
      RuntimeError: if the channel to the model server is not ready within
        the startup timeout.
    """
    if not model_path.startswith("@"):
        port = "--port=%s" % MODEL_SERVER_PORT
        tensorflow_session_parallelism = (
            flags.tensorflow_session_parallelism or
            os.environ.get("tensorflow_session_parallelism") or 0)
        args = [
            flags.model_server_binary_path, port,
            "--model_base_path={}".format(os.path.dirname(model_path)),
            "--file_system_poll_wait_seconds={}".format(
                MODEL_SERVER_FS_POLLING_INTERVAL),
            "--tensorflow_session_parallelism={}".format(
                tensorflow_session_parallelism)
        ]
        logging.debug("Starting model server: %s", args)
        model_server = subprocess.Popen(args=args, stdin=None)
        hostport = "localhost:%s" % MODEL_SERVER_PORT
    else:
        model_server = None
        args = None
        hostport = model_path[1:]

    # Default grpc message limit is 4 MiB. Some of our models like
    # predict_io_test exceed that limit. We raise this limit as suggested in
    # https://groups.google.com/a/google.com/forum/?utm_medium=email&utm_source=footer#!msg/grpc-users/LSf9JvK69bw/qDHXdTIZDAAJ
    # The limit is set under max_message_length as well as the
    # max_receive/max_send keys, because the latter were introduced in a later
    # version of grpc than max_message_length. We pick max message length to be
    # 32MiB to match AppEngine limit of 32MiB message size.
    # NOTE(review): the constant's name suggests megabytes while these options
    # are byte counts -- confirm MAX_MESSAGE_LENGTH_MB holds a byte value.
    logging.debug("Connecting to model server at %s", hostport)
    channel = grpc.insecure_channel(
        hostport,
        options=[
            ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH_MB),
            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH_MB),
            ("grpc.max_message_length", MAX_MESSAGE_LENGTH_MB),
            # TODO(b/37854783): replace this with min_backoff.
            ("grpc.testing.fixed_reconnect_backoff_ms", 1000)
        ])
    channel_ready_future = grpc.channel_ready_future(channel)

    model_server_startup_timeout_secs = os.environ.get(
        "startup_timeout", MAX_STARTUP_TIMEOUT_SEC)
    # Environment values are strings; the future needs a numeric timeout.
    if isinstance(model_server_startup_timeout_secs, str):
        model_server_startup_timeout_secs = float(
            model_server_startup_timeout_secs)

    try:
        channel_ready_future.result(timeout=model_server_startup_timeout_secs)
    except grpc.FutureTimeoutError:
        if model_server is None:
            # No subprocess was spawned, so there is nothing to terminate or
            # restart; report the timeout directly.
            raise RuntimeError(
                "Timed out connecting to model server at %s" % hostport)
        model_server.terminate()
        # Reload the model. This time we pipe the stderr to get the error
        # message.
        # TODO(b/67371299): Remove this hack.
        model_server = subprocess.Popen(
            args=args, stdin=None, stderr=subprocess.PIPE)
        time.sleep(MODEL_RELOAD_SEC)
        model_server.terminate()
        _, stderr_data = model_server.communicate()
        raise RuntimeError(parse_error_message(stderr_data))

    grpc_channel = implementations.Channel(channel)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(
        grpc_channel, pool=None, pool_size=None)
    logging.debug("Connected to model server at %s", hostport)
    return model_server, stub
def tensorflow_model_server_predict(host_port=None,
                                    model_id=None,
                                    signature_name=None,
                                    serialized_examples=None,
                                    serialized_examples_tensor_name=None,
                                    version=None,
                                    batch_size=256,
                                    retries=10):
    """Queries a tensorflow_model_server, potentially spawning it first.

    Args:
      host_port: tensorflow_model_server address. If None, will spawn a new
        tms process, and kill it afterwards.
      model_id: If host_port is given - the name of the model to query.
        Otherwise a path to the model saved_model dir.
      signature_name: model's signature to query against.
      serialized_examples: list of serialized tf.train.Example.
      serialized_examples_tensor_name: when querying the model, supply the
        serialized tf examples into a tensor with this name.
      version: saved_model version or None. None means: query the newest
        available version.
      batch_size: if number of serialized examples exceeds batch_size, the
        model will be queried multiple times, each time with at most
        batch_size examples.
      retries: it takes some time for the tms to spawn and load the model. We
        do a couple retries, each time waiting for 2 secs.

    Returns:
      A mapping from output name (str) to a numpy array holding the values
      gathered for that output across all batches.

    Raises:
      RuntimeError: if the spawned tensorflow_model_server process has exited.
      Exception: if every attempt to query the server failed.
    """
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2

    if host_port is None:
        versions = [] if version is None else [version]
        models = {model_id: versions}
        mgr = TFModelServer(models=models)
    else:
        mgr = _DummyCtxMgr(None, host_port)

    trial_num = 0
    last_error = None
    with mgr as (popen, (host, port)):
        while trial_num < retries:
            if trial_num >= 2:
                print('(Re)trying (%d) to query tensorflow_model_server.'
                      % trial_num)
            if popen is not None and popen.poll() is not None:
                raise RuntimeError(
                    'tensorflow_model_server exited with code %d'
                    % popen.returncode)
            trial_num += 1
            try:
                channel = grpc_beta_implementations.Channel(
                    grpc.insecure_channel(
                        target='%s:%s' % (host, port),
                        options=[
                            (cygrpc.ChannelArgKey.max_send_message_length,
                             -1),
                            (cygrpc.ChannelArgKey.max_receive_message_length,
                             -1),
                        ]))
                stub = (
                    prediction_service_pb2.beta_create_PredictionService_stub(
                        channel))
                # Results are rebuilt from scratch on every attempt.
                results = collections.defaultdict(list)
                for i, chunk in enumerate(
                        _chunks(serialized_examples, batch_size)):
                    if i > 9 and i % 10 == 0:
                        # Floor division keeps the batch count an integer.
                        print('Prediction for batch %d / %d' % (
                            i,
                            (len(serialized_examples) + batch_size - 1)
                            // batch_size))
                    request = predict_pb2.PredictRequest()
                    request.model_spec.name = model_id
                    if version is not None:
                        request.model_spec.version.value = version
                    request.model_spec.signature_name = signature_name
                    request.inputs[serialized_examples_tensor_name].CopyFrom(
                        tf.contrib.util.make_tensor_proto(chunk))
                    result_future = stub.Predict.future(request, 15.0)
                    while not result_future.done():
                        time.sleep(0.01)
                    result = result_future.result()
                    for key, tensor in result.outputs.items():
                        name = str(key)
                        # Each value field is listed once; string_val used to
                        # appear twice, which duplicated string outputs.
                        # NOTE(review): bool_val is not collected -- confirm
                        # whether boolean outputs are expected here.
                        results[name] = np.concatenate([
                            np.array(results[name]),
                            np.array(tensor.float_val),
                            np.array(tensor.double_val),
                            np.array(tensor.int_val),
                            np.array(tensor.string_val),
                            np.array(tensor.int64_val),
                        ])
                return results
            except Exception as e:
                # Remember the failure so it can be reported if every attempt
                # fails; early failures are expected while the model loads.
                last_error = e
                if trial_num > 3:
                    print('Failed to query tensorflow_model_server: %s'
                          % (e,))
                time.sleep(2.)

    message = ('Failed to query tensorflow_model_server. Double check ' +
               'the model you passed in exists.')
    if last_error is not None:
        message += ' Last error: %s' % (last_error,)
    raise Exception(message)