Esempio n. 1
0
def _get_temporary_analyzer_output(
    temp_dir: str,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> TemporaryAnalyzerOutputWrapper:
  """Builds a stand-in graph tensor for an analyzer output.

  The kind of tensor depends on `tensor_info` and `temp_dir`:
    * Not an asset path (`temporary_asset_value` is None): a
      `PlaceholderWithDefault` whose default comes from
      `tf2_utils.supply_missing_tensor`. For string dtype, that default is
      built as int64 and converted with `tf.strings.as_string`.
    * Asset path and a truthy `temp_dir`: `temporary_asset_value` is written
      to a new uniquely named file in `temp_dir`. The file is wrapped in a
      `common_types.Asset`, and a constant tensor holding the file path is
      returned.
    * Asset path and a falsy `temp_dir`: a plain `Placeholder`.
  All ops are created under the 'temporary_analyzer_output' name scope.

  Args:
    temp_dir: Directory to write temporary asset files to. If falsy, no file
      is written.
    tensor_info: A `TensorInfo` supplying the `dtype`, `shape` and
      `temporary_asset_value` of the tensor to create.
    name: Optional name for the created graph tensor.

  Returns:
    A `TemporaryAnalyzerOutputWrapper(asset, graph_tensor)`. `asset` is the
    `common_types.Asset` for the written temporary file. It is None when no
    file was written.
  """
  asset = None
  with tf.name_scope('temporary_analyzer_output'):
    if tensor_info.temporary_asset_value is None:
      # A placeholder without a default value makes tracing fail when any
      # control flow depends on a tensor derived from it, so a temporary
      # default is supplied. String outputs get a default built from int64
      # values rendered as strings ('0's rather than b''), so that string to
      # numeric conversion ops can still be traced.
      is_string_output = tensor_info.dtype == tf.string
      default_dtype = tf.int64 if is_string_output else tensor_info.dtype
      default_value = tf2_utils.supply_missing_tensor(
          1, tf.TensorShape(tensor_info.shape), default_dtype)
      if is_string_output:
        default_value = tf.strings.as_string(default_value)
      graph_tensor = tf.raw_ops.PlaceholderWithDefault(
          input=default_value, shape=tensor_info.shape, name=name)
    elif not temp_dir:
      # No temporary directory: assume the graph will not be serialized and
      # fall back to a plain placeholder for the asset path.
      graph_tensor = tf.raw_ops.Placeholder(
          dtype=tensor_info.dtype, shape=tensor_info.shape, name=name)
    else:
      # Placeholders cannot stand in for assets in a graph that will be
      # serialized to a SavedModel, because assets are initialized by the init
      # op. Write a real temporary file instead.
      # TODO(b/164921571) Support temporary files in tfrecord format.
      # TODO(b/149997088): Reduce number of temporary files written out.
      with tf.init_scope():
        asset_path = os.path.join(temp_dir, uuid.uuid4().hex)
        with tf.io.gfile.GFile(asset_path, 'w') as asset_file:
          asset_file.write(tensor_info.temporary_asset_value)
        # Wrapping the file in an Asset keeps exported SavedModels hermetic.
        asset = common_types.Asset(asset_path)
      graph_tensor = tf.constant(
          asset_path,
          dtype=tensor_info.dtype,
          shape=tensor_info.shape,
          name=name)
    return TemporaryAnalyzerOutputWrapper(asset, graph_tensor)
def write_v2_saved_model(tf_function: function.Function, name: str,
                         saved_model_dir: str) -> function.ConcreteFunction:
  """Writes `tf_function` under attr `name` to `saved_model_dir`.

  `tf_function` is traced once inside scopes that record the following:
    * resources (via `tracking.ResourceTracker`);
    * trackable objects (via `annotators.ObjectTracker`);
    * variables created during tracing.
  These are attached to a fresh `tf.Module`, together with lookup-table
  initializers and asset file paths from the traced graph. The module is then
  exported with `tf.saved_model.save`.

  Args:
    tf_function: The tf.function to trace and export.
    name: Attribute name under which the function is exported on the module.
    saved_model_dir: Directory to write the SavedModel to.

  Returns:
    The concrete function obtained by tracing `tf_function`. On the
    forward-compatible path the module exports the output of
    `optimize_concrete_function` rather than this concrete function.
  """
  module = tf.Module()

  resources_seen = tracking.ResourceTracker()
  trackables_seen = annotators.ObjectTracker()
  variables_seen = []

  def _record_variable(next_creator, **kwargs):
    # Variable creator that remembers every variable made during tracing.
    new_variable = next_creator(**kwargs)
    variables_seen.append(new_variable)
    return new_variable

  # TODO(b/164921571): Handle generic Trackable objects.
  # Trace inside the tracker scopes so that every resource, trackable object
  # and variable created by `tf_function` is recorded. They are attached to
  # `module` below so they are tracked when the SavedModel is exported.
  with tracking.resource_tracker_scope(resources_seen), \
       annotators.object_tracker_scope(trackables_seen), \
       tf.variable_creator_scope(_record_variable):
    traced_fn = tf_function.get_concrete_function()

  # Before 2020/10/08, saving a tf.function with a concrete function signature
  # guaranteed it would not be re-traced in a SavedModel round trip. That no
  # longer holds, so when forward compatible the concrete function itself is
  # saved.
  if not tf.compat.forward_compatible(2020, 10, 8):
    setattr(module, name, tf_function)
  else:
    optimized_fn = optimize_concrete_function(traced_fn)
    module.pruned_variables = optimized_fn.variables
    setattr(module, name, optimized_fn)

  # Created variables, resources and trackable objects must be attached
  # explicitly so that they are tracked.
  module.created_variables = variables_seen
  module.resources = resources_seen.resources
  module.trackable_objects = trackables_seen.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  module.initializers = [
      table._initializer  # pylint: disable=protected-access
      for table in module.resources
      if isinstance(table, lookup_ops.InitializableLookupTableBase)
  ]
  asset_paths = traced_fn.graph.get_collection(
      tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  module.assets = [common_types.Asset(path) for path in asset_paths]
  tf.saved_model.save(module, saved_model_dir)
  return traced_fn
Esempio n. 3
0
def trace_and_update_module(
    module: tf.Module, tf_function: function.Function, name: str,
    strip_control_dependencies: bool) -> function.ConcreteFunction:
  """Traces `tf_function` and saves it under attr `name` of `module`.

  Tracing happens inside scopes that record three things: resources, trackable
  objects, and variables created during tracing. These are attached to
  `module` so that they are explicitly tracked when `module` is saved. Two
  more things are attached as well:
    * the initializers of `lookup_ops.InitializableLookupTableBase` resources;
    * the asset file paths found in the traced graph's `ASSET_FILEPATHS`
      collection.

  Args:
    module: A saveable module which will contain the traced `tf_function` under
      attr `name`. These attributes are overwritten: `created_variables`,
      `resources`, `trackable_objects`, `initializers` and `assets`.
      `pruned_variables` is also overwritten on the forward-compatible path.
    tf_function: A tf.function to trace.
    name: The attribute name to save the traced `tf_function` under.
    strip_control_dependencies: Boolean. If True, automatic control dependencies
      will be stripped from the outputs of `tf_function`. This should almost
      always be False. It is useful only if you want to use the structure of the
      TF graph to perform any graph manipulations. The flag is only honored on
      the forward-compatible path (2020/10/08 and later). Otherwise
      `tf_function` itself is stored and the flag has no effect.

  Returns:
    The concrete function obtained from tracing `tf_function`. This is the
    trace before any pruning by `optimize_concrete_function`.
  """
  resource_tracker = tracking.ResourceTracker()
  object_tracker = annotators.ObjectTracker()
  created_variables = []

  def _variable_creator(next_creator, **kwargs):
    # Records every variable created while tracing so it can be tracked.
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # Trace `tf_function` to gather any resources in it using the
  # resource_tracker. These are then assigned to `module.resources` and tracked
  # before exporting to SavedModel.
  with tracking.resource_tracker_scope(resource_tracker), \
       annotators.object_tracker_scope(object_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_fn = tf_function.get_concrete_function()

  # Prior to 2020/10/08, saving a tf.function with a concrete function signature
  # would ensure that the function was not re-traced in a round-trip to a
  # SavedModel. Since this is no longer the case, we save the concrete function
  # directly.
  if tf.compat.forward_compatible(2020, 10, 8):
    pruned_function = optimize_concrete_function(concrete_fn,
                                                 strip_control_dependencies)
    module.pruned_variables = pruned_function.variables
    setattr(module, name, pruned_function)
  else:
    setattr(module, name, tf_function)

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  module.trackable_objects = object_tracker.trackable_objects
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  module.initializers = [
      resource._initializer  # pylint: disable=protected-access
      for resource in module.resources
      if isinstance(resource, lookup_ops.InitializableLookupTableBase)
  ]
  module.assets = [
      common_types.Asset(asset_filepath) for asset_filepath in
      concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  return concrete_fn
Esempio n. 4
0
def trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn,
                                   input_signature, base_temp_dir,
                                   tensor_replacement_map,
                                   output_keys_to_name_map):
  """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

  The SavedModel written contains a method called `transform_fn` that
  represents the traced `preprocessing_fn`. It may also contain a method
  called `metadata_fn` that provides deferred schema annotations. This happens
  only if no unevaluated analyzers remain in the traced `transform_fn` graph,
  i.e. if this is the final SavedModel being written out.

  Args:
    saved_model_dir: Path to write SavedModel to.
    preprocessing_fn: A user defined python function to be traced.
    input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`.
    base_temp_dir: Base path to write temporary artifacts to.
    tensor_replacement_map: A map from placeholder tensor names to their
      evaluated replacement tensors.
    output_keys_to_name_map: A map from output dictionary keys to the names of
      the tensors that they represent.

  Returns:
    A tuple `(concrete_transform_fn, concrete_metadata_fn)`:
      1. The `tf.ConcreteFunction` of the traced preprocessing_fn.
      2. The `tf.ConcreteFunction` of a metadata_fn that returns a dictionary
         containing the deferred annotations added to the graph when invoked
         with any valid input. This is None when the traced preprocessing_fn
         still has unevaluated analyzers, i.e. a non-empty
         `TENSOR_REPLACEMENTS` collection.
  """
  module = tf.Module()
  transform_fn = get_traced_transform_fn(
      preprocessing_fn,
      input_signature,
      base_temp_dir,
      tensor_replacement_map=tensor_replacement_map,
      output_keys_to_name_map=output_keys_to_name_map)
  metadata_fn = None

  resource_tracker = tracking.ResourceTracker()
  created_variables = []

  def _variable_creator(next_creator, **kwargs):
    # Records every variable created while tracing so it can be tracked.
    var = next_creator(**kwargs)
    created_variables.append(var)
    return var

  # TODO(b/164921571): Handle generic Trackable objects.
  # Trace the `transform_fn` and `metadata_fn` to gather any resources in it
  # using the resource_tracker. These are then assigned to `module.resources`
  # and tracked before exporting to SavedModel.
  with tracking.resource_tracker_scope(resource_tracker), \
       tf.variable_creator_scope(_variable_creator):
    concrete_transform_fn = transform_fn.get_concrete_function()
    concrete_metadata_fn = None
    # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers
    # in the `preprocessing_fn` have already been evaluated, so the deferred
    # schema annotations can be traced.
    if not concrete_transform_fn.graph.get_collection(
        analyzer_nodes.TENSOR_REPLACEMENTS):
      metadata_fn = schema_inference.get_traced_metadata_fn(
          tensor_replacement_map,
          preprocessing_fn,
          input_signature,
          base_temp_dir,
          evaluate_schema_overrides=True)
      concrete_metadata_fn = metadata_fn.get_concrete_function()

  # Save ConcreteFunction when possible since the above workaround won't work if
  # the tf.function is retraced.
  if tf.compat.forward_compatible(2020, 10, 8):
    module.transform_fn = concrete_transform_fn
    module.metadata_fn = concrete_metadata_fn
  else:
    module.transform_fn = transform_fn
    module.metadata_fn = metadata_fn

  # Any variables created need to be explicitly tracked.
  module.created_variables = created_variables
  # Resources need to be explicitly tracked.
  module.resources = resource_tracker.resources
  # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the
  # table should be sufficient.
  module.initializers = [
      resource._initializer  # pylint: disable=protected-access
      for resource in module.resources
      if isinstance(resource, lookup_ops.InitializableLookupTableBase)
  ]
  module.assets = [
      common_types.Asset(asset_filepath)
      for asset_filepath in concrete_transform_fn.graph.get_collection(
          tf.compat.v1.GraphKeys.ASSET_FILEPATHS)
  ]
  tf.saved_model.save(module, saved_model_dir)
  return concrete_transform_fn, concrete_metadata_fn
Esempio n. 5
0
def _bind_future_as_tensor_v2(
    future: nodes.ValueNode,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
  """Bind a future value as a tensor to a TF2 FuncGraph.

  If the future is expected to write out an asset file and this method is
  invoked within a `TFGraphContext` that was provided a temporary directory,
  a temporary file is written out by this method. The temporary output is
  created before evaluated replacements are looked up. So this happens even
  when the future has already been evaluated.

  This could write out a significant number of temporary files depending on
  the number of times the `preprocessing_fn` is traced and the number of asset
  files in each tracing.

  Args:
    future: Future whose result should replace the graph tensor to which it is
      bound.
    tensor_info: A `TensorInfo` object containing attributes to create the graph
      tensor.
    name: (Optional) If provided, the graph tensor created uses this name.

  Returns:
    A graph tensor or `tf.saved_model.Asset` that this future is bound to. If
    this future has already been evaluated in a previous TFT phase, the
    evaluated value is returned directly; asset paths are wrapped in an Asset.
  """
  graph = ops.get_default_graph()
  temp_dir = TFGraphContext.get_or_create_temp_dir()
  temporary_analyzer_info = _get_temporary_analyzer_output(
      temp_dir, tensor_info, name)
  is_asset_filepath = tensor_info.temporary_asset_value is not None

  # TODO(b/149997088): Switch to using a counter instead of tensor names.
  # `evaluated_replacements` is a dictionary from placeholder name to evaluated
  # tensor (or None). If `preprocessing_fn` was traced previously and this
  # future was then evaluated in a TFT phase, the result is present in it.
  evaluated_replacements = TFGraphContext.get_evaluated_replacements()
  analyzer_name = temporary_analyzer_info.graph_tensor.name
  already_evaluated = (
      evaluated_replacements is not None and
      analyzer_name in evaluated_replacements)

  if not already_evaluated:
    # Not yet evaluated: register the temporary tensor so it can be replaced
    # once the future's value is known.
    graph.add_to_collection(
        TENSOR_REPLACEMENTS,
        TensorSink(temporary_analyzer_info.graph_tensor, future,
                   is_asset_filepath))
    if is_asset_filepath and temporary_analyzer_info.asset:
      graph.add_to_collection(ASSET_REPLACEMENTS, temporary_analyzer_info)
      return temporary_analyzer_info.asset
    return temporary_analyzer_info.graph_tensor

  replaced_result = evaluated_replacements[analyzer_name]
  if not is_asset_filepath:
    return replaced_result
  # Wrap asset files using `tf.saved_model.Asset` to ensure that
  # `SavedModel`s exported are hermetic.
  with tf.init_scope():
    asset = common_types.Asset(replaced_result)
  graph.add_to_collection(ASSET_REPLACEMENTS,
                          TemporaryAnalyzerOutputWrapper(asset, None))
  return asset