def test_infer_feature_schema_bad_rank(self):
  """Schema inference raises ValueError for a placeholder of shape `()`."""
  with tf.Graph().as_default() as graph:
    scalar_placeholder = tf.placeholder(tf.float32, ())
    tensors = {'a': scalar_placeholder}
    self.assertRaises(
        ValueError, schema_inference.infer_feature_schema, tensors, graph)
def _get_schema(self,
                preprocessing_fn,
                use_compat_v1,
                inputs=None,
                input_signature=None,
                create_session=False):
  """Infers the output schema of `preprocessing_fn` in TF1 or TF2 mode.

  Args:
    preprocessing_fn: Callable mapping a dict of input tensors to outputs.
    use_compat_v1: If true, build a `tf.compat.v1.Graph` and call
      `infer_feature_schema`; otherwise trace a `tf.function` and call
      `infer_feature_schema_v2`.
    inputs: Optional dict of input values; defaults to an empty dict. Only
      read on the compat.v1 path.
    input_signature: Optional dict keyed like `inputs`; defaults to an empty
      dict. Its entries supply `.dtype` on the compat.v1 path and it is the
      traced signature on the TF2 path.
    create_session: On the compat.v1 path, whether to pass a session to
      schema inference; on the TF2 path, forwarded as
      `evaluate_schema_overrides`.

  Returns:
    The schema produced by the selected inference function.
  """
  inputs = {} if inputs is None else inputs
  input_signature = {} if input_signature is None else input_signature

  if not use_compat_v1:
    concrete_fn = tf.function(
        preprocessing_fn,
        input_signature=[input_signature]).get_concrete_function()
    tensors = tf.nest.pack_sequence_as(
        structure=concrete_fn.structured_outputs,
        flat_sequence=concrete_fn.outputs,
        expand_composites=True)
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_fn.graph)
    tf_graph_context = graph_context.TFGraphContext(
        module_to_export=tf.Module(),
        temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName),
        evaluated_replacements={})
    concrete_metadata_fn = schema_inference.get_traced_metadata_fn(
        preprocessing_fn=preprocessing_fn,
        structured_inputs=structured_inputs,
        tf_graph_context=tf_graph_context,
        evaluate_schema_overrides=create_session)
    return schema_inference.infer_feature_schema_v2(
        tensors,
        concrete_metadata_fn,
        evaluate_schema_overrides=create_session)

  with tf.compat.v1.Graph().as_default() as graph:
    # Convert eager tensors to graph tensors.
    graph_inputs = {}
    for key, value in inputs.items():
      graph_inputs[key] = tf.constant(value, input_signature[key].dtype)
    tensors = preprocessing_fn(graph_inputs)
    if not create_session:
      return schema_inference.infer_feature_schema(tensors, graph)
    # Create a session to actually evaluate the annotations and extract
    # the output schema with annotations applied.
    with tf.compat.v1.Session(graph=graph) as session:
      return schema_inference.infer_feature_schema(tensors, graph, session)
def test_infer_feature_schema(self,
                              make_tensors_fn,
                              feature_spec,
                              domains=None,
                              create_session=False):
  """Compares the inferred schema with one built from `feature_spec`.

  Args:
    make_tensors_fn: Zero-argument callable that builds the tensors to infer
      a schema for; it is invoked inside a fresh graph.
    feature_spec: Feature spec passed to `dataset_schema.from_feature_spec`.
    domains: Optional domains passed to `dataset_schema.from_feature_spec`.
    create_session: Whether to pass a session to schema inference.
  """
  infer = schema_inference.infer_feature_schema
  with tf.Graph().as_default() as graph:
    tensors = make_tensors_fn()
    if not create_session:
      inferred_schema = infer(tensors, graph)
    else:
      with tf.Session(graph=graph) as session:
        inferred_schema = infer(tensors, graph, session)
  self.assertEqual(
      inferred_schema,
      dataset_schema.from_feature_spec(feature_spec, domains))
def test_global_annotation(self):
  """Checks that a schema-level annotation ends up in the inferred schema.

  The test is reported as skipped (rather than silently passing) when
  `annotations_pb2` cannot be imported.
  """
  # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
  # pylint: disable=g-import-not-at-top
  try:
    from tensorflow_transform import annotations_pb2
  except ImportError:
    self.skipTest('tensorflow_transform.annotations_pb2 is not available.')
  # pylint: enable=g-import-not-at-top
  with tf.Graph().as_default() as graph:
    outputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
        'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
    }

    # Annotate an arbitrary proto at the schema level (not sure what global
    # schema boundaries would mean, but hey I'm just a test).
    boundaries = tf.constant([[1.0]])
    message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
    sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
    message_proto = encode_proto_op.encode_proto(
        sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
        message_type)[0]
    # A type URL is not a filesystem path: always join with '/', never with
    # the platform-dependent separator that os.path.join would use.
    type_url = 'type.googleapis.com/' + message_type
    schema_inference.annotate(type_url, message_proto)

    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      self.assertLen(schema._schema_proto.annotation.extra_metadata, 1)
      for annotation in schema._schema_proto.annotation.extra_metadata:
        # Extract the annotated message and validate its contents
        message = annotations_pb2.BucketBoundaries()
        annotation.Unpack(message)
        self.assertAllClose(list(message.boundaries), [1])
def test_infer_feature_schema_with_ragged_tensor(self):
  """Checks the schema inferred for a `tf.RaggedTensor` output.

  The inferred feature is expected to carry the "ragged_tensor" annotation
  tag, and converting that schema to a feature spec is expected to raise
  ValueError.
  """
  with tf.compat.v1.Graph().as_default() as graph:
    outputs = {
        'foo':
            tf.RaggedTensor.from_row_splits(
                values=tf.constant([3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                row_splits=[0, 4, 4, 7, 8, 8]),
    }
    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      expected_schema_ascii = """feature {
  name: "foo"
  type: INT
  annotation {
    tag: "ragged_tensor"
  }
}
"""
      expected_schema = text_format.Parse(expected_schema_ascii,
                                          schema_pb2.Schema())
      schema_utils_legacy.set_generate_legacy_feature_spec(
          expected_schema, False)
      self.assertProtoEquals(expected_schema, schema)
      # assertRaisesRegexp is a deprecated alias that newer Python versions
      # no longer provide; assertRaisesRegex is the supported spelling.
      with self.assertRaisesRegex(ValueError,
                                  'Feature "foo" had tag "ragged_tensor"'):
        schema_utils.schema_as_feature_spec(schema)
def testInferFeatureSchemaWithSession(self):
  """Infers a schema with a session so the override on 'c' is evaluated."""
  with tf.Graph().as_default() as graph:
    float_input = tf.placeholder(tf.float32, (None,))
    string_input = tf.placeholder(tf.string, (1, 2, 3))
    int_input = tf.placeholder(tf.int64, (None,))
    tensors = {'a': float_input, 'b': string_input, 'c': int_input}
    schema_inference.set_tensor_schema_override(
        tensors['c'], tf.constant(5), tf.constant(6))
    with tf.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(tensors, graph, session)

  expected_columns = {}
  expected_columns['a'] = dataset_schema.ColumnSchema(
      tf.float32, [], dataset_schema.FixedColumnRepresentation())
  expected_columns['b'] = dataset_schema.ColumnSchema(
      tf.string, [2, 3], dataset_schema.FixedColumnRepresentation())
  # The override constants 5 and 6 are expected to appear in the int domain.
  expected_columns['c'] = dataset_schema.ColumnSchema(
      dataset_schema.IntDomain(tf.int64, 5, 6, is_categorical=True), [],
      dataset_schema.FixedColumnRepresentation())
  self.assertEqual(
      schema, dataset_schema.Schema(column_schemas=expected_columns))
def test_global_annotation(self):
  """Checks that a schema-level annotation ends up in the inferred schema."""
  with tf.compat.v1.Graph().as_default() as graph:
    outputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
        'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
    }

    # Annotate an arbitrary proto at the schema level (not sure what global
    # schema boundaries would mean, but hey I'm just a test).
    boundaries = tf.constant([[1.0]])
    message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
    sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
    message_proto = tf.raw_ops.EncodeProto(
        sizes=sizes,
        values=[tf.cast(boundaries, tf.float32)],
        field_names=['boundaries'],
        message_type=message_type)[0]
    # A type URL is not a filesystem path: always join with '/', never with
    # the platform-dependent separator that os.path.join would use.
    type_url = 'type.googleapis.com/' + message_type
    schema_inference.annotate(type_url, message_proto)

    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      self.assertLen(schema.annotation.extra_metadata, 1)
      for annotation in schema.annotation.extra_metadata:
        # Extract the annotated message and validate its contents
        message = annotations_pb2.BucketBoundaries()
        annotation.Unpack(message)
        self.assertAllClose(list(message.boundaries), [1])
def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs,
                     type_specs, transform_output_path):
  """Analyzes the `preprocessing_fn` in-place without looking at the data.

  Only valid when the `preprocessing_fn` contains no TFT analyzers or TFT
  mappers that use analyzers. A transform function and the transformed
  metadata are written to subdirectories of `transform_output_path`.

  Args:
    preprocessing_fn: The tf.Transform preprocessing_fn.
    force_tf_compat_v1: If True, call Transform's API to use Tensorflow in
      tf.compat.v1 mode.
    feature_specs: a Dict from input feature key to its feature spec.
    type_specs: a Dict from input feature key to its type spec.
    transform_output_path: An absolute path to write the output to.

  Raises:
    RuntimeError if `preprocessing_fn` contains TFT analyzers.
  """
  use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1)
  saved_transform_dir = os.path.join(transform_output_path,
                                     TFTransformOutput.TRANSFORM_FN_DIR)
  metadata_dir = os.path.join(transform_output_path,
                              TFTransformOutput.TRANSFORMED_METADATA_DIR)

  if not use_tf_compat_v1:
    # TF2 path: trace and export the function, then trace its metadata.
    concrete_transform_fn = _trace_and_write_transform_fn(
        saved_model_dir=saved_transform_dir,
        preprocessing_fn=preprocessing_fn,
        input_signature=type_specs,
        base_temp_dir=None,
        tensor_replacement_map=None,
        output_keys_to_name_map=None)
    _assert_no_analyzers_in_graph(concrete_transform_fn.graph)
    structured_inputs = tf2_utils.get_structured_inputs_from_func_graph(
        concrete_transform_fn.graph)
    transformed_metadata = _trace_and_get_metadata(
        concrete_transform_fn=concrete_transform_fn,
        structured_inputs=structured_inputs,
        preprocessing_fn=preprocessing_fn,
        base_temp_dir=None,
        tensor_replacement_map=None)
  else:
    # compat.v1 path: trace into a graph and export it from a live session.
    traced = trace_preprocessing_function(
        preprocessing_fn, feature_specs, use_tf_compat_v1=use_tf_compat_v1)
    graph, structured_inputs, structured_outputs = traced
    _assert_no_analyzers_in_graph(graph)
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.tables_initializer())
      saved_transform_io.write_saved_transform_from_session(
          sess, structured_inputs, structured_outputs, saved_transform_dir)
      inferred_schema = schema_inference.infer_feature_schema(
          structured_outputs, graph, sess)
      transformed_metadata = dataset_metadata.DatasetMetadata(
          schema=inferred_schema)

  metadata_io.write_metadata(transformed_metadata, metadata_dir)
def _infer_metadata_from_saved_model(saved_model_dir):
  """Infers a DatasetMetadata for outputs of a SavedModel.

  Args:
    saved_model_dir: Location of the SavedModel; it is applied with an empty
      input dict so only its outputs are loaded into a fresh graph.

  Returns:
    A `dataset_metadata.DatasetMetadata` wrapping the schema inferred from
    the SavedModel outputs with a live session.
  """
  with tf.Graph().as_default() as graph, tf.Session(graph=graph) as session:
    unused_inputs, outputs = (
        saved_transform_io.partially_apply_saved_transform_internal(
            saved_model_dir, {}))
    # Variables first, then tables, before anything is evaluated.
    for make_initializer in (tf.global_variables_initializer,
                             tf.tables_initializer):
      session.run(make_initializer())
    inferred_schema = schema_inference.infer_feature_schema(
        outputs, graph, session)
    return dataset_metadata.DatasetMetadata(schema=inferred_schema)
def test_bucketization_annotation(self):
  """Checks the per-feature annotations produced by `apply_buckets`.

  The test is reported as skipped (rather than silently passing) when
  `annotations_pb2` cannot be imported.
  """
  # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
  # pylint: disable=g-import-not-at-top
  try:
    from tensorflow_transform import annotations_pb2
  except ImportError:
    self.skipTest('tensorflow_transform.annotations_pb2 is not available.')
  # pylint: enable=g-import-not-at-top
  with tf.Graph().as_default() as graph:
    inputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3]),
        'bar': tf.convert_to_tensor([0, 2, 0, 2]),
    }
    boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]), axis=0)
    boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]), axis=0)
    outputs = {}

    # tft.apply_buckets will annotate the feature in the output schema to
    # indicate the bucket boundaries that were applied.
    outputs['Bucketized_foo'] = mappers.apply_buckets(inputs['foo'],
                                                      boundaries_foo)
    outputs['Bucketized_bar'] = mappers.apply_buckets(inputs['bar'],
                                                      boundaries_bar)
    # Create a session to actually evaluate the annotations and extract the
    # output schema with annotations applied.
    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      self.assertLen(schema.feature, 2)
      for feature in schema.feature:
        self.assertLen(feature.annotation.extra_metadata, 1)
        for annotation in feature.annotation.extra_metadata:
          # Extract the annotated message and validate its contents
          message = annotations_pb2.BucketBoundaries()
          annotation.Unpack(message)
          if feature.name == 'Bucketized_foo':
            self.assertAllClose(list(message.boundaries), [.5, 1.5])
          elif feature.name == 'Bucketized_bar':
            self.assertAllClose(list(message.boundaries), [.1, .2])
          else:
            raise RuntimeError('Unexpected features in schema')
def test_vocab_annotation(self):
  """Checks that two vocabulary annotations appear in the inferred schema."""

  def _unpack_vocab_metadata(annotation):
    """Unpacks one annotation into a `VocabularyMetadata` message."""
    message = annotations_pb2.VocabularyMetadata()
    annotation.Unpack(message)
    return message

  with tf.compat.v1.Graph().as_default() as graph:
    tensors = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
    }
    for file_name, vocab_size in (('file1', 100), ('file2', 200)):
      analyzers._maybe_annotate_vocab_metadata(
          file_name, tf.constant(vocab_size, dtype=tf.int64))
    # A session is needed so the annotations are actually evaluated and land
    # in the output schema.
    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(tensors, graph, session)
      extra_metadata = schema.annotation.extra_metadata
      self.assertLen(extra_metadata, 2)
      messages = [_unpack_vocab_metadata(item) for item in extra_metadata]
      sizes = {
          message.file_name: message.unfiltered_vocabulary_size
          for message in messages
      }
      self.assertDictEqual(sizes, {'file1': 100, 'file2': 200})
def expand(self, dataset):
  """Analyze the dataset.

  Traces `self._preprocessing_fn` into a fresh graph, builds the analysis
  graph from it, and translates that into Beam stages.

  Args:
    dataset: A 3-tuple `(flattened_pcoll, input_values_pcoll_dict,
      input_metadata)`. `input_values_pcoll_dict` may be falsy, in which case
      an empty dict is used; `input_metadata.schema` is the input schema.

  Returns:
    A pair `(transform_fn_pcoll, full_metadata)`, where `full_metadata` is a
    `BeamDatasetMetadata` combining the schema inferred at graph-construction
    time with metadata computed in a deferred manner from
    `transform_fn_pcoll`.

  Raises:
    ValueError: If preprocessing_fn has no outputs, or if the traced graph
      contains trainable variables.
  """
  flattened_pcoll, input_values_pcoll_dict, input_metadata = dataset
  input_schema = input_metadata.schema

  input_values_pcoll_dict = input_values_pcoll_dict or dict()

  analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys())

  with tf.Graph().as_default() as graph:

    with tf.name_scope('inputs'):
      feature_spec = input_schema.as_feature_spec()
      input_signature = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn.  Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead.  A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)))

  pipeline = flattened_pcoll.pipeline
  # The TensorFlow config is looked up by the pipeline's runner; `.get`
  # yields None when the runner has no entry.
  serialized_tf_config = common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
      pipeline.runner)
  extra_args = common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      serialized_tf_config=serialized_tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      cache_location=self._cache_location)

  transform_fn_future = analysis_graph_builder.build(
      graph, input_signature, output_signature,
      input_values_pcoll_dict.keys(), self._cache_location)

  # Walk the analysis graph from its final node; the visitor turns it into
  # Beam stages and yields the transform function PCollection.
  transform_fn_pcoll = nodes.Traverser(
      common.ConstructBeamPipelineVisitor(extra_args)).visit_value_node(
          transform_fn_future)

  # Infer metadata.  We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph.  The override tensors must
  # be "constant" in that they don't depend on input data.  The tensors can
  # depend on analyzer outputs though.  This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use
  # the analyzer outputs stored in `transform_fn` to compute the metadata in
  # a deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >>
      beam.Map(_infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

  return transform_fn_pcoll, full_metadata
def expand(self, dataset):
  """Analyze the dataset.

  Traces `self._preprocessing_fn` into a fresh graph, builds the analysis
  graph from it, and translates that into Beam stages.

  Args:
    dataset: A 4-tuple `(flattened_pcoll, input_values_pcoll_dict,
      dataset_cache_dict, input_metadata)`. When `self._use_tfxio` is set,
      `input_metadata` is used as the tensor adapter config; otherwise its
      `.schema` attribute is used as the input schema.
      `input_values_pcoll_dict` may be falsy, in which case an empty dict is
      used.

  Returns:
    A pair `((transform_fn_pcoll, full_metadata), output_cache_pcoll_dict)`.
    `output_cache_pcoll_dict` maps dataset key to a dict from cache key to
    the visited cache value, or is None when the analysis graph builder
    returned no cache value nodes.

  Raises:
    ValueError: If preprocessing_fn has no outputs, or if the traced graph
      contains trainable variables.
  """
  (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
   input_metadata) = dataset
  # Exactly one of `input_schema` / `input_tensor_adapter_config` is set,
  # depending on whether TFXIO is in use.
  if self._use_tfxio:
    input_schema = None
    input_tensor_adapter_config = input_metadata
  else:
    input_schema = input_metadata.schema
    input_tensor_adapter_config = None

  input_values_pcoll_dict = input_values_pcoll_dict or dict()

  with tf.compat.v1.Graph().as_default() as graph:

    with tf.compat.v1.name_scope('inputs'):
      if self._use_tfxio:
        specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
      else:
        specs = schema_utils.schema_as_feature_spec(input_schema).feature_spec
      input_signature = impl_helper.batched_placeholders_from_specs(specs)
      # In order to avoid a bug where import_graph_def fails when the
      # input_map and return_elements of an imported graph are the same
      # (b/34288791), we avoid using the placeholder of an input column as an
      # output of a graph. We do this by applying tf.identity to all inputs of
      # the preprocessing_fn.  Note this applies at the level of raw tensors.
      # TODO(b/34288791): Remove this workaround and use a shallow copy of
      # inputs instead.  A shallow copy is needed in case
      # self._preprocessing_fn mutates its input.
      copied_inputs = impl_helper.copy_tensors(input_signature)

    output_signature = self._preprocessing_fn(copied_inputs)

  # At this point we check that the preprocessing_fn has at least one
  # output. This is because if we allowed the output of preprocessing_fn to
  # be empty, we wouldn't be able to determine how many instances to
  # "unbatch" the output into.
  if not output_signature:
    raise ValueError('The preprocessing function returned an empty dict')

  if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
    raise ValueError(
        'The preprocessing function contained trainable variables '
        '{}'.format(
            graph.get_collection_ref(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

  # Prefer an explicitly set pipeline; otherwise take it from the flattened
  # PCollection, or from the first non-None per-dataset PCollection.
  pipeline = self.pipeline or (flattened_pcoll or next(
      v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

  # Add a stage that inspects graph collections for API use counts and logs
  # them as a beam metric.
  _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(graph))

  # The TensorFlow config is looked up by the runner's type; `.get` yields
  # None when that type has no entry.
  tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
      type(pipeline.runner))
  extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
      base_temp_dir=Context.create_base_temp_dir(),
      tf_config=tf_config,
      pipeline=pipeline,
      flat_pcollection=flattened_pcoll,
      pcollection_dict=input_values_pcoll_dict,
      graph=graph,
      input_signature=input_signature,
      input_schema=input_schema,
      input_tensor_adapter_config=input_tensor_adapter_config,
      use_tfxio=self._use_tfxio,
      cache_pcoll_dict=dataset_cache_dict)

  transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
      graph,
      input_signature,
      output_signature,
      input_values_pcoll_dict.keys(),
      cache_dict=dataset_cache_dict)
  # A single traverser is reused for the transform function node and for
  # every cache value node below.
  traverser = nodes.Traverser(
      beam_common.ConstructBeamPipelineVisitor(extra_args))
  transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

  if cache_value_nodes is not None:
    # Regroup the flat {(dataset_key, cache_key): node} mapping into a
    # nested {dataset_key: {cache_key: visited value}} dict.
    output_cache_pcoll_dict = {}
    for (dataset_key,
         cache_key), value_node in six.iteritems(cache_value_nodes):
      if dataset_key not in output_cache_pcoll_dict:
        output_cache_pcoll_dict[dataset_key] = {}
      output_cache_pcoll_dict[dataset_key][cache_key] = (
          traverser.visit_value_node(value_node))
  else:
    output_cache_pcoll_dict = None

  # Infer metadata.  We take the inferred metadata and apply overrides that
  # refer to values of tensors in the graph.  The override tensors must
  # be "constant" in that they don't depend on input data.  The tensors can
  # depend on analyzer outputs though.  This allows us to set metadata that
  # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
  # analyzer outputs stored in `transform_fn` to compute the metadata in a
  # deferred manner, once the analyzer outputs are known.
  metadata = dataset_metadata.DatasetMetadata(
      schema=schema_inference.infer_feature_schema(output_signature, graph))

  deferred_metadata = (
      transform_fn_pcoll
      | 'ComputeDeferredMetadata' >>
      beam.Map(_infer_metadata_from_saved_model))

  full_metadata = beam_metadata_io.BeamDatasetMetadata(
      metadata, deferred_metadata)

  _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

  return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
def expand(self, dataset):
  """Analyze the dataset.

  Traces `self._preprocessing_fn` into a fresh graph, runs its analyzers
  phase by phase as Beam stages, and produces the final transform function.

  Args:
    dataset: A pair `(input_values, input_metadata)`; `input_metadata.schema`
      is the input schema and `input_values.pipeline` is the Beam pipeline.

  Returns:
    A pair `(transform_fn, full_metadata)`, where `full_metadata` is a
    `BeamDatasetMetadata` combining the schema inferred at graph-construction
    time with metadata computed in a deferred manner from `transform_fn`.

  Raises:
    ValueError: If preprocessing_fn has no outputs, or if the traced graph
      contains trainable variables.
  """
  input_values, input_metadata = dataset
  input_schema = input_metadata.schema

  base_temp_dir = Context.create_base_temp_dir()

  with tf.Graph().as_default() as graph:

    with tf.name_scope('inputs'):
      feature_spec = input_schema.as_feature_spec()
      inputs = impl_helper.feature_spec_as_batched_placeholders(
          feature_spec)
    # In order to avoid a bug where import_graph_def fails when the input_map
    # and return_elements of an imported graph are the same (b/34288791), we
    # avoid using the placeholder of an input column as an output of a graph.
    # We do this by applying tf.identity to all inputs of the
    # preprocessing_fn.  Note this applies at the level of raw tensors.
    outputs = self._preprocessing_fn(impl_helper.copy_tensors(inputs))

    # At this point we check that the preprocessing_fn has at least one
    # output. This is because if we allowed the output of preprocessing_fn to
    # be empty, we wouldn't be able to determine how many instances to
    # "unbatch" the output into.
    if not outputs:
      raise ValueError(
          'The preprocessing function returned an empty dict')

    if graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(
                  tf.GraphKeys.TRAINABLE_VARIABLES)))

    # NOTE: it's important that create_phases is called directly after
    # preprocessing_fn, because we later mutate the graph's TABLE_INITIALIZERS
    # collection which would break the logic in create_phases.
    phases = impl_helper.create_phases(inputs)

    # Iterate through levels.  tensor_pcoll_mapping is a mapping from tensor
    # names to singleton PCollections containing a _TensorValue.  We compute
    # tensor_pcoll_mapping in phases, where at each phase we compute the
    # analyzers that are ready to run and update tensor_pcoll_mapping.
    tensor_pcoll_mapping = {}
    # `table_initializers` is the graph's live collection list: keep a copy
    # of its contents and empty it in place, so that each phase only adds the
    # initializers it needs. The original contents are restored after the
    # loop.
    table_initializers = graph.get_collection_ref(
        tf.GraphKeys.TABLE_INITIALIZERS)
    original_table_initializers = list(table_initializers)
    del table_initializers[:]

    serialized_tf_config = (
        common._DEFAULT_TENSORFLOW_CONFIG_BY_RUNNER.get(  # pylint: disable=protected-access
            input_values.pipeline.runner))
    for level, phase in enumerate(phases):
      # Create a SavedModel that describes the mapping from the input data
      # to the inputs of the analyzers at this level.  The column names of the
      # outputs are the tensor names of the analyzer inputs in the graph.
      # This graph has the analyzer outputs computed so far replaced with
      # constants.
      analyzer_inputs = {}
      for analyzer in phase.analyzer_infos:
        for input_tensor_name in analyzer.input_tensor_names:
          analyzer_inputs[
              input_tensor_name] = graph.get_tensor_by_name(
                  input_tensor_name)
      # Initializers accumulate across phases: each phase appends its own to
      # whatever earlier phases already added.
      table_initializers.extend(phase.table_initializers)
      unbound_saved_model_dir = common.make_unique_temp_dir(
          base_temp_dir)
      _write_saved_transform(graph, inputs, analyzer_inputs,
                             unbound_saved_model_dir)

      tensor_pcoll_mapping_update = (
          (input_values, tensor_pcoll_mapping)
          | 'RunPhase[{}]'.format(level) >> _RunPhase(
              phase.analyzer_infos, unbound_saved_model_dir,
              base_temp_dir, input_schema, serialized_tf_config, level))

      # Update the mapping for all analyzers.
      tensor_pcoll_mapping.update(tensor_pcoll_mapping_update)

    # Restore the collection to its original contents before writing the
    # final SavedModel.
    del table_initializers[:]
    table_initializers.extend(original_table_initializers)
    saved_model_dir = common.make_unique_temp_dir(base_temp_dir)
    _write_saved_transform(graph, inputs, outputs, saved_model_dir)
    transform_fn = (
        tensor_pcoll_mapping
        | 'ReplaceTensorsWithConstants' >> _ReplaceTensorsWithConstants(
            saved_model_dir, base_temp_dir, input_values.pipeline))

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use
    # the analyzer outputs stored in `transform_fn` to compute the metadata
    # in a deferred manner, once the analyzer outputs are known.
    metadata = dataset_metadata.DatasetMetadata(
        schema=schema_inference.infer_feature_schema(outputs, graph))

    deferred_metadata = (transform_fn
                         | 'ComputeDeferredMetadata' >>
                         beam.Map(_infer_metadata_from_saved_model))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        metadata, deferred_metadata)

    _clear_shared_state_after_barrier(input_values.pipeline, transform_fn)

    return transform_fn, full_metadata