Example #1
 def __init__(self,
              host,
              port,
              timeout,
              user,
              password,
              creds=None,
              options=None):
     """This class creates grpc calls using python.
         :param username: Username for device login
         :param password: Password for device login
         :param host: The ip address for the device
         :param port  The port for the device
         :param timeout: how long before the rpc call timesout
         :param creds: Input of the pem file
         :param options: TLS server name
         :type password: str
         :type username: str
         :type server: str
         :type port: int
         :type timeout:int
         :type creds: str
         :type options: str
     """
     if creds is not None:
         self._target = '%s:%d' % (host, port)
         self._creds = implementations.ssl_channel_credentials(creds)
         self._options = options
         channel = grpc.secure_channel(self._target, self._creds, ((
             'grpc.ssl_target_name_override',
             self._options,
         ), ))
         self._channel = implementations.Channel(channel)
     else:
         self._host = host
         self._port = port
         self._channel = implementations.insecure_channel(
             self._host, self._port)
     self._stub = ems_grpc_pb2.beta_create_gRPCConfigOper_stub(
         self._channel)
     self._timeout = int(timeout)
     self._metadata = [('username', user), ('password', password)]
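
A minimal usage sketch for the constructor above. The class name GrpcClient is an assumption (the original snippet does not show it); the address and credential values reuse those from the script example further down.

# Hypothetical usage; GrpcClient is a placeholder for the class defined above.
with open('ems.pem', 'rb') as f:   # PEM bytes, as expected by `creds`
    pem = f.read()
client = GrpcClient(host='10.75.58.60', port=57400, timeout=10,
                    user='cisco', password='cisco',
                    creds=pem, options='ems.cisco.com')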
Example #2
def not_really_secure_channel(
    host, port, channel_credentials, server_host_override):
  """Creates an insecure Channel to a remote host.

  Args:
    host: The name of the remote host to which to connect.
    port: The port of the remote host to which to connect.
    channel_credentials: The implementations.ChannelCredentials with which to
      connect.
    server_host_override: The target name used for SSL host name checking.

  Returns:
    An implementations.Channel to the remote host through which RPCs may be
      conducted.
  """
  target = '%s:%d' % (host, port)
  channel = grpc.secure_channel(
      target, channel_credentials,
      ((b'grpc.ssl_target_name_override', server_host_override,),))
  return implementations.Channel(channel)
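
A short call sketch for the helper above; the CA file name, port, and override host are placeholders, and implementations refers to grpc.beta.implementations.

# Illustrative call; file name and override host are placeholders.
from grpc.beta import implementations

with open('test_ca.pem', 'rb') as f:
    channel_credentials = implementations.ssl_channel_credentials(f.read())
channel = not_really_secure_channel(
    'localhost', 50051, channel_credentials,
    server_host_override='test.example.com')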
Example #3
def not_really_secure_channel(host, port, client_credentials,
                              server_host_override):
    """Creates an insecure Channel to a remote host.

  Args:
    host: The name of the remote host to which to connect.
    port: The port of the remote host to which to connect.
    client_credentials: The implementations.ClientCredentials with which to
      connect.
    server_host_override: The target name used for SSL host name checking.

  Returns:
    An implementations.Channel to the remote host through which RPCs may be
      conducted.
  """
    hostport = '%s:%d' % (host, port)
    intermediary_low_channel = _intermediary_low.Channel(
        hostport,
        client_credentials._intermediary_low_credentials,
        server_host_override=server_host_override)
    return implementations.Channel(intermediary_low_channel._internal,
                                   intermediary_low_channel)
Example #4
import grpc
from grpc.beta import implementations
from google.protobuf.json_format import MessageToJson

import ems_grpc_pb2       # generated gRPC bindings for the device
import telemetry_pb2      # generated telemetry message definitions

host = '10.75.58.60'
port = 57400
options = 'ems.cisco.com'

ca_cert = 'ems.pem'  # credential file copied from the device via scp
creds = open(ca_cert).read()

target = '%s:%d' % (host, port)
creds = implementations.ssl_channel_credentials(
    creds.encode('utf-8'))  # the credentials argument must be bytes
channel = grpc.secure_channel(target, creds, ((
    'grpc.ssl_target_name_override',
    options,
), ))
channel = implementations.Channel(channel)

stub = ems_grpc_pb2.beta_create_gRPCConfigOper_stub(channel)
sub_id = 'test_sub'  # Telemetry MDT subscription
sub_args = ems_grpc_pb2.CreateSubsArgs(ReqId=1, encode=3, subidstr=sub_id)

timeout = float(100000)
metadata = [('username', 'cisco'), ('password', 'cisco')]

stream = stub.CreateSubs(sub_args, timeout=timeout, metadata=metadata)

for segment in stream:
    telemetry_pb = telemetry_pb2.Telemetry()
    telemetry_pb.ParseFromString(segment.data)
    # Print the decoded message as JSON
    print(MessageToJson(telemetry_pb))
Example #5
    def _connect(self):
        """
        Create a gRPC connection to the target host
        :return: None
        """
        if not HAS_GRPC:
            raise AnsibleError(
                "grpcio is required to use the gRPC connection type. Please run 'pip install grpcio'"
            )
        host = self._play_context.remote_addr
        if self.connected:
            self.queue_message(
                "log", "gRPC connection to host %s already exist" % host
            )
            return

        port = self.get_option("port")
        self._target = host if port is None else "%s:%d" % (host, port)
        self._timeout = self.get_option("persistent_command_timeout")
        self._login_credentials = [
            ("username", self.get_option("remote_user")),
            ("password", self.get_option("password")),
        ]
        ssl_target_name_override = self.get_option("ssl_target_name_override")
        if ssl_target_name_override:
            self._channel_options = [
                ("grpc.ssl_target_name_override", ssl_target_name_override),
            ]
        else:
            self._channel_options = None

        certs = {}
        private_key_file = self.get_option("private_key_file")
        root_certificates_file = self.get_option("root_certificates_file")
        certificate_chain_file = self.get_option("certificate_chain_file")

        try:
            if root_certificates_file:
                with open(root_certificates_file, "rb") as f:
                    certs["root_certificates"] = f.read()
            if private_key_file:
                with open(private_key_file, "rb") as f:
                    certs["private_key"] = f.read()
            if certificate_chain_file:
                with open(certificate_chain_file, "rb") as f:
                    certs["certificate_chain"] = f.read()
        except Exception as e:
            raise AnsibleConnectionFailure(
                "Failed to read certificate keys: %s" % e
            )
        if certs:
            creds = ssl_channel_credentials(**certs)
            channel = secure_channel(
                self._target, creds, options=self._channel_options
            )
        else:
            channel = insecure_channel(
                self._target, options=self._channel_options
            )

        self.queue_message(
            "vvv",
            "ESTABLISH GRPC CONNECTION FOR USER: %s on PORT %s TO %s"
            % (self.get_option("remote_user"), port, host),
        )
        self._channel = implementations.Channel(channel)
        self.queue_message(
            "vvvv", "grpc connection has completed successfully"
        )
        self._connected = True
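
The certificate handling in _connect reduces to a small standalone pattern: build a secure channel when a root CA is available, otherwise fall back to an insecure one. A sketch under that assumption, with placeholder paths, using only calls from the grpc package:

import grpc

def build_channel(target, root_cert_path=None, options=None):
    """Secure channel if a root CA is supplied, insecure otherwise."""
    if root_cert_path:  # placeholder for e.g. root_certificates_file
        with open(root_cert_path, 'rb') as f:
            creds = grpc.ssl_channel_credentials(root_certificates=f.read())
        return grpc.secure_channel(target, creds, options=options)
    return grpc.insecure_channel(target, options=options)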
Example #6
def _start_model_server(model_path, flags):
    """Start the model server as a subprocess.

  Args:
    model_path: string path to pass to model server in --model_base_path flag
    flags: a _FlagConfig object containing configuration for the webapp

  Returns:
    model_server: model server process
    stub: grpc stub to PredictionService running on model server

  Raises:
    RuntimeError: if model server fails to come up and is used as prediction
    engine.
  """
    if not model_path.startswith("@"):
        port = "--port=%s" % MODEL_SERVER_PORT
        tensorflow_session_parallelism = (
            flags.tensorflow_session_parallelism
            or os.environ.get("tensorflow_session_parallelism") or 0)
        args = [
            flags.model_server_binary_path, port,
            "--model_base_path={}".format(os.path.dirname(model_path)),
            "--file_system_poll_wait_seconds={}".format(
                MODEL_SERVER_FS_POLLING_INTERVAL),
            "--tensorflow_session_parallelism={}".format(
                tensorflow_session_parallelism)
        ]
        logging.debug("Starting model server: %s", args)
        model_server = subprocess.Popen(args=args, stdin=None)
        hostport = "localhost:%s" % MODEL_SERVER_PORT
    else:
        model_server = None
        hostport = model_path[1:]
    # Default grpc message limit is 4 MiB. Some of our models like predict_io_test
    # exceed that limit. We raise this limit as suggested in
    # https://groups.google.com/a/google.com/forum/?utm_medium=email&utm_source=footer#!msg/grpc-users/LSf9JvK69bw/qDHXdTIZDAAJ
    # Note that we are using max_message_length as opposed to
    # max_receive_message_length because it was introduced in a later version of
    # grpc, and grpc 1.0.1 which we are currently using only has
    # max_message_length. We pick max message length to be 32MiB to match
    # AppEngine limit of 32MiB message size.
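    # NOTE: MAX_MESSAGE_LENGTH_MB is assumed to be defined elsewhere in this
    # module as the 32 MiB limit expressed in bytes (e.g. 32 * 1024 * 1024);
    # despite the "_MB" suffix, gRPC channel options take a byte count.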
    logging.debug("Connecting to model server at %s", hostport)
    channel = grpc.insecure_channel(
        hostport,
        options=[
            ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH_MB),
            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH_MB),
            ("grpc.max_message_length", MAX_MESSAGE_LENGTH_MB),
            # TODO(b/37854783): replace this with min_backoff.
            ("grpc.testing.fixed_reconnect_backoff_ms", 1000)
        ])
    channel_ready_future = grpc.channel_ready_future(channel)
    model_server_startup_timeout_secs = os.environ.get(
        "startup_timeout", MAX_STARTUP_TIMEOUT_SEC)
    try:
        channel_ready_future.result(timeout=model_server_startup_timeout_secs)
    except grpc.FutureTimeoutError:
        model_server.terminate()
        # Reload the model. This time we pipe the stderr to get the error message.
        # TODO(b/67371299): Remove this hack.
        model_server = subprocess.Popen(args=args,
                                        stdin=None,
                                        stderr=subprocess.PIPE)
        time.sleep(MODEL_RELOAD_SEC)
        model_server.terminate()
        _, stderr_data = model_server.communicate()
        raise RuntimeError(parse_error_message(stderr_data))
    grpc_channel = implementations.Channel(channel)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(
        grpc_channel, pool=None, pool_size=None)

    logging.debug("Connected to model server at %s", hostport)
    return model_server, stub
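
The readiness check above uses the standard grpc.channel_ready_future pattern; a minimal standalone sketch (the target address and timeout are placeholders):

import grpc

channel = grpc.insecure_channel('localhost:9000')  # placeholder address
try:
    # Block until the channel is connected, or give up after 30 seconds.
    grpc.channel_ready_future(channel).result(timeout=30)
except grpc.FutureTimeoutError:
    print('server did not come up in time')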
Example #7
def tensorflow_model_server_predict(host_port=None,
                                    model_id=None,
                                    signature_name=None,
                                    serialized_examples=None,
                                    serialized_examples_tensor_name=None,
                                    version=None,
                                    batch_size=256,
                                    retries=10):
    """Queries a tensorflow_model_server, potentially spawning it first.

  Args:
    host_port: tensorflow_model_server address. If None, will spawn a new
        tms process, and kill it afterwards.
    model_id: If host_port is given - the name of the model to query. Otherwise
        a path to the model saved_model dir.
    signature_name: model's signature to query against.
    serialized_exapmles: list of serialized tf.train.Example.
    serialized_examples_tensor_name: when querying the model, supply the
        serialized tf exmaples into a tensor with this name.
    version: saved_model version or None. None means: query the newest
        available version.
    batch_size: if number of serialized examples exceeds batch_size, the model
        will be queried multiple times, each time with at most batch_size
        examples.
    retries: it takes some time for the tms to spawn and load the model. We
        do a couple retries, each time waiting for 2 secs.

  """
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2

    if host_port is None:
        versions = [] if version is None else [version]
        models = {model_id: versions}
        mgr = TFModelServer(models=models)
    else:
        mgr = _DummyCtxMgr(None, host_port)

    trial_num = 0
    with mgr as (popen, (host, port)):
        while trial_num < retries:
            if trial_num >= 2:
                print('(Re)trying (%d) to query tensorflow_model_server.' %
                      trial_num)

            if popen is not None and popen.poll() is not None:
                raise RuntimeError(
                    'tensorflow_model_server exited with code %d' %
                    popen.returncode)

            trial_num += 1
            try:
                channel = grpc_beta_implementations.Channel(
                    grpc.insecure_channel(
                        target='%s:%s' % (host, port),
                        options=[
                            (cygrpc.ChannelArgKey.max_send_message_length, -1),
                            (cygrpc.ChannelArgKey.max_receive_message_length,
                             -1)
                        ]))
                stub = prediction_service_pb2.beta_create_PredictionService_stub(
                    channel)
                results = collections.defaultdict(list)
                for i, chunk in enumerate(
                        _chunks(serialized_examples, batch_size)):
                    if i > 9 and i % 10 == 0:
                        print('Prediction for batch %d / %d' % (
                            i, (len(serialized_examples) + batch_size - 1) //
                            batch_size))
                    request = predict_pb2.PredictRequest()
                    request.model_spec.name = model_id
                    if version is not None:
                        request.model_spec.version.value = version
                    request.model_spec.signature_name = signature_name
                    request.inputs[serialized_examples_tensor_name].CopyFrom(
                        tf.contrib.util.make_tensor_proto(chunk))

                    result_future = stub.Predict.future(request, 15.0)
                    while not result_future.done():
                        time.sleep(0.01)
                    result = result_future.result()
                    for key, tensor in result.outputs.items():
                        # Accumulate whichever value field the tensor uses;
                        # the duplicated string_val entry in the original has
                        # been dropped.
                        results[str(key)] = np.concatenate([
                            np.array(results[str(key)]),
                            np.array(tensor.float_val),
                            np.array(tensor.double_val),
                            np.array(tensor.int_val),
                            np.array(tensor.string_val),
                            np.array(tensor.int64_val),
                        ])
                return results
            except Exception as e:
                if trial_num > 3:
                    print('Failed to query tensorflow_model_server:', e)
                time.sleep(2.)

        raise Exception(
            'Failed to query tensorflow_model_server. Double check ' +
            'the model you passed in exists.')
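
A hedged call sketch for the function above; the saved_model path, signature name, and tensor name are placeholders, and serialized_examples is assumed to be a list of serialized tf.train.Example bytes.

# Hypothetical invocation; names and paths below are illustrative only.
results = tensorflow_model_server_predict(
    host_port=None,                    # spawn a local server, kill it after
    model_id='/tmp/my_model',          # saved_model dir, since host_port is None
    signature_name='serving_default',  # placeholder signature
    serialized_examples=serialized_examples,
    serialized_examples_tensor_name='inputs')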