def test_list_cross_features(self):
    """list_cross_features() reports each (x, y) path pair of the default slice."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()
    expected_pairs = [
        (types.FeaturePath(['f1x']), types.FeaturePath(['f1y'])),
        (types.FeaturePath(['f2x']), types.FeaturePath(['f2y'])),
    ]
    # Order of the returned pairs is not part of the contract being checked.
    self.assertCountEqual(default_view.list_cross_features(), expected_pairs)
def test_get_feature_by_path(self):
    """get_feature() resolves a multi-step FeaturePath in the default slice."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    default_view = dataset_list.get_default_slice()
    lookup_path = types.FeaturePath(['f0_step1', 'f0_step2'])
    found = default_view.get_feature(lookup_path)
    # The match must wrap the first feature of the fixture's third dataset.
    expected_proto = self._stats_proto.datasets[2].features[0]
    self.assertEqual(expected_proto, found.proto())
def test_get_cross_feature(self):
    """get_cross_feature() finds a cross feature by its x and y paths."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()
    path_x = types.FeaturePath(['f1x'])
    path_y = types.FeaturePath(['f1y'])
    found = default_view.get_cross_feature(path_x, path_y)
    # The match must wrap the first cross feature of the fixture's third
    # dataset.
    expected_proto = self._stats_proto.datasets[2].cross_features[0]
    self.assertEqual(expected_proto, found.proto())
def test_list_features(self):
    """list_features() reports the path of every feature in the default slice."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()
    expected_paths = [
        types.FeaturePath(['f0_step1', 'f0_step2']),
        types.FeaturePath(['f1']),
        types.FeaturePath(['f3_derived']),
    ]
    # Only membership is checked; ordering is irrelevant.
    self.assertCountEqual(default_view.list_features(), expected_paths)
def test_get_derived_feature(self):
    """get_derived_feature() matches on deriver name plus source paths."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()
    source_paths = [
        types.FeaturePath(['f0_step1', 'f0_step2']),
        types.FeaturePath(['f1']),
    ]
    derived = default_view.get_derived_feature('my_deriver_name', source_paths)
    # The match must wrap the third feature of the fixture's third dataset.
    expected_proto = self._stats_proto.datasets[2].features[2]
    self.assertEqual(expected_proto, derived.proto())
def test_get_derived_feature_ambiguous(self):
    """get_derived_feature() raises ValueError when two features match."""
    # Work on a private copy (serialize + parse) so the shared fixture is
    # left untouched by the mutation below.
    serialized = self._stats_proto.SerializeToString()
    mutated_stats = statistics_pb2.DatasetFeatureStatisticsList.FromString(
        serialized)

    # Duplicate the derived feature.
    target_features = mutated_stats.datasets[2].features
    target_features.append(target_features[2])

    default_view = stats_util.DatasetListView(mutated_stats).get_default_slice()
    source_paths = [
        types.FeaturePath(['f0_step1', 'f0_step2']),
        types.FeaturePath(['f1']),
    ]
    with self.assertRaisesRegex(ValueError,
                                'Ambiguous result, 2 features matched'):
        default_view.get_derived_feature('my_deriver_name', source_paths)
def test_get_derived_feature_missing(self):
    """get_derived_feature() returns None when name or sources do not match."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()

    # (deriver name, source paths) lookups that must all come back empty,
    # checked in this order.
    non_matching_lookups = [
        # Wrong deriver name, matching sources.
        ('mismatched_name', [
            types.FeaturePath(['f0_step1', 'f0_step2']),
            types.FeaturePath(['f1']),
        ]),
        # Matching deriver name, one source path has an extra step.
        ('my_deriver_name', [
            types.FeaturePath(['f0_step1', 'f0_step2', 'mismatched_step']),
            types.FeaturePath(['f1']),
        ]),
        # Matching deriver name, no sources at all.
        ('my_deriver_name', []),
    ]
    for deriver_name, source_paths in non_matching_lookups:
        self.assertIsNone(
            default_view.get_derived_feature(deriver_name, source_paths))
def test_get_feature_defined_by_name(self):
    """A feature declared with `name` is retrievable through a one-step path."""
    named_stats = statistics_pb2.DatasetFeatureStatisticsList()
    textproto = """ datasets: { name: 'All Examples' features: { name: "f0" } features: { name: "f1" } } """
    text_format.Parse(textproto, named_stats)

    default_view = stats_util.DatasetListView(named_stats).get_default_slice()
    found = default_view.get_feature(types.FeaturePath(['f1']))
    # "f1" is the second feature declared in the textproto above.
    self.assertEqual(named_stats.datasets[0].features[1], found.proto())
def test_mixed_path_and_name_is_an_error(self):
    """Looking up a feature fails when a dataset mixes `path` and `name` ids.

    The dataset built here declares one feature by `path` and another by
    `name`; get_feature() is expected to reject that with a ValueError.
    """
    stats = statistics_pb2.DatasetFeatureStatisticsList()
    text_format.Parse(
        """ datasets: { name: 'All Examples' features: { path: { step: "f0_step1" step: "f0_step2" } } features: { name: "f1" } } """,
        stats)
    view = stats_util.DatasetListView(stats).get_default_slice()
    expected_message = ('Features must be specified with '
                        'either path or name within a Dataset')
    with self.assertRaisesRegex(ValueError, expected_message):
        # Pass the steps as a list, like every other FeaturePath in these
        # tests; the bare string 'f1' used previously did not spell the
        # single-step path for the feature named "f1".
        view.get_feature(types.FeaturePath(['f1']))
def _flatten_statistics_for_sliced_validation(
    statistics: statistics_pb2.DatasetFeatureStatisticsList
) -> Tuple[statistics_pb2.DatasetFeatureStatisticsList, Set[str]]:
    """Flattens sliced stats into unsliced stats with prepended slice keys.

    Every feature and cross feature of every dataset in `statistics` is copied
    into a single output dataset, and each copied path is rewritten through
    `_prepend_slice_path` using the name of the dataset it was copied from.

    Args:
      statistics: Statistics list whose datasets are treated as slices.

    Returns:
      A `(flattened, slice_names)` tuple: a statistics list containing exactly
      one dataset, and the set of dataset names found in `statistics`.

    Raises:
      ValueError: If `statistics` has no default slice.
    """
    default_view = stats_util.DatasetListView(statistics).get_default_slice()
    if default_view is None:
        raise ValueError('Missing default slice')

    flattened = statistics_pb2.DatasetFeatureStatisticsList()
    merged = flattened.datasets.add()
    # Seed the output with the default (overall) slice so its top level
    # metadata carries over, then drop the per-feature content that is
    # rebuilt below.
    merged.CopyFrom(default_view.proto())
    for repeated_field in ('features', 'cross_features'):
        merged.ClearField(repeated_field)

    def _prefix_in_place(path_proto, slice_name):
        # Overwrites `path_proto` with its slice-prefixed equivalent.
        prefixed = _prepend_slice_path(
            slice_name, types.FeaturePath.from_proto(path_proto))
        path_proto.CopyFrom(prefixed.to_proto())

    # NOTE(review): only `path`, `path_x` and `path_y` are rewritten here; how
    # features identified by `name` should be treated is not visible from this
    # function -- confirm against callers.
    seen_slices = set()
    for sliced in statistics.datasets:
        seen_slices.add(sliced.name)
        for source_feature in sliced.features:
            flat_feature = merged.features.add()
            flat_feature.CopyFrom(source_feature)
            _prefix_in_place(flat_feature.path, sliced.name)
        for source_cross in sliced.cross_features:
            flat_cross = merged.cross_features.add()
            flat_cross.CopyFrom(source_cross)
            _prefix_in_place(flat_cross.path_x, sliced.name)
            _prefix_in_place(flat_cross.path_y, sliced.name)
    return flattened, seen_slices
def test_get_missing_feature(self):
    """get_feature() returns None for a path that is not in the slice."""
    default_view = stats_util.DatasetListView(
        self._stats_proto).get_default_slice()
    unknown_path = types.FeaturePath(['not', 'a', 'path'])
    self.assertIsNone(default_view.get_feature(unknown_path))
def test_get_feature_by_name(self):
    """get_feature() also accepts a plain feature-name string."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    default_view = dataset_list.get_default_slice()
    found = default_view.get_feature('f1')
    # The match must wrap the second feature of the fixture's third dataset.
    expected_proto = self._stats_proto.datasets[2].features[1]
    self.assertEqual(expected_proto, found.proto())
def test_get_missing_slice(self):
    """get_slice() returns None for a slice name that does not exist."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    self.assertIsNone(dataset_list.get_slice('slice99'))
def test_get_default(self):
    """get_default_slice() wraps the fixture's third dataset."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    expected_proto = self._stats_proto.datasets[2]
    self.assertEqual(expected_proto, dataset_list.get_default_slice().proto())
def test_get_slice1(self):
    """get_slice('slice1') wraps the fixture's second dataset."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    expected_proto = self._stats_proto.datasets[1]
    self.assertEqual(expected_proto, dataset_list.get_slice('slice1').proto())
def test_list_slices(self):
    """list_slices() reports the name of every dataset in the list."""
    dataset_list = stats_util.DatasetListView(self._stats_proto)
    expected_names = ['slice0', 'slice1', 'All Examples']
    # Only membership is checked; ordering is irrelevant.
    self.assertCountEqual(expected_names, dataset_list.list_slices())