Example #1
  def testMaybeSavedModelDir(self):
    base_path = test.test_src_dir_path("/python/saved_model")
    self.assertFalse(loader.maybe_saved_model_directory(base_path))
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.assertTrue(loader.maybe_saved_model_directory(base_path))
    base_path = "complete_garbage"
    self.assertFalse(loader.maybe_saved_model_directory(base_path))
Example #2
  def testMaybeSavedModelDir(self):
   base_path = test.test_src_dir_path("/python/saved_model")
   self.assertFalse(loader.maybe_saved_model_directory(base_path))
   base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
   self.assertTrue(loader.maybe_saved_model_directory(base_path))
   base_path = "complete_garbage"
   self.assertFalse(loader.maybe_saved_model_directory(base_path))
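For reference, maybe_saved_model_directory() essentially checks whether a saved_model.pb or saved_model.pbtxt file exists in the given directory; it does not validate the graph or variables, which is why the "complete_garbage" path above simply returns False. A minimal standalone sketch of the same check, using a hypothetical export path:

# Minimal sketch: the export path is a placeholder, not part of the examples above.
from tensorflow.python.saved_model import loader

export_dir = "/tmp/exported_model"  # hypothetical path; point this at a real export
if loader.maybe_saved_model_directory(export_dir):
  print("Directory contains a SavedModel protocol buffer:", export_dir)
else:
  print("No SavedModel found in:", export_dir)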
Example #3
def load_model(model_path, config=None):
  """Loads the model at the specified path.

  Args:
    model_path: the path to either session_bundle or SavedModel
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, SignatureDef) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
  if loader.maybe_saved_model_directory(model_path):
    try:
      session = tf_session.Session(target="", graph=None, config=config)
      meta_graph = loader.load(session, tags=[tag_constants.SERVING],
                               export_dir=model_path)
    except Exception:  # pylint: disable=broad-except
      raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                            "Failed to load the model due to bad model data.")
  else:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Cloud ML only supports TF 1.0 or above and models "
                          "saved in SavedModel format.")

  if session is None:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Failed to create session when loading the model")
  signature = _get_signature_from_meta_graph(session.graph, meta_graph)

  return session, signature
Example #4
def load_tf_model(model_path, tags=(tag_constants.SERVING,), config=None):
  """Loads the model at the specified path.

  Args:
    model_path: the path to either session_bundle or SavedModel
    tags: the tags that determine the model to load.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, map<string, SignatureDef>) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
  if loader.maybe_saved_model_directory(model_path):
    try:
      logging.info("Importing tensorflow.contrib in load_tf_model")
      # pylint: disable=redefined-outer-name,unused-variable,g-import-not-at-top
      import tensorflow as tf
      from tensorflow.python.framework.ops import Graph
      # pylint: enable=redefined-outer-name,unused-variable,g-import-not-at-top
      if tf.__version__.startswith("1.0"):
        session = tf_session.Session(target="", graph=None, config=config)
      else:
        session = tf_session.Session(target="", graph=Graph(), config=config)
      meta_graph = loader.load(session, tags=list(tags), export_dir=model_path)
    except Exception as e:  # pylint: disable=broad-except
      raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                            "Failed to load the model due to bad model data."
                            " tags: %s\n%s" % (list(tags), str(e)))
  else:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Cloud ML only supports TF 1.0 or above and models "
                          "saved in SavedModel format.")

  if session is None:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Failed to create session when loading the model")

  if not meta_graph.signature_def:
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "MetaGraph must have at least one signature_def.")

  # Remove invalid signatures from the signature map.
  invalid_signatures = []
  for signature_name in meta_graph.signature_def:
    try:
      signature = meta_graph.signature_def[signature_name]
      _update_dtypes(session.graph, signature.inputs)
      _update_dtypes(session.graph, signature.outputs)
    except ValueError as e:
      logging.warn("Error updating signature %s: %s", signature_name, str(e))
      invalid_signatures.append(signature_name)
  for signature_name in invalid_signatures:
    del meta_graph.signature_def[signature_name]

  return session, meta_graph.signature_def
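A hypothetical call site for load_tf_model() above, showing one way the returned (Session, signature_def map) pair could be consumed. The export path, the presence of the default serving signature, and the signature_constants import (used in the next example) are assumptions:

# Usage sketch: the path and signature key are assumptions, not taken from the example.
session, signature_map = load_tf_model("/tmp/exported_model")
try:
  sig = signature_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  # SignatureDef.inputs/outputs map friendly aliases to TensorInfo protos; the
  # .name field is the tensor name to feed or fetch in the loaded graph.
  fetches = {alias: info.name for alias, info in sig.outputs.items()}
  # results = session.run(fetches, feed_dict={sig.inputs[alias].name: data, ...})
finally:
  session.close()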
Example #5
def load_model(
        model_path,
        tags=(tag_constants.SERVING, ),
        signature_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
        config=None):
    """Loads the model at the specified path.

  Args:
    model_path: the path to either session_bundle or SavedModel
    tags: the tags that determines the model to load.
    signature_name: the string used as the key to signature map to locate the
                   serving signature.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, SignatureDef) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
    if loader.maybe_saved_model_directory(model_path):
        try:
            session = tf_session.Session(target="", graph=None, config=config)
            meta_graph = loader.load(session,
                                     tags=list(tags),
                                     export_dir=model_path)
        except Exception:  # pylint: disable=broad-except
            raise PredictionError(
                PredictionError.FAILED_TO_LOAD_MODEL,
                "Failed to load the model due to bad model data."
                " tags: %s" % tags)
    else:
        raise PredictionError(
            PredictionError.FAILED_TO_LOAD_MODEL,
            "Cloud ML only supports TF 1.0 or above and models "
            "saved in SavedModel format.")

    if session is None:
        raise PredictionError(
            PredictionError.FAILED_TO_LOAD_MODEL,
            "Failed to create session when loading the model")
    signature = _get_signature_from_meta_graph(session.graph, meta_graph,
                                               signature_name)

    return session, signature
Example #6
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, adds the graph data to the
  default graph, and restores the parameters for the session it creates.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model is
          saved as SavedModel. These should correspond to the tags used when
          saving the variables using the SavedModel `save()` API.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
            tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
  metagraph_def = None
  sess = None
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    metagraph_def = loader.load(sess, tags, export_dir)
  elif session_bundle.maybe_session_bundle_dir(export_dir):
    sess, metagraph_def = _load_saved_model_from_session_bundle_path(export_dir,
                                                                     target,
                                                                     config)
  else:
    raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                       "specified export location: %s" % export_dir)

  return sess, metagraph_def
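A hypothetical call to the helper above, assuming a SavedModel export at a placeholder path; for a SavedModel the tags must match those used at save time, while the session-bundle branch ignores them:

# Usage sketch: placeholder path; assumes the export was saved with the SERVING tag.
sess, meta_graph = load_session_bundle_or_saved_model_bundle_from_path(
    "/tmp/exported_model", tags=[tag_constants.SERVING])
try:
  # For a SavedModel, meta_graph.signature_def maps signature names to SignatureDef protos.
  print("Available signatures:", sorted(meta_graph.signature_def.keys()))
finally:
  sess.close()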