Exemplo n.º 1
0
def _restore_from_v1_saved_model(
    restored_function: function.ConcreteFunction, saved_model_dir: str
) -> Tuple[function.ConcreteFunction, Mapping[str, Any], Mapping[
    str, common_types.TensorType]]:
  """Restores a transform function exported as a TF1 compat SavedModel.

  Args:
    restored_function: Concrete function whose graph is handed to the pyfunc
      re-registration step.
    saved_model_dir: Directory of the SavedModel to parse.

  Returns:
    A tuple `(function, signature_inputs, structured_outputs)`. When pyfunc
    re-registration returns no graph def, `function` is `restored_function`
    itself. Otherwise it is a new function built from the returned graph def
    and pruned to the output structure of `restored_function`.
  """
  transform_meta_graph = saved_model_loader.choose_meta_graph_def_and_raise(
      saved_model_loader.parse_saved_model(saved_model_dir))
  transform_signature = transform_meta_graph.signature_def[
      constants.TRANSFORM_SIGNATURE]
  original_graph = restored_function.graph

  # Pyfuncs, if there are any, need to be registered again.
  rewritten_graph_def = pyfunc_helper.register_pyfuncs_from_saved_transform(
      original_graph, transform_meta_graph, loaded_in_tf2=True)

  if rewritten_graph_def is not None:
    # Rebuild the function from the returned graph def, addressing its inputs
    # and outputs by the tensor names used in the restored graph.
    input_names = [tensor.name for tensor in original_graph.inputs]
    output_names = [tensor.name for tensor in original_graph.outputs]
    rebuilt = wrap_function.function_from_graph_def(
        rewritten_graph_def, input_names, output_names)
    # Arrange the flat outputs in the structure of the restored function's
    # outputs before pruning.
    nested_outputs = tf.nest.pack_sequence_as(
        restored_function.structured_outputs,
        rebuilt.outputs,
        expand_composites=True)
    pruned = rebuilt.prune(rebuilt.inputs, nested_outputs)
    return (pruned, transform_signature.inputs, pruned.structured_outputs)

  return (restored_function, transform_signature.inputs,
          restored_function.structured_outputs)
 def _get_input_signature_from_v1_saved_model(self, saved_model_dir):
     """Reads the transform signature inputs of a TF1 compat SavedModel."""
     meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise(
         saved_model_loader.parse_saved_model(saved_model_dir))
     return meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE].inputs
Exemplo n.º 3
0
def _load_transform_saved_model(transform_savedmodel_dir):
  """Reads the SavedModel of a transform function from disk.

  Args:
    transform_savedmodel_dir: the SavedModel directory to read.

  Returns:
    A 4-tuple `(meta_graph_def, input_signature, output_signature,
    asset_path_dict)`: the chosen `MetaGraphDef` proto, the inputs and the
    outputs of its transform `SignatureDef` proto, and a dict from tensor
    names to absolute paths for asset filepaths.
  """
  chosen_meta_graph = saved_model_loader.choose_meta_graph_def_and_raise(
      saved_model_loader.parse_saved_model(transform_savedmodel_dir))
  transform_signature = chosen_meta_graph.signature_def[
      constants.TRANSFORM_SIGNATURE]

  # The call below handles models produced prior to CL/200123875, which used
  # a non-standard naming convention for features in order to support
  # SparseTensor.
  # TODO(b/34253951): Remove this call once the legacy formats no longer need
  # to be supported.
  _update_legacy_signature(transform_signature)

  # Both signature fields map a name to a TensorInfo.
  signature_inputs = transform_signature.inputs
  signature_outputs = transform_signature.outputs

  # {string: string} mapping tensor names to absolute asset paths.
  tensor_name_to_asset_path = saved_model_loader.get_asset_tensors(
      transform_savedmodel_dir, chosen_meta_graph)

  return (chosen_meta_graph, signature_inputs, signature_outputs,
          tensor_name_to_asset_path)