def _bind_future_as_tensor_v2(future, tensor_info, name=None):
  """Binds `future` to a graph tensor inside a TF2 FuncGraph.

  If the future is expected to write out an asset file, a temporary file is
  written out by this method. This could write out a significant number of
  temporary files depending on the number of times the `preprocessing_fn` is
  traced and the number of asset files in each tracing.

  Args:
    future: Future whose result should replace the graph tensor to which it
      is bound.
    tensor_info: A `TensorInfo` object containing attributes to create the
      graph tensor.
    name: (Optional) If provided, the graph tensor created uses this name.

  Returns:
    A graph tensor that this future is bound to. If this future has already
    been evaluated in a previous TFT phase, it is directly returned.
  """
  default_graph = ops.get_default_graph()
  base_temp_dir = TFGraphContext.get_or_create_temp_dir()
  placeholder_tensor = _get_temporary_graph_tensor(
      base_temp_dir, tensor_info, name)
  needs_asset_file = tensor_info.temporary_asset_value is not None
  # TODO(b/149997088): Switch to using a counter instead of tensor names.
  # `replacements` maps placeholder names to evaluated values. If the
  # `preprocessing_fn` was traced previously and this future was evaluated
  # in a TFT phase, its result will be present in this dictionary.
  replacements = TFGraphContext.get_evaluated_replacements()
  if (replacements is not None and
      placeholder_tensor.name in replacements):
    evaluated_value = replacements[placeholder_tensor.name]
    if not needs_asset_file:
      return evaluated_value
    # Wrap asset files using `tf.saved_model.Asset` to ensure that
    # `SavedModel`s exported are hermetic.
    asset = tf.saved_model.Asset(evaluated_value)
    default_graph.add_to_collection(ASSET_REPLACEMENTS, asset)
    return asset
  # Not evaluated yet: record a sink so a later phase can replace the
  # temporary tensor with the future's result.
  bound_result = placeholder_tensor
  if base_temp_dir and needs_asset_file:
    # Wrap asset files using `tf.saved_model.Asset` to ensure that
    # `SavedModel`s exported are hermetic.
    bound_result = tf.saved_model.Asset(placeholder_tensor)
    default_graph.add_to_collection(ASSET_REPLACEMENTS, bound_result)
  default_graph.add_to_collection(
      TENSOR_REPLACEMENTS,
      TensorSink(placeholder_tensor, future, needs_asset_file))
  return bound_result
def _get_temporary_graph_tensor(tensor_info, name):
  """Creates a temporary graph tensor from attributes in `tensor_info`."""
  if tensor_info.temporary_asset_value is not None:
    # Placeholders cannot be used for assets as they will be initialized as
    # part of the init op. Hence, a temporary file is written out during
    # tracing and a constant pointing at it is returned.
    # TODO(b/164921571) Support temporary files in tfrecord format.
    # TODO(b/149997088): Reduce number of temporary files written out.
    with tf.init_scope():
      asset_path = os.path.join(
          TFGraphContext.get_or_create_temp_dir(), uuid.uuid4().hex)
      with tf.io.gfile.GFile(asset_path, 'w') as writer:
        writer.write(tensor_info.temporary_asset_value)
    return tf.constant(
        asset_path,
        dtype=tensor_info.dtype,
        shape=tensor_info.shape,
        name=name)
  # Using a placeholder with no default value causes tracing to fail if
  # there is any control flow dependent on a child tensor of this
  # placeholder. Hence, provide a temporary default value for it.
  # If dtype is string, we want a tensor that contains '0's instead of b''
  # to allow string to numeric conversion ops to trace successfully.
  default_dtype = (
      tf.int64 if tensor_info.dtype == tf.string else tensor_info.dtype)
  default_value = tf2_utils.supply_missing_tensor(
      1, tf.TensorShape(tensor_info.shape), default_dtype)
  if tensor_info.dtype == tf.string:
    default_value = tf.strings.as_string(default_value)
  return tf.raw_ops.PlaceholderWithDefault(
      input=default_value, shape=tensor_info.shape)
def _get_object(name: str) -> Optional[base.Trackable]:
  """Returns the object tracked under `name`, or None if none exists."""
  tracking_module = TFGraphContext.get_module_to_export()
  # The `preprocessing_fn` should always be invoked within a TFGraphContext;
  # if it was not, there is no module available to look objects up on.
  if tracking_module is None:
    raise RuntimeError(
        f'No module found to track {name} with. Check that the `preprocessing_fn` is'
        ' invoked within a `TFGraphContext` with a valid '
        '`TFGraphContext.module_to_export`.')
  return getattr(tracking_module, name, None)
def add_trackable_object(self, trackable_object: base.Trackable,
                         name: Optional[str]):
  """Adds `trackable_object` to the list of objects being tracked.

  Anonymous objects (`name` is None) are kept alive on this instance;
  named objects are attached as attributes of the module being exported.
  """
  if name is None:
    self._trackable_objects.append(trackable_object)
    return
  module = TFGraphContext.get_module_to_export()
  # The `preprocessing_fn` should always be invoked within a TFGraphContext;
  # if it was not, there is no module available to attach named objects to.
  if module is None:
    raise RuntimeError(
        f'No module found to track {name} with. Check that the '
        '`preprocessing_fn` is invoked within a `TFGraphContext` with a '
        'valid `TFGraphContext.module_to_export`.')
  if hasattr(module, name):
    raise ValueError(
        f'An object with name {name} is already being tracked. Check that a '
        'unique name was passed.')
  setattr(module, name, trackable_object)
def _bind_future_as_tensor_v2(
    future: nodes.ValueNode,
    tensor_info: TensorInfo,
    name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
  """Binds a future value as a tensor to a TF2 FuncGraph.

  If the future is expected to write out an asset file and this method is
  invoked within a `TFGraphContext` that was provided a temporary directory,
  a temporary file is written out by this method. This could write out a
  significant number of temporary files depending on the number of times the
  `preprocessing_fn` is traced and the number of asset files in each tracing.

  Args:
    future: Future whose result should replace the graph tensor to which it
      is bound.
    tensor_info: A `TensorInfo` object containing attributes to create the
      graph tensor.
    name: (Optional) If provided, the graph tensor created uses this name.

  Returns:
    A graph tensor or `tf.saved_model.Asset` that this future is bound to.
    If this future has already been evaluated in a previous TFT phase, it is
    directly returned.
  """
  default_graph = ops.get_default_graph()
  base_temp_dir = TFGraphContext.get_or_create_temp_dir()
  analyzer_output = _get_temporary_analyzer_output(
      base_temp_dir, tensor_info, name)
  needs_asset_file = tensor_info.temporary_asset_value is not None
  # TODO(b/149997088): Switch to using a counter instead of tensor names.
  # `replacements` maps placeholder names to evaluated values. If the
  # `preprocessing_fn` was traced previously and this future was evaluated
  # in a TFT phase, its result will be present in this dictionary.
  replacements = TFGraphContext.get_evaluated_replacements()
  placeholder_name = analyzer_output.graph_tensor.name
  if (replacements is not None and
      placeholder_name in replacements):
    evaluated_value = replacements[placeholder_name]
    if needs_asset_file:
      default_graph.add_to_collection(
          tf.compat.v1.GraphKeys.ASSET_FILEPATHS, evaluated_value)
      return evaluated_value
    # Without the identity wrapper some V2 tests fail with AttributeError:
    # Tensor.name is meaningless when eager execution is enabled.
    # TODO(b/149997088): Remove the identity wrapper once we no longer rely
    # on tensor names.
    return tf.identity(evaluated_value)
  # Not evaluated yet: record a sink so a later phase can replace the
  # temporary tensor with the future's result.
  default_graph.add_to_collection(
      TENSOR_REPLACEMENTS,
      TensorSink(analyzer_output.graph_tensor, future, needs_asset_file))
  eager_asset_path = analyzer_output.eager_asset_path
  if needs_asset_file and eager_asset_path is not None:
    tf_utils.track_asset_analyzer_output(eager_asset_path,
                                         analyzer_output.graph_tensor)
    default_graph.add_to_collection(
        tf.compat.v1.GraphKeys.ASSET_FILEPATHS, eager_asset_path)
    return eager_asset_path
  return analyzer_output.graph_tensor