def test_filepattern_for_dataset_split(self):
  """Checks the glob pattern built for a split, with and without suffix."""
  common_kwargs = dict(
      dataset_name='foo',
      split=splits.Split.TEST,
      data_dir='/tmp/bar/',
  )
  # Without a filetype suffix the pattern is `<data_dir>/<name>-<split>*`.
  self.assertEqual(
      '/tmp/bar/foo-test*',
      naming.filepattern_for_dataset_split(**common_kwargs))
  # A suffix is appended after the split name: `...-<split>.<suffix>*`.
  self.assertEqual(
      '/tmp/bar/foo-test.bar*',
      naming.filepattern_for_dataset_split(
          filetype_suffix='bar', **common_kwargs))
def test_filepattern_for_dataset_split(self):
  """Checks glob filepatterns with and without a filetype suffix."""
  # Without a filetype suffix the pattern is `<data_dir>/<name>-<split>*`.
  pattern = naming.filepattern_for_dataset_split(
      dataset_name="foo",
      split=splits.Split.TEST,
      data_dir="/tmp/bar/")
  self.assertEqual("/tmp/bar/foo-test*", pattern)
  # With a suffix, it is appended after the split: `...-<split>.<suffix>*`.
  pattern_with_suffix = naming.filepattern_for_dataset_split(
      dataset_name="foo",
      split=splits.Split.TEST,
      filetype_suffix="bar",
      data_dir="/tmp/bar/")
  self.assertEqual("/tmp/bar/foo-test.bar*", pattern_with_suffix)
def filepattern(self):
  """Returns a Glob filepattern for this split."""
  # Delegate to the shared naming helper so the pattern format is
  # defined in exactly one place.
  split_kwargs = dict(
      dataset_name=self.dataset_name,
      split=self.split,
      data_dir=self.data_dir,
      filetype_suffix=self.filetype_suffix,
  )
  return naming.filepattern_for_dataset_split(**split_kwargs)
def get_dataset_feature_statistics(builder, split, filetype_suffix="tfrecord"):
  """Calculate statistics for the specified split.

  Args:
    builder: dataset builder; its `name`, `data_dir` and `info.features`
      are used to locate the split's files and to override shapes in the
      inferred schema.
    split: split for which statistics are computed.
    filetype_suffix: on-disk format of the split's files, either
      `"tfrecord"` (default, preserving the previous hardcoded behavior)
      or `"csv"`.

  Returns:
    A `(statistics, schema)` tuple, where `statistics` is the first
    `DatasetFeatureStatistics` proto produced by TFDV for the split and
    `schema` is the schema inferred from those statistics.

  Raises:
    ValueError: if `filetype_suffix` is not `"tfrecord"` or `"csv"`.
  """
  tfdv = lazy_imports_lib.lazy_imports.tensorflow_data_validation
  # Previously the suffix was hardcoded to "tfrecord", which made this
  # validation (and the csv branch below) unreachable; it is now a
  # backward-compatible parameter.
  if filetype_suffix not in ("tfrecord", "csv"):
    raise ValueError(
        "Cannot generate statistics for filetype {}".format(filetype_suffix))
  filepattern = naming.filepattern_for_dataset_split(
      builder.name, split, builder.data_dir, filetype_suffix)
  # Avoid generating a large number of buckets in rank histogram
  # (default is 1000).
  stats_options = tfdv.StatsOptions(
      num_top_values=10, num_rank_histogram_buckets=10)
  if filetype_suffix == "csv":
    statistics = tfdv.generate_statistics_from_csv(
        filepattern, stats_options=stats_options)
  else:
    statistics = tfdv.generate_statistics_from_tfrecord(
        filepattern, stats_options=stats_options)
  schema = tfdv.infer_schema(statistics)
  schema_features = {feature.name: feature for feature in schema.feature}
  # Override shape in the schema with the shapes declared by the builder's
  # feature definitions.
  for feature_name, feature in builder.info.features.items():
    _populate_shape(feature.shape, [feature_name], schema_features)
  # Remove legacy field, if present on this schema proto version.
  if getattr(schema, "generate_legacy_feature_spec", None) is not None:
    schema.ClearField("generate_legacy_feature_spec")
  return statistics.datasets[0], schema