Example #1
0
 def test_infer_feature_schema_bad_rank(self):
     """Checks that a rank-0 (scalar) output makes schema inference raise ValueError."""
     with tf.Graph().as_default() as graph:
         scalar_placeholder = tf.placeholder(tf.float32, ())
     bad_rank_tensors = {'a': scalar_placeholder}
     with self.assertRaises(ValueError):
         schema_inference.infer_feature_schema(bad_rank_tensors, graph)
 def _get_schema(self,
                 preprocessing_fn,
                 use_compat_v1,
                 inputs=None,
                 input_signature=None,
                 create_session=False):
     """Runs `preprocessing_fn` and infers the schema of its outputs.

     Args:
       preprocessing_fn: Function applied to the inputs; its outputs are passed
         to schema inference.
       use_compat_v1: If True, trace inside a tf.compat.v1 graph and call
         `infer_feature_schema`; otherwise trace with `tf.function` and call
         `infer_feature_schema_v2`.
       inputs: Optional dict of input values, only used on the compat.v1 path,
         where each value becomes a graph constant. Defaults to {}.
       input_signature: Optional dict of input specs. On the compat.v1 path each
         entry's `dtype` types the constant for the same key; on the TF2 path
         the dict is the `tf.function` input signature. Defaults to {}.
       create_session: If True, schema overrides/annotations are evaluated: via
         a session on the compat.v1 path, or `evaluate_schema_overrides` on the
         TF2 path.

     Returns:
       The schema returned by schema inference.
     """
     inputs = {} if inputs is None else inputs
     input_signature = {} if input_signature is None else input_signature

     if not use_compat_v1:
         concrete_fn = tf.function(
             preprocessing_fn,
             input_signature=[input_signature]).get_concrete_function()
         output_tensors = tf.nest.pack_sequence_as(
             structure=concrete_fn.structured_outputs,
             flat_sequence=concrete_fn.outputs,
             expand_composites=True)
         func_graph_inputs = tf2_utils.get_structured_inputs_from_func_graph(
             concrete_fn.graph)
         context = graph_context.TFGraphContext(
             module_to_export=tf.Module(),
             temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName),
             evaluated_replacements={})
         metadata_fn = schema_inference.get_traced_metadata_fn(
             preprocessing_fn=preprocessing_fn,
             structured_inputs=func_graph_inputs,
             tf_graph_context=context,
             evaluate_schema_overrides=create_session)
         return schema_inference.infer_feature_schema_v2(
             output_tensors,
             metadata_fn,
             evaluate_schema_overrides=create_session)

     with tf.compat.v1.Graph().as_default() as graph:
         # Eager input values are re-created as constants inside this graph.
         graph_inputs = {
             name: tf.constant(value, input_signature[name].dtype)
             for name, value in inputs.items()
         }
         output_tensors = preprocessing_fn(graph_inputs)
         if not create_session:
             return schema_inference.infer_feature_schema(output_tensors, graph)
         # A live session lets schema inference evaluate the annotation
         # tensors so they are applied to the returned schema.
         with tf.compat.v1.Session(graph=graph) as session:
             return schema_inference.infer_feature_schema(
                 output_tensors, graph, session)
  def test_infer_feature_schema(self, make_tensors_fn, feature_spec,
                                domains=None, create_session=False):
    """Compares the schema inferred from `make_tensors_fn()` with one built
    by `dataset_schema.from_feature_spec(feature_spec, domains)`."""
    with tf.Graph().as_default() as graph:
      tensors = make_tensors_fn()

    if not create_session:
      actual = schema_inference.infer_feature_schema(tensors, graph)
    else:
      with tf.Session(graph=graph) as session:
        actual = schema_inference.infer_feature_schema(tensors, graph, session)

    self.assertEqual(
        actual, dataset_schema.from_feature_spec(feature_spec, domains))
    def test_global_annotation(self):
        """Checks that a schema-level proto annotation lands in extra_metadata.

        Skipped when `annotations_pb2` cannot be imported (e.g. OSS builds).
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report a skip rather than returning early, which would make the
            # test look like it passed without asserting anything.
            self.skipTest('tensorflow_transform.annotations_pb2 is unavailable')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = encode_proto_op.encode_proto(
                sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
                message_type)[0]
            # Any type URLs always use '/'; os.path.join would produce a
            # backslash-separated URL on Windows.
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema._schema_proto.annotation.extra_metadata,
                               1)
                for annotation in schema._schema_proto.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])
Example #5
0
    def test_infer_feature_schema_with_ragged_tensor(self):
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo':
                tf.RaggedTensor.from_row_splits(values=tf.constant(
                    [3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                                                row_splits=[0, 4, 4, 7, 8, 8]),
            }
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                expected_schema_ascii = """feature {
  name: "foo"
  type: INT
  annotation {
    tag: "ragged_tensor"
  }
}
"""
                expected_schema = text_format.Parse(expected_schema_ascii,
                                                    schema_pb2.Schema())
                schema_utils_legacy.set_generate_legacy_feature_spec(
                    expected_schema, False)
                self.assertProtoEquals(expected_schema, schema)
                with self.assertRaisesRegexp(
                        ValueError, 'Feature "foo" had tag "ragged_tensor"'):
                    schema_utils.schema_as_feature_spec(schema)
    def testInferFeatureSchemaWithSession(self):
        """Checks that a schema override on 'c' becomes a categorical [5, 6]
        IntDomain when schema inference is given a session."""
        with tf.Graph().as_default() as graph:
            a = tf.placeholder(tf.float32, (None, ))
            b = tf.placeholder(tf.string, (1, 2, 3))
            c = tf.placeholder(tf.int64, (None, ))
            schema_inference.set_tensor_schema_override(
                c, tf.constant(5), tf.constant(6))
            with tf.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    {'a': a, 'b': b, 'c': c}, graph, session)

        fixed = dataset_schema.FixedColumnRepresentation
        expected_schema = dataset_schema.Schema(column_schemas={
            'a': dataset_schema.ColumnSchema(tf.float32, [], fixed()),
            'b': dataset_schema.ColumnSchema(tf.string, [2, 3], fixed()),
            'c': dataset_schema.ColumnSchema(
                dataset_schema.IntDomain(tf.int64, 5, 6, is_categorical=True),
                [], fixed()),
        })
        self.assertEqual(schema, expected_schema)
    def test_global_annotation(self):
        """Checks that a schema-level proto annotation lands in extra_metadata."""
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            # Any type URLs always use '/'; os.path.join would produce a
            # backslash-separated URL on Windows.
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 1)
                for annotation in schema.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])
Example #8
0
def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
                     type_specs, transform_output_path):
    """Analyzes the `preprocessing_fn` in-place without looking at the data.

    This should only be used if the `preprocessing_fn` contains no TFT
    analyzers or TFT mappers that use analyzers.

    The transform function is written under
    `transform_output_path/TFTransformOutput.TRANSFORM_FN_DIR`. The transformed
    metadata is written under
    `transform_output_path/TFTransformOutput.TRANSFORMED_METADATA_DIR`.

    Args:
      preprocessing_fn: The tf.Transform preprocessing_fn.
      force_tf_compat_v1: If True, call Transform's API to use Tensorflow in
        tf.compat.v1 mode.
      feature_specs: A dict from input feature key to its feature spec. Used
        when tracing in tf.compat.v1 mode.
      type_specs: A dict from input feature key to its type spec. Used as the
        input signature when tracing in TF2 mode.
      transform_output_path: An absolute path to write the output to.

    Raises:
      RuntimeError: If `preprocessing_fn` contains TFT analyzers.
    """
    compat_v1_mode = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
    transform_fn_dir = os.path.join(transform_output_path,
                                    TFTransformOutput.TRANSFORM_FN_DIR)
    if not compat_v1_mode:
        concrete_fn = _trace_and_write_transform_fn(
            saved_model_dir=transform_fn_dir,
            preprocessing_fn=preprocessing_fn,
            input_signature=type_specs,
            base_temp_dir=None,
            tensor_replacement_map=None,
            output_keys_to_name_map=None)
        # NOTE(review): judging by its name, `_trace_and_write_transform_fn`
        # has already written to `transform_fn_dir` before this check runs,
        # so a failing check may leave output behind. Confirm.
        _assert_no_analyzers_in_graph(concrete_fn.graph)
        inputs_from_graph = tf2_utils.get_structured_inputs_from_func_graph(
            concrete_fn.graph)
        metadata = _trace_and_get_metadata(
            concrete_transform_fn=concrete_fn,
            structured_inputs=inputs_from_graph,
            preprocessing_fn=preprocessing_fn,
            base_temp_dir=None,
            tensor_replacement_map=None)
    else:
        traced_graph, graph_inputs, graph_outputs = (
            trace_preprocessing_function(preprocessing_fn,
                                         feature_specs,
                                         use_tf_compat_v1=compat_v1_mode))
        _assert_no_analyzers_in_graph(traced_graph)
        with tf.compat.v1.Session(graph=traced_graph) as session:
            # Initialize variables and tables before exporting the graph.
            session.run(tf.compat.v1.global_variables_initializer())
            session.run(tf.compat.v1.tables_initializer())
            saved_transform_io.write_saved_transform_from_session(
                session, graph_inputs, graph_outputs, transform_fn_dir)
            metadata = dataset_metadata.DatasetMetadata(
                schema=schema_inference.infer_feature_schema(
                    graph_outputs, traced_graph, session))

    metadata_io.write_metadata(
        metadata,
        os.path.join(transform_output_path,
                     TFTransformOutput.TRANSFORMED_METADATA_DIR))
Example #9
0
def _infer_metadata_from_saved_model(saved_model_dir):
  """Infers a DatasetMetadata for outputs of a SavedModel.

  Args:
    saved_model_dir: Directory of the SavedModel to apply.

  Returns:
    A `dataset_metadata.DatasetMetadata` whose schema is inferred from the
    SavedModel's outputs. Inference uses a session in which variables and
    tables have been initialized.
  """
  graph = tf.Graph()
  with graph.as_default(), tf.Session(graph=graph) as session:
    # Apply the SavedModel with an empty input mapping; only its outputs are
    # used here.
    _, output_tensors = (
        saved_transform_io.partially_apply_saved_transform_internal(
            saved_model_dir, {}))
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    inferred = schema_inference.infer_feature_schema(
        output_tensors, graph, session)
    return dataset_metadata.DatasetMetadata(schema=inferred)
    def test_bucketization_annotation(self):
        """Checks that apply_buckets records its boundaries as feature annotations.

        Skipped when `annotations_pb2` cannot be imported (e.g. OSS builds).
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report a skip rather than returning early, which would make the
            # test look like it passed without asserting anything.
            self.skipTest('tensorflow_transform.annotations_pb2 is unavailable')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}

            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            # Create a session to actually evaluate the annotations and extract
            # the output schema with annotations applied.
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.feature, 2)
                for feature in schema.feature:
                    self.assertLen(feature.annotation.extra_metadata, 1)
                    for annotation in feature.annotation.extra_metadata:

                        # Extract the annotated message and validate its contents
                        message = annotations_pb2.BucketBoundaries()
                        annotation.Unpack(message)
                        if feature.name == 'Bucketized_foo':
                            self.assertAllClose(list(message.boundaries),
                                                [.5, 1.5])
                        elif feature.name == 'Bucketized_bar':
                            self.assertAllClose(list(message.boundaries),
                                                [.1, .2])
                        else:
                            raise RuntimeError('Unexpected features in schema')
 def test_vocab_annotation(self):
     """Checks that vocabulary-size annotations appear in schema extra_metadata."""
     with tf.compat.v1.Graph().as_default() as graph:
         tensors = {
             'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
         }
         for file_name, vocab_size in (('file1', 100), ('file2', 200)):
             analyzers._maybe_annotate_vocab_metadata(
                 file_name, tf.constant(vocab_size, dtype=tf.int64))
         # A session is needed so the annotation tensors are evaluated and
         # applied to the inferred schema.
         with tf.compat.v1.Session(graph=graph) as session:
             schema = schema_inference.infer_feature_schema(
                 tensors, graph, session)
             self.assertLen(schema.annotation.extra_metadata, 2)

             def _unpack_vocab_metadata(any_proto):
                 unpacked = annotations_pb2.VocabularyMetadata()
                 any_proto.Unpack(unpacked)
                 return unpacked

             sizes = {
                 meta.file_name: meta.unfiltered_vocabulary_size
                 for meta in map(_unpack_vocab_metadata,
                                 schema.annotation.extra_metadata)
             }
             self.assertDictEqual(sizes, {'file1': 100, 'file2': 200})
Example #12
0
  def expand(self, dataset):
    """Analyze the dataset.

    Traces `self._preprocessing_fn` into a new graph, builds the Beam stages
    that compute its analyzers, and infers the output metadata.

    Args:
      dataset: A `(flattened_pcoll, input_values_pcoll_dict, input_metadata)`
        tuple. `input_values_pcoll_dict` may be None or empty, in which case an
        empty dict is used. `input_metadata.schema` describes the inputs.

    Returns:
      A `(transform_fn_pcoll, full_metadata)` tuple: the deferred transform
      function, and a `BeamDatasetMetadata`. The metadata combines what was
      inferred at graph-construction time with what is computed from the
      transform function once analyzer outputs are known.

    Raises:
      ValueError: If preprocessing_fn has no outputs, or if the traced graph
        contains trainable variables.
    """
    flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
    input_schema = input_metadata.schema

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

    # Trace preprocessing_fn into a fresh graph. Its inputs are batched
    # placeholders derived from the input schema's feature spec.
    with tf.Graph().as_default() as graph:

      with tf.name_scope('inputs'):
        feature_spec = input_schema.as_feature_spec()
        input_signature = impl_helper.feature_spec_as_batched_placeholders(
            feature_spec)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

    pipeline = flattened_pcoll.pipeline
    # NOTE(review): the TF config is looked up by the runner *instance*. If
    # _DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER is keyed by runner class, this
    # always yields None. Confirm the dict's key type.
    serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
        pipeline.runner)
    extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        serialized_tf_config=serialized_tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        cache_location=self._cache_location)

    transform_fn_future = analysis_graph_builder.build(
        graph, input_signature, output_signature,
        input_values_pcoll_dict.keys(), self._cache_location)

    # Visiting the future with ConstructBeamPipelineVisitor adds the Beam
    # stages to `pipeline` and yields the transform fn PCollection.
    transform_fn_pcoll = nodes.Traverser(
        common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
            transform_fn_future)

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
    # analyzer outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return transform_fn_pcoll, full_metadata
Example #13
0
  def expand(self, dataset):
    """Analyze the dataset.

    Traces `self._preprocessing_fn` into a new tf.compat.v1 graph, builds the
    Beam stages that compute its analyzers (reading from and producing
    analyzer cache), and infers the output metadata.

    Args:
      dataset: A `(flattened_pcoll, input_values_pcoll_dict,
        dataset_cache_dict, input_metadata)` tuple. When `self._use_tfxio` is
        set, `input_metadata` is used as a TensorAdapter config; otherwise its
        `.schema` describes the inputs. `input_values_pcoll_dict` may be None
        or empty, in which case an empty dict is used.

    Returns:
      `((transform_fn_pcoll, full_metadata), output_cache_pcoll_dict)`.
      `output_cache_pcoll_dict` is None when the analysis graph produced no
      cache value nodes. Otherwise it is a nested dict
      `{dataset_key: {cache_key: pcollection}}`.

    Raises:
      ValueError: If preprocessing_fn has no outputs, or if the traced graph
        contains trainable variables.
    """
    (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
     input_metadata) = dataset
    if self._use_tfxio:
      input_schema = None
      input_tensor_adapter_config = input_metadata
    else:
      input_schema = input_metadata.schema
      input_tensor_adapter_config = None

    input_values_pcoll_dict = input_values_pcoll_dict or dict()

    with tf.compat.v1.Graph().as_default() as graph:

      with tf.compat.v1.name_scope('inputs'):
        # Input specs come from the TensorAdapter (TFXIO) or from the schema.
        if self._use_tfxio:
          specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
        else:
          specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
        input_signature = impl_helper.batched_placeholders_from_specs(specs)
        # In order to avoid a bug where import_graph_def fails when the
        # input_map and return_elements of an imported graph are the same
        # (b/34288791), we avoid using the placeholder of an input column as an
        # output of a graph. We do this by applying tf.identity to all inputs of
        # the preprocessing_fn.  Note this applies at the level of raw tensors.
        # TODO(b/34288791): Remove this workaround and use a shallow copy of
        # inputs instead.  A shallow copy is needed in case
        # self._preprocessing_fn mutates its input.
        copied_inputs = impl_helper.copy_tensors(input_signature)

      output_signature = self._preprocessing_fn(copied_inputs)

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not output_signature:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(
                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

    # Use self.pipeline when set; otherwise take the pipeline from the
    # flattened PCollection, or from the first non-None per-dataset
    # PCollection.
    # NOTE(review): if flattened_pcoll is falsy and every dict value is None,
    # next() raises StopIteration here instead of a descriptive error.
    pipeline = self.pipeline or (flattened_pcoll or next(
        v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

    # Add a stage that inspects graph collections for API use counts and logs
    # them as a beam metric.
    _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

    tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
        type(pipeline.runner))
    extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=Context.create_base_temp_dir(),
        tf_config=tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=input_signature,
        input_schema=input_schema,
        input_tensor_adapter_config=input_tensor_adapter_config,
        use_tfxio=self._use_tfxio,
        cache_pcoll_dict=dataset_cache_dict)

    transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
        graph,
        input_signature,
        output_signature,
        input_values_pcoll_dict.keys(),
        cache_dict=dataset_cache_dict)
    traverser = nodes.Traverser(
        beam_common.ConstructBeamPipelineVisitor(extra_args))
    transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

    # Materialize cache outputs grouped as {dataset_key: {cache_key: pcoll}}.
    if cache_value_nodes is not None:
      output_cache_pcoll_dict = {}
      for (dataset_key,
           cache_key), value_node in six.iteritems(cache_value_nodes):
        if dataset_key not in output_cache_pcoll_dict:
          output_cache_pcoll_dict[dataset_key] = {}
        output_cache_pcoll_dict[dataset_key][cache_key] = (
            traverser.visit_value_node(value_node))
    else:
      output_cache_pcoll_dict = None

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
    # analyzer outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(output_signature, graph))

    deferred_metadata = (
        transform_fn_pcoll
        |
        'ComputeDeferredMetadata' >> beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
Example #14
0
    def expand(self, dataset):
        """Analyze the dataset.

        Traces `self._preprocessing_fn`, runs its analyzers phase by phase
        with Beam, and infers the output metadata.

        Args:
          dataset: An `(input_values, input_metadata)` tuple, where
            `input_metadata.schema` describes the input features.

        Returns:
          A `(transform_fn, full_metadata)` tuple: the deferred transform
          function, and a `BeamDatasetMetadata` that combines metadata
          inferred at graph-construction time with metadata computed from
          `transform_fn` once analyzer outputs are known.

        Raises:
          ValueError: If preprocessing_fn has no outputs, or if the traced
            graph contains trainable variables.
        """
        input_values, input_metadata = dataset
        input_schema = input_metadata.schema

        base_temp_dir = Context.create_base_temp_dir()

        with tf.Graph().as_default() as graph:

            with tf.name_scope('inputs'):
                feature_spec = input_schema.as_feature_spec()
                inputs = impl_helper.feature_spec_as_batched_placeholders(
                    feature_spec)
            # In order to avoid a bug where import_graph_def fails when the input_map
            # and return_elements of an imported graph are the same (b/34288791), we
            # avoid using the placeholder of an input column as an output of a graph.
            # We do this by applying tf.identity to all inputs of the
            # preprocessing_fn.  Note this applies at the level of raw tensors.
            outputs = self._preprocessing_fn(impl_helper.copy_tensors(inputs))

            # At this point we check that the preprocessing_fn has at least one
            # output. This is because if we allowed the output of preprocessing_fn to
            # be empty, we wouldn't be able to determine how many instances to
            # "unbatch" the output into.
            if not outputs:
                raise ValueError(
                    'The preprocessing function returned an empty dict')

            if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                raise ValueError(
                    'The preprocessing function contained trainable variables '
                    '{}'.format(
                        graph.get_collection_ref(
                            tf.GraphKeys.TRAINABLE_VARIABLES)))

            # NOTE: it's important that create_phases is called directly after
            # preprocessing_fn, because we later mutate the graph's TABLE_INITIALIZERS
            # collection which would break the logic in create_phases.
            phases = impl_helper.create_phases(inputs)

            # Iterate through levels.  tensor_pcoll_mapping is a mapping from tensor
            # names to singleton PCollections containing a _TensorValue.  We compute
            # tensor_pcoll_mapping in phases, where at each phase we compute the
            # analyzers that are ready to run and update tensor_pcoll_mapping.
            tensor_pcoll_mapping = {}
            # Empty the TABLE_INITIALIZERS collection in place (keeping a copy)
            # so each phase's SavedModel only includes that phase's
            # initializers; the originals are restored after the loop.
            table_initializers = graph.get_collection_ref(
                tf.GraphKeys.TABLE_INITIALIZERS)
            original_table_initializers = list(table_initializers)
            del table_initializers[:]

            # NOTE(review): the TF config is looked up by the runner
            # *instance*. If _DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER is keyed by
            # runner class, this always yields None. Confirm the key type.
            serialized_tf_config = (
                common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
                    input_values.pipeline.runner))
            for level, phase in enumerate(phases):
                # Create a SavedModel that describes the mapping from the input data
                # to the inputs of the analyzers at this level.  The column names of the
                # outputs are the tensor names of the analyzer inputs in the graph.
                # This graph has the analyzer outputs computed so far replaced with
                # constants.
                analyzer_inputs = {}
                for analyzer in phase.analyzer_infos:
                    for input_tensor_name in analyzer.input_tensor_names:
                        analyzer_inputs[
                            input_tensor_name] = graph.get_tensor_by_name(
                                input_tensor_name)
                table_initializers.extend(phase.table_initializers)
                unbound_saved_model_dir = common.make_unique_temp_dir(
                    base_temp_dir)
                _write_saved_transform(graph, inputs, analyzer_inputs,
                                       unbound_saved_model_dir)

                tensor_pcoll_mapping_update = (
                    (input_values, tensor_pcoll_mapping)
                    | 'RunPhase[{}]'.format(level) >> _RunPhase(
                        phase.analyzer_infos, unbound_saved_model_dir,
                        base_temp_dir, input_schema, serialized_tf_config,
                        level))

                # Update the mapping for all analyzers.
                tensor_pcoll_mapping.update(tensor_pcoll_mapping_update)

            # Restore the full set of table initializers before writing the
            # final transform SavedModel.
            del table_initializers[:]
            table_initializers.extend(original_table_initializers)
            saved_model_dir = common.make_unique_temp_dir(base_temp_dir)
            _write_saved_transform(graph, inputs, outputs, saved_model_dir)
            transform_fn = (
                tensor_pcoll_mapping
                |
                'ReplaceTensorsWithConstants' >> _ReplaceTensorsWithConstants(
                    saved_model_dir, base_temp_dir, input_values.pipeline))

            # Infer metadata.  We take the inferred metadata and apply overrides that
            # refer to values of tensors in the graph.  The override tensors must
            # be "constant" in that they don't depend on input data.  The tensors can
            # depend on analyzer outputs though.  This allows us to set metadata that
            # depends on analyzer outputs. _infer_metadata_from_saved_model will use
            # the analyzer outputs stored in `transform_fn` to compute the metadata in
            # a deferred manner, once the analyzer outputs are known.
            metadata = dataset_metadata.DatasetMetadata(
                schema=schema_inference.infer_feature_schema(outputs, graph))

            deferred_metadata = (transform_fn
                                 | 'ComputeDeferredMetadata' >>
                                 beam.Map(_infer_metadata_from_saved_model))

            full_metadata = beam_metadata_io.BeamDatasetMetadata(
                metadata, deferred_metadata)

            _clear_shared_state_after_barrier(input_values.pipeline,
                                              transform_fn)

            return transform_fn, full_metadata