Example #1
def wrapped_finalized(inputs):
  # Synthesize values for any inputs that were not fed, using batch size 1.
  missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size=1)
  # Directly modifying `inputs` is not allowed in a tf.function. Hence, we
  # make a copy here.
  inputs_copy = tf_utils.copy_tensors(inputs)
  inputs_copy.update(missing_inputs)
  flattened_inputs = tf.nest.flatten(inputs_copy, expand_composites=True)
  transformed_features = self._wrapped_function(*flattened_inputs)
  return {key: transformed_features[key] for key in fetches_keys}
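The copy step is what makes this pattern safe: the wrapped function receives fresh tensors, and updating the copied dict never mutates the caller's mapping. A minimal sketch of what a helper like tf_utils.copy_tensors does for dense tensors (the name copy_tensors_sketch and the restriction to a flat dict are assumptions; the real helper also handles composites such as tf.SparseTensor):

import tensorflow as tf

def copy_tensors_sketch(tensors):
    # Return a new dict whose values are tf.identity copies of the
    # originals, so callers can update the copy without mutating the
    # caller's mapping or aliasing the raw input tensors.
    return {name: tf.identity(tensor) for name, tensor in tensors.items()}

batch = {'x': tf.constant([1.0, 2.0])}
batch_copy = copy_tensors_sketch(batch)
batch_copy['y'] = tf.constant([3.0, 4.0])  # `batch` itself is unchanged.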
Example #2
def _trace_preprocessing_fn_v1(preprocessing_fn, specs):
    """Trace TF1 graph for `preprocessing_fn`."""
    with tf.compat.v1.Graph().as_default() as graph:
        with tf.compat.v1.name_scope('inputs'):
            structured_inputs = batched_placeholders_from_specs(specs)
            # In order to avoid a bug where import_graph_def fails when the
            # input_map and return_elements of an imported graph are the same
            # (b/34288791), we avoid using the placeholder of an input column as an
            # output of a graph. We do this by applying tf.identity to all inputs of
            # the preprocessing_fn.  Note this applies at the level of raw tensors.
            # TODO(b/34288791): Remove this workaround and use a shallow copy of
            # inputs instead.  A shallow copy is needed in case
            # self._preprocessing_fn mutates its input.
            copied_inputs = tf_utils.copy_tensors(structured_inputs)

        structured_outputs = preprocessing_fn(copied_inputs)
    return graph, structured_inputs, structured_outputs
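Here batched_placeholders_from_specs comes from the surrounding module and is not shown. A self-contained approximation of the same tracing pattern for tf.io.FixedLenFeature specs might look like the sketch below (the function name and the restriction to dense features are assumptions):

import tensorflow as tf

def trace_preprocessing_fn_sketch(preprocessing_fn, specs):
    # Same pattern as above: trace inside a fresh TF1 graph, create one
    # batched placeholder per feature spec, and shield the placeholders
    # behind tf.identity before handing them to the user's preprocessing_fn.
    with tf.compat.v1.Graph().as_default() as graph:
        with tf.compat.v1.name_scope('inputs'):
            structured_inputs = {
                name: tf.compat.v1.placeholder(
                    spec.dtype, shape=[None] + list(spec.shape), name=name)
                for name, spec in specs.items()
            }
            copied_inputs = {
                name: tf.identity(t) for name, t in structured_inputs.items()
            }
        structured_outputs = preprocessing_fn(copied_inputs)
    return graph, structured_inputs, structured_outputs

specs = {'x': tf.io.FixedLenFeature([], tf.float32)}
graph, inputs, outputs = trace_preprocessing_fn_sketch(
    lambda d: {'x_plus_one': d['x'] + 1.0}, specs)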
Example #3
    def metadata_fn(inputs):
        graph = ops.get_default_graph()
        # The user-defined `preprocessing_fn` may directly modify its inputs,
        # which is not allowed in a tf.function. Hence, we make a copy here.
        inputs_copy = tf_utils.copy_tensors(inputs)
        with graph_context.TFGraphContext(
                temp_dir=base_temp_dir,
                evaluated_replacements=tensor_replacement_map):
            transformed_features = preprocessing_fn(inputs_copy)

        # Get a map from tensor value names to feature keys.
        reversed_features = _get_tensor_value_to_key_map(transformed_features)

        result = collections.defaultdict(list)
        if not evaluate_schema_overrides:
            schema_override_tensors = graph.get_collection(
                _TF_METADATA_TENSOR_COLLECTION)
            for tensor in schema_override_tensors:
                if tensor.name in reversed_features:
                    result[_TF_METADATA_TENSOR_COLLECTION].append(
                        reversed_features[tensor.name])
        else:
            # Obtain schema overrides for feature tensor ranges.
            result.update(
                _get_schema_overrides(graph, reversed_features,
                                      _TF_METADATA_TENSOR_COLLECTION, [
                                          _TF_METADATA_TENSOR_MIN_COLLECTION,
                                          _TF_METADATA_TENSOR_MAX_COLLECTION
                                      ]))
            # Obtain schema overrides for feature protos. If no feature tensor is in
            # the `_TF_METADATA_EXTRA_ANNOTATION` collection for a specified
            # annotation, `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the
            # feature name to indicate that this annotation should be added to the
            # global schema.
            result.update(
                _get_schema_overrides(
                    graph, reversed_features, _TF_METADATA_EXTRA_ANNOTATION, [
                        _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
                        _TF_METADATA_EXTRA_ANNOTATION_PROTO
                    ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL))
        return result
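The helper that builds reversed_features is not shown. For dense graph-mode tensors it plausibly just inverts the feature dict by tensor name, as in this hypothetical reconstruction (the real helper presumably also handles composite tensors via their component tensors):

def get_tensor_value_to_key_map_sketch(features):
    # Invert {feature_key: tensor} into {tensor_name: feature_key} so that
    # tensors recovered from graph collections can be mapped back to the
    # feature keys they were produced under.
    return {tensor.name: key for key, tensor in features.items()}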
Example #4
def transform_fn(inputs):
    graph = ops.get_default_graph()
    # If any analyzers have already been evaluated, their values are passed
    # in via the `graph_context.TFGraphContext` and used in place of the
    # analyzer nodes.
    # The user-defined `preprocessing_fn` may directly modify its inputs,
    # which is not allowed in a tf.function. Hence, we make a copy here.
    inputs_copy = tf_utils.copy_tensors(inputs)
    with tf_graph_context:
        transformed_features = preprocessing_fn(inputs_copy)
    # An empty `TENSOR_REPLACEMENTS` collection signals that there are no
    # analyzers left for Transform to evaluate. If this collection is empty
    # or no specific outputs have been requested, return the same output as
    # `preprocessing_fn` (i.e., `transformed_features`).
    if (output_keys_to_name_map is None or
            not graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)):
        return transformed_features
    else:
        return {
            key: graph.get_tensor_by_name(value)
            for key, value in output_keys_to_name_map.items()
        }
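The else branch relies on the fact that a tensor recorded by name can later be recovered from the same graph. A minimal, self-contained illustration of that lookup (all variable names here are illustrative):

import tensorflow as tf

with tf.compat.v1.Graph().as_default() as graph:
    x = tf.constant([1.0, 2.0], name='x')
    scaled = tf.identity(x * 2.0, name='x_scaled')
    # Record outputs by tensor name while tracing...
    output_keys_to_name_map = {'x_scaled': scaled.name}
    # ...and later resolve the names back to the live tensors, exactly as
    # the else branch above does.
    outputs = {
        key: graph.get_tensor_by_name(name)
        for key, name in output_keys_to_name_map.items()
    }
    assert outputs['x_scaled'] is scaled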