def test_schema_from_feature_spec_fails(self, feature_spec, error_msg,
                                        domain=None, error_class=ValueError):
  with self.assertRaisesRegexp(error_class, error_msg):
    schema_utils.schema_from_feature_spec(feature_spec, domain)
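# A minimal sketch of a spec that should trip the error path exercised above:
# schema_from_feature_spec only understands tf.string, tf.int64 and
# tf.float32, so a tf.bool feature is expected to raise ValueError. The
# feature name below is illustrative.
import tensorflow as tf
from tensorflow_transform.tf_metadata import schema_utils

try:
  schema_utils.schema_from_feature_spec(
      {'flag': tf.io.FixedLenFeature([], tf.bool)})
except ValueError as e:
  print('Rejected as expected:', e)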
def testConvertToRecordBatchPassthroughData(self):
  passthrough_key1 = '__passthrough_with_batch_length__'
  passthrough_key2 = '__passthrough_with_one_value__'
  passthrough_key3 = '__passthrough_with_one_distinct_value_none__'
  passthrough_key4 = '__passthrough_with_one_distinct_value_not_none__'
  batch_dict = {
      'a': np.array([100, 1, 10], np.int64),
      passthrough_key1: pa.array([[1], None, [0]], pa.large_list(pa.int64())),
      passthrough_key2: pa.array([None], pa.large_list(pa.float32())),
      passthrough_key3: pa.array([None, None],
                                 pa.large_list(pa.large_binary())),
      passthrough_key4: pa.array([[10], [10]], pa.large_list(pa.int64()))
  }
  schema = schema_utils.schema_from_feature_spec(
      {'a': tf.io.FixedLenFeature([], tf.int64)})
  converter = impl_helper.make_tensor_to_arrow_converter(schema)
  passthrough_keys = {
      passthrough_key1, passthrough_key2, passthrough_key3, passthrough_key4
  }
  arrow_schema = pa.schema([
      ('a', pa.large_list(pa.int64())),
      (passthrough_key1, batch_dict[passthrough_key1].type),
      (passthrough_key2, batch_dict[passthrough_key2].type),
      (passthrough_key3, batch_dict[passthrough_key3].type),
      (passthrough_key4, batch_dict[passthrough_key4].type)
  ])
  # Note that we only need `input_metadata.arrow_schema`.
  input_metadata = TensorAdapterConfig(arrow_schema, {})
  record_batch, unary_features = impl._convert_to_record_batch(
      batch_dict, schema, converter, passthrough_keys, input_metadata)
  expected_record_batch = {
      'a': [[100], [1], [10]],
      passthrough_key1: [[1], None, [0]]
  }
  self.assertDictEqual(expected_record_batch, record_batch.to_pydict())
  expected_unary_features = {
      passthrough_key2: [None],
      passthrough_key3: [None],
      passthrough_key4: [[10]]
  }
  unary_features = {k: v.to_pylist() for k, v in unary_features.items()}
  self.assertDictEqual(expected_unary_features, unary_features)

  # Test pass-through data when the input and output batch sizes differ and
  # the pass-through feature has more than one distinct value.
  passthrough_key5 = '__passthrough_with_wrong_batch_size__'
  passthrough_keys.add(passthrough_key5)
  batch_dict[passthrough_key5] = pa.array([[1], [2]],
                                          pa.large_list(pa.int64()))
  input_metadata.arrow_schema = input_metadata.arrow_schema.append(
      pa.field(passthrough_key5, batch_dict[passthrough_key5].type))
  with self.assertRaisesRegexp(
      ValueError, 'Cannot pass-through data when '
      'input and output batch sizes are different'):
    _ = impl._convert_to_record_batch(batch_dict, schema, converter,
                                      passthrough_keys, input_metadata)
def main(_):
  # Define schema.
  raw_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'text': tf.FixedLenFeature([], tf.string),
          'language_code': tf.FixedLenFeature([], tf.string),
      }))

  # Add in padding tokens.
  reserved_tokens = FLAGS.reserved_tokens
  if FLAGS.num_pad_tokens:
    padded_tokens = ['<pad>']
    padded_tokens += ['<pad%d>' % i for i in range(1, FLAGS.num_pad_tokens)]
    reserved_tokens = padded_tokens + reserved_tokens

  params = learner.Params(FLAGS.upper_thresh, FLAGS.lower_thresh,
                          FLAGS.num_iterations, FLAGS.max_input_tokens,
                          FLAGS.max_token_length, FLAGS.max_unique_chars,
                          FLAGS.vocab_size, FLAGS.slack_ratio,
                          FLAGS.include_joiner_token, FLAGS.joiner,
                          reserved_tokens)

  generate_vocab(FLAGS.data_file, FLAGS.vocab_file, FLAGS.metrics_file,
                 raw_metadata, params)
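# Worked example of the padding logic above, with a hypothetical flag value:
# FLAGS.num_pad_tokens = 3 yields padded_tokens == ['<pad>', '<pad1>',
# '<pad2>'] (the loop starts at 1 because the bare '<pad>' token is added
# first), and these are prepended to any user-supplied reserved tokens.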
def test_no_data_needed(self):
  span_0_key = 'span-0'
  span_1_key = 'span-1'

  def preprocessing_fn(inputs):
    return {k: tf.identity(v) for k, v in six.iteritems(inputs)}

  input_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          'x': tf.io.FixedLenFeature([], tf.float32),
      }))
  input_data_dict = {
      span_0_key: None,
      span_1_key: None,
  }
  with _TestPipeline() as p:
    flat_data = None
    cache_dict = {
        span_0_key: {},
        span_1_key: {},
    }
    _, output_cache = (
        (flat_data, input_data_dict, cache_dict, input_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(
            preprocessing_fn, pipeline=p))
    self.assertFalse(output_cache)
def tfrecord_schema(num_detections: int) -> schema_pb2.Schema:
  return schema_utils.schema_from_feature_spec({
      "boxes_xmax": tf.io.FixedLenFeature([num_detections], tf.float32),
      "boxes_xmin": tf.io.FixedLenFeature([num_detections], tf.float32),
      "boxes_ymax": tf.io.FixedLenFeature([num_detections], tf.float32),
      "boxes_ymin": tf.io.FixedLenFeature([num_detections], tf.float32),
      "client_version": tf.io.FixedLenFeature([], tf.string),
      "detection_classes": tf.io.FixedLenFeature([num_detections], tf.int64),
      "detection_scores": tf.io.FixedLenFeature([num_detections], tf.float32),
      "cloudiot_device_id": tf.io.FixedLenFeature([], tf.int64),
      "octoprint_device_id": tf.io.FixedLenFeature([], tf.int64),
      "image_data": tf.io.FixedLenFeature([], tf.string),
      "image_height": tf.io.FixedLenFeature([], tf.int64),
      "image_width": tf.io.FixedLenFeature([], tf.int64),
      "num_detections": tf.io.FixedLenFeature([], tf.float32),
      "print_session": tf.io.FixedLenFeature([], tf.string),
      "ts": tf.io.FixedLenFeature([], tf.int64),
      "user_id": tf.io.FixedLenFeature([], tf.int64),
  })
def testPreprocessingFn(self):
  schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt')
  schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())
  feature_spec = taxi_utils._get_raw_feature_spec(schema)
  working_dir = self.get_temp_dir()
  transform_graph_path = os.path.join(working_dir, 'transform_graph')
  transformed_examples_path = os.path.join(working_dir,
                                           'transformed_examples')

  # Run a very simplified version of the executor logic.
  # TODO(kestert): Replace with tft_unit.assertAnalyzeAndTransformResults.
  # Generate a legacy `DatasetMetadata` object. Future versions of Transform
  # will accept the `Schema` proto directly.
  legacy_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(feature_spec))
  tfxio = tf_example_record.TFExampleRecord(
      file_pattern=os.path.join(self._testdata_path,
                                'csv_example_gen/Split-train/*'),
      telemetry_descriptors=['Tests'],
      schema=legacy_metadata.schema)
  with beam.Pipeline() as p:
    with tft_beam.Context(temp_dir=os.path.join(working_dir, 'tmp')):
      examples = p | 'ReadTrainData' >> tfxio.BeamSource()
      (transformed_examples, transformed_metadata), transform_fn = (
          (examples, tfxio.TensorAdapterConfig())
          | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset(
              taxi_utils.preprocessing_fn))

      # WriteTransformFn writes transform_fn and metadata to subdirectories
      # tensorflow_transform.SAVED_MODEL_DIR and
      # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
      # pylint: disable=expression-not-assigned
      (transform_fn
       | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
           transform_graph_path))

      encoder = tft.coders.ExampleProtoCoder(transformed_metadata.schema)
      (transformed_examples
       | 'EncodeTrainData' >> beam.Map(encoder.encode)
       | 'WriteTrainData' >> beam.io.WriteToTFRecord(
           os.path.join(transformed_examples_path,
                        'Split-train/transformed_examples.gz'),
           coder=beam.coders.BytesCoder()))
      # pylint: enable=expression-not-assigned

  # Verify the output matches the golden output.
  # NOTE: we don't verify that transformed examples match the golden output.
  expected_transformed_schema = io_utils.parse_pbtxt_file(
      os.path.join(
          self._testdata_path,
          'transform/transform_graph/transformed_metadata/schema.pbtxt'),
      schema_pb2.Schema())
  transformed_schema = io_utils.parse_pbtxt_file(
      os.path.join(transform_graph_path,
                   'transformed_metadata/schema.pbtxt'),
      schema_pb2.Schema())
  # Clear annotations so we only have to test the main schema.
  transformed_schema.ClearField('annotation')
  for feature in transformed_schema.feature:
    feature.ClearField('annotation')
  self.assertEqual(transformed_schema, expected_transformed_schema)
def test_make_feed_list(self, feature_spec, instances, feed_dict):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  feature_names = list(feature_spec.keys())
  expected_feed_list = [feed_dict[key] for key in feature_names]
  np.testing.assert_equal(
      impl_helper.make_feed_list(feature_names, schema, instances),
      expected_feed_list)
def _get_common_variables(dataset, force_tf_compat_v1):
  """Returns metadata schema, preprocessing fn, input dataset metadata."""
  tf_metadata_schema = benchmark_utils.read_schema(
      dataset.tf_metadata_schema_path())
  preprocessing_fn = dataset.tft_preprocessing_fn()
  feature_spec = schema_utils.schema_as_feature_spec(
      tf_metadata_schema).feature_spec
  type_spec = impl_helper.get_type_specs_from_feature_specs(feature_spec)
  transform_input_columns = (
      tft.get_transform_input_columns(
          preprocessing_fn, type_spec, force_tf_compat_v1=force_tf_compat_v1))
  transform_input_dataset_metadata = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec({
          feature: feature_spec[feature]
          for feature in transform_input_columns
      }))
  tfxio = tf_example_record.TFExampleBeamRecord(
      physical_format="tfexamples",
      schema=transform_input_dataset_metadata.schema,
      telemetry_descriptors=["TFTransformBenchmark"])
  return CommonVariablesTuple(
      tf_metadata_schema=tf_metadata_schema,
      preprocessing_fn=preprocessing_fn,
      transform_input_dataset_metadata=transform_input_dataset_metadata,
      tfxio=tfxio)
def test_encode_non_serialized(self, feature_spec, ascii_proto, instance,
                               **kwargs):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  coder = example_proto_coder.ExampleProtoCoder(
      schema, serialized=False, **kwargs)
  proto = _ascii_to_example(ascii_proto)
  np.testing.assert_equal(coder.encode(instance), proto)
def test_convert_to_arrow(self, feature_spec, instances, feed_dict,
                          feed_eager_tensors):
  if feed_eager_tensors:
    test_case.skip_if_not_tf2('Tensorflow 2.x required')
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  converter = impl_helper.make_tensor_to_arrow_converter(schema)
  feed_dict_local = copy.copy(feed_dict)
  if feed_eager_tensors:
    for key, value in six.iteritems(feed_dict_local):
      if isinstance(value, tf.compat.v1.SparseTensorValue):
        feed_dict_local[key] = tf.sparse.SparseTensor.from_value(value)
      else:
        feed_dict_local[key] = tf.constant(value)
  arrow_columns, arrow_schema = impl_helper.convert_to_arrow(
      schema, converter, feed_dict_local)
  record_batch = pa.RecordBatch.from_arrays(arrow_columns, arrow_schema)

  # Merge and flatten expected instance dicts.
  expected = collections.defaultdict(list)
  for instance_dict in instances:
    for key, value in instance_dict.items():
      expected[key].append(np.ravel(value))
  actual = record_batch.to_pydict()
  self.assertEqual(len(actual), len(expected))
  for key, expected_value in expected.items():
    # Floating-point error breaks exact equality for some floating values.
    # However, the approximate equality testing fails on strings.
    if np.issubdtype(expected_value[0].dtype, np.number):
      self.assertAllClose(actual[key], expected_value)
    else:
      np.testing.assert_equal(actual[key], expected_value)
def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  # This schema is usually either an output of SchemaGen or a
  # manually-curated version provided by the pipeline author. A schema can
  # also be derived from the TFT graph if a Transform component is used. When
  # either is missing, `schema_from_feature_spec` can be used to generate a
  # schema from a very simple feature_spec, but the resulting schema will be
  # very primitive.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      schema,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      schema,
      batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The result of the training should be saved in the
  # `fn_args.serving_model_dir` directory.
  model.save(fn_args.serving_model_dir, save_format='tf')
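# `_FEATURE_SPEC` is defined elsewhere in the module; a plausible minimal
# shape is sketched below for illustration only (the feature names are
# hypothetical, not taken from the original code):
#
# _FEATURE_SPEC = {
#     'x': tf.io.FixedLenFeature([], tf.float32),
#     'label': tf.io.FixedLenFeature([], tf.int64),
# }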
def from_feature_spec(
    cls: Type[_DatasetMetadataType],
    feature_spec: Mapping[str, common_types.FeatureSpecType],
    domains: Optional[Mapping[str, common_types.DomainType]] = None
) -> _DatasetMetadataType:
  """Creates a DatasetMetadata from a TF feature spec dict."""
  return cls(schema_utils.schema_from_feature_spec(feature_spec, domains))
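# A hedged usage sketch for the classmethod above; the feature names and
# domain bounds are illustrative. It builds DatasetMetadata straight from a
# feature spec, optionally attaching an IntDomain to an integer feature.
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import dataset_metadata

metadata = dataset_metadata.DatasetMetadata.from_feature_spec(
    feature_spec={
        'age': tf.io.FixedLenFeature([], tf.int64),
        'name': tf.io.VarLenFeature(tf.string),
    },
    domains={
        'age': schema_pb2.IntDomain(min=0, max=120, is_categorical=False)
    })
print(metadata.schema)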
def _validate_column_schemas(self):
  """Validate that this Schema can be represented as a schema_pb2.Schema."""
  feature_spec = self.as_feature_spec()
  int_domains = {}
  for name, column_schema in self._column_schemas.items():
    domain = column_schema.domain
    if isinstance(domain, IntDomain):
      int_domains[name] = schema_pb2.IntDomain(
          min=domain.min_value,
          max=domain.max_value,
          is_categorical=domain.is_categorical)
  try:
    schema_utils.schema_from_feature_spec(feature_spec, int_domains)
  except Exception as e:
    # Note: format the exception itself rather than `e.message`, which does
    # not exist on Python 3 exceptions.
    raise ValueError(
        'The values of column_schemas were invalid, as detected when '
        'converting them to a schema_pb2.Schema proto. Original error: '
        '{}'.format(e))
def test_to_instance_dicts_error(self, feature_spec, feed_dict, error_msg,
                                 error_type=ValueError):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  with self.assertRaisesRegexp(error_type, error_msg):
    impl_helper.to_instance_dicts(schema, feed_dict)
def test_schema_from_feature_spec(self, ascii_proto, feature_spec,
                                  domains=None,
                                  generate_legacy_feature_spec=False):
  expected_schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema())
  schema_utils_legacy.set_generate_legacy_feature_spec(
      expected_schema_proto, generate_legacy_feature_spec)
  result = schema_utils.schema_from_feature_spec(feature_spec, domains)
  self.assertEqual(result, expected_schema_proto)
def test_schema_from_feature_spec(self, ascii_proto, feature_spec,
                                  domains=None):
  expected_schema_proto = _parse_schema_ascii_proto(ascii_proto)
  self.assertEqual(
      schema_utils.schema_from_feature_spec(feature_spec, domains),
      expected_schema_proto)
def __init__(self, column_schemas):
  feature_spec = {name: spec for name, (_, spec) in column_schemas.items()}
  domains = {name: domain for name, (domain, _) in column_schemas.items()
             if domain is not None}
  self._schema_proto = schema_utils.schema_from_feature_spec(
      feature_spec, domains)
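# As the comprehensions above show, each `column_schemas` value is a
# (domain, feature_spec) pair, with `None` for features that have no domain.
# A hypothetical instance (names and bounds are illustrative) might be:
#
# column_schemas = {
#     'age': (schema_pb2.IntDomain(min=0, max=120),
#             tf.io.FixedLenFeature([], tf.int64)),
#     'name': (None, tf.io.VarLenFeature(tf.string)),
# }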
def _infer_feature_schema_common(features, tensor_ranges, feature_annotations,
                                 global_annotations):
  """Given a dict of tensors, creates a `Schema`.

  Args:
    features: A dict mapping column names to `Tensor` or `SparseTensor`s. The
      `Tensor` or `SparseTensor`s should have a 0th dimension which is
      interpreted as the batch dimension.
    tensor_ranges: A dict mapping a tensor to a tuple containing its min and
      max value.
    feature_annotations: A dict from feature name to a list of any_pb2.Any
      protos to be added as an annotation for that feature in the schema.
    global_annotations: A list of any_pb2.Any protos to be added at the
      global schema level.

  Returns:
    A `Schema` proto.
  """
  domains = {}
  feature_tags = collections.defaultdict(list)
  for name, tensor in six.iteritems(features):
    if isinstance(tensor, tf.RaggedTensor):
      # Add the 'ragged_tensor' tag which will cause the coder and
      # schema_as_feature_spec to raise an error, as currently there is no
      # feature spec for ragged tensors.
      feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
    if name in tensor_ranges:
      min_value, max_value = tensor_ranges[name]
      domains[name] = schema_pb2.IntDomain(
          min=min_value, max=max_value, is_categorical=True)
  feature_spec = _feature_spec_from_batched_tensors(features)

  schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

  # Add the annotations to the schema.
  for annotation in global_annotations:
    schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)

  # Build a map from logical feature names to Feature protos.
  feature_protos_by_name = {}
  for feature in schema_proto.feature:
    feature_protos_by_name[feature.name] = feature
  for sparse_feature in schema_proto.sparse_feature:
    for index_feature in sparse_feature.index_feature:
      feature_protos_by_name.pop(index_feature.name)
    value_feature = feature_protos_by_name.pop(
        sparse_feature.value_feature.name)
    feature_protos_by_name[sparse_feature.name] = value_feature

  # Update annotations.
  for feature_name, annotations in feature_annotations.items():
    feature_proto = feature_protos_by_name[feature_name]
    for annotation in annotations:
      feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
  for feature_name, tags in feature_tags.items():
    feature_proto = feature_protos_by_name[feature_name]
    for tag in tags:
      feature_proto.annotation.tag.append(tag)
  return schema_proto
def test_convert_to_arrow_error(self, feature_spec, feed_dict, error_msg,
                                error_type=ValueError):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  converter = impl_helper.make_tensor_to_arrow_converter(schema)
  with self.assertRaisesRegexp(error_type, error_msg):
    impl_helper.convert_to_arrow(schema, converter, feed_dict)
def test_constructor_error(self, columns, feature_spec, error_msg,
                           error_type=ValueError, **kwargs):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  with self.assertRaisesRegexp(error_type, error_msg):
    csv_coder.CsvCoder(columns, schema, **kwargs)
def _remove_columns_from_metadata(metadata, excluded_columns):
  """Remove columns from metadata without mutating original metadata."""
  feature_spec, domains = schema_utils.schema_as_feature_spec(metadata.schema)
  new_feature_spec = {name: spec for name, spec in feature_spec.items()
                      if name not in excluded_columns}
  new_domains = {name: domain for name, domain in domains.items()
                 if name not in excluded_columns}
  return dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(new_feature_spec, new_domains))
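# A hedged usage sketch for the helper above; the metadata construction and
# column names are illustrative. The original object is left untouched.
metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec({
        'x': tf.io.FixedLenFeature([], tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }))
pruned = _remove_columns_from_metadata(metadata, excluded_columns={'label'})
assert 'label' not in schema_utils.schema_as_feature_spec(
    pruned.schema).feature_spec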
def get_raw_metadata(
    columns: List[str],
    schema_map: Dict[str, collections.namedtuple]
) -> dataset_metadata.DatasetMetadata:
  """Returns metadata prior to TF Transform preprocessing.

  Note: takes the base schema_map as input, not raw_schema_map.
  """
  feature_spec = get_raw_feature_spec(columns, schema_map)
  return dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(feature_spec))
def test_to_instance_dicts(self, feature_spec, instances, record_batch,
                           feed_dict, feed_eager_tensors):
  del record_batch
  if feed_eager_tensors:
    test_case.skip_if_not_tf2('Tensorflow 2.x required')
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  feed_dict_local = (
      _eager_tensor_from_values(feed_dict)
      if feed_eager_tensors else copy.copy(feed_dict))
  result = impl_helper.to_instance_dicts(schema, feed_dict_local)
  np.testing.assert_equal(instances, result)
def test_encode_error(self, feature_spec, instance, error_msg,
                      error_type=ValueError, **kwargs):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  with self.assertRaisesRegexp(error_type, error_msg):
    coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs)
    coder.encode(instance)
def test_decode_error(self, feature_spec, ascii_proto, error_msg,
                      error_type=ValueError, **kwargs):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs)
  serialized_proto = _ascii_to_binary(ascii_proto)
  with self.assertRaisesRegexp(error_type, error_msg):
    coder.decode(serialized_proto)
def test_encode_error(self, columns, feature_spec, instance, error_msg,
                      error_type=ValueError, **kwargs):
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  coder = csv_coder.CsvCoder(columns, schema, **kwargs)
  with self.assertRaisesRegexp(error_type, error_msg):
    coder.encode(instance)
def test_make_feed_list_error(self, feature_spec, instances, error_msg,
                              error_type=ValueError):
  tensors = tf.io.parse_example(
      serialized=tf.compat.v1.placeholder(tf.string, [None]),
      features=feature_spec)
  schema = schema_utils.schema_from_feature_spec(feature_spec)
  with self.assertRaisesRegexp(error_type, error_msg):
    impl_helper.make_feed_list(tensors, schema, instances)
def metadata_from_feature_spec(feature_spec, domains=None):
  """Construct a DatasetMetadata from a feature spec.

  Args:
    feature_spec: A feature spec.
    domains: A dict containing domains of features.

  Returns:
    A `tft.tf_metadata.dataset_metadata.DatasetMetadata` object.
  """
  return dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(feature_spec, domains))
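# A minimal usage sketch of the helper above; the feature name and domain
# bounds are illustrative.
metadata = metadata_from_feature_spec(
    {'a': tf.io.FixedLenFeature([], tf.int64)},
    domains={'a': schema_pb2.IntDomain(min=0, max=9, is_categorical=True)})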
def train(**args):
  print('Ingesting data.')
  client = storage.Client()
  bucket = client.get_bucket('ames-house-dataset')
  blob = storage.Blob('train.csv', bucket)
  content = blob.download_as_string()
  data = pd.read_csv(BytesIO(content), index_col=0)

  print('Creating metadata specification.')
  RAW_DATA_FEATURE_SPEC = get_raw_data_spec(data)
  RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
      schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC))
def _parse_schema_json(schema_json):
  """Translate a JSON schema into a Schema proto."""
  schema_dict = json.loads(schema_json)
  feature_spec = {
      feature_dict['name']: _column_schema_from_json(feature_dict)
      for feature_dict in schema_dict.get('feature', [])
  }
  domains = {
      feature_dict['name']: _domain_from_json(feature_dict['domain'])
      for feature_dict in schema_dict.get('feature', [])
  }
  return schema_utils.schema_from_feature_spec(feature_spec, domains)
def _RunBeamImpl(self, inputs, outputs, preprocessing_fn,
                 input_dataset_metadata, raw_examples_data_format,
                 transform_output_path, compute_statistics,
                 materialize_output_paths):
  """Perform data preprocessing with the FlumeC++ runner.

  Args:
    inputs: A dictionary of labelled input values.
    outputs: A dictionary of labelled output values.
    preprocessing_fn: The tf.Transform preprocessing_fn.
    input_dataset_metadata: A DatasetMetadata object for the input data.
    raw_examples_data_format: A string describing the raw data format.
    transform_output_path: An absolute path to write the output to.
    compute_statistics: A bool indicating whether or not to compute
      statistics.
    materialize_output_paths: Paths to materialized outputs.

  Raises:
    RuntimeError: If reset() is not invoked between two run() calls.
    ValueError: If the schema is empty.

  Returns:
    Status of the execution.
  """
  raw_examples_file_format = common.GetSoleValue(
      inputs, labels.EXAMPLES_FILE_FORMAT_LABEL, strict=False)
  analyze_and_transform_data_paths = common.GetValues(
      inputs, labels.ANALYZE_AND_TRANSFORM_DATA_PATHS_LABEL)
  transform_only_data_paths = common.GetValues(
      inputs, labels.TRANSFORM_ONLY_DATA_PATHS_LABEL)
  stats_use_tfdv = common.GetSoleValue(inputs,
                                       labels.TFT_STATISTICS_USE_TFDV_LABEL)
  per_set_stats_output_paths = common.GetValues(
      outputs, labels.PER_SET_STATS_OUTPUT_PATHS_LABEL)
  temp_path = common.GetSoleValue(outputs, labels.TEMP_OUTPUT_LABEL)

  tf.logging.info('Analyze and transform data patterns: %s',
                  list(enumerate(analyze_and_transform_data_paths)))
  tf.logging.info('Transform data patterns: %s',
                  list(enumerate(transform_only_data_paths)))
  tf.logging.info('Transform materialization output paths: %s',
                  list(enumerate(materialize_output_paths)))
  tf.logging.info('Transform output path: %s', transform_output_path)

  feature_spec = input_dataset_metadata.schema.as_feature_spec()
  try:
    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn, feature_spec)
    transform_input_columns = (
        tft.get_transform_input_columns(preprocessing_fn, feature_spec))
  except AttributeError:
    # If using TFT 1.12, fall back to assuming all features are used.
    analyze_input_columns = feature_spec.keys()
    transform_input_columns = feature_spec.keys()
  # Use the same dataset (same columns) for AnalyzeDataset and computing
  # pre-transform stats so that the data will only be read once for these
  # two operations.
  if compute_statistics:
    analyze_input_columns = list(
        set(list(analyze_input_columns) + list(transform_input_columns)))
  analyze_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
  transform_input_dataset_metadata = copy.deepcopy(input_dataset_metadata)
  if input_dataset_metadata.schema is not _RAW_EXAMPLE_SCHEMA:
    analyze_input_dataset_metadata.schema = dataset_schema.from_feature_spec(
        {feature: feature_spec[feature] for feature in analyze_input_columns})
    transform_input_dataset_metadata.schema = (
        dataset_schema.from_feature_spec({
            feature: feature_spec[feature]
            for feature in transform_input_columns
        }))

  can_process_jointly = not bool(per_set_stats_output_paths or
                                 materialize_output_paths)
  analyze_data_list = self._MakeDatasetList(analyze_and_transform_data_paths,
                                            raw_examples_file_format,
                                            raw_examples_data_format,
                                            analyze_input_dataset_metadata,
                                            can_process_jointly)
  transform_data_list = self._MakeDatasetList(
      list(analyze_and_transform_data_paths) +
      list(transform_only_data_paths), raw_examples_file_format,
      raw_examples_data_format, transform_input_dataset_metadata,
      can_process_jointly)

  desired_batch_size = self._GetDesiredBatchSize(raw_examples_data_format)

  with self._CreatePipeline(outputs) as p:
    with tft_beam.Context(
        temp_dir=temp_path,
        desired_batch_size=desired_batch_size,
        passthrough_keys={_TRANSFORM_INTERNAL_FEATURE_FOR_KEY},
        use_deep_copy_optimization=True):
      # pylint: disable=expression-not-assigned
      # pylint: disable=no-value-for-parameter

      analyze_decode_fn = (
          self._GetDecodeFunction(raw_examples_data_format,
                                  analyze_input_dataset_metadata.schema))

      for (idx, dataset) in enumerate(analyze_data_list):
        dataset.encoded = (
            p
            | 'ReadAnalysisDataset[{}]'.format(idx) >>
            self._ReadExamples(dataset))
        dataset.decoded = (
            dataset.encoded
            | 'DecodeAnalysisDataset[{}]'.format(idx) >>
            self._DecodeInputs(analyze_decode_fn))

      input_analysis_data = (
          [dataset.decoded for dataset in analyze_data_list]
          | 'FlattenAnalysisDatasets' >> beam.Flatten())
      transform_fn = (
          (input_analysis_data, input_dataset_metadata)
          | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn))

      # Write the raw/input metadata.
      (input_dataset_metadata
       | 'WriteMetadata' >> tft_beam.WriteMetadata(
           os.path.join(transform_output_path,
                        tft.TFTransformOutput.RAW_METADATA_DIR), p))

      # WriteTransformFn writes transform_fn and metadata to subdirectories
      # tensorflow_transform.SAVED_MODEL_DIR and
      # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
      (transform_fn
       | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
           transform_output_path))

      if compute_statistics or materialize_output_paths:
        # Do not compute pre-transform stats if the input format is raw
        # proto, as StatsGen would treat any input as tf.Example.
        if (compute_statistics and
            not self._IsDataFormatProto(raw_examples_data_format)):
          # Aggregated feature stats before transformation.
          pre_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in the
          # schema. Currently the input dataset schema only contains dtypes,
          # and other metadata is dropped due to the roundtrip to tensors.
          schema_proto = schema_utils.schema_from_feature_spec(
              analyze_input_dataset_metadata.schema.as_feature_spec())
          ([
              dataset.decoded if stats_use_tfdv else dataset.encoded
              for dataset in analyze_data_list
          ]
           | 'FlattenPreTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePreTransformAnalysisStats' >>
           self._GenerateStats(
               pre_transform_feature_stats_path,
               schema_proto,
               use_deep_copy_optimization=True,
               use_tfdv=stats_use_tfdv))

        transform_decode_fn = (
            self._GetDecodeFunction(raw_examples_data_format,
                                    transform_input_dataset_metadata.schema))
        # transform_data_list is a superset of analyze_data_list; we pay the
        # cost to read the same dataset (analyze_data_list) again here to
        # prevent certain Beam runners from doing large temp materialization.
        for (idx, dataset) in enumerate(transform_data_list):
          dataset.encoded = (
              p
              | 'ReadTransformDataset[{}]'.format(idx) >>
              self._ReadExamples(dataset))
          dataset.decoded = (
              dataset.encoded
              | 'DecodeTransformDataset[{}]'.format(idx) >>
              self._DecodeInputs(transform_decode_fn))
          (dataset.transformed, metadata) = (
              ((dataset.decoded, transform_input_dataset_metadata),
               transform_fn)
              | 'TransformDataset[{}]'.format(idx) >>
              tft_beam.TransformDataset())

          if materialize_output_paths or not stats_use_tfdv:
            dataset.transformed_and_encoded = (
                dataset.transformed
                | 'EncodeTransformedDataset[{}]'.format(idx) >> beam.ParDo(
                    self._EncodeAsExamples(), metadata))

        if compute_statistics:
          # Aggregated feature stats after transformation.
          _, metadata = transform_fn
          post_transform_feature_stats_path = os.path.join(
              transform_output_path,
              tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in the
          # schema. Currently the input dataset schema only contains dtypes,
          # and other metadata is dropped due to the roundtrip to tensors.
          transformed_schema_proto = schema_utils.schema_from_feature_spec(
              metadata.schema.as_feature_spec())

          ([(dataset.transformed
             if stats_use_tfdv else dataset.transformed_and_encoded)
            for dataset in transform_data_list]
           | 'FlattenPostTransformAnalysisDatasets' >> beam.Flatten()
           | 'GenerateAggregatePostTransformAnalysisStats' >>
           self._GenerateStats(
               post_transform_feature_stats_path,
               transformed_schema_proto,
               use_tfdv=stats_use_tfdv))

          if per_set_stats_output_paths:
            assert len(transform_data_list) == len(per_set_stats_output_paths)
            # TODO(b/67632871): Remove duplicate stats gen compute that is
            # done both on a flattened view of the data, and on each span
            # below.
            bundles = zip(transform_data_list, per_set_stats_output_paths)
            for (idx, (dataset, output_path)) in enumerate(bundles):
              if stats_use_tfdv:
                data = dataset.transformed
              else:
                data = dataset.transformed_and_encoded
              (data
               | 'GeneratePostTransformStats[{}]'.format(idx) >>
               self._GenerateStats(
                   output_path,
                   transformed_schema_proto,
                   use_tfdv=stats_use_tfdv))

        if materialize_output_paths:
          assert len(transform_data_list) == len(materialize_output_paths)
          bundles = zip(transform_data_list, materialize_output_paths)
          for (idx, (dataset, output_path)) in enumerate(bundles):
            (dataset.transformed_and_encoded
             | 'Materialize[{}]'.format(idx) >> self._WriteExamples(
                 raw_examples_file_format, output_path))

  return _Status.OK()