コード例 #1
0
ファイル: impl_helper.py プロジェクト: tensorflow/transform
def _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir):
    """Trace TF2 graph for `preprocessing_fn`.

    Args:
      preprocessing_fn: The user preprocessing function to trace.
      specs: Input signature forwarded to `get_traced_transform_fn`.
      base_temp_dir: Temporary directory handed to the `TFGraphContext`.

    Returns:
      A 3-tuple `(graph, structured_inputs, structured_outputs)`:
      - the `FuncGraph` of the traced concrete function,
      - its structured inputs as recovered by
        `tf2_utils.get_structured_inputs_from_func_graph`,
      - the concrete function's `structured_outputs`.
    """
    # No previously-evaluated analyzer results are substituted during this
    # trace (`evaluated_replacements=None`).
    context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=base_temp_dir,
        evaluated_replacements=None)
    tracker = annotators.ObjectTracker()
    # Both the traced-function construction and the concrete-function trace
    # run inside the object-tracker scope.
    with annotators.object_tracker_scope(tracker):
        traced_fn = get_traced_transform_fn(preprocessing_fn, specs, context)
        concrete_fn = traced_fn.get_concrete_function()
    func_graph = concrete_fn.graph
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        func_graph)
    return func_graph, structured_inputs, concrete_fn.structured_outputs
コード例 #2
0
 def _get_schema(self,
                 preprocessing_fn,
                 use_compat_v1,
                 inputs=None,
                 input_signature=None,
                 create_session=False):
     """Infers the schema of the features produced by `preprocessing_fn`.

     Args:
       preprocessing_fn: The function under test; called with a dict of
         tensors.
       use_compat_v1: If true, trace `preprocessing_fn` in a fresh
         `tf.compat.v1.Graph` and call `schema_inference.infer_feature_schema`.
         Otherwise trace it with `tf.function` and call
         `schema_inference.infer_feature_schema_v2`.
       inputs: Optional dict of values, used only on the TF1 path, where each
         value is re-created as a graph constant. Every key must also appear
         in `input_signature`, whose `dtype` is used for the conversion.
         Defaults to `{}`.
       input_signature: Optional dict of type specs. On the TF2 path it is the
         single element of the `tf.function` input signature. Defaults to
         `{}`.
       create_session: On the TF1 path, open a session so the annotations are
         actually evaluated into the schema. On the TF2 path it is forwarded
         as `evaluate_schema_overrides`.

     Returns:
       The schema returned by the chosen `schema_inference` function.
     """
     inputs = {} if inputs is None else inputs
     input_signature = {} if input_signature is None else input_signature

     if not use_compat_v1:
         concrete_fn = tf.function(
             preprocessing_fn,
             input_signature=[input_signature]).get_concrete_function()
         output_tensors = tf.nest.pack_sequence_as(
             structure=concrete_fn.structured_outputs,
             flat_sequence=concrete_fn.outputs,
             expand_composites=True)
         traced_inputs = tf2_utils.get_structured_inputs_from_func_graph(
             concrete_fn.graph)
         context = graph_context.TFGraphContext(
             module_to_export=tf.Module(),
             temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName),
             evaluated_replacements={})
         metadata_fn = schema_inference.get_traced_metadata_fn(
             preprocessing_fn=preprocessing_fn,
             structured_inputs=traced_inputs,
             tf_graph_context=context,
             evaluate_schema_overrides=create_session)
         return schema_inference.infer_feature_schema_v2(
             output_tensors,
             metadata_fn,
             evaluate_schema_overrides=create_session)

     with tf.compat.v1.Graph().as_default() as graph:
         # The caller's values may be eager tensors; re-create them as
         # constants that belong to this graph.
         graph_inputs = {
             name: tf.constant(value, input_signature[name].dtype)
             for name, value in inputs.items()
         }
         output_tensors = preprocessing_fn(graph_inputs)
         if not create_session:
             return schema_inference.infer_feature_schema(output_tensors, graph)
         # A live session lets schema inference evaluate the annotation
         # tensors, so the annotations show up in the output schema.
         with tf.compat.v1.Session(graph=graph) as session:
             return schema_inference.infer_feature_schema(
                 output_tensors, graph, session)
コード例 #3
0
def _create_test_saved_model(export_in_tf1,
                             input_specs,
                             preprocessing_fn,
                             export_path_suffix=None,
                             base_dir=None):
    """Exports `preprocessing_fn` as a transform SavedModel for tests.

    Args:
      export_in_tf1: If true, build placeholders in a TF1 graph, run
        `preprocessing_fn` on them and export with
        `saved_transform_io.write_saved_transform_from_session`. Otherwise
        trace a TF2 transform function and export with
        `saved_transform_io_v2.write_v2_saved_model`.
      input_specs: Mapping from input name to a `tf.TensorSpec`,
        `tf.SparseTensorSpec` or `tf.RaggedTensorSpec`.
      preprocessing_fn: Callable taking the dict of input tensors and
        returning the dict of output tensors.
      export_path_suffix: Name of the export directory created inside a new
        temporary directory. Falsy values (None, '') fall back to 'export'.
      base_dir: Parent for the temporary directory, forwarded to
        `tempfile.mkdtemp(dir=...)`; None uses the system default.

    Returns:
      The path the SavedModel was written to.

    Raises:
      ValueError: In the TF1 path, if a spec in `input_specs` is not one of
        the three supported `TypeSpec` classes.
    """
    export_path = os.path.join(
        tempfile.mkdtemp(dir=base_dir), export_path_suffix or 'export')
    if export_in_tf1:
        with tf.compat.v1.Graph().as_default():
            # Entering the Session itself (instead of `Session().as_default()`)
            # installs it as the default session AND closes it on exit;
            # `as_default()` alone never closes the session, leaking it.
            with tf.compat.v1.Session() as session:
                inputs = {}
                for key, tensor_spec in input_specs.items():
                    if isinstance(tensor_spec, tf.TensorSpec):
                        inputs[key] = tf.compat.v1.placeholder(
                            tensor_spec.dtype, shape=tensor_spec.shape)
                    elif isinstance(tensor_spec, tf.SparseTensorSpec):
                        inputs[key] = tf.compat.v1.sparse_placeholder(
                            tensor_spec.dtype, shape=tensor_spec.shape)
                    elif isinstance(tensor_spec, tf.RaggedTensorSpec):
                        # Use the public `dtype` / `ragged_rank` properties
                        # rather than the private `_dtype` / `_ragged_rank`.
                        inputs[key] = tf.compat.v1.ragged.placeholder(
                            tensor_spec.dtype, tensor_spec.ragged_rank, [])
                    else:
                        raise ValueError(
                            'TypeSpecs specified should be one of `tf.TensorSpec`, '
                            '`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`')
                outputs = preprocessing_fn(inputs)
                # show that unrelated & unmapped placeholders do not interfere
                tf.compat.v1.placeholder(tf.int64)
                saved_transform_io.write_saved_transform_from_session(
                    session, inputs, outputs, export_path)
    else:
        module = tf.Module()
        tf_graph_context = graph_context.TFGraphContext(
            module_to_export=module,
            temp_dir=None,
            evaluated_replacements=None)
        transform_fn = impl_helper.get_traced_transform_fn(
            preprocessing_fn=preprocessing_fn,
            input_signature=input_specs,
            tf_graph_context=tf_graph_context,
            output_keys_to_name_map=None)

        saved_transform_io_v2.write_v2_saved_model(module, transform_fn,
                                                   'transform_fn', export_path)
    return export_path
コード例 #4
0
 def transform_fn(inputs):
   """Runs `preprocessing_fn` on `inputs` in the current default graph.

   Args:
     inputs: The input features handed to `preprocessing_fn`.

   Returns:
     `preprocessing_fn`'s output unchanged when `output_keys_to_name_map` is
     None or the graph's `TENSOR_REPLACEMENTS` collection is empty. Otherwise
     a dict mapping each key of `output_keys_to_name_map` to the graph tensor
     whose name is the corresponding value.
   """
   graph = ops.get_default_graph()
   # If any analyzers have already been evaluated, pass them using the
   # `graph_context.TFGraphContext`. This will be used in place of the analyzer
   # nodes.
   with graph_context.TFGraphContext(
       temp_dir=base_temp_dir, evaluated_replacements=tensor_replacement_map):
     transformed_features = preprocessing_fn(inputs)
   # An empty `TENSOR_REPLACEMENTS` collection means no analyzer is left for
   # Transform to evaluate. If that collection is empty, or if no specific
   # outputs were requested, return the output of `preprocessing_fn`
   # (i.e. `transformed_features`) as-is.
   if (output_keys_to_name_map is None or
       not graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)):
     return transformed_features
   # Plain `dict.items()` works on Python 2 and 3 alike; no `six` shim needed.
   return {
       key: graph.get_tensor_by_name(value)
       for key, value in output_keys_to_name_map.items()
   }
コード例 #5
0
    def metadata_fn(inputs):
        """Traces `preprocessing_fn` and collects its schema-override metadata.

        Args:
          inputs: The traced function's inputs (presumably a dict of feature
            tensors; TODO confirm). They are copied with
            `tf_utils.copy_tensors` before `preprocessing_fn` sees them.

        Returns:
          A `collections.defaultdict(list)`. When `evaluate_schema_overrides`
          is false, it holds at most one entry: `_TF_METADATA_TENSOR_COLLECTION`
          mapped to the feature keys of output tensors found in that graph
          collection. When it is true, it is updated with the dicts returned by
          `_get_schema_overrides` for the tensor-range collections and for the
          extra-annotation collections.
        """
        current_graph = ops.get_default_graph()
        # User code may modify the structure it receives, which is not allowed
        # for a tf.function's own inputs, so `preprocessing_fn` gets a copy.
        copied_inputs = tf_utils.copy_tensors(inputs)
        with graph_context.TFGraphContext(
                temp_dir=base_temp_dir,
                evaluated_replacements=tensor_replacement_map):
            outputs = preprocessing_fn(copied_inputs)

        # Tensor value name -> feature key for the transformed features.
        name_to_key = _get_tensor_value_to_key_map(outputs)

        overrides = collections.defaultdict(list)
        if evaluate_schema_overrides:
            # Schema overrides for feature tensor ranges (min / max).
            overrides.update(
                _get_schema_overrides(
                    current_graph, name_to_key, _TF_METADATA_TENSOR_COLLECTION,
                    [_TF_METADATA_TENSOR_MIN_COLLECTION,
                     _TF_METADATA_TENSOR_MAX_COLLECTION]))
            # Schema overrides for feature protos. When no feature tensor is in
            # the `_TF_METADATA_EXTRA_ANNOTATION` collection for an annotation,
            # `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the feature
            # name so the annotation goes on the global schema.
            overrides.update(
                _get_schema_overrides(
                    current_graph, name_to_key, _TF_METADATA_EXTRA_ANNOTATION,
                    [_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
                     _TF_METADATA_EXTRA_ANNOTATION_PROTO],
                    _TF_METADATA_EXTRA_ANNOTATION_GLOBAL))
            return overrides

        # Overrides are not evaluated here. Only record which output features
        # appear in the override collection, in collection order. The key is
        # set only when at least one feature matches.
        annotated_keys = [
            name_to_key[tensor.name]
            for tensor in current_graph.get_collection(
                _TF_METADATA_TENSOR_COLLECTION)
            if tensor.name in name_to_key
        ]
        if annotated_keys:
            overrides[_TF_METADATA_TENSOR_COLLECTION] = annotated_keys
        return overrides
コード例 #6
0
ファイル: impl_helper.py プロジェクト: tensorflow/transform
def _trace_and_write_transform_fn(
    saved_model_dir: str,
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
    output_keys_to_name_map: Optional[Dict[str, str]]
) -> function.ConcreteFunction:
    """Trace `preprocessing_fn` and serialize to a SavedModel.

    Args:
      saved_model_dir: Directory the SavedModel is written to.
      preprocessing_fn: User function mapping input tensors to output tensors.
      input_signature: Type specs of the inputs, keyed by feature name.
      base_temp_dir: Temporary directory for the `TFGraphContext`.
      tensor_replacement_map: Already-evaluated analyzer results, passed to
        the graph context as `evaluated_replacements`.
      output_keys_to_name_map: Optional output-key -> tensor-name mapping,
        forwarded to `get_traced_transform_fn`.

    Returns:
      Whatever `saved_transform_io_v2.write_v2_saved_model` returns (annotated
      here as a `ConcreteFunction`).
    """
    context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=base_temp_dir,
        evaluated_replacements=tensor_replacement_map)
    traced_fn = get_traced_transform_fn(
        preprocessing_fn=preprocessing_fn,
        input_signature=input_signature,
        tf_graph_context=context,
        output_keys_to_name_map=output_keys_to_name_map)
    # Export the module held by the graph context under the 'transform_fn'
    # name.
    return saved_transform_io_v2.write_v2_saved_model(
        context.module_to_export, traced_fn, 'transform_fn', saved_model_dir)
コード例 #7
0
    def metadata_fn():
        """Traces `preprocessing_fn` on stand-in inputs and gathers metadata.

        The inputs come from `tf2_utils.supply_missing_inputs(structured_inputs,
        batch_size=1)`; see that helper for exactly what it fills in.

        Returns:
          A `collections.defaultdict(list)`. When `evaluate_schema_overrides`
          is false, it holds at most one entry: `_TF_METADATA_TENSOR_COLLECTION`
          mapped to the feature keys of output tensors found in that graph
          collection. When it is true, it is updated with the dicts returned by
          `_get_schema_overrides` for the tensor-range collections and for the
          extra-annotation collections.
        """
        current_graph = ops.get_default_graph()
        stand_in_inputs = tf2_utils.supply_missing_inputs(
            structured_inputs, batch_size=1)
        with graph_context.TFGraphContext(
                temp_dir=base_temp_dir,
                evaluated_replacements=tensor_replacement_map):
            outputs = preprocessing_fn(stand_in_inputs)

        # Tensor value name -> feature key for the transformed features.
        name_to_key = _get_tensor_value_to_key_map(outputs)

        overrides = collections.defaultdict(list)
        if evaluate_schema_overrides:
            # Schema overrides for feature tensor ranges (min / max).
            overrides.update(
                _get_schema_overrides(
                    current_graph, name_to_key, _TF_METADATA_TENSOR_COLLECTION,
                    [_TF_METADATA_TENSOR_MIN_COLLECTION,
                     _TF_METADATA_TENSOR_MAX_COLLECTION]))
            # Schema overrides for feature protos. When no feature tensor is in
            # the `_TF_METADATA_EXTRA_ANNOTATION` collection for an annotation,
            # `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the feature
            # name so the annotation goes on the global schema.
            overrides.update(
                _get_schema_overrides(
                    current_graph, name_to_key, _TF_METADATA_EXTRA_ANNOTATION,
                    [_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL,
                     _TF_METADATA_EXTRA_ANNOTATION_PROTO],
                    _TF_METADATA_EXTRA_ANNOTATION_GLOBAL))
            return overrides

        # Overrides are not evaluated here. Only record which output features
        # appear in the override collection, in collection order. The key is
        # set only when at least one feature matches.
        annotated_keys = [
            name_to_key[tensor.name]
            for tensor in current_graph.get_collection(
                _TF_METADATA_TENSOR_COLLECTION)
            if tensor.name in name_to_key
        ]
        if annotated_keys:
            overrides[_TF_METADATA_TENSOR_COLLECTION] = annotated_keys
        return overrides
コード例 #8
0
ファイル: impl_helper.py プロジェクト: tensorflow/transform
def _trace_and_get_metadata(
    concrete_transform_fn: function.ConcreteFunction,
    structured_inputs: Mapping[str, common_types.TensorType],
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
) -> dataset_metadata.DatasetMetadata:
    """Compute and return metadata for the outputs of `concrete_transform_fn`.

    Args:
      concrete_transform_fn: Traced transform function; its
        `structured_outputs` are the features the schema describes.
      structured_inputs: Structured inputs used to re-trace
        `preprocessing_fn` for metadata.
      preprocessing_fn: The user preprocessing function.
      base_temp_dir: Temporary directory for the `TFGraphContext`.
      tensor_replacement_map: Already-evaluated analyzer results, passed to
        the graph context as `evaluated_replacements`.

    Returns:
      A `DatasetMetadata` wrapping the schema from
      `schema_inference.infer_feature_schema_v2`, with schema overrides
      evaluated.
    """
    context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=base_temp_dir,
        evaluated_replacements=tensor_replacement_map)
    metadata_fn = schema_inference.get_traced_metadata_fn(
        preprocessing_fn=preprocessing_fn,
        structured_inputs=structured_inputs,
        tf_graph_context=context,
        evaluate_schema_overrides=True)
    schema = schema_inference.infer_feature_schema_v2(
        concrete_transform_fn.structured_outputs,
        metadata_fn,
        evaluate_schema_overrides=True)
    return dataset_metadata.DatasetMetadata(schema=schema)