def __init__(self,
             schema: schema_pb2.Schema,
             name: Text = 'SparseFeatureStatsGenerator') -> None:
    """Builds the constituent generators for every sparse feature in `schema`.

    Args:
      schema: The dataset schema; its sparse features are resolved into
        (value path, index paths) components.
      name: Optional unique name for this statistics generator.
    """
    self._sparse_feature_components = _get_components(
        _get_all_sparse_features(schema))

    constituents = []
    # Only the (value, indices) components are needed here, not the keys
    # they are stored under.
    for value, indices in self._sparse_feature_components.values():
        # Every path that makes up this sparse feature is handed to each
        # constituent generator as its required paths.
        required_paths = [value, *indices]
        # Missing count for the value path itself.
        constituents.append(
            count_missing_generator.CountMissingGenerator(
                value, required_paths))
        # For each index path: a length diff against the value path, followed
        # by a missing count for that index path.
        for index in indices:
            constituents.extend([
                length_diff_generator.LengthDiffGenerator(
                    index, value, required_paths),
                count_missing_generator.CountMissingGenerator(
                    index, required_paths),
            ])

    super(SparseFeatureStatsGenerator, self).__init__(
        name, constituents, schema)
def test_count_missing_generator_single_batch(self):
    """Checks the missing count over a single three-row batch.

    The 'feature' column holds a one-element list, a null and an empty list;
    the generator is expected to report a count of 1.
    """
    feature_path = types.FeaturePath(['feature'])
    table = pa.Table.from_arrays([pa.array([[1], None, []])], ['feature'])
    batch = input_batch.InputBatch(table)

    generator = count_missing_generator.CountMissingGenerator(feature_path)
    accumulator = generator.add_input(generator.create_accumulator(), batch)

    self.assertEqual(1, generator.extract_output(accumulator))
def test_count_missing_generator_key(self):
    """Checks that the instance key and the class-level key both match."""
    feature_path = types.FeaturePath(['feature'])
    generator = count_missing_generator.CountMissingGenerator(feature_path)
    expected = {('CountMissingGenerator', feature_path): None}

    # Comparing single-entry dicts exercises the key's hash as well as its
    # equality, and assertDictEqual gives a readable diff on failure.
    instance_key = generator.get_key()
    self.assertDictEqual(expected, {instance_key: None})

    class_key = count_missing_generator.CountMissingGenerator.key(feature_path)
    self.assertDictEqual(expected, {class_key: None})
def __init__(self,
             schema: schema_pb2.Schema,
             name: Text = 'WeightedFeatureStatsGenerator') -> None:
    """Builds the constituent generators for every weighted feature.

    For each entry in `schema.weighted_feature`, three constituents are
    registered: a length diff between the weight and value paths, and a
    missing count for each of the two paths. All three receive both paths as
    their required paths.

    Args:
      schema: The dataset schema whose `weighted_feature` entries are read.
      name: Optional unique name for this statistics generator.
    """
    constituents = []
    for weighted_feature in schema.weighted_feature:
        weight = types.FeaturePath.from_proto(
            weighted_feature.weight_feature)
        value = types.FeaturePath.from_proto(weighted_feature.feature)
        component_paths = [weight, value]
        constituents.extend([
            length_diff_generator.LengthDiffGenerator(
                weight, value, required_paths=component_paths),
            count_missing_generator.CountMissingGenerator(
                value, required_paths=component_paths),
            count_missing_generator.CountMissingGenerator(
                weight, required_paths=component_paths),
        ])

    super(WeightedFeatureStatsGenerator, self).__init__(
        name, constituents, schema)
def test_count_missing_generator_key_with_required(self):
    """Checks that the required path is included in the generator key."""
    index_path = types.FeaturePath(['index'])
    value_path = types.FeaturePath(['value'])
    generator = count_missing_generator.CountMissingGenerator(
        index_path, [value_path])
    expected = {('CountMissingGenerator', index_path, value_path): None}

    # Single-entry dicts are compared so the key's hash is exercised too.
    instance_key = generator.get_key()
    self.assertDictEqual(expected, {instance_key: None})

    class_key = count_missing_generator.CountMissingGenerator.key(
        index_path, [value_path])
    self.assertDictEqual(expected, {class_key: None})
def test_count_missing_generator_required_path(self):
    """Checks the missing count when a required path is supplied.

    The 'index' and 'value' columns hold identical rows (a one-element list,
    a null and an empty list), and the expected count for 'index' is 0.
    NOTE(review): presumably rows where the required path is itself missing
    are excluded from the count -- confirm against CountMissingGenerator.
    """
    index_path = types.FeaturePath(['index'])
    value_path = types.FeaturePath(['value'])
    columns = [pa.array([[1], None, []]), pa.array([[1], None, []])]
    batch = input_batch.InputBatch(
        pa.RecordBatch.from_arrays(columns, ['index', 'value']))

    generator = count_missing_generator.CountMissingGenerator(
        index_path, [value_path])
    accumulator = generator.add_input(generator.create_accumulator(), batch)

    self.assertEqual(0, generator.extract_output(accumulator))