예제 #1
0
def _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir):
    """Trace `preprocessing_fn` as a TF2 function graph.

    Returns a 3-tuple of (func graph, structured inputs, structured outputs)
    for the traced concrete function.
    """
    traced_fn = get_traced_transform_fn(preprocessing_fn, specs, base_temp_dir)
    concrete_fn = traced_fn.get_concrete_function()
    graph = concrete_fn.graph
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(graph)
    return graph, structured_inputs, concrete_fn.structured_outputs
예제 #2
0
def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
                     type_specs, transform_output_path):
    """Analyzes the `preprocessing_fn` in-place without looking at the data.

    This should only be used if the `preprocessing_fn` contains no TFT
    analyzers or TFT mappers that use analyzers.

    Writes out a transform function and transformed metadata to subdirs under
    `transform_output_path`.

    Args:
      preprocessing_fn: The tf.Transform preprocessing_fn.
      force_tf_compat_v1: If True, call Transform's API to use Tensorflow in
        tf.compat.v1 mode.
      feature_specs: a Dict from input feature key to its feature spec.
      type_specs: a Dict from input feature key to its type spec.
      transform_output_path: An absolute path to write the output to.

    Raises:
      RuntimeError if `preprocessing_fn` contains TFT analyzers.
    """
    compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    saved_transform_path = os.path.join(transform_output_path,
                                        TFTransformOutput.TRANSFORM_FN_DIR)
    if not compat_v1:
        # TF2 path: trace the transform fn and write it out directly.
        concrete_fn = _trace_and_write_transform_fn(
            saved_model_dir=saved_transform_path,
            preprocessing_fn=preprocessing_fn,
            input_signature=type_specs,
            base_temp_dir=None,
            tensor_replacement_map=None,
            output_keys_to_name_map=None)
        _assert_no_analyzers_in_graph(concrete_fn.graph)
        inputs = tf2_utils.get_structured_inputs_from_func_graph(
            concrete_fn.graph)
        metadata = _trace_and_get_metadata(
            concrete_transform_fn=concrete_fn,
            structured_inputs=inputs,
            preprocessing_fn=preprocessing_fn,
            base_temp_dir=None,
            tensor_replacement_map=None)
    else:
        # TF1 compat path: build the graph, run initializers in a session and
        # export the transform from that session.
        graph, inputs, outputs = (
            trace_preprocessing_function(preprocessing_fn,
                                         feature_specs,
                                         use_tf_compat_v1=compat_v1))
        _assert_no_analyzers_in_graph(graph)
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            sess.run(tf.compat.v1.tables_initializer())
            saved_transform_io.write_saved_transform_from_session(
                sess, inputs, outputs, saved_transform_path)
            metadata = dataset_metadata.DatasetMetadata(
                schema=schema_inference.infer_feature_schema(
                    outputs, graph, sess))
    metadata_io.write_metadata(
        metadata,
        os.path.join(transform_output_path,
                     TFTransformOutput.TRANSFORMED_METADATA_DIR))
예제 #3
0
def trace_and_write_v2_saved_model(
        saved_model_dir: str,
        preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                                   Mapping[str, common_types.TensorType]],
        input_signature: Mapping[str, tf.TypeSpec],
        base_temp_dir: Optional[str], baseline_analyzers_fingerprint: Mapping[
            str, graph_tools.AnalyzersFingerprint],
        tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
        output_keys_to_name_map: Optional[Dict[str, str]]):
    """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function.

  The SavedModel written contains a method called `transform_fn` that
  represents the traced `preprocessing_fn`. Additionally, if all TFT analyzers
  in the `preprocessing_fn` have already been evaluated (i.e. the
  `TENSOR_REPLACEMENTS` graph collection is empty), the transformed metadata
  is computed and written alongside the SavedModel.

  Args:
    saved_model_dir: Path to write SavedModel to.
    preprocessing_fn: A user defined python function to be traced.
    input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`.
    base_temp_dir: Base path to write temporary artifacts to.
    baseline_analyzers_fingerprint: A mapping from analyzer name to a set of
      paths that define its fingerprint.
    tensor_replacement_map: A map from placeholder tensor names to their
      evaluated replacement tensors.
    output_keys_to_name_map: A map from output dictionary keys to the names of
      the tensors that they represent.

  Returns:
    The traced `preprocessing_fn` as a `tf.ConcreteFunction`.

  Raises:
    RuntimeError: if analyzers in `preprocessing_fn` are encountered in a
    non-deterministic order.
  """
    concrete_transform_fn = _trace_and_write_transform_fn(
        saved_model_dir, preprocessing_fn, input_signature, base_temp_dir,
        tensor_replacement_map, output_keys_to_name_map)
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_transform_fn.graph)
    # Raises if analyzer fingerprints do not match the baseline, i.e. the
    # analyzers were traced in a non-deterministic order.
    _validate_analyzers_fingerprint(baseline_analyzers_fingerprint,
                                    concrete_transform_fn.graph,
                                    structured_inputs)

    # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers
    # in the `preprocessing_fn` have already been evaluated.
    if not concrete_transform_fn.graph.get_collection(
            analyzer_nodes.TENSOR_REPLACEMENTS):
        metadata = _trace_and_get_metadata(concrete_transform_fn,
                                           structured_inputs, preprocessing_fn,
                                           base_temp_dir,
                                           tensor_replacement_map)
        metadata_io.write_metadata(
            metadata, os.path.join(saved_model_dir, METADATA_DIR_NAME))
    # BUG FIX: the docstring promised a return value, but the function fell
    # off the end and implicitly returned None. Return the traced transform
    # fn so callers can use it.
    return concrete_transform_fn
    def __init__(self, saved_model_dir):
        """Init method for SavedModelLoader.

    Args:
      saved_model_dir: A SavedModel directory providing a transform graph.  The
        MetaGraphDef and signature are selected from the SavedModel using keys
        defined in `../constants.py` ('transform' and 'transform_signature',
        respectively).
    """
        # BUG FIX: the original lexicographic string comparison
        # (tf.version.VERSION < '2.5') misclassifies two-digit minors, e.g.
        # '2.10' < '2.5' is True. Compare (major, minor) numerically instead.
        major, minor = tf.version.VERSION.split('.')[:2]
        if (int(major), int(minor)) < (2, 5):
            self._imported = load.load_internal(saved_model_dir,
                                                loader_cls=_Loader)
            # Some TF versions return a dict with the root object under 'root'
            # instead of the root object itself.
            if isinstance(self._imported, dict):
                self._imported = self._imported['root']
        else:
            # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
            # dropped.
            self._imported = tf.compat.v2.saved_model.load(saved_model_dir)
        # Presence of the TF1 'transform_signature' indicates a TF1 SavedModel
        # being loaded with TF2 behaviors enabled.
        self.load_v2_in_compat = (constants.TRANSFORM_SIGNATURE
                                  in self._imported.signatures)
        if self.load_v2_in_compat:
            self._wrapped = self._imported.signatures[
                constants.TRANSFORM_SIGNATURE]
            self._func_graph = self._wrapped.graph
            self._structured_inputs = self._get_input_signature_from_v1_saved_model(
                saved_model_dir)
            self._structured_outputs = self._wrapped.structured_outputs
        else:
            # TODO(b/160550490): Remove local import.
            from tensorflow_transform import tf2_utils  # pylint: disable=g-import-not-at-top

            # transform_fn is now a ConcreteFunction, but was a tf.function. We
            # need to handle both to maintain backward compatibility. If it's a
            # tf.function, since `input_signature` was specified when exporting
            # the tf function to `SavedModel`, there should be exactly one
            # concrete function present on loading the `SavedModel`.
            if hasattr(self._imported.transform_fn, 'concrete_functions'):
                concrete_functions = self._imported.transform_fn.concrete_functions
                assert len(concrete_functions) == 1, concrete_functions
                self._wrapped = concrete_functions[0]
            else:
                self._wrapped = self._imported.transform_fn
            self._func_graph = self._wrapped.graph
            self._structured_inputs = (
                tf2_utils.get_structured_inputs_from_func_graph(
                    self._func_graph))
            # Re-pack the flat graph outputs into their original nested
            # structure.
            self._structured_outputs = tf.nest.pack_sequence_as(
                self._func_graph.structured_outputs,
                self._func_graph.outputs,
                expand_composites=True)
        self._output_to_inputs_map = (self._get_output_to_inputs_map(
            self._structured_outputs))
        saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access
예제 #5
0
def _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir):
    """Trace TF2 graph for `preprocessing_fn`."""
    # Graph context carries the temp dir; no analyzer replacements are
    # available at this stage.
    context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=base_temp_dir,
        evaluated_replacements=None)
    tracker = annotators.ObjectTracker()
    with annotators.object_tracker_scope(tracker):
        traced_fn = get_traced_transform_fn(preprocessing_fn, specs, context)
        concrete_fn = traced_fn.get_concrete_function()
    graph = concrete_fn.graph
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(graph)
    return graph, structured_inputs, concrete_fn.structured_outputs
예제 #6
0
 def _get_schema(self,
                 preprocessing_fn,
                 use_compat_v1,
                 inputs=None,
                 input_signature=None,
                 create_session=False):
     """Traces `preprocessing_fn` and returns the inferred output schema."""
     inputs = {} if inputs is None else inputs
     input_signature = {} if input_signature is None else input_signature
     if not use_compat_v1:
         concrete_fn = tf.function(
             preprocessing_fn,
             input_signature=[input_signature]).get_concrete_function()
         # Re-pack flat graph outputs into their original nested structure.
         outputs = tf.nest.pack_sequence_as(
             structure=concrete_fn.structured_outputs,
             flat_sequence=concrete_fn.outputs,
             expand_composites=True)
         structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
             concrete_fn.graph)
         graph_ctx = graph_context.TFGraphContext(
             module_to_export=tf.Module(),
             temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName),
             evaluated_replacements={})
         metadata_fn = schema_inference.get_traced_metadata_fn(
             preprocessing_fn=preprocessing_fn,
             structured_inputs=structured_inputs,
             tf_graph_context=graph_ctx,
             evaluate_schema_overrides=create_session)
         schema = schema_inference.infer_feature_schema_v2(
             outputs,
             metadata_fn,
             evaluate_schema_overrides=create_session)
     else:
         with tf.compat.v1.Graph().as_default() as graph:
             # Convert eager tensors to graph tensors.
             graph_inputs = {
                 key: tf.constant(value, input_signature[key].dtype)
                 for key, value in inputs.items()
             }
             outputs = preprocessing_fn(graph_inputs)
             if not create_session:
                 schema = schema_inference.infer_feature_schema(
                     outputs, graph)
             else:
                 # Create a session to actually evaluate the annotations and
                 # extract the output schema with annotations applied.
                 with tf.compat.v1.Session(graph=graph) as session:
                     schema = schema_inference.infer_feature_schema(
                         outputs, graph, session)
     return schema
예제 #7
0
def infer_feature_schema_v2(features, concrete_metadata_fn,
                            evaluate_schema_overrides):
    """Given a dict of tensors, creates a `Schema`.

  Infers a schema, in the format of a tf.Transform `Schema`, for the given
  dictionary of tensors.

  If there is an override specified, we override the inferred schema for the
  given feature's tensor.  An override has the meaning that we should set
  is_categorical=True.  If evaluate_schema_overrides is False then we just set
  is_categorical=True, and if evaluate_schema_overrides is True then we also
  compute values of the tensors representing the min and max values and set them
  in the schema.

  If annotations have been specified, they are added to the output schema.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
      `Tensor` or `SparseTensor`s should have a 0'th dimension which is
      interpreted as the batch dimension.
    concrete_metadata_fn: A `tf.ConcreteFunction` that returns a dictionary
      containing the deferred annotations added to the graph when invoked with
      any valid input.
    evaluate_schema_overrides: A Boolean used to compute schema overrides. If
      `False`, schema overrides will not be computed.

  Returns:
    A `Schema` proto.
  """
    graph_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_metadata_fn.graph)
    # Invoke concrete_metadata_fn with some dummy data.
    dummy_inputs = tf2_utils.supply_missing_inputs(graph_inputs, batch_size=1)
    flat_dummy_inputs = tf.nest.flatten(dummy_inputs, expand_composites=True)
    metadata = collections.defaultdict(
        list, concrete_metadata_fn(*flat_dummy_inputs))

    if evaluate_schema_overrides:
        tensor_ranges = _get_tensor_ranges_v2(metadata)
        tensor_annotations, global_annotations = (
            _get_schema_annotations_v2(metadata))
    else:
        # Without evaluation, only record which tensors carry overrides; their
        # min/max values stay unset.
        tensor_ranges = {}
        for tensor in metadata[_TF_METADATA_TENSOR_COLLECTION]:
            tensor_ranges[tensor.numpy().decode()] = (None, None)
        tensor_annotations = {}
        global_annotations = []
    return _infer_feature_schema_common(features, tensor_ranges,
                                        tensor_annotations, global_annotations)
예제 #8
0
def _trace_and_get_metadata(
    concrete_transform_fn: function.ConcreteFunction,
    structured_inputs: Mapping[str, common_types.TensorType],
    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]],
    base_temp_dir: Optional[str],
    tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
) -> dataset_metadata.DatasetMetadata:
    """Compute and return metadata for the outputs of `concrete_transform_fn`.

  Args:
    concrete_transform_fn: The traced transform `tf.ConcreteFunction`.
    structured_inputs: The structured inputs of `concrete_transform_fn`'s
      graph, as already computed by the caller.
    preprocessing_fn: The user defined tf.Transform preprocessing_fn.
    base_temp_dir: Base path to write temporary artifacts to.
    tensor_replacement_map: A map from placeholder tensor names to their
      evaluated replacement tensors.

  Returns:
    A `DatasetMetadata` with the schema inferred from the transform outputs.
  """
    # BUG FIX: the signature was missing `structured_inputs`, which every
    # visible caller passes (as the second argument / keyword); the parameter
    # is restored and used directly instead of being recomputed here.
    metadata_fn = schema_inference.get_traced_metadata_fn(
        tensor_replacement_map,
        preprocessing_fn,
        structured_inputs,
        base_temp_dir,
        evaluate_schema_overrides=True)
    return dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema_v2(
            concrete_transform_fn.structured_outputs,
            metadata_fn.get_concrete_function(),
            evaluate_schema_overrides=True))
예제 #9
0
  def __init__(self, saved_model_dir: str):
    """Init method for SavedModelLoader.

    Args:
      saved_model_dir: A SavedModel directory providing a transform graph.  The
        MetaGraphDef and signature are selected from the SavedModel using keys
        defined in `../constants.py` ('transform' and 'transform_signature',
        respectively).
    """
    # TODO(b/160294509): Stop using tf.compat.v2 when TF1.15 support is
    # dropped.
    loaded = tf.compat.v2.saved_model.load(saved_model_dir)
    is_v1_signature = constants.TRANSFORM_SIGNATURE in loaded.signatures
    if is_v1_signature:
      signature_fn = loaded.signatures[constants.TRANSFORM_SIGNATURE]
      wrapped_fn, inputs, outputs = _restore_from_v1_saved_model(
          signature_fn, saved_model_dir)
    else:
      # transform_fn is now a ConcreteFunction, but was a tf.function. We need
      # to handle both to maintain backward compatibility. If it's a
      # tf.function, since `input_signature` was specified when exporting the
      # tf function to `SavedModel`, there should be exactly one concrete
      # function present on loading the `SavedModel`.
      transform_fn = loaded.transform_fn
      if hasattr(transform_fn, 'concrete_functions'):
        candidates = transform_fn.concrete_functions
        assert len(candidates) == 1, candidates
        wrapped_fn = candidates[0]
      else:
        wrapped_fn = transform_fn
      graph = wrapped_fn.graph
      inputs = tf2_utils.get_structured_inputs_from_func_graph(graph)
      # Re-pack the flat graph outputs into their original nested structure.
      outputs = tf.nest.pack_sequence_as(
          graph.structured_outputs,
          graph.outputs,
          expand_composites=True)
    outputs_to_inputs = _get_output_to_inputs_map(outputs)
    self._initialize(is_v1_signature, loaded, wrapped_fn, inputs, outputs,
                     outputs_to_inputs)
    saved_transform_io._maybe_register_addon_ops()  # pylint: disable=protected-access