def wrapped_finalized(inputs):
  missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size=1)
  # Directly modifying inputs is not allowed in a tf.function. Hence, we
  # make a deep copy here.
  inputs_copy = tf_utils.copy_tensors(inputs)
  inputs_copy.update(missing_inputs)
  flattened_inputs = tf.nest.flatten(inputs_copy, expand_composites=True)
  transformed_features = self._wrapped_function(*flattened_inputs)
  return {key: transformed_features[key] for key in fetches_keys}
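# Standalone sketch of the flattening step above, assuming only TensorFlow
# (the feature names below are illustrative, not part of this module). With
# expand_composites=True, tf.nest.flatten decomposes composite tensors such
# as tf.SparseTensor into their component tensors, which is why the wrapped
# ConcreteFunction can be invoked with a flat positional argument list.
import tensorflow as tf

example_features = {
    'dense': tf.constant([[1.0], [2.0]]),
    'sparse': tf.sparse.from_dense(tf.constant([[0.0, 3.0], [4.0, 0.0]])),
}
flat = tf.nest.flatten(example_features, expand_composites=True)
# 'dense' contributes one tensor; 'sparse' contributes its indices, values
# and dense_shape components, so four tensors in total.
assert len(flat) == 4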
def _trace_preprocessing_fn_v1(preprocessing_fn, specs):
  """Trace TF1 graph for `preprocessing_fn`."""
  with tf.compat.v1.Graph().as_default() as graph:
    with tf.compat.v1.name_scope('inputs'):
      structured_inputs = batched_placeholders_from_specs(specs)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn. Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead. A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = tf_utils.copy_tensors(structured_inputs)

    structured_outputs = preprocessing_fn(copied_inputs)
  return graph, structured_inputs, structured_outputs
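# A minimal usage sketch for the tracer above, assuming TF Transform is
# importable; `toy_preprocessing_fn` and the 'x' feature are hypothetical
# stand-ins for a user-defined preprocessing function and its input spec.
import tensorflow as tf
import tensorflow_transform as tft

def toy_preprocessing_fn(inputs):
  # scale_to_z_score introduces analyzer nodes into the traced graph, just
  # as a real user-defined preprocessing_fn would.
  return {'x_scaled': tft.scale_to_z_score(inputs['x'])}

specs = {'x': tf.TensorSpec(shape=[None], dtype=tf.float32)}
graph, structured_inputs, structured_outputs = _trace_preprocessing_fn_v1(
    toy_preprocessing_fn, specs)
assert set(structured_outputs) == {'x_scaled'}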
def metadata_fn(inputs):
  graph = ops.get_default_graph()
  # The user defined `preprocessing_fn` may directly modify its inputs which
  # is not allowed in a tf.function. Hence, we make a copy here.
  inputs_copy = tf_utils.copy_tensors(inputs)
  with graph_context.TFGraphContext(
      temp_dir=base_temp_dir,
      evaluated_replacements=tensor_replacement_map):
    transformed_features = preprocessing_fn(inputs_copy)

  # Get a map from tensor value names to feature keys.
  reversed_features = _get_tensor_value_to_key_map(transformed_features)

  result = collections.defaultdict(list)
  if not evaluate_schema_overrides:
    schema_override_tensors = graph.get_collection(
        _TF_METADATA_TENSOR_COLLECTION)
    for tensor in schema_override_tensors:
      if tensor.name in reversed_features:
        result[_TF_METADATA_TENSOR_COLLECTION].append(
            reversed_features[tensor.name])
  else:
    # Obtain schema overrides for feature tensor ranges.
    result.update(
        _get_schema_overrides(graph, reversed_features,
                              _TF_METADATA_TENSOR_COLLECTION, [
                                  _TF_METADATA_TENSOR_MIN_COLLECTION,
                                  _TF_METADATA_TENSOR_MAX_COLLECTION
                              ]))
    # Obtain schema overrides for feature protos. If no feature tensor is in
    # the `_TF_METADATA_EXTRA_ANNOTATION` collection for a specified
    # annotation, `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the
    # feature name to indicate that this annotation should be added to the
    # global schema.
    result.update(
        _get_schema_overrides(graph, reversed_features,
                              _TF_METADATA_EXTRA_ANNOTATION, [
                                  _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
                                  _TF_METADATA_EXTRA_ANNOTATION_PROTO
                              ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL))
  return result
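# Sketch of the reverse map consumed above (a hypothetical stand-in for
# `_get_tensor_value_to_key_map`; the real helper also handles composite
# tensors). It maps each output tensor's graph name back to its feature key,
# so entries in the metadata collections, which store tensors, can be
# reported as feature names.
import tensorflow as tf

def toy_tensor_value_to_key_map(features):
  return {tensor.name: key for key, tensor in features.items()}

with tf.compat.v1.Graph().as_default():
  features = {'a': tf.identity(tf.constant(1.0), name='a_out')}
  assert toy_tensor_value_to_key_map(features) == {'a_out:0': 'a'}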
def transform_fn(inputs):
  graph = ops.get_default_graph()
  # If any analyzers have already been evaluated, pass them using the
  # `graph_context.TFGraphContext`. This will be used in place of the
  # analyzer nodes.
  # The user defined `preprocessing_fn` may directly modify its inputs which
  # is not allowed in a tf.function. Hence, we make a copy here.
  inputs_copy = tf_utils.copy_tensors(inputs)
  with tf_graph_context:
    transformed_features = preprocessing_fn(inputs_copy)
  # An empty `TENSOR_REPLACEMENTS` collection symbolizes that there is no
  # analyzer left for Transform to evaluate. If this collection is empty, or
  # if no specific outputs have been requested, return the same output as
  # `preprocessing_fn` (i.e., transformed_features).
  if (output_keys_to_name_map is None or
      not graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)):
    return transformed_features
  else:
    return {
        key: graph.get_tensor_by_name(value)
        for key, value in output_keys_to_name_map.items()
    }
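# Standalone illustration of the name-based fetch in the else branch, using
# hypothetical tensor names unrelated to Transform's collections: tensors in
# a TF1 graph can be recovered from their string names, which is how the
# previously recorded output names in `output_keys_to_name_map` are resolved.
import tensorflow as tf

with tf.compat.v1.Graph().as_default() as graph:
  x = tf.compat.v1.placeholder(tf.float32, shape=[None], name='x')
  y = tf.identity(x * 2.0, name='y')

example_map = {'doubled': 'y:0'}
fetched = {k: graph.get_tensor_by_name(v) for k, v in example_map.items()}
assert fetched['doubled'] is y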