示例#1
0
def get_generators(
        options: stats_options.StatsOptions,
        in_memory: bool = False) -> List[stats_generator.StatsGenerator]:
    """Builds the list of stats generators to run, including custom ones.

    Args:
      options: A StatsOptions object.
      in_memory: True if the generators will compute statistics in memory,
        False if they will run inside a Beam pipeline.

    Returns:
      A list of stats generator objects. Every CombinerFeatureStatsGenerator
      in the list is folded into one CombinerFeatureStatsWrapperGenerator that
      is placed after all the other generators.

    Raises:
      TypeError: If `in_memory` is True and some generator does not extend
        CombinerStatsGenerator.
    """
    result = _get_default_generators(options, in_memory)
    if options.generators:
        # User-supplied generators run alongside the defaults.
        result.extend(options.generators)

    if options.enable_semantic_domain_stats:
        # The semantic-domain generators get a wrapper of their own so that
        # sampling applies to them alone and leaves the other feature stats
        # generators unaffected.
        result.append(
            CombinerFeatureStatsWrapperGenerator(
                [
                    image_stats_generator.ImageStatsGenerator(),
                    natural_language_stats_generator.NLStatsGenerator(),
                    time_stats_generator.TimeStatsGenerator(),
                ],
                weight_feature=options.weight_feature,
                sample_rate=options.semantic_domain_stats_sample_rate))

    schema = options.schema
    if schema is not None and _schema_has_sparse_features(schema):
        result.append(
            sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema))

    # Fold every CombinerFeatureStatsGenerator into a single
    # CombinerFeatureStatsWrapperGenerator appended after the rest.
    feature_level = []
    others = []
    for gen in result:
        if isinstance(gen, stats_generator.CombinerFeatureStatsGenerator):
            feature_level.append(gen)
        else:
            others.append(gen)
    if feature_level:
        result = others + [
            CombinerFeatureStatsWrapperGenerator(
                feature_level, weight_feature=options.weight_feature)
        ]

    if in_memory:
        # In-memory computation only supports combiner-style generators.
        non_combiners = [
            gen for gen in result
            if not isinstance(gen, stats_generator.CombinerStatsGenerator)
        ]
        if non_combiners:
            raise TypeError(
                'Statistics generator used in '
                'generate_statistics_in_memory must '
                'extend CombinerStatsGenerator, found object of '
                'type %s.' % non_combiners[0].__class__.__name__)
    return result
 def test_sparse_feature_generator_multiple_sparse_features(self):
     """Checks stats for two sparse features spread over two record batches.

     `sparse_feature` only has columns in the first batch: its first two rows
     lack the value feature and the following three rows lack
     `index_feature1`. `other_sparse_feature` occupies the last two rows of
     each batch, and its `other_index_feature2` column is always null.
     """
     schema = text_format.Parse(
         """
     sparse_feature {
       name: 'sparse_feature'
       index_feature {
         name: 'index_feature1'
       }
       index_feature {
         name: 'index_feature2'
       }
       value_feature {
         name: 'value_feature'
       }
     }
     sparse_feature {
       name: 'other_sparse_feature'
       index_feature {
         name: 'other_index_feature1'
       }
       index_feature {
         name: 'other_index_feature2'
       }
       value_feature {
         name: 'other_value_feature'
       }
     }
     """, schema_pb2.Schema())

     # In both batches, rows 0-4 hold no `other_sparse_feature` data and rows
     # 5-6 hold two-element value/index lists for it.
     five_nulls = [None] * 5
     first_batch = pa.RecordBatch.from_arrays(
         [
             pa.array([None, None, ['a', 'b'], ['a', 'b'], ['a', 'b'], None,
                       None]),
             pa.array([[1, 2], [1, 2], None, None, None, None, None]),
             pa.array([[2, 4], [2, 4], [2, 4, 6], [2, 4, 6], [2, 4, 6], None,
                       None]),
             pa.array(five_nulls + [['a', 'b'], ['a', 'b']]),
             pa.array(five_nulls + [[2, 4], [2, 4]]),
             pa.array([None] * 7, type=pa.null()),
         ],
         [
             'value_feature', 'index_feature1', 'index_feature2',
             'other_value_feature', 'other_index_feature1',
             'other_index_feature2'
         ])
     second_batch = pa.RecordBatch.from_arrays(
         [
             pa.array(five_nulls + [['a', 'b'], ['a', 'b']]),
             pa.array(five_nulls + [[2, 4], [2, 4]]),
             pa.array([None] * 7, type=pa.null()),
         ],
         [
             'other_value_feature', 'other_index_feature1',
             'other_index_feature2'
         ])

     expected_stats = {
         types.FeaturePath(['sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 2
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 3
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 0
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 2
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 2
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: -2
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 1
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics()),
         types.FeaturePath(['other_sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'other_sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 0
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: 4
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: -2
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: -2
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics())
     }
     self.assertCombinerOutputEqual(
         [first_batch, second_batch],
         sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema),
         expected_stats)
 def test_sparse_feature_generator_dataset_missing_entire_sparse_feature(
         self):
     """A sparse feature absent from every batch yields empty stats.

     The only batch contains an unrelated `other_feature`, so `missing_value`
     is 0 and no histogram bucket receives a sample count.
     """
     schema = text_format.Parse(
         """
     sparse_feature {
       name: 'sparse_feature'
       index_feature {
         name: 'index_feature1'
       }
       index_feature {
         name: 'index_feature2'
       }
       value_feature {
         name: 'value_feature'
       }
     }
     """, schema_pb2.Schema())
     record_batches = [
         pa.RecordBatch.from_arrays([pa.array([['a']])], ['other_feature'])
     ]
     # Semantically empty stats, which should not raise any anomalies.
     expected_stats = {
         types.FeaturePath(['sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 0
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                 }
                 buckets {
                   label: 'index_feature2'
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                 }
                 buckets {
                   label: 'index_feature2'
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                 }
                 buckets {
                   label: 'index_feature2'
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics())
     }
     self.assertCombinerOutputEqual(
         record_batches,
         sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema),
         expected_stats)
 def test_sparse_feature_generator_with_struct_leaves(self):
     """A sparse feature nested in a struct is reported under its full path.

     In both batches the index lists and the value list have equal lengths,
     so every missing count and every length diff is 0.
     """
     schema = text_format.Parse(
         """
     feature {
       name: 'parent'
       type: STRUCT
       struct_domain {
         feature {
           name: 'index_feature1'
         }
         feature {
           name: 'index_feature2'
         }
         feature {
           name: 'value_feature'
         }
         sparse_feature {
           name: 'sparse_feature'
           index_feature {
             name: 'index_feature1'
           }
           index_feature {
             name: 'index_feature2'
           }
           value_feature {
             name: 'value_feature'
           }
         }
       }
     }
     """, schema_pb2.Schema())
     leaves_per_batch = [
         {
             'value_feature': ['a'],
             'index_feature1': [1],
             'index_feature2': [2]
         },
         {
             'value_feature': ['a', 'b'],
             'index_feature1': [1, 3],
             'index_feature2': [2, 4]
         },
     ]
     # Each batch has one `parent` row: a list holding a single struct.
     record_batches = [
         pa.RecordBatch.from_arrays([pa.array([[leaf]])], ['parent'])
         for leaf in leaves_per_batch
     ]
     expected_stats = {
         types.FeaturePath(['parent', 'sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'parent'
               step: 'sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 0
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 0
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 0
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 0
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics())
     }
     self.assertCombinerOutputEqual(
         record_batches,
         sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema),
         expected_stats)
 def test_sparse_feature_generator_index_feature_not_in_batch(self):
     """An index feature absent from the batch is counted as missing.

     The batch carries `not_index_feature2` instead of `index_feature2`, so
     both examples are missing `index_feature2`. Compared with the 1- and
     2-element value lists, that gives a max length diff of -1 and a min of
     -2 for `index_feature2`.
     """
     batches = [
         # Built as a pa.RecordBatch, the same input type the other tests in
         # this class feed to the combiner (a pa.Table would expose
         # ChunkedArray columns instead of Arrays).
         pa.RecordBatch.from_arrays([
             pa.array([['a'], ['a', 'b']]),
             pa.array([[1], [1, 3]]),
             pa.array([[2], [2, 4]])
         ], ['value_feature', 'index_feature1', 'not_index_feature2']),
     ]
     schema = text_format.Parse(
         """
     sparse_feature {
       name: 'sparse_feature'
       index_feature {
         name: 'index_feature1'
       }
       index_feature {
         name: 'index_feature2'
       }
       value_feature {
         name: 'value_feature'
       }
     }
     """, schema_pb2.Schema())
     expected_result = {
         types.FeaturePath(['sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 0
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 2
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: -1
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: -2
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics())
     }
     generator = (
         sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema))
     self.assertCombinerOutputEqual(batches, generator, expected_result)
示例#6
0
def get_generators(
        options: stats_options.StatsOptions,
        in_memory: bool = False) -> List[stats_generator.StatsGenerator]:
    """Assembles the stats generators to run, including custom generators.

    A NumExamplesStatsGenerator is always included. Default, custom,
    semantic-domain and schema-driven generators are added according to
    `options`.

    Args:
      options: A StatsOptions object.
      in_memory: True if the generators will compute statistics in memory,
        False if they will run with Beam.

    Returns:
      A list of stats generator objects. Every CombinerFeatureStatsGenerator
      in the list is folded into one CombinerFeatureStatsWrapperGenerator that
      is appended after all the other generators.

    Raises:
      TypeError: If `in_memory` is True and some generator does not extend
        CombinerStatsGenerator.
    """
    result = [NumExamplesStatsGenerator(options.weight_feature)]
    if options.add_default_generators:
        result.extend(_get_default_generators(options, in_memory))
    if options.generators:
        # User-supplied generators.
        result.extend(options.generators)

    if options.enable_semantic_domain_stats:
        # The semantic-domain generators get a wrapper of their own so that
        # the sample rate applies to them alone and not to the other feature
        # stats generators.
        result.append(
            CombinerFeatureStatsWrapperGenerator(
                [
                    image_stats_generator.ImageStatsGenerator(),
                    (natural_language_domain_inferring_stats_generator
                     .NLDomainInferringStatsGenerator()),
                    time_stats_generator.TimeStatsGenerator(),
                ],
                sample_rate=options.semantic_domain_stats_sample_rate))

    schema = options.schema
    if schema is not None:
        if _schema_has_sparse_features(schema):
            result.append(
                sparse_feature_stats_generator.SparseFeatureStatsGenerator(
                    schema))
        if _schema_has_natural_language_domains(schema):
            result.append(
                natural_language_stats_generator.NLStatsGenerator(
                    schema, options.vocab_paths,
                    options.num_histogram_buckets,
                    options.num_quantiles_histogram_buckets,
                    options.num_rank_histogram_buckets))
        if schema.weighted_feature:
            result.append(
                weighted_feature_stats_generator.WeightedFeatureStatsGenerator(
                    schema))
        # LiftStatsGenerator is not a CombinerStatsGenerator, so it currently
        # cannot be used for in-memory runs; add it only for Beam.
        if options.label_feature and not in_memory:
            result.append(
                lift_stats_generator.LiftStatsGenerator(
                    y_path=types.FeaturePath([options.label_feature]),
                    schema=schema,
                    example_weight_map=options.example_weight_map,
                    output_custom_stats=True))

    # Fold every CombinerFeatureStatsGenerator into a single
    # CombinerFeatureStatsWrapperGenerator appended after the rest.
    feature_level = []
    others = []
    for gen in result:
        if isinstance(gen, stats_generator.CombinerFeatureStatsGenerator):
            feature_level.append(gen)
        else:
            others.append(gen)
    if feature_level:
        result = others + [CombinerFeatureStatsWrapperGenerator(feature_level)]

    if in_memory:
        # In-memory computation only supports combiner-style generators.
        non_combiners = [
            gen for gen in result
            if not isinstance(gen, stats_generator.CombinerStatsGenerator)
        ]
        if non_combiners:
            raise TypeError(
                'Statistics generator used in '
                'generate_statistics_in_memory must '
                'extend CombinerStatsGenerator, found object of '
                'type %s.' % non_combiners[0].__class__.__name__)
    return result