def __init__(self,
                 schema: schema_pb2.Schema,
                 name: Text = 'SparseFeatureStatsGenerator') -> None:
        """Initializes a sparse feature statistics generator.

        For each sparse feature found in `schema`, the following constituent
        generators are created, in this order:

          * a CountMissingGenerator for the feature's value path;
          * for every index path, a LengthDiffGenerator(index, value)
            followed by a CountMissingGenerator for that index path.

        All constituents of one sparse feature share the same required-paths
        list object: the value path followed by every index path.

        Args:
          schema: A required schema for the dataset.
          name: An optional unique name associated with the statistics
            generator.
        """
        self._sparse_feature_components = _get_components(
            _get_all_sparse_features(schema))

        # Create length diff generators for each index / value pair and count
        # missing generator for all paths. Only the (value, indices) pairs are
        # needed, so iterate the mapping's values instead of discarding keys.
        constituents = []
        for value, indices in self._sparse_feature_components.values():
            # Materialize the index paths once: they are used both to build
            # required_paths and to drive the loop below, which would silently
            # see nothing if `indices` were a one-shot iterator.
            index_paths = list(indices)
            required_paths = [value] + index_paths
            constituents.append(
                count_missing_generator.CountMissingGenerator(
                    value, required_paths))
            for index in index_paths:
                constituents.append(
                    length_diff_generator.LengthDiffGenerator(
                        index, value, required_paths))
                constituents.append(
                    count_missing_generator.CountMissingGenerator(
                        index, required_paths))

        super(SparseFeatureStatsGenerator,
              self).__init__(name, constituents, schema)
 def test_count_missing_generator_single_batch(self):
   """Checks the missing count for one batch holding [1], None and []."""
   table = pa.Table.from_arrays([pa.array([[1], None, []])], ['feature'])
   generator = count_missing_generator.CountMissingGenerator(
       types.FeaturePath(['feature']))
   # Presumably only the None row counts as missing; the expected total is 1.
   acc = generator.add_input(generator.create_accumulator(),
                             input_batch.InputBatch(table))
   self.assertEqual(1, generator.extract_output(acc))
Example No. 3
0
 def test_count_missing_generator_key(self):
     """Checks get_key() and the class-level key() produce the expected key."""
     path = types.FeaturePath(['feature'])
     generator_cls = count_missing_generator.CountMissingGenerator
     expected = {('CountMissingGenerator', path): None}
     # Wrapping each key in a dict forces it to be hashed while still giving
     # assertDictEqual a readable diff on failure.
     instance_key = generator_cls(path).get_key()
     self.assertDictEqual(expected, {instance_key: None})
     class_key = generator_cls.key(path)
     self.assertDictEqual(expected, {class_key: None})
 def __init__(self,
              schema: schema_pb2.Schema,
              name: Text = 'WeightedFeatureStatsGenerator') -> None:
     """Initializes a weighted feature statistics generator.

     For every `weighted_feature` entry in `schema`, the weight path and the
     value path are resolved from their protos, and three constituent
     generators are added (in this order), each requiring both paths:
     a LengthDiffGenerator(weight, value), a CountMissingGenerator for the
     value path, and a CountMissingGenerator for the weight path.

     Args:
       schema: The schema whose weighted features are checked.
       name: An optional unique name for the statistics generator.
     """
     constituents = []
     for feature in schema.weighted_feature:
         weight_path = types.FeaturePath.from_proto(feature.weight_feature)
         value_path = types.FeaturePath.from_proto(feature.feature)
         both_paths = [weight_path, value_path]
         constituents.extend([
             length_diff_generator.LengthDiffGenerator(
                 weight_path, value_path, required_paths=both_paths),
             count_missing_generator.CountMissingGenerator(
                 value_path, required_paths=both_paths),
             count_missing_generator.CountMissingGenerator(
                 weight_path, required_paths=both_paths),
         ])
     super(WeightedFeatureStatsGenerator,
           self).__init__(name, constituents, schema)
 def test_count_missing_generator_key_with_required(self):
   """Checks both key accessors include the required path in the key."""
   index_path = types.FeaturePath(['index'])
   value_path = types.FeaturePath(['value'])
   generator_cls = count_missing_generator.CountMissingGenerator
   generator = generator_cls(index_path, [value_path])
   expected = {('CountMissingGenerator', index_path, value_path): None}
   self.assertDictEqual(expected, {generator.get_key(): None})
   self.assertDictEqual(
       expected, {generator_cls.key(index_path, [value_path]): None})
Example No. 6
0
 def test_count_missing_generator_required_path(self):
     """Checks the missing count for `index` when `value` is a required path.

     Both columns hold the same rows ([1], None, []), and with `value` as a
     required path the expected missing count for `index` is 0.
     """
     index_path = types.FeaturePath(['index'])
     value_path = types.FeaturePath(['value'])
     rows = [[1], None, []]
     record_batch = pa.RecordBatch.from_arrays(
         [pa.array(rows), pa.array(rows)], ['index', 'value'])
     generator = count_missing_generator.CountMissingGenerator(
         index_path, [value_path])
     accumulator = generator.add_input(generator.create_accumulator(),
                                       input_batch.InputBatch(record_batch))
     self.assertEqual(0, generator.extract_output(accumulator))