Code Example #1
def test_filepattern_for_dataset_split(self):
  # Excerpt from a unittest.TestCase; `naming` and `splits` refer to the
  # tensorflow_datasets core modules of the same names.
  self.assertEqual(
      '/tmp/bar/foo-test*',
      naming.filepattern_for_dataset_split(dataset_name='foo',
                                           split=splits.Split.TEST,
                                           data_dir='/tmp/bar/'))
  self.assertEqual(
      '/tmp/bar/foo-test.bar*',
      naming.filepattern_for_dataset_split(dataset_name='foo',
                                           split=splits.Split.TEST,
                                           filetype_suffix='bar',
                                           data_dir='/tmp/bar/'))
Code Example #2
def test_filepattern_for_dataset_split(self):
  self.assertEqual(
      "/tmp/bar/foo-test*",
      naming.filepattern_for_dataset_split(dataset_name="foo",
                                           split=splits.Split.TEST,
                                           data_dir="/tmp/bar/"))
  self.assertEqual(
      "/tmp/bar/foo-test.bar*",
      naming.filepattern_for_dataset_split(dataset_name="foo",
                                           split=splits.Split.TEST,
                                           filetype_suffix="bar",
                                           data_dir="/tmp/bar/"))
Code Example #3
def filepattern(self):
  """Returns a Glob filepattern for this split."""
  # `self` is expected to expose dataset_name, split, data_dir and
  # filetype_suffix attributes.
  return naming.filepattern_for_dataset_split(
      dataset_name=self.dataset_name,
      split=self.split,
      data_dir=self.data_dir,
      filetype_suffix=self.filetype_suffix)
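
For reference, the expected strings in the two tests above imply the following construction: data_dir joined with '<dataset_name>-<split>', an optional '.<filetype_suffix>', and a trailing '*'. The sketch below is an illustrative reimplementation of that behaviour, not the tensorflow_datasets implementation; it assumes `split` stringifies to its lowercase name (e.g. Split.TEST -> 'test').

import os

def filepattern_for_dataset_split(
    dataset_name, split, data_dir, filetype_suffix=None):
  """Sketch: returns '<data_dir>/<dataset_name>-<split>[.<suffix>]*'."""
  # Assumption: str(split) yields the lowercase split name (e.g. 'test').
  prefix = "%s-%s" % (dataset_name, str(split).lower())
  if filetype_suffix:
    prefix += "." + filetype_suffix
  return os.path.join(data_dir, prefix + "*")

# filepattern_for_dataset_split('foo', 'test', '/tmp/bar/')        -> '/tmp/bar/foo-test*'
# filepattern_for_dataset_split('foo', 'test', '/tmp/bar/', 'bar') -> '/tmp/bar/foo-test.bar*'
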
Code Example #4
def get_dataset_feature_statistics(builder, split):
  """Calculate statistics for the specified split."""
  tfdv = lazy_imports_lib.lazy_imports.tensorflow_data_validation
  # TODO(epot): Avoid hardcoding file format.
  filetype_suffix = "tfrecord"
  if filetype_suffix not in ["tfrecord", "csv"]:
    raise ValueError(
        "Cannot generate statistics for filetype {}".format(filetype_suffix))
  filepattern = naming.filepattern_for_dataset_split(
      builder.name, split, builder.data_dir, filetype_suffix)
  # Avoid generating a large number of buckets in rank histogram
  # (default is 1000).
  stats_options = tfdv.StatsOptions(num_top_values=10,
                                    num_rank_histogram_buckets=10)
  if filetype_suffix == "csv":
    statistics = tfdv.generate_statistics_from_csv(
        filepattern, stats_options=stats_options)
  else:
    statistics = tfdv.generate_statistics_from_tfrecord(
        filepattern, stats_options=stats_options)
  schema = tfdv.infer_schema(statistics)
  schema_features = {feature.name: feature for feature in schema.feature}
  # Override shape in the schema.
  for feature_name, feature in builder.info.features.items():
    _populate_shape(feature.shape, [feature_name], schema_features)

  # Remove legacy field.
  if getattr(schema, "generate_legacy_feature_spec", None) is not None:
    schema.ClearField("generate_legacy_feature_spec")
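  # generate_statistics_from_* returns a DatasetFeatureStatisticsList proto;
  # datasets[0] holds the statistics for the single dataset analyzed here.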
  return statistics.datasets[0], schema
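
A hedged usage sketch for the helper above. It assumes get_dataset_feature_statistics is in scope and tensorflow_data_validation is installed; the dataset name, data_dir and split below are illustrative.

import tensorflow_datasets as tfds

builder = tfds.builder("mnist", data_dir="/tmp/data")  # illustrative choices
builder.download_and_prepare()
# "train" is assumed to be a valid split of the chosen dataset.
statistics, schema = get_dataset_feature_statistics(builder, "train")
print(schema)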