コード例 #1
0
ファイル: analyzer_nodes.py プロジェクト: Mikehem/tfx
def _bind_future_as_tensor_v2(future, tensor_info, name=None):
    """Bind a future value as a tensor to a TF2 FuncGraph.

    When the future is expected to write out an asset file, a temporary file
    is written out while creating the temporary graph tensor. Depending on how
    many times the `preprocessing_fn` is traced and how many asset files each
    trace produces, this could leave a significant number of temporary files.

    If this future has already been evaluated in a previous TFT phase, the
    evaluated value is used instead of a temporary tensor.

    Args:
      future: Future whose result should replace the graph tensor it is bound
        to.
      tensor_info: A `TensorInfo` object containing attributes used to create
        the graph tensor.
      name: (Optional) If provided, the graph tensor created uses this name.

    Returns:
      The graph tensor this future is bound to, wrapped in a
      `tf.saved_model.Asset` when it represents an asset file path and a
      temporary directory is available. If the future was already evaluated in
      a previous TFT phase, the evaluated value is returned instead (wrapped in
      a `tf.saved_model.Asset` for asset file paths).
    """
    graph = ops.get_default_graph()
    temp_dir = TFGraphContext.get_or_create_temp_dir()
    placeholder = _get_temporary_graph_tensor(temp_dir, tensor_info, name)
    is_asset_filepath = tensor_info.temporary_asset_value is not None

    # TODO(b/149997088): Switch to using a counter instead of tensor names.
    # `evaluated_replacements` maps placeholder names to evaluated tensors. A
    # name is present when `preprocessing_fn` was traced previously and this
    # future was then evaluated in a TFT phase.
    evaluated_replacements = TFGraphContext.get_evaluated_replacements()
    already_evaluated = (
        evaluated_replacements is not None
        and placeholder.name in evaluated_replacements)

    if already_evaluated:
        evaluated_value = evaluated_replacements[placeholder.name]
        if not is_asset_filepath:
            return evaluated_value
        # Asset files are wrapped in `tf.saved_model.Asset` so that exported
        # `SavedModel`s stay hermetic.
        evaluated_asset = tf.saved_model.Asset(evaluated_value)
        graph.add_to_collection(ASSET_REPLACEMENTS, evaluated_asset)
        return evaluated_asset

    bound_value = placeholder
    if temp_dir and is_asset_filepath:
        # Same hermeticity wrapping as above, applied to the temporary tensor.
        bound_value = tf.saved_model.Asset(placeholder)
        graph.add_to_collection(ASSET_REPLACEMENTS, bound_value)
    # Record the placeholder so it can be replaced once the future evaluates.
    graph.add_to_collection(
        TENSOR_REPLACEMENTS,
        TensorSink(placeholder, future, is_asset_filepath))
    return bound_value
コード例 #2
0
ファイル: analyzer_nodes.py プロジェクト: geewynn/transform
def _get_temporary_graph_tensor(tensor_info, name):
  """Get a temporary graph tensor using attributes in `tensor_info`.

  Args:
    tensor_info: An object exposing `dtype`, `shape` and
      `temporary_asset_value` attributes describing the tensor to create.
    name: Name given to the created graph tensor, in both the asset and the
      placeholder branches. May be None to let TF choose a default name.

  Returns:
    If `tensor_info.temporary_asset_value` is not None, a constant tensor (of
    `tensor_info.dtype`/`tensor_info.shape`) holding the path of a freshly
    written temporary file containing that value. Otherwise, a
    `PlaceholderWithDefault` tensor whose default is a dummy value of the
    requested shape.
  """
  is_asset_filepath = tensor_info.temporary_asset_value is not None
  if is_asset_filepath:
    # Placeholders cannot be used for assets as they will be initialized as part
    # of the init op. Hence, a temporary file is written out during tracing.
    # TODO(b/164921571) Support temporary files in tfrecord format.
    # TODO(b/149997088): Reduce number of temporary files written out.
    with tf.init_scope():
      temporary_asset_filepath = os.path.join(
          TFGraphContext.get_or_create_temp_dir(),
          uuid.uuid4().hex)
      with tf.io.gfile.GFile(temporary_asset_filepath, 'w') as f:
        f.write(tensor_info.temporary_asset_value)
    result = tf.constant(
        temporary_asset_filepath,
        dtype=tensor_info.dtype,
        shape=tensor_info.shape,
        name=name)
  else:
    # Using a placeholder with no default value causes tracing to fail if there
    # is any control flow dependent on a child tensor of this placeholder.
    # Hence, provide a temporary default value for it.
    # If dtype is string, we want a tensor that contains '0's instead of b'' to
    # allow string to numeric conversion ops to trace successfully.
    temporary_dtype = (
        tf.int64 if tensor_info.dtype == tf.string else tensor_info.dtype)
    temporary_tensor = tf2_utils.supply_missing_tensor(
        1, tf.TensorShape(tensor_info.shape), temporary_dtype)
    if tensor_info.dtype == tf.string:
      temporary_tensor = tf.strings.as_string(temporary_tensor)
    # Pass `name` here too (it was previously dropped for this branch) so the
    # placeholder is named consistently with the asset branch; callers may key
    # evaluated replacements on the tensor name.
    result = tf.raw_ops.PlaceholderWithDefault(
        input=temporary_tensor, shape=tensor_info.shape, name=name)
  return result
コード例 #3
0
ファイル: annotators.py プロジェクト: tensorflow/transform
def _get_object(name: str) -> Optional[base.Trackable]:
    """Look up the object tracked under `name` on the module to export.

    Args:
      name: Attribute name the object was tracked under.

    Returns:
      The tracked object, or None when the module has no attribute `name`.

    Raises:
      RuntimeError: If the current `TFGraphContext` has no module to export.
    """
    # `preprocessing_fn` is expected to run inside a TFGraphContext; outside
    # of one there is no module to look the object up on.
    tracking_module = TFGraphContext.get_module_to_export()
    if tracking_module is not None:
        return getattr(tracking_module, name, None)
    raise RuntimeError(
        f'No module found to track {name} with. Check that the '
        '`preprocessing_fn` is invoked within a `TFGraphContext` with a valid '
        '`TFGraphContext.module_to_export`.')
コード例 #4
0
ファイル: annotators.py プロジェクト: tensorflow/transform
 def add_trackable_object(self, trackable_object: base.Trackable,
                          name: Optional[str]):
     """Start tracking `trackable_object`, optionally under attribute `name`.

     Args:
       trackable_object: The object to track.
       name: When None, the object is appended to `self._trackable_objects`.
         Otherwise it is set as attribute `name` on the module returned by
         `TFGraphContext.get_module_to_export()`.

     Raises:
       RuntimeError: If `name` is given but the current `TFGraphContext` has
         no module to export.
       ValueError: If the module already has an attribute called `name`.
     """
     if name is None:
         self._trackable_objects.append(trackable_object)
         return

     # `preprocessing_fn` is expected to run inside a TFGraphContext; outside
     # of one there is no module to attach the object to.
     export_module = TFGraphContext.get_module_to_export()
     if export_module is None:
         raise RuntimeError(
             f'No module found to track {name} with. Check that the '
             '`preprocessing_fn` is invoked within a `TFGraphContext` with a '
             'valid `TFGraphContext.module_to_export`.')
     # Refuse to overwrite any existing attribute, tracked object or not.
     if hasattr(export_module, name):
         raise ValueError(
             f'An object with name {name} is already being tracked. Check that a '
             'unique name was passed.')
     setattr(export_module, name, trackable_object)
コード例 #5
0
def _bind_future_as_tensor_v2(
    future: nodes.ValueNode,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
  """Bind a future value as a tensor to a TF2 FuncGraph.

  If the future is expected to write out an asset file and this method is
  invoked within a `TFGraphContext` that was provided a temporary directory, a
  temporary file is written out by this method. Depending on how many times the
  `preprocessing_fn` is traced and how many asset files each trace produces,
  this could write out a significant number of temporary files.

  Args:
    future: Future whose result should replace the graph tensor it is bound
      to.
    tensor_info: A `TensorInfo` object containing attributes to create the graph
      tensor.
    name: (Optional) If provided, the graph tensor created uses this name.

  Returns:
    A graph tensor or `tf.saved_model.Asset` that this future is bound to. If
    this future has already been evaluated in a previous TFT phase, the
    evaluated value is returned instead (through `tf.identity` for non-asset
    values).
  """
  graph = ops.get_default_graph()
  temp_dir = TFGraphContext.get_or_create_temp_dir()
  temporary_output = _get_temporary_analyzer_output(temp_dir, tensor_info, name)
  is_asset_filepath = tensor_info.temporary_asset_value is not None

  # TODO(b/149997088): Switch to using a counter instead of tensor names.
  # `evaluated_replacements` maps placeholder names to evaluated tensors. A
  # name is present when `preprocessing_fn` was traced previously and this
  # future was then evaluated in a TFT phase.
  evaluated_replacements = TFGraphContext.get_evaluated_replacements()
  graph_tensor = temporary_output.graph_tensor
  analyzer_name = graph_tensor.name
  already_evaluated = (
      evaluated_replacements is not None and
      analyzer_name in evaluated_replacements)

  if already_evaluated:
    evaluated_value = evaluated_replacements[analyzer_name]
    if not is_asset_filepath:
      # Without the identity wrapper some V2 tests fail with AttributeError:
      # Tensor.name is meaningless when eager execution is enabled.
      # TODO(b/149997088): Remove the identity wrapper once we no longer rely on
      # tensor names.
      return tf.identity(evaluated_value)
    graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
                            evaluated_value)
    return evaluated_value

  # Not evaluated yet: register the temporary tensor for later replacement.
  graph.add_to_collection(
      TENSOR_REPLACEMENTS,
      TensorSink(graph_tensor, future, is_asset_filepath))
  eager_asset_path = temporary_output.eager_asset_path
  if not is_asset_filepath or eager_asset_path is None:
    return graph_tensor
  tf_utils.track_asset_analyzer_output(eager_asset_path, graph_tensor)
  graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,
                          eager_asset_path)
  return eager_asset_path