Exemple #1
0
 def test_list_cross_features(self):
     """The default slice reports both (x, y) cross-feature path pairs."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     expected_pairs = [
         (types.FeaturePath(['f1x']), types.FeaturePath(['f1y'])),
         (types.FeaturePath(['f2x']), types.FeaturePath(['f2y'])),
     ]
     self.assertCountEqual(default_view.list_cross_features(), expected_pairs)
Exemple #2
0
 def test_get_feature_by_path(self):
     """A multi-step FeaturePath resolves to the matching feature proto."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     nested_path = types.FeaturePath(['f0_step1', 'f0_step2'])
     found = default_view.get_feature(nested_path)
     # The fixture's default slice is datasets[2].
     expected = self._stats_proto.datasets[2].features[0]
     self.assertEqual(expected, found.proto())
Exemple #3
0
 def test_get_cross_feature(self):
     """A cross feature is found by its (x, y) pair of feature paths."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     x_path = types.FeaturePath(['f1x'])
     y_path = types.FeaturePath(['f1y'])
     found = default_view.get_cross_feature(x_path, y_path)
     expected = self._stats_proto.datasets[2].cross_features[0]
     self.assertEqual(expected, found.proto())
Exemple #4
0
 def test_list_features(self):
     """The default slice lists exactly the expected feature paths."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     expected_paths = [
         types.FeaturePath(['f0_step1', 'f0_step2']),
         types.FeaturePath(['f1']),
         types.FeaturePath(['f3_derived']),
     ]
     self.assertCountEqual(default_view.list_features(), expected_paths)
Exemple #5
0
 def test_get_derived_feature(self):
     """A derived feature is found by deriver name plus its source paths."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     source_paths = [
         types.FeaturePath(['f0_step1', 'f0_step2']),
         types.FeaturePath(['f1']),
     ]
     derived = default_view.get_derived_feature('my_deriver_name', source_paths)
     self.assertEqual(self._stats_proto.datasets[2].features[2],
                      derived.proto())
Exemple #6
0
 def test_get_derived_feature_ambiguous(self):
     """Two identical derived features make the lookup raise ValueError."""
     # Round-trip through serialization to get an independent copy, so the
     # shared fixture proto is not modified.
     stats_copy = statistics_pb2.DatasetFeatureStatisticsList.FromString(
         self._stats_proto.SerializeToString())
     default_dataset = stats_copy.datasets[2]
     # Append a second copy of the derived feature to create the ambiguity.
     default_dataset.features.append(default_dataset.features[2])
     default_view = stats_util.DatasetListView(stats_copy).get_default_slice()
     source_paths = [
         types.FeaturePath(['f0_step1', 'f0_step2']),
         types.FeaturePath(['f1']),
     ]
     with self.assertRaisesRegex(ValueError,
                                 'Ambiguous result, 2 features matched'):
         default_view.get_derived_feature('my_deriver_name', source_paths)
Exemple #7
0
 def test_get_derived_feature_missing(self):
     """Lookups with a wrong deriver name or wrong sources return None."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     non_matching_queries = [
         # Unknown deriver name, otherwise-matching source paths.
         ('mismatched_name', [
             types.FeaturePath(['f0_step1', 'f0_step2']),
             types.FeaturePath(['f1']),
         ]),
         # Known deriver name, but one source path has an extra step.
         ('my_deriver_name', [
             types.FeaturePath(['f0_step1', 'f0_step2', 'mismatched_step']),
             types.FeaturePath(['f1']),
         ]),
         # Known deriver name with no source paths at all.
         ('my_deriver_name', []),
     ]
     for deriver_name, source_paths in non_matching_queries:
         self.assertIsNone(
             default_view.get_derived_feature(deriver_name, source_paths))
Exemple #8
0
 def test_get_feature_defined_by_name(self):
     """Features declared via `name` (not `path`) can be looked up by path."""
     # text_format.Parse fills and returns the message passed to it.
     stats = text_format.Parse(
         """
 datasets: {
   name: 'All Examples'
   features: {
     name: "f0"
   }
   features: {
     name: "f1"
   }
 }
 """, statistics_pb2.DatasetFeatureStatisticsList())
     default_view = stats_util.DatasetListView(stats).get_default_slice()
     expected = stats.datasets[0].features[1]
     found = default_view.get_feature(types.FeaturePath(['f1']))
     self.assertEqual(expected, found.proto())
Exemple #9
0
 def test_mixed_path_and_name_is_an_error(self):
     """A dataset mixing path- and name-defined features is rejected.

     The fixture declares one feature by `path` and another by `name`. A
     feature lookup on its default slice must raise ValueError.
     """
     stats = statistics_pb2.DatasetFeatureStatisticsList()
     text_format.Parse(
         """
 datasets: {
   name: 'All Examples'
   features: {
     path: {
       step: "f0_step1"
       step: "f0_step2"
     }
   }
   features: {
     name: "f1"
   }
 }
 """, stats)
     view = stats_util.DatasetListView(stats).get_default_slice()
     with self.assertRaisesRegex(ValueError,
                                 ('Features must be specified with '
                                  'either path or name within a Dataset')):
         # FeaturePath takes a sequence of steps, as everywhere else in these
         # tests. A bare 'f1' string would presumably be iterated into the
         # steps ('f', '1'), so wrap it in a list to look up the intended
         # feature.
         view.get_feature(types.FeaturePath(['f1']))
def _flatten_statistics_for_sliced_validation(
    statistics: statistics_pb2.DatasetFeatureStatisticsList
) -> Tuple[statistics_pb2.DatasetFeatureStatisticsList, Set[str]]:
    """Flattens sliced stats into unsliced stats with prepended slice keys.

    All datasets (slices) in ``statistics`` are merged into one output
    dataset. That dataset's top-level fields come from the default slice.
    Every feature and cross feature from every slice is copied in, with its
    path(s) passed through ``_prepend_slice_path`` together with the name of
    the dataset it came from.

    Args:
      statistics: Sliced statistics. Must contain a default slice.

    Returns:
      A tuple of (a statistics list holding exactly one flattened dataset,
      the set of dataset names found in ``statistics``).

    Raises:
      ValueError: If ``statistics`` has no default slice.
    """
    default_slice = stats_util.DatasetListView(statistics).get_default_slice()
    if default_slice is None:
        raise ValueError('Missing default slice')

    def _sliced(slice_name, path_proto):
        # Apply _prepend_slice_path to a path proto; return the result as a proto.
        return _prepend_slice_path(
            slice_name, types.FeaturePath.from_proto(path_proto)).to_proto()

    flattened = statistics_pb2.DatasetFeatureStatisticsList()
    merged = flattened.datasets.add()
    # Keep the default slice's top-level metadata, but drop its (cross)
    # features; they are refilled from every slice below.
    merged.CopyFrom(default_slice.proto())
    for field_name in ('features', 'cross_features'):
        merged.ClearField(field_name)

    seen_slice_names = set()
    for dataset in statistics.datasets:
        seen_slice_names.add(dataset.name)
        for feature in dataset.features:
            feature_copy = merged.features.add()
            feature_copy.CopyFrom(feature)
            feature_copy.path.CopyFrom(_sliced(dataset.name, feature_copy.path))
        for cross_feature in dataset.cross_features:
            cross_copy = merged.cross_features.add()
            cross_copy.CopyFrom(cross_feature)
            cross_copy.path_x.CopyFrom(_sliced(dataset.name, cross_copy.path_x))
            cross_copy.path_y.CopyFrom(_sliced(dataset.name, cross_copy.path_y))
    return flattened, seen_slice_names
Exemple #11
0
 def test_get_missing_feature(self):
     """Looking up a path absent from the default slice yields None."""
     absent_path = types.FeaturePath(['not', 'a', 'path'])
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     self.assertIsNone(default_view.get_feature(absent_path))
Exemple #12
0
 def test_get_feature_by_name(self):
     """A plain string feature name resolves to the matching feature proto."""
     default_view = stats_util.DatasetListView(
         self._stats_proto).get_default_slice()
     found = default_view.get_feature('f1')
     expected = self._stats_proto.datasets[2].features[1]
     self.assertEqual(expected, found.proto())
Exemple #13
0
 def test_get_missing_slice(self):
     """Requesting an unknown slice name returns None."""
     list_view = stats_util.DatasetListView(self._stats_proto)
     self.assertIsNone(list_view.get_slice('slice99'))
Exemple #14
0
 def test_get_default(self):
     """The default slice wraps datasets[2] of the fixture."""
     list_view = stats_util.DatasetListView(self._stats_proto)
     default_view = list_view.get_default_slice()
     self.assertEqual(self._stats_proto.datasets[2], default_view.proto())
Exemple #15
0
 def test_get_slice1(self):
     """Looking up 'slice1' by name returns datasets[1] of the fixture."""
     list_view = stats_util.DatasetListView(self._stats_proto)
     slice_view = list_view.get_slice('slice1')
     self.assertEqual(self._stats_proto.datasets[1], slice_view.proto())
Exemple #16
0
 def test_list_slices(self):
     """All three slice names are listed, in any order."""
     expected_names = ['slice0', 'slice1', 'All Examples']
     listed = stats_util.DatasetListView(self._stats_proto).list_slices()
     self.assertCountEqual(expected_names, listed)