# Example no. 1
def _get_temporary_analyzer_output(
    temp_dir: str,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> TemporaryAnalyzerOutputWrapper:
  """Create a temporary graph tensor using attributes in `tensor_info`.

  Args:
    temp_dir: Path to a directory to write out any temporary asset files to.
    tensor_info: A `TensorInfo` object containing attributes to create the graph
      tensor.
    name: A string (or None). The created graph tensor uses this name.

  Returns:
    A named tuple `TemporaryAnalyzerOutputWrapper` with:
      asset: If the graph tensor represents a path to an asset file, a
        `tf.saved_model.Asset` object for tracking. Else, None.
      graph_tensor: The graph tensor.
  """
  asset = None
  with tf.name_scope('temporary_analyzer_output'):
    is_asset_filepath = tensor_info.temporary_asset_value is not None
    if is_asset_filepath:
      # Placeholders cannot be used for assets, if this graph will be serialized
      # to a SavedModel, as they will be initialized with the init op. If a
      # `temp_dir` is provided, it is assumed that this graph will be
      # serialized and a temporary asset file is written out. Else, a
      # placeholder is returned.
      # TODO(b/164921571) Support temporary files in tfrecord format.
      # TODO(b/149997088): Reduce number of temporary files written out.
      if temp_dir:
        with tf.init_scope():
          temporary_asset_filepath = os.path.join(temp_dir, uuid.uuid4().hex)
          # NOTE(review): the file-open call and the write were lost from this
          # block; reconstructed as a `tf.io.gfile.GFile` write of the
          # temporary asset value -- confirm against the upstream source.
          with tf.io.gfile.GFile(temporary_asset_filepath, 'w') as f:
            f.write(tensor_info.temporary_asset_value)
          # Wrap asset files using `tf.saved_model.Asset` to ensure that
          # `SavedModel`s exported are hermetic.
          asset = common_types.Asset(temporary_asset_filepath)
        # NOTE(review): the arguments of this constant were lost; they mirror
        # the placeholder branch below -- confirm.
        graph_tensor = tf.constant(
            temporary_asset_filepath,
            dtype=tensor_info.dtype,
            shape=tensor_info.shape,
            name=name)
      else:
        graph_tensor = tf.raw_ops.Placeholder(
            dtype=tensor_info.dtype, shape=tensor_info.shape, name=name)
    else:
      # Using a placeholder with no default value causes tracing to fail if
      # there is any control flow dependent on a child tensor of this
      # placeholder. Hence, provide a temporary default value for it.
      # If dtype is string, we want a tensor that contains '0's instead of b'[]
      # to allow string to numeric conversion ops to trace successfully.
      temporary_dtype = (
          tf.int64 if tensor_info.dtype == tf.string else tensor_info.dtype)
      temporary_tensor = tf2_utils.supply_missing_tensor(
          1, tf.TensorShape(tensor_info.shape), temporary_dtype)
      if tensor_info.dtype == tf.string:
        temporary_tensor = tf.strings.as_string(temporary_tensor)
      graph_tensor = tf.raw_ops.PlaceholderWithDefault(
          input=temporary_tensor, shape=tensor_info.shape, name=name)
    return TemporaryAnalyzerOutputWrapper(asset, graph_tensor)
def write_v2_saved_model(tf_function: function.Function, name: str,
                         saved_model_dir: str) -> function.ConcreteFunction:
    """Writes `tf_function` under attr `name` to `saved_model_dir`.

    Args:
      tf_function: A tf.function to trace and export.
      name: Attribute name on the exported module that the function is stored
        under.
      saved_model_dir: Path to write the SavedModel to.

    Returns:
      The concrete function obtained from tracing `tf_function`.
    """
    module = tf.Module()

    resource_tracker = tracking.ResourceTracker()
    object_tracker = annotators.ObjectTracker()
    created_variables = []

    def _variable_creator(next_creator, **kwargs):
        var = next_creator(**kwargs)
        # Record every variable created while tracing so that it can be
        # attached to `module` below; without this the list stays empty.
        created_variables.append(var)
        return var

    # TODO(b/164921571): Handle generic Trackable objects.
    # Trace `tf_function` to gather any resources in it using the
    # resource_tracker. These are then assigned to `module.resources` and tracked
    # before exporting to SavedModel.
    # NOTE(review): the third context manager was lost from this block;
    # `tf.variable_creator_scope(_variable_creator)` is the only consumer of
    # `_variable_creator` -- confirm.
    with tracking.resource_tracker_scope(resource_tracker), \
         annotators.object_tracker_scope(object_tracker), \
         tf.variable_creator_scope(_variable_creator):
        concrete_fn = tf_function.get_concrete_function()

    # Prior to 2020/10/08, saving a tf.function with a concrete function signature
    # would ensure that the function was not re-traced in a round-trip to a
    # SavedModel. Since this is no longer the case, we save the concrete function
    # directly.
    if tf.compat.forward_compatible(2020, 10, 8):
        pruned_function = optimize_concrete_function(concrete_fn)
        module.pruned_variables = pruned_function.variables
        setattr(module, name, pruned_function)
    else:
        setattr(module, name, tf_function)

    # Any variables created need to be explicitly tracked.
    module.created_variables = created_variables
    # Resources need to be explicitly tracked.
    module.resources = resource_tracker.resources
    module.trackable_objects = object_tracker.trackable_objects
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    module.initializers = [
        resource._initializer  # pylint: disable=protected-access
        for resource in module.resources
        if isinstance(resource, lookup_ops.InitializableLookupTableBase)
    ]
    # NOTE(review): the element expression and collection key were lost from
    # this block; reconstructed as wrapping each path in the graph's
    # ASSET_FILEPATHS collection in an `Asset` -- confirm.
    module.assets = [
        common_types.Asset(asset_filepath)
        for asset_filepath in concrete_fn.graph.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    ]
    tf.saved_model.save(module, saved_model_dir)
    return concrete_fn
# Example no. 3
def trace_and_update_module(
    module: tf.Module, tf_function: function.Function, name: str,
    strip_control_dependencies: bool) -> function.ConcreteFunction:
  """Traces `tf_function` and saves under attr `name` of `module`.

  Args:
    module: A saveable module which will contain the traced `tf_function` under
      attr `name`.
    tf_function: A tf.function to trace.
    name: A name to save the traced `tf_function` to.
    strip_control_dependencies: Boolean. If True, automatic control dependencies
      will be stripped from the outputs of `tf_function`. This should almost
      always be False. It is useful only if you want to use the structure of the
      TF graph to perform any graph manipulations.

  Returns:
    The concrete function obtained from tracing `tf_function`.
  """
  resource_tracker = tracking.ResourceTracker()
  object_tracker = annotators.ObjectTracker()
  created_variables = []

  def _variable_creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    # Record every variable created while tracing so that it can be attached
    # to `module` below; without this the list stays empty.
    created_variables.append(var)
    return var

  # Trace `tf_function` to gather any resources in it using the
  # resource_tracker. These are then assigned to `module.resources` and tracked
  # before exporting to SavedModel.
  # NOTE(review): the third context manager was lost from this block;
  # `tf.variable_creator_scope(_variable_creator)` is the only consumer of
  # `_variable_creator` -- confirm.
  with tracking.resource_tracker_scope(resource_tracker), \
       annotators.object_tracker_scope(object_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_fn = tf_function.get_concrete_function()

  # Prior to 2020/10/08, saving a tf.function with a concrete function signature
  # would ensure that the function was not re-traced in a round-trip to a
  # SavedModel. Since this is no longer the case, we save the concrete function
  # directly.
  if tf.compat.forward_compatible(2020, 10, 8):
    # NOTE(review): the second argument was lost; `strip_control_dependencies`
    # is otherwise unused in this function, so it is presumably passed here --
    # confirm.
    pruned_function = optimize_concrete_function(concrete_fn,
                                                 strip_control_dependencies)
    module.pruned_variables = pruned_function.variables
    setattr(module, name, pruned_function)
  else:
    setattr(module, name, tf_function)

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  module.trackable_objects = object_tracker.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  module.initializers = [
      resource._initializer  # pylint: disable=protected-access
      for resource in module.resources
      if isinstance(resource, lookup_ops.InitializableLookupTableBase)
  ]
  # NOTE(review): the iterable of this comprehension was lost; reconstructed as
  # the traced graph's ASSET_FILEPATHS collection -- confirm.
  module.assets = [
      common_types.Asset(asset_filepath) for asset_filepath in
      concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  return concrete_fn
# Example no. 4
def trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn,
                                   input_signature, base_temp_dir,
                                   tensor_replacement_map,
                                   output_keys_to_name_map):
    """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

    The SavedModel written contains a method called `transform_fn` that
    represents the traced `preprocessing_fn`. Additionally, if this is the final
    SavedModel being written out, it will contain a method called `metadata_fn`
    that provides deferred schema annotations.

    Args:
      saved_model_dir: Path to write SavedModel to.
      preprocessing_fn: A user defined python function to be traced.
      input_signature: TypeSpecs describing the inputs to the
        `preprocessing_fn`.
      base_temp_dir: Base path to write temporary artifacts to.
      tensor_replacement_map: A map from placeholder tensor names to their
        evaluated replacement tensors.
      output_keys_to_name_map: A map from output dictionary keys to the names of
        the tensors that they represent.

    Returns:
      A tuple containing a pair of `tf.ConcreteFunction`s:
        1. The traced preprocessing_fn.
        2. A metadata_fn that returns a dictionary containing the deferred
        annotations added to the graph when invoked with any valid input.
        This is None when analyzers in the graph are still unevaluated.
    """

    module = tf.Module()
    # NOTE(review): the arguments of this call were lost from this block;
    # reconstructed from the parameters documented above -- confirm against
    # the signature of `get_traced_transform_fn`.
    transform_fn = get_traced_transform_fn(
        preprocessing_fn,
        input_signature,
        base_temp_dir,
        tensor_replacement_map=tensor_replacement_map,
        output_keys_to_name_map=output_keys_to_name_map)
    metadata_fn = None

    resource_tracker = tracking.ResourceTracker()
    created_variables = []

    def _variable_creator(next_creator, **kwargs):
        var = next_creator(**kwargs)
        # Record every variable created while tracing so that it can be
        # attached to `module` below; without this the list stays empty.
        created_variables.append(var)
        return var

    # TODO(b/164921571): Handle generic Trackable objects.
    # Trace the `transform_fn` and `metadata_fn` to gather any resources in it
    # using the resource_tracker. These are then assigned to `module.resources`
    # and tracked before exporting to SavedModel.
    with tracking.resource_tracker_scope(
            resource_tracker), tf.variable_creator_scope(_variable_creator):
        concrete_transform_fn = transform_fn.get_concrete_function()
        concrete_metadata_fn = None
        # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers
        # in the `preprocessing_fn` have already been evaluated.
        # NOTE(review): the collection key and the `get_traced_metadata_fn`
        # arguments were lost from this block; both are reconstructed --
        # confirm the module that exports `TENSOR_REPLACEMENTS` and the
        # callee's signature.
        if not concrete_transform_fn.graph.get_collection(
                analyzer_nodes.TENSOR_REPLACEMENTS):
            metadata_fn = schema_inference.get_traced_metadata_fn(
                tensor_replacement_map,
                preprocessing_fn,
                input_signature,
                base_temp_dir,
                evaluate_schema_overrides=True)
            concrete_metadata_fn = metadata_fn.get_concrete_function()

    # Save ConcreteFunction when possible since the above workaround won't work if
    # the tf.function is retraced.
    if tf.compat.forward_compatible(2020, 10, 8):
        module.transform_fn = concrete_transform_fn
        module.metadata_fn = concrete_metadata_fn
    else:
        module.transform_fn = transform_fn
        module.metadata_fn = metadata_fn

    # Any variables created need to be explicitly tracked.
    module.created_variables = created_variables
    # Resources need to be explicitly tracked.
    module.resources = resource_tracker.resources
    # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
    # table should be sufficient.
    module.initializers = [
        resource._initializer  # pylint: disable=protected-access
        for resource in module.resources
        if isinstance(resource, lookup_ops.InitializableLookupTableBase)
    ]
    # NOTE(review): the element expression and collection key were lost from
    # this block; reconstructed as wrapping each path in the graph's
    # ASSET_FILEPATHS collection in an `Asset` -- confirm.
    module.assets = [
        common_types.Asset(asset_filepath)
        for asset_filepath in concrete_transform_fn.graph.get_collection(
            tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
    ]
    tf.saved_model.save(module, saved_model_dir)
    return concrete_transform_fn, concrete_metadata_fn
# Example no. 5
def _bind_future_as_tensor_v2(
    future: nodes.ValueNode,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
  """Bind a future value as a tensor to a TF2 FuncGraph.

  If the future is expected to write out an asset file and this method is
  invoked within a `TFGraphContext` that was provided a temporary directory,
  a temporary file is written out by this method.

  This could write out a significant number of temporary files depending on
  number of times the `preprocessing_fn` is traced and number of asset files
  in each tracing.

  Args:
    future: Future whose result should replace the graph tensor to which it is
      bound.
    tensor_info: A `TensorInfo` object containing attributes to create the graph
      tensor.
    name: (Optional) If provided, the graph tensor created uses this name.

  Returns:
    A graph tensor or `tf.saved_model.Asset` that this future is bound to. If
    this future has already been evaluated in a previous TFT phase, it is
    directly returned.
  """
  graph = ops.get_default_graph()
  temp_dir = TFGraphContext.get_or_create_temp_dir()
  temporary_analyzer_info = _get_temporary_analyzer_output(
      temp_dir, tensor_info, name)
  is_asset_filepath = tensor_info.temporary_asset_value is not None

  # TODO(b/149997088): Switch to using a counter instead of tensor names.
  # Check if an evaluated value exists for this analyzer node.
  evaluated_replacements = TFGraphContext.get_evaluated_replacements()
  # evaluated_replacements is a dictionary from placeholder name to evaluated
  # tensor.
  # If `preprocessing_fn` was traced previously and this future was then
  # evaluated in a TFT phase, the result will be present in this dictionary.
  # NOTE(review): the right-hand side of this assignment was lost; the
  # dictionary is keyed by placeholder name, so the temporary graph tensor's
  # name is used -- confirm.
  analyzer_name = temporary_analyzer_info.graph_tensor.name
  if (evaluated_replacements is not None and
      analyzer_name in evaluated_replacements):
    replaced_result = evaluated_replacements[analyzer_name]
    if is_asset_filepath:
      # Wrap asset files using `tf.saved_model.Asset` to ensure that
      # `SavedModel`s exported are hermetic.
      with tf.init_scope():
        asset = common_types.Asset(replaced_result)
      # NOTE(review): the collection key was lost on this call; reconstructed
      # as `ASSET_REPLACEMENTS` to match the unevaluated branch below --
      # confirm.
      graph.add_to_collection(ASSET_REPLACEMENTS,
                              TemporaryAnalyzerOutputWrapper(asset, None))
      return asset
    else:
      return replaced_result
  else:
    # The future has not been evaluated yet: register the temporary tensor so
    # that it can be replaced once the future's value is available.
    # NOTE(review): the collection key and the last `TensorSink` argument were
    # lost from this block; reconstructed as `TENSOR_REPLACEMENTS` and
    # `is_asset_filepath` -- confirm.
    graph.add_to_collection(
        TENSOR_REPLACEMENTS,
        TensorSink(temporary_analyzer_info.graph_tensor, future,
                   is_asset_filepath))
    if is_asset_filepath and temporary_analyzer_info.asset:
      graph.add_to_collection(ASSET_REPLACEMENTS, temporary_analyzer_info)
      return temporary_analyzer_info.asset
    else:
      return temporary_analyzer_info.graph_tensor