def _get_input_signature_from_v1_saved_model(self, saved_model_dir):
    """Return the transform signature's inputs from a TF1 compat SavedModel.

    Args:
      saved_model_dir: Path of the SavedModel directory to read.

    Returns:
      The `inputs` field of the `constants.TRANSFORM_SIGNATURE` `SignatureDef`
      found in the chosen `MetaGraphDef`.
    """
    parsed_model = saved_model_loader.parse_saved_model(saved_model_dir)
    # The `_and_raise` variant is used here; presumably it raises instead of
    # returning None when no suitable MetaGraphDef exists -- TODO confirm.
    graph_def = saved_model_loader.choose_meta_graph_def_and_raise(parsed_model)
    transform_sig = graph_def.signature_def[constants.TRANSFORM_SIGNATURE]
    return transform_sig.inputs
예제 #2
0
def _load_transform_saved_model(transform_savedmodel_dir):
    """Read a transform-function SavedModel from disk.

    Args:
      transform_savedmodel_dir: Path of the SavedModel directory to read.

    Returns:
      A 4-tuple `(meta_graph_def, inputs, outputs, asset_paths)`:
        * `meta_graph_def`: the `MetaGraphDef` chosen for
          `constants.TRANSFORM_TAG`.
        * `inputs` / `outputs`: the `inputs` / `outputs` maps (name to
          `TensorInfo`) of the transform `SignatureDef`.
        * `asset_paths`: dict mapping tensor names to absolute asset paths.
    """
    parsed_model = saved_model_loader.parse_saved_model(
        transform_savedmodel_dir)
    graph_def = saved_model_loader.choose_meta_graph_def(
        parsed_model, [constants.TRANSFORM_TAG])
    transform_sig = graph_def.signature_def[constants.TRANSFORM_SIGNATURE]

    # Models produced prior to CL/200123875 used a non-standard naming
    # convention for features in order to support SparseTensor; the helper's
    # return value is unused, so it presumably fixes `transform_sig` in place.
    # TODO(b/34253951): Remove this once the legacy formats are no longer
    # supported.
    _update_legacy_signature(transform_sig)

    sig_inputs = transform_sig.inputs
    sig_outputs = transform_sig.outputs

    # {tensor name: absolute asset path}
    asset_paths = saved_model_loader.get_asset_tensors(
        transform_savedmodel_dir, graph_def)

    return graph_def, sig_inputs, sig_outputs, asset_paths
예제 #3
0
def _load_transform_saved_model(transform_savedmodel_dir):
  """Load a SavedModel representing a transform function from disk.

  Args:
    transform_savedmodel_dir: a SavedModel directory.

  Returns:
    A tuple `(meta_graph_def, input_signature, output_signature,
    asset_path_dict)` where:
      * `meta_graph_def` is the `MetaGraphDef` proto chosen for
        `constants.TRANSFORM_TAG`.
      * `input_signature` / `output_signature` are dicts mapping each logical
        name in the transform `SignatureDef`'s inputs / outputs to the name of
        the corresponding tensor (a string, not the full `TensorInfo`).
      * `asset_path_dict` is a dict mapping tensor names to absolute paths for
        asset filepaths.
  """
  saved_model = saved_model_loader.parse_saved_model(
      transform_savedmodel_dir)
  meta_graph_def = saved_model_loader.choose_meta_graph_def(
      saved_model, [constants.TRANSFORM_TAG])

  signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE]

  # Maps logical name -> tensor name. Only `TensorInfo.name` is kept, so
  # callers receive plain strings rather than protos.
  input_signature = {logical_name: tensor_info.name
                     for logical_name, tensor_info in signature.inputs.items()}
  output_signature = {logical_name: tensor_info.name
                      for logical_name, tensor_info
                      in signature.outputs.items()}

  # asset_path_dict is {string: string}, mapping tensor names to absolute paths.
  asset_path_dict = saved_model_loader.get_asset_tensors(
      transform_savedmodel_dir, meta_graph_def)

  return meta_graph_def, input_signature, output_signature, asset_path_dict
예제 #4
0
def _restore_from_v1_saved_model(
    restored_function: function.ConcreteFunction, saved_model_dir: str
) -> Tuple[function.ConcreteFunction, Mapping[str, Any], Mapping[
    str, common_types.TensorType]]:
  """Restore an exported TF1 compat SavedModel as a concrete function.

  Args:
    restored_function: The concrete function already loaded from the
      SavedModel.
    saved_model_dir: Path of the SavedModel directory.

  Returns:
    A 3-tuple `(fn, inputs, structured_outputs)`. `inputs` is the `inputs` map
    of the `constants.TRANSFORM_SIGNATURE` `SignatureDef`. If re-registering
    pyfuncs yields a `GraphDef`, `fn` is a function rebuilt from that
    `GraphDef` and pruned to the original outputs' structure; otherwise `fn`
    is `restored_function` itself.
  """
  parsed_model = saved_model_loader.parse_saved_model(saved_model_dir)
  chosen_meta_graph = saved_model_loader.choose_meta_graph_def_and_raise(
      parsed_model)
  transform_sig = chosen_meta_graph.signature_def[
      constants.TRANSFORM_SIGNATURE]

  # Re-register any pyfuncs; a non-None result is the GraphDef to rebuild from.
  pyfunc_graph_def = pyfunc_helper.register_pyfuncs_from_saved_transform(
      restored_function.graph, chosen_meta_graph, loaded_in_tf2=True)

  if pyfunc_graph_def is not None:
    source_graph = restored_function.graph
    rebuilt = wrap_function.function_from_graph_def(
        pyfunc_graph_def,
        [tensor.name for tensor in source_graph.inputs],
        [tensor.name for tensor in source_graph.outputs])
    # Re-nest the flat outputs so they match the original output structure.
    packed_outputs = tf.nest.pack_sequence_as(
        restored_function.structured_outputs,
        rebuilt.outputs,
        expand_composites=True)
    rebuilt = rebuilt.prune(rebuilt.inputs, packed_outputs)
    return (rebuilt, transform_sig.inputs, rebuilt.structured_outputs)

  return (restored_function, transform_sig.inputs,
          restored_function.structured_outputs)
예제 #5
0
def _load_transform_saved_model(transform_savedmodel_dir):
  """Read a transform-function SavedModel from disk.

  Args:
    transform_savedmodel_dir: Path of the SavedModel directory to read.

  Returns:
    A 4-tuple `(meta_graph_def, inputs, outputs, asset_paths)`:
      * `meta_graph_def`: the `MetaGraphDef` chosen for
        `constants.TRANSFORM_TAG`.
      * `inputs` / `outputs`: the `inputs` / `outputs` maps (name to
        `TensorInfo`) of the transform `SignatureDef`.
      * `asset_paths`: dict mapping tensor names to absolute asset paths.
  """
  parsed_model = saved_model_loader.parse_saved_model(
      transform_savedmodel_dir)
  graph_def = saved_model_loader.choose_meta_graph_def(
      parsed_model, [constants.TRANSFORM_TAG])
  transform_sig = graph_def.signature_def[constants.TRANSFORM_SIGNATURE]

  sig_inputs, sig_outputs = transform_sig.inputs, transform_sig.outputs

  # {tensor name: absolute asset path}
  asset_paths = saved_model_loader.get_asset_tensors(
      transform_savedmodel_dir, graph_def)

  return graph_def, sig_inputs, sig_outputs, asset_paths
예제 #6
0
def _load_transform_saved_model(transform_savedmodel_dir):
    """Load a SavedModel representing a transform function from disk.

    Args:
      transform_savedmodel_dir: a SavedModel directory.

    Returns:
      A tuple `(meta_graph_def, input_signature, output_signature)` where
      `meta_graph_def` is the `MetaGraphDef` proto chosen for
      `constants.TRANSFORM_TAG`, and `input_signature` / `output_signature`
      are dicts mapping each logical name in the transform `SignatureDef`'s
      inputs / outputs to the name of the corresponding tensor.

    Raises:
      NotImplementedError: if the SavedModel has any asset tensors; assets are
        not supported by this loader.
    """
    saved_model = saved_model_loader.parse_saved_model(
        transform_savedmodel_dir)
    meta_graph_def = saved_model_loader.choose_meta_graph_def(
        saved_model, [constants.TRANSFORM_TAG])

    signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE]

    # Maps logical name -> tensor name (only `TensorInfo.name` is kept).
    input_signature = {
        logical_name: tensor_info.name
        for logical_name, tensor_info in signature.inputs.items()
    }
    output_signature = {
        logical_name: tensor_info.name
        for logical_name, tensor_info in signature.outputs.items()
    }

    # Any asset tensor at all is rejected, since assets are unsupported here.
    asset_tensors = saved_model_loader.get_asset_tensors(
        transform_savedmodel_dir, meta_graph_def)
    if asset_tensors:
        raise NotImplementedError('tf.Transform does not yet support assets.')

    return meta_graph_def, input_signature, output_signature
예제 #7
0
def exported_as_v1(transform_savedmodel_dir):
  """Report whether a SavedModel directory holds a TF1-style export.

  Args:
    transform_savedmodel_dir: Path of the SavedModel directory to inspect.

  Returns:
    `True` when `saved_model_loader.choose_meta_graph_def` finds a
    `MetaGraphDef` (i.e. a TF1 SavedModel), `False` when it returns None.
  """
  parsed_model = saved_model_loader.parse_saved_model(transform_savedmodel_dir)
  # No tags are passed, so the loader's default tag selection applies.
  chosen = saved_model_loader.choose_meta_graph_def(parsed_model)
  return chosen is not None
예제 #8
0
def _load_transform_saved_model(transform_savedmodel_dir):
  """Read a transform-function SavedModel from disk, wrapping its assets.

  Args:
    transform_savedmodel_dir: Path of the SavedModel directory to read.

  Returns:
    A 4-tuple `(meta_graph_def, inputs, outputs, assets)`:
      * `meta_graph_def`: the `MetaGraphDef` chosen for
        `constants.TRANSFORM_TAG`.
      * `inputs` / `outputs`: the `inputs` / `outputs` maps (name to
        `TensorInfo`) of the transform `SignatureDef`.
      * `assets`: dict mapping tensor names to `tf.saved_model.Asset`
        objects (or `TrackableAsset` on older TF) built from absolute paths.
  """
  parsed_model = saved_model_loader.parse_saved_model(
      transform_savedmodel_dir)
  graph_def = saved_model_loader.choose_meta_graph_def(
      parsed_model, [constants.TRANSFORM_TAG])
  transform_sig = graph_def.signature_def[constants.TRANSFORM_SIGNATURE]

  # Models produced prior to CL/200123875 used a non-standard naming
  # convention for features in order to support SparseTensor; the helper's
  # return value is unused, so it presumably fixes `transform_sig` in place.
  # TODO(b/34253951): Remove this once the legacy formats are no longer
  # supported.
  _update_legacy_signature(transform_sig)

  sig_inputs, sig_outputs = transform_sig.inputs, transform_sig.outputs

  # TODO(zoyahav): Remove this branch when TFT no longer supports TF 1.15.
  if not hasattr(tf.saved_model, 'Asset'):
    from tensorflow.python.training.tracking import tracking  # pylint: disable=g-direct-tensorflow-import, g-import-not-at-top
    asset_cls = tracking.TrackableAsset
  else:
    asset_cls = tf.saved_model.Asset

  # {tensor name: absolute asset path}, each path then wrapped as an asset.
  asset_paths = saved_model_loader.get_asset_tensors(
      transform_savedmodel_dir, graph_def)
  wrapped_assets = {
      tensor_name: asset_cls(abs_path)
      for tensor_name, abs_path in asset_paths.items()
  }

  return graph_def, sig_inputs, sig_outputs, wrapped_assets
예제 #9
0
 def _get_input_signature(self, saved_model_dir):
     """Return the transform signature's inputs from a SavedModel directory.

     Args:
       saved_model_dir: Path of the SavedModel directory to read.

     Returns:
       The `inputs` field of the `constants.TRANSFORM_SIGNATURE`
       `SignatureDef` in the `MetaGraphDef` chosen for
       `constants.TRANSFORM_TAG`.
     """
     parsed_model = saved_model_loader.parse_saved_model(saved_model_dir)
     graph_def = saved_model_loader.choose_meta_graph_def(
         parsed_model, [constants.TRANSFORM_TAG])
     return graph_def.signature_def[constants.TRANSFORM_SIGNATURE].inputs