Example #1
 def test_parse_saved_model_exception(self, builder_cls):
     """Test that error message for not exist model have OS-depend delimiter in path"""
     path = _get_export_dir("not_existing_dir")
     pattern = os.path.sep + "{"
     with self.assertRaises(IOError) as err:
         loader_impl.parse_saved_model(path)
     self.assertTrue(pattern in str(err.exception))
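For reference, `parse_saved_model` returns the deserialized `SavedModel` protocol buffer. A minimal inspection sketch (the path is a placeholder):

# Minimal sketch, not part of the example above; the path is a placeholder.
from tensorflow.python.saved_model import loader_impl

saved_model_proto = loader_impl.parse_saved_model("/tmp/saved_model")
print(saved_model_proto.saved_model_schema_version)    # e.g. 1
for meta_graph in saved_model_proto.meta_graphs:
    print(list(meta_graph.meta_info_def.tags))         # e.g. ['serve']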
Example #2
def smoke_test_model(model_path):
    try:
        resolved_model = hub.resolve(model_path)
        loader_impl.parse_saved_model(resolved_model)
    except Exception as e:  # pylint: disable=broad-except
        return False, e
    return True, None
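A usage sketch for the helper above (the path is a placeholder):

# smoke_test_model returns (ok, error); the path below is illustrative.
ok, error = smoke_test_model("/tmp/my_model")
if not ok:
    print("Model failed to parse:", error)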
Example #3
def load_model(filepath, custom_objects=None, compile=True, options=None):  # pylint: disable=redefined-builtin
    """Loads a model saved via `model.save()`.

  Usage:

  >>> model = tf.keras.Sequential([
  ...     tf.keras.layers.Dense(5, input_shape=(3,)),
  ...     tf.keras.layers.Softmax()])
  >>> model.save('/tmp/model')
  >>> loaded_model = tf.keras.models.load_model('/tmp/model')
  >>> x = tf.random.uniform((10, 3))
  >>> assert np.allclose(model.predict(x), loaded_model.predict(x))

  Note that the model weights may have different scoped names after being
  loaded. Scoped names include the model/layer names, such as
  `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
  access specific variables, e.g. `model.get_layer("dense_1").kernel`.

  Args:
      filepath: One of the following:
          - String or `pathlib.Path` object, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.
      options: Optional `tf.saved_model.LoadOptions` object that specifies
        options for loading from SavedModel.

  Returns:
      A Keras model instance. If the original model was compiled, and saved with
      the optimizer, then the returned model will be compiled. Otherwise, the
      model will be left uncompiled. In the case that an uncompiled model is
      returned, a warning is displayed if the `compile` argument is set to
      `True`.

  Raises:
      ImportError: if loading from an hdf5 file and h5py is not available.
      IOError: In case of an invalid savefile.
  """
    with generic_utils.SharedObjectLoadingScope():
        with generic_utils.CustomObjectScope(custom_objects or {}):
            with load_context.load_context(options):
                if (h5py is not None and (isinstance(filepath, h5py.File)
                                          or h5py.is_hdf5(filepath))):
                    return hdf5_format.load_model_from_hdf5(
                        filepath, custom_objects, compile)

                filepath = path_to_string(filepath)
                if isinstance(filepath, six.string_types):
                    loader_impl.parse_saved_model(filepath)
                    return saved_model_load.load(filepath, compile, options)

    raise IOError(
        'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
        'available) or SavedModel.')
Example #4
 def assert_can_resolve_asset(self, asset_path: str):
   """Attempt to hub.resolve the given asset path."""
   try:
     resolved_model = hub.resolve(asset_path)
     loader_impl.parse_saved_model(resolved_model)
     _validate_file_paths(resolved_model)
   except Exception as e:  # pylint: disable=broad-except
     raise MarkdownDocumentationError(
         f"The model on path {asset_path} failed to parse. Please make sure "
         "that the asset-path metadata points to a valid TF2 SavedModel or a "
         "TF1 Hub module, compressed as described in section 'Model' of "
         f"README.md. Underlying reason for failure: {e}.")
Example #5
 def assert_can_resolve_asset(self, asset_path: str):
   """Attempt to hub.resolve the given asset path."""
   try:
     resolved_model = hub.resolve(asset_path)
     loader_impl.parse_saved_model(resolved_model)
     _validate_file_paths(resolved_model)
   except Exception as e:  # pylint: disable=broad-except
     raise MarkdownDocumentationError(
         f"The model on path {asset_path} failed to parse. Please make sure "
         "that the asset-path metadata points to a valid TF2 SavedModel or a "
         "TF1 Hub module as described on "
         "https://www.tensorflow.org/hub/exporting_tf2_saved_model. "
         f"Underlying reason for failure: {e}.")
Example #6
  def _MakeSavedModelV1(self, run_params):
    """Write the saved model as an input for testing.

    In addition to creating a SavedModel like its parent method, this method
    replaces this SavedModel by adding TF-TRT conversion parameters as function
    attributes to each function in the SavedModel.

    Args:
      run_params: The current test run parameters.

    Returns:
      The directory of the saved model.
    """
    saved_model_dir = trt_test.TfTrtIntegrationTestBase._MakeSavedModelV1(
        self, run_params)
    saved_model_proto = loader_impl.parse_saved_model(saved_model_dir)
    new_saved_model = saved_model_pb2.SavedModel()
    new_saved_model.CopyFrom(saved_model_proto)
    new_meta_graph_def = new_saved_model.meta_graphs[0]
    for func_def in new_meta_graph_def.graph_def.library.function:
      # Disable function inlining.
      func_def.attr["_noinline"].CopyFrom(attr_value_pb2.AttrValue(b=True))
      self._copy_test_attributes_to_func_def(func_def)
    old_saved_model_file = os.path.join(saved_model_dir,
                                        constants.SAVED_MODEL_FILENAME_PB)
    if os.path.exists(old_saved_model_file):
      os.remove(old_saved_model_file)
    path = os.path.join(
        compat.as_bytes(saved_model_dir),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
    file_io.write_string_to_file(
        path, new_saved_model.SerializeToString(deterministic=True))
    return saved_model_dir
Example #7
def load_internal(export_dir, tags=None, loader_cls=Loader):
  """Loader implementation."""
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  if (len(saved_model_proto.meta_graphs) == 1
      and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
      loader = loader_cls(object_graph_proto,
                          saved_model_proto,
                          export_dir)
      root = loader.get(0)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
  return root
Example #8
    def test_export_tpu_savedmodel_export_to_cpu_false(self):
        # Test that when `export_to_cpu` is `False`, CPU metagraph is not exported.
        tmpdir = tempfile.mkdtemp()

        model_fn = get_model_fn(export_tpu_tensor=True, export_cpu_tensor=True)
        run_config = create_run_config(iterations_per_loop=4)

        def _input_fn(params):
            return dummy_input_fn(params['batch_size'])

        est = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                         config=run_config,
                                         train_batch_size=16,
                                         export_to_tpu=True,
                                         export_to_cpu=False)
        est.train(_input_fn, steps=1)

        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export_no_tpu'))
        export_dir = est.export_saved_model(export_dir_base,
                                            self._serving_input_receiver_fn)
        saved_model = loader_impl.parse_saved_model(export_dir)
        self.assertLen(saved_model.meta_graphs, 1)
        tags = set(saved_model.meta_graphs[0].meta_info_def.tags)
        self.assertEqual(tags, set([tag_constants.SERVING, tag_constants.TPU]))

        # Clean up.
        gfile.DeleteRecursively(tmpdir)
Example #9
def available_signature_names(answers):
    """Generate the available saved model signatures from the proto file
    and selected tags.
  Args:
    answers: user-selected parameter dict.
  """
    if (is_saved_model(answers[common.INPUT_FORMAT])
            and common.SAVED_MODEL_TAGS in answers):
        path = answers[common.INPUT_PATH]
        tags = answers[common.SAVED_MODEL_TAGS]
        saved_model = loader_impl.parse_saved_model(path)
        for meta_graph in saved_model.meta_graphs:
            if tags == ",".join(meta_graph.meta_info_def.tags):
                signatures = []
                for key in meta_graph.signature_def:
                    input_nodes = meta_graph.signature_def[key].inputs
                    output_nodes = meta_graph.signature_def[key].outputs
                    signatures.append({
                        'value':
                        key,
                        'name':
                        format_signature(key, input_nodes, output_nodes)
                    })
                return signatures
    return []
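A usage sketch for `available_signature_names`; the `common.*` key constants come from the surrounding module, and the concrete values below are illustrative:

# Illustrative answers dict; the values are made up for this sketch.
answers = {
    common.INPUT_FORMAT: 'tf_saved_model',   # assumed format label
    common.INPUT_PATH: '/tmp/saved_model',   # placeholder path
    common.SAVED_MODEL_TAGS: 'serve',        # comma-joined tag string
}
for signature in available_signature_names(answers):
    print(signature['value'], '->', signature['name'])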
Example #10
  def test_save_variable_devices(self, save_devices, meta_graph_only):
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    root = tracking.AutoTrackable()
    with ops.device("CPU:0"):
      root.v0 = variables.Variable(1., name="v0")
    with ops.device("CPU:1"):
      root.v1 = variables.Variable(1., name="v1")

    options = save_options.SaveOptions(
        experimental_variable_policy=save_devices)
    file_name = os.path.join(self.get_temp_dir(), "saved_model")
    if meta_graph_only:
      save.export_meta_graph(obj=root, filename=file_name, options=options)
    else:
      save.save(obj=root, export_dir=file_name, options=options)

    meta = None
    if meta_graph_only:
      meta = meta_graph.read_meta_graph_file(file_name)
    else:
      meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]

    # Check devices in meta graph nodes.
    graph_def = meta.graph_def
    v0 = next((n for n in graph_def.node if n.name == "v0"), None)
    v1 = next((n for n in graph_def.node if n.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)

    # Check devices in object graph nodes.
    object_graph_def = meta.object_graph_def
    v0 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v0"), None)
    v1 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)
Example #11
  def _get_model_signature(self, model_path: Text) -> _SignatureDef:
    """Returns a model signature."""

    saved_model_pb = loader_impl.parse_saved_model(model_path)
    meta_graph_def = None
    for graph_def in saved_model_pb.meta_graphs:
      if graph_def.meta_info_def.tags == [
          tf.compat.v1.saved_model.tag_constants.SERVING
      ]:
        meta_graph_def = graph_def
    if not meta_graph_def:
      raise RuntimeError('Tag tf.compat.v1.saved_model.tag_constants.SERVING'
                         ' does not exist in saved model: %s. This is required'
                         ' for remote inference.' % model_path)
    if tf.saved_model.PREDICT_METHOD_NAME in meta_graph_def.signature_def:
      return meta_graph_def.signature_def[tf.saved_model.PREDICT_METHOD_NAME]
    if (tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY in
        meta_graph_def.signature_def):
      return meta_graph_def.signature_def[
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    raise RuntimeError(
        'Cannot find serving signature in saved model: %s,'
        ' tf.saved_model.PREDICT_METHOD_NAME or '
        'tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY is needed.' %
        model_path)
Example #12
def replace_graph_def_of_saved_model(input_model, output_model, graph_def):
    model_variables_dir = os.path.join(input_model, 'variables')
    if not os.path.exists(output_model):
        os.makedirs(output_model)
    export_variables_dir = os.path.join(output_model, 'variables')
    export_saved_model = os.path.join(output_model, 'saved_model.pb')

    checkpoint_file = os.path.join(export_variables_dir, 'checkpoint')
    if not os.path.exists(export_variables_dir):
        os.makedirs(export_variables_dir)

    with open(checkpoint_file, 'w') as f:
        f.write("model_checkpoint_path: \"variables\"\n")
    from tensorflow.python.saved_model import loader_impl
    saved_model = loader_impl.parse_saved_model(input_model)
    meta_graph = saved_model.meta_graphs[0]
    # not all saved model have variables
    try:
        with tf.compat.v1.Session(graph=tf.Graph()) as sess:
            loaded = tf.compat.v1.saved_model.loader.load(
                sess, ["serve"], input_model)
            # sess.run('init_all_tables')
            saver = tf.compat.v1.train.Saver()
            saver.save(sess,
                       os.path.join(export_variables_dir, 'variables'),
                       write_meta_graph=False,
                       write_state=False)
    except Exception:  # pylint: disable=broad-except
        logger.info('no variables in the saved model')
    meta_graph.graph_def.CopyFrom(graph_def)
    from tensorflow.python.lib.io import file_io
    file_io.write_string_to_file(export_saved_model,
                                 saved_model.SerializeToString())
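A usage sketch; the paths are placeholders and `new_graph_def` stands for a `tf.compat.v1.GraphDef` produced elsewhere (e.g. after a graph transform):

# Usage sketch; new_graph_def is assumed to be built elsewhere.
replace_graph_def_of_saved_model(
    input_model='/tmp/original_model',
    output_model='/tmp/rewritten_model',
    graph_def=new_graph_def)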
Example #13
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
    """Loads a model saved via `save_model`.

  Note that the model weights may have different scoped names after being
  loaded. Scoped names include the model/layer names, such as
  "dense_1/kernel:0"`. It is recommended that you use the layer properties to
  access specific variables, e.g. `model.get_layer("dense_1").kernel`.

  Arguments:
      filepath: One of the following:
          - String or `pathlib.Path` object, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if loading from an hdf5 file and h5py is not available.
      IOError: In case of an invalid savefile.
  """
    with generic_utils.CustomObjectScope(custom_objects or {}):
        if (h5py is not None and
            (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
            return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
                                                    compile)

        if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
            filepath = str(filepath)
        if isinstance(filepath, six.string_types):
            loader_impl.parse_saved_model(filepath)
            return saved_model_load.load(filepath, compile)

    raise IOError(
        'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
        'available) or SavedModel.')
Example #14
def detect_saved_model(input_path):
    if os.path.exists(os.path.join(input_path, 'assets', 'saved_model.json')):
        return common.KERAS_SAVED_MODEL
    saved_model = loader_impl.parse_saved_model(input_path)
    graph_def = saved_model.meta_graphs[0].object_graph_def
    if graph_def.nodes:
        if 'tf_keras' in graph_def.nodes[0].user_object.identifier:
            return common.KERAS_SAVED_MODEL
    return common.TF_SAVED_MODEL
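A usage sketch for `detect_saved_model` (placeholder path; the constants are the module's own):

# Distinguish a Keras SavedModel from a plain TF SavedModel.
model_format = detect_saved_model('/tmp/saved_model')  # placeholder path
if model_format == common.KERAS_SAVED_MODEL:
    print('Keras SavedModel')
else:
    print('plain TF SavedModel')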
Example #15
 def test_version_information_included(self):
     root = tracking.AutoTrackable()
     save_dir = os.path.join(self.get_temp_dir(), "saved_model")
     save.save(root, save_dir)
     saved_model_proto = loader_impl.parse_saved_model(save_dir)
     self.assertEqual(
         versions.__version__,
         saved_model_proto.meta_graphs[0].meta_info_def.tensorflow_version)
     self.assertEqual(
         versions.__git_version__, saved_model_proto.meta_graphs[0].
         meta_info_def.tensorflow_git_version)
Example #16
def _is_qat_saved_model(saved_model_path: str):
  """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops."""
  saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path)
  for meta_graph in saved_model_proto.meta_graphs:
    if any(
        node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node):
      return True
    for function in meta_graph.graph_def.library.function:
      if any(node.op.startswith('FakeQuant') for node in function.node_def):
        return True
  return False
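A usage sketch with a placeholder path:

# QAT-enabled SavedModels contain FakeQuant ops in their graph or functions.
if _is_qat_saved_model('/tmp/saved_model'):  # placeholder path
    print('SavedModel is QAT-enabled.')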
Example #17
 def test_version_information_included(self):
   root = tracking.AutoTrackable()
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   save.save(root, save_dir)
   saved_model_proto = loader_impl.parse_saved_model(save_dir)
   self.assertEqual(
       versions.__version__,
       saved_model_proto.meta_graphs[0].meta_info_def.tensorflow_version)
   self.assertEqual(
       versions.__git_version__,
       saved_model_proto.meta_graphs[0].meta_info_def.tensorflow_git_version)
Example #18
  def __init__(self, export_dir):
    """Creates an MethodNameUpdater object.

    Args:
      export_dir: Directory containing the SavedModel files.

    Raises:
      IOError: If the saved model file does not exist, or cannot be successfully
      parsed.
    """
    self._export_dir = export_dir
    self._saved_model = loader.parse_saved_model(export_dir)
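In TensorFlow this class is exposed as `tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater`; a typical usage sketch (path, signature key, and method name are illustrative):

# Usage sketch; the path, signature key, and method name are illustrative.
updater = MethodNameUpdater('/tmp/saved_model')
updater.replace_method_name(signature_key='serving_default',
                            method_name='tensorflow/serving/predict')
updater.save()  # without a new export dir, rewrites the SavedModel in place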
Example #19
def available_tags(answers):
    """Generate the available saved model tags from the proto file.
  Args:
    answers: user-selected parameter dict.
  """
    if is_saved_model(answers[common.INPUT_FORMAT]):
        saved_model = loader_impl.parse_saved_model(answers[common.INPUT_PATH])
        tags = []
        for meta_graph in saved_model.meta_graphs:
            tags.append(",".join(meta_graph.meta_info_def.tags))
        return tags
    return []
Example #20
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
    """Loads a model saved via `save_model`.

  Arguments:
      filepath: One of the following:
          - String or `pathlib.Path` object, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if loading from an hdf5 file and h5py is not available.
      IOError: In case of an invalid savefile.
  """
    if (h5py is not None
            and (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
        return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
                                                compile)

    if sys.version_info >= (3, 4) and isinstance(filepath, pathlib.Path):
        filepath = str(filepath)
    if isinstance(filepath, six.string_types):
        loader_impl.parse_saved_model(filepath)
        return saved_model_load.load(filepath, compile)

    raise IOError(
        'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
        'available) or SavedModel.')
Example #21
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model`.

  Arguments:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if loading from an hdf5 file and h5py is not available.
      IOError: In case of an invalid savefile.
  """
  if not tf2.enabled() or (
      h5py is not None and (
          isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)

  if isinstance(filepath, six.string_types):
    loader_impl.parse_saved_model(filepath)
    return saved_model.load_from_saved_model(filepath)

  raise IOError(
      'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
      'available) or SavedModel.')
Example #22
    def test_export_correct_output_shapes(self):
        """Asserts that nodes are exported with the correct number of output shapes.

    After backpropagation rewrite, functions are rewritten with additional
    outputs. When exporting to SavedModel, the shapes of the additional outputs
    were incorrectly added to the FunctionDef proto (b/133666530).
    """
        obj = tracking.AutoTrackable()
        obj.v = variables.Variable(2.)

        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def f(x):
            return (math_ops.multiply(obj.v,
                                      x), math_ops.multiply(obj.v,
                                                            (x + 1)), None)

        obj.f = f

        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def g(x):
            return obj.f(x)[1]

        obj.g = g

        # After the following lines, the concrete functions of obj.g and obj.f are
        # rewritten with many extra outputs.
        with backprop.GradientTape():
            obj.g(constant_op.constant(3.0))

        save_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save.save(obj, save_dir, signatures={"g": obj.g})
        graph_def = loader_impl.parse_saved_model(
            save_dir).meta_graphs[0].graph_def

        def assert_correct_number_of_output_shapes(node):
            if node.op == "StatefulPartitionedCall":
                fn_name = node.attr["f"].func.name
                if fn_name.startswith("__inference_f"):
                    self.assertLen(node.attr["_output_shapes"].list.shape, 2)
                if fn_name.startswith("__inference_g"):
                    self.assertLen(node.attr["_output_shapes"].list.shape, 1)

        for f in graph_def.library.function:
            if (f.signature.name.startswith("__inference_f")
                    or f.signature.name.startswith("__inference_g")):
                for node in f.node_def:
                    assert_correct_number_of_output_shapes(node)
Example #23
 def test_function_aliases(self):
   root = tracking.AutoTrackable()
   root.f = def_function.function(
       lambda x: 2. * x,
       input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
   save_dir = os.path.join(self.get_temp_dir(), "saved_model")
   options = save_options.SaveOptions(function_aliases={
       "my_func": root.f,
   })
   save.save(root, save_dir, root.f, options=options)
   function_cache = list(root.f._stateful_fn._function_cache.all_values())
   function_aliases = loader_impl.parse_saved_model(
       save_dir).meta_graphs[0].meta_info_def.function_aliases
   self.assertLen(function_cache, 1)
   self.assertEqual(function_cache[0].name.decode("utf-8"),
                    list(function_aliases.keys())[0])
Example #24
  def testTF1HubFormattedModel(self):
    """Test a TF1 hub formatted model."""
    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])

    # TF1 hub model is based on V1 saved model and they omit the saved model
    # schema version setting.
    saved_model_proto = parse_saved_model(saved_model_dir)
    saved_model_proto.saved_model_schema_version = 0

    saved_model_pb_file_path = os.path.join(saved_model_dir, 'saved_model.pb')
    with file_io.FileIO(saved_model_pb_file_path, 'wb') as writer:
      writer.write(saved_model_proto.SerializeToString())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
Example #25
def load(export_dir):
  """Load a SavedModel from `export_dir`."""
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  object_graph_filename = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
      compat.as_bytes("object_graph.pb"))
  if file_io.file_exists(object_graph_filename):
    object_graph_proto = _load_saved_object_graph_proto(object_graph_filename)
    loader = _Loader(object_graph_proto,
                     saved_model_proto,
                     export_dir)
    root = loader.get(0)
  else:
    raise NotImplementedError(
        "Currently only SavedModels exported with `tf.saved_model.save` may be "
        "imported. Other SavedModels may eventually be supported via load().")
  return root
Example #26
def _get_signatures(model_path: Text, signatures: Sequence[Text],
                    tags: Sequence[Text]) -> Sequence[_Signature]:
  """Returns a sequence of {model_signature_name: signature}."""

  if signatures:
    signature_names = signatures
  else:
    signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

  saved_model_pb = loader_impl.parse_saved_model(model_path)
  meta_graph_def = _get_meta_graph_def(saved_model_pb, tags)
  result = []
  for signature_name in signature_names:
    if signature_name in meta_graph_def.signature_def:
      result.append(
          _Signature(signature_name,
                     meta_graph_def.signature_def[signature_name]))
    else:
      raise RuntimeError('Signature %s could not be found in SavedModel' %
                         signature_name)
  return result
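A usage sketch, assuming `_Signature` is a `(name, signature_def)` pair as constructed above (path and tags are placeholders):

# Passing an empty `signatures` list falls back to the default serving key.
for name, signature_def in _get_signatures(
    model_path='/tmp/saved_model',
    signatures=[],
    tags=[tf.saved_model.SERVING]):
  print(name, list(signature_def.inputs))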
Example #27
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled loss
  functions, and metrics are not retained. This is temporary, and `model.save`
  will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.


  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # The Keras metadata file is not yet saved, so create it from the SavedModel.
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  # TODO(kathywu): When the keras metadata file is saved, load it directly
  # instead of calling the _read_legacy_metadata function.
  _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config))
      saving_utils.try_build_compiled_arguments(model)
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
Example #28
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled loss
  functions, and metrics are not retained. This is temporary, and `model.save`
  will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.


  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e)))
  else:
    logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config))
      saving_utils.try_build_compiled_arguments(model)
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
Example #29
def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have checkpointable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A checkpointable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to checkpointable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
    if tags is not None:
        # Supports e.g. tags=SERVING and tags=[SERVING]
        tags = nest.flatten(tags)
    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    object_graph_filename = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY),
        compat.as_bytes("object_graph.pb"))
    if (file_io.file_exists(object_graph_filename)
            and len(saved_model_proto.meta_graphs) == 1):
        meta_graph_def = saved_model_proto.meta_graphs[0]
        if (tags is not None
                and set(tags) != set(meta_graph_def.meta_info_def.tags)):
            raise ValueError((
                "The SavedModel at {} has one MetaGraph with tags {}, but got an "
                "incompatible argument tags={} to tf.saved_model.load. You may omit "
                "it, pass 'None', or pass matching tags.").format(
                    export_dir, meta_graph_def.meta_info_def.tags, tags))
        object_graph_proto = _load_saved_object_graph_proto(
            object_graph_filename)
        with ops.init_scope():
            loader = _Loader(object_graph_proto, saved_model_proto, export_dir)
            root = loader.get(0)
    else:
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
    return root
Example #30
def load(export_dir, tags=None):
    """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will have functions
  corresponding to their signatures in the `.signatures` attribute, but also
  have a `.prune` method which allows you to extract functions for new
  subgraphs. This is equivalent to importing the SavedModel and naming feeds and
  fetches in a Session from TensorFlow 1.x.

  ```python
  imported = tf.saved_model.load(path_to_v1_saved_model)
  pruned = imported.prune("x:0", "out:0")
  pruned(tf.ones([]))
  ```

  See `tf.compat.v1.wrap_function` for details. These SavedModels also have a
  `.variables` attribute containing imported variables, and a `.graph` attribute
  representing the whole imported graph. For SavedModels exported from
  `tf.saved_model.save`, variables are instead assigned to whichever attributes
  they were assigned before export.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
    if tags is not None and not isinstance(tags, set):
        # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
        # sequences for nest.flatten, so we put those through as-is.
        tags = nest.flatten(tags)
    saved_model_proto = loader_impl.parse_saved_model(export_dir)
    if (len(saved_model_proto.meta_graphs) == 1
            and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
        meta_graph_def = saved_model_proto.meta_graphs[0]
        if (tags is not None
                and set(tags) != set(meta_graph_def.meta_info_def.tags)):
            raise ValueError((
                "The SavedModel at {} has one MetaGraph with tags {}, but got an "
                "incompatible argument tags={} to tf.saved_model.load. You may omit "
                "it, pass 'None', or pass matching tags.").format(
                    export_dir, meta_graph_def.meta_info_def.tags, tags))
        object_graph_proto = meta_graph_def.object_graph_def
        with ops.init_scope():
            loader = _Loader(object_graph_proto, saved_model_proto, export_dir)
            root = loader.get(0)
        root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
    else:
        with ops.init_scope():
            root = load_v1_in_v2.load(export_dir, tags)
    return root
Example #31
def load(export_dir, tags=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will have functions
  corresponding to their signatures in the `.signatures` attribute, but also
  have a `.prune` method which allows you to extract functions for new
  subgraphs. This is equivalent to importing the SavedModel and naming feeds and
  fetches in a Session from TensorFlow 1.x.

  ```python
  imported = tf.saved_model.load(path_to_v1_saved_model)
  pruned = imported.prune("x:0", "out:0")
  pruned(tf.ones([]))
  ```

  See `tf.compat.v1.wrap_function` for details. These SavedModels also have a
  `.variables` attribute containing imported variables, and a `.graph` attribute
  representing the whole imported graph. For SavedModels exported from
  `tf.saved_model.save`, variables are instead assigned to whichever attributes
  they were assigned before export.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  if (len(saved_model_proto.meta_graphs) == 1
      and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
      loader = _Loader(object_graph_proto,
                       saved_model_proto,
                       export_dir)
      root = loader.get(0)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
  return root
Example #32
def load(export_dir, tags=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  if (len(saved_model_proto.meta_graphs) == 1
      and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
      loader = _Loader(object_graph_proto,
                       saved_model_proto,
                       export_dir)
      root = loader.get(0)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
  return root
Example #33
 def assert_saved_model(self, path):
   loader_impl.parse_saved_model(path)
Example #34
 def _GetSignatureDef(export_dir):
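     # Nested test helper: `self` and `signature_key` are captured from the enclosing test method.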
     saved_model_proto = loader_impl.parse_saved_model(export_dir)
     self.assertEqual(1, len(saved_model_proto.meta_graphs))
     meta_graph = saved_model_proto.meta_graphs[0]
     self.assertIn(signature_key, meta_graph.signature_def)
     return meta_graph.signature_def[signature_key]