Example #1
0
    def test_nl_generator_avg_word_heuristic_match(self):
        """Checks NL stats output when the avg-word-length heuristic fires."""
        gen = nlsg.NLStatsGenerator(values_threshold=2)
        batches = [
            pa.Column.from_array(
                'feature',
                pa.array([[
                    'This looks correct.', 'This one too, it should be text.'
                ], ['xosuhddsofuhg123fdgosh']])),
            pa.Column.from_array(
                'feature',
                pa.array(
                    [['This should be text as well',
                      'Here is another text']])),
            pa.Column.from_array(
                'feature',
                pa.array([['This should also be considered good.']])),
        ]

        # Five of the six values read as text, giving a 5/6 match rate.
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info', str='natural_language_domain {}'),
            statistics_pb2.CustomStatistic(
                name='natural_language_match_rate', num=0.8333333)
        ])
        self.assertCombinerOutputEqual(batches, gen, expected)
Example #2
0
    def test_nl_generator_values_threshold_check(self):
        """Checks that values_threshold gates stats creation (fake heuristic)."""
        # The batches below contribute six matching values in total.
        batches = [
            pa.Column.from_array(
                'feature', pa.array([['MATCH', 'MATCH', 'MATCH'], ['MATCH']])),
            pa.Column.from_array('feature', pa.array([['MATCH', 'MATCH']])),
            # Null entries must not count towards the threshold.
            pa.Column.from_array('feature', pa.array([None, None])),
        ]

        # values_threshold=7 exceeds the six available values: no stats.
        self.assertCombinerOutputEqual(
            batches,
            nlsg.NLStatsGenerator(_FakeHeuristic(), values_threshold=7),
            statistics_pb2.FeatureNameStatistics())

        # values_threshold=6 is met exactly: stats are produced.
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info', str='natural_language_domain {}'),
            statistics_pb2.CustomStatistic(
                name='natural_language_match_rate', num=1.0)
        ])
        self.assertCombinerOutputEqual(
            batches,
            nlsg.NLStatsGenerator(_FakeHeuristic(), values_threshold=6),
            expected)
    def test_time_stats_generator_match_ratio_with_different_valid_formats(
            self):
        """Checks the match ratio when valid values use several formats."""
        batches = [
            pa.array(
                [['2018-11-30', '2018/11/30', '20181130', '18-11-30',
                  '18/11/30'],
                 ['11-30-2018', '11/30/2018', '11302018', '11/30/18',
                  '11/30/18']]),
        ]

        # Several formats individually satisfy match_ratio=0.05, but only the
        # single most common format should be reported as the time format.
        gen = time_stats_generator.TimeStatsGenerator(match_ratio=0.05,
                                                      values_threshold=1)
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info',
                str="time_domain {string_format: '%m/%d/%y'}"),
            statistics_pb2.CustomStatistic(name='time_match_ratio', num=0.2),
        ])
        self.assertCombinerOutputEqual(batches, gen, expected)

        # With match_ratio=0.3, no single valid format is common enough, so
        # no stats should be created.
        gen = time_stats_generator.TimeStatsGenerator(match_ratio=0.3,
                                                      values_threshold=1)
        self.assertCombinerOutputEqual(batches, gen,
                                       statistics_pb2.FeatureNameStatistics())
Example #4
0
    def test_nl_generator_match_ratio_check(self):
        """Checks that match_ratio gates stats creation (fake heuristic)."""
        batches = [
            pa.Column.from_array(
                'feature',
                pa.array([['MATCH', 'MATCH', 'MATCH'], ['MATCH', 'Nope']])),
            pa.Column.from_array('feature',
                                 pa.array([['MATCH', 'MATCH', 'MATCH']])),
            pa.Column.from_array('feature', pa.array([['12345', 'No']])),
        ]

        # values_threshold=5 is always satisfied here, so only match_ratio
        # decides the outcome; the expected match rate is 0.7.
        gen = nlsg.NLStatsGenerator(_FakeHeuristic(),
                                    match_ratio=0.71,
                                    values_threshold=5)
        self.assertCombinerOutputEqual(batches, gen,
                                       statistics_pb2.FeatureNameStatistics())

        gen = nlsg.NLStatsGenerator(_FakeHeuristic(),
                                    match_ratio=0.69,
                                    values_threshold=5)
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info', str='natural_language_domain {}'),
            statistics_pb2.CustomStatistic(
                name='natural_language_match_rate', num=0.7)
        ])
        self.assertCombinerOutputEqual(batches, gen, expected)
    def test_time_stats_generator_values_threshold_check(self):
        """Checks that values_threshold gates time-stats creation."""
        # Six valid values share one format; null entries contribute nothing.
        batches = [
            pa.array([['2018-11-30', '2018-11-30', '2018-11-30'],
                      ['2018-11-30']]),
            pa.array([['2018-11-30', '2018-11-30']]),
            pa.array([None, None]),
        ]

        # values_threshold=7 exceeds the six available values: no stats.
        self.assertCombinerOutputEqual(
            batches,
            time_stats_generator.TimeStatsGenerator(values_threshold=7),
            statistics_pb2.FeatureNameStatistics())

        # values_threshold=6 is met exactly: stats are produced.
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info',
                str="time_domain {string_format: '%Y-%m-%d'}"),
            statistics_pb2.CustomStatistic(name='time_match_ratio', num=1.0),
        ])
        self.assertCombinerOutputEqual(
            batches,
            time_stats_generator.TimeStatsGenerator(values_threshold=6),
            expected)
 def test_time_stats_generator_match_ratio_with_same_valid_format(self):
     """Checks the match ratio when all valid values share one format."""
     batches = [
         pa.array([['2018-11-30', '2018-11-30', '2018-11-30'],
                   ['2018-11-30', '2018-11-30']]),
         pa.array([['not-valid', 'not-valid', 'not-valid'],
                   ['not-valid', 'not-valid']]),
     ]

     # Exactly half the values parse as dates, so requiring a 0.51 ratio
     # must yield no stats.
     gen = time_stats_generator.TimeStatsGenerator(match_ratio=0.51,
                                                   values_threshold=5)
     self.assertCombinerOutputEqual(batches, gen,
                                    statistics_pb2.FeatureNameStatistics())

     # A 0.49 ratio requirement is satisfied, so stats are produced.
     gen = time_stats_generator.TimeStatsGenerator(match_ratio=0.49,
                                                   values_threshold=5)
     expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
         statistics_pb2.CustomStatistic(
             name='domain_info',
             str="time_domain {string_format: '%Y-%m-%d'}"),
         statistics_pb2.CustomStatistic(name='time_match_ratio', num=0.50),
     ])
     self.assertCombinerOutputEqual(batches, gen, expected)
    def test_nl_generator_example_threshold_check(self):
        """Checks that examples_threshold gates stats (fake heuristic)."""
        # The batches below contribute six matching values in total.
        batches = [
            [
                np.array(['MATCH', 'MATCH', 'MATCH']),
                np.array(['MATCH']),
            ],
            [
                np.array(['MATCH', 'MATCH']),
            ],
            # None entries are ignored entirely.
            [
                None,
                np.array([None] * 10),
            ],
        ]

        # examples_threshold=7 is not reached: no stats expected.
        self.assertCombinerOutputEqual(
            batches,
            nlsg.NLStatsGenerator(_FakeHeuristic(), examples_threshold=7),
            statistics_pb2.FeatureNameStatistics())

        # examples_threshold=6 is met exactly: stats are produced.
        expected = statistics_pb2.FeatureNameStatistics(custom_stats=[
            statistics_pb2.CustomStatistic(
                name='domain_info', str='natural_language_domain {}'),
            statistics_pb2.CustomStatistic(
                name='natural_language_match_rate', num=1.0)
        ])
        self.assertCombinerOutputEqual(
            batches,
            nlsg.NLStatsGenerator(_FakeHeuristic(), examples_threshold=6),
            expected)
 def test_time_stats_generator_combined_formats(self):
     """Tests that the generator handles combined date+time formats."""
     # The combined format is the most common, since the generator should
     # count it only as the combined format and not its component parts.
     input_batches = [[np.array(['2018/11/30 23:59', '2018/12/01 23:59'])],
                      [np.array(['2018/11/30 23:59', '23:59'])],
                      [np.array(['2018/11/30', '2018/11/30'])]]
     generator = time_stats_generator.TimeStatsGenerator(match_ratio=0.1,
                                                         values_threshold=1)
     self.assertCombinerOutputEqual(
         input_batches, generator,
         statistics_pb2.FeatureNameStatistics(custom_stats=[
             statistics_pb2.CustomStatistic(
                 name='domain_info',
                 # 'string_format' is the TimeDomain field name used by every
                 # other domain_info expectation in this file.
                 str="time_domain {string_format: '%Y/%m/%d %H:%M'}"),
             statistics_pb2.CustomStatistic(name='time_match_ratio',
                                            num=0.5),
         ]))
Example #9
0
 def _create_expected_feature_name_statistics(self,
                                              feature_coverage=None,
                                              avg_token_length=None):
     """Builds the expected FeatureNameStatistics proto for NL stats tests.

     Args:
       feature_coverage: Optional numeric feature coverage; when not None it
         is set on the NaturalLanguageStatistics and mirrored as the
         'nl_feature_coverage' custom stat.
       avg_token_length: Optional numeric average token length; when not None
         it is set on the NaturalLanguageStatistics and mirrored as the
         'nl_avg_token_length' custom stat.

     Returns:
       A statistics_pb2.FeatureNameStatistics whose custom stats end with an
       'nl_statistics' entry carrying the packed NaturalLanguageStatistics.
     """
     custom_stats = []
     nls = statistics_pb2.NaturalLanguageStatistics()
     if feature_coverage is not None:
         nls.feature_coverage = feature_coverage
         custom_stats.append(
             statistics_pb2.CustomStatistic(name='nl_feature_coverage',
                                            num=feature_coverage))
     if avg_token_length is not None:
         nls.avg_token_length = avg_token_length
         custom_stats.append(
             statistics_pb2.CustomStatistic(name='nl_avg_token_length',
                                            num=avg_token_length))
     my_proto = any_pb2.Any()
     # Any.Pack() mutates the message in place and returns None, so its
     # return value must not be passed directly as the `any` field.
     my_proto.Pack(nls)
     custom_stats.append(
         statistics_pb2.CustomStatistic(name='nl_statistics', any=my_proto))
     return statistics_pb2.FeatureNameStatistics(custom_stats=custom_stats)
 def test_time_stats_generator_integer_formats(self):
     """Tests that the generator handles integer time formats."""
     # Three values fall in the valid range for Unix seconds, one in the
     # valid range for Unix milliseconds, and two (1 and 2) are not within
     # the valid range for any integer time format.
     input_batches = [
         pa.array([[631152001, 631152002]]),
         pa.array([[631152003, 631152000001]]),
         pa.array([[1, 2]])
     ]
     generator = time_stats_generator.TimeStatsGenerator(match_ratio=0.1,
                                                         values_threshold=1)
     # The expected domain_info hard-codes the enum value, so pin it with a
     # real test assertion (a bare `assert` is stripped under `python -O`).
     self.assertEqual(schema_pb2.TimeDomain.UNIX_SECONDS, 1)
     self.assertCombinerOutputEqual(
         input_batches, generator,
         statistics_pb2.FeatureNameStatistics(custom_stats=[
             statistics_pb2.CustomStatistic(
                 name='domain_info',
                 str='time_domain {integer_format: 1}'),
             statistics_pb2.CustomStatistic(name='time_match_ratio',
                                            num=0.5),
         ]))
def _make_num_values_custom_stats_proto(
    common_stats: _PartialCommonStats,
    num_histogram_buckets: int,
    ) -> List[statistics_pb2.CustomStatistic]:
  """Builds CustomStatistic protos holding num-values histograms.

  One histogram is produced per nest level greater than 1, capturing the
  distribution of the number of values at that level. Level 1 is skipped
  because its histogram already lives in CommonStatistics.num_values_histogram.

  Args:
    common_stats: a _PartialCommonStats.
    num_histogram_buckets: number of buckets in each histogram.
  Returns:
    a (potentially empty) list of statistics_pb2.CustomStatistic.
  """
  if common_stats.type is None:
    return []
  pv_stats = common_stats.presence_and_valency_stats
  if pv_stats is None:
    return []

  protos = []
  # Pair each level's stats (from the second entry on) with its parent's,
  # numbering levels from 2; level 1 is covered by CommonStats.
  for level, (child, parent) in enumerate(zip(pv_stats[1:], pv_stats), 2):
    quantiles = (
        child.num_values_summary.GetQuantiles(
            num_histogram_buckets).flatten().to_pylist())
    hist = quantiles_util.generate_quantiles_histogram(
        quantiles, parent.num_non_missing, num_histogram_buckets)
    proto = statistics_pb2.CustomStatistic(
        name='level_{}_value_list_length'.format(level))
    proto.histogram.CopyFrom(hist)
    protos.append(proto)
  return protos
Example #12
0
    def _create_expected_feature_name_statistics(
            self,
            feature_coverage=None,
            avg_token_length=None,
            min_sequence_length=None,
            max_sequence_length=None,
            token_len_quantiles=None,
            sequence_len_quantiles=None,
            sorted_token_names_and_counts=None,
            reported_sequences=None,
            token_statistics=None):
        """Assembles a FeatureNameStatistics carrying packed NL statistics.

        Every argument is optional; only the supplied ones are reflected in
        the NaturalLanguageStatistics message, which is then packed into a
        single 'nl_statistics' custom stat.
        """
        nls = statistics_pb2.NaturalLanguageStatistics()
        if feature_coverage is not None:
            nls.feature_coverage = feature_coverage
        if avg_token_length:
            nls.avg_token_length = avg_token_length
        if min_sequence_length:
            nls.min_sequence_length = min_sequence_length
        if max_sequence_length:
            nls.max_sequence_length = max_sequence_length
        for low, high, count in token_len_quantiles or []:
            nls.token_length_histogram.type = statistics_pb2.Histogram.QUANTILES
            nls.token_length_histogram.buckets.add(
                low_value=low, high_value=high, sample_count=count)
        for low, high, count in sequence_len_quantiles or []:
            nls.sequence_length_histogram.type = statistics_pb2.Histogram.QUANTILES
            nls.sequence_length_histogram.buckets.add(
                low_value=low, high_value=high, sample_count=count)
        for rank, (token_name, count) in enumerate(
                sorted_token_names_and_counts or []):
            nls.rank_histogram.buckets.add(low_rank=rank,
                                           high_rank=rank,
                                           label=token_name,
                                           sample_count=count)
        for token, stats in (token_statistics or {}).items():
            ts = nls.token_statistics.add(frequency=stats[0],
                                          fraction_of_sequences=stats[1],
                                          per_sequence_min_frequency=stats[2],
                                          per_sequence_max_frequency=stats[3],
                                          per_sequence_avg_frequency=stats[4])
            # String tokens and integer tokens go to different oneof fields.
            if isinstance(token, str):
                ts.string_token = token
            else:
                ts.int_token = token
            ts.positions.CopyFrom(stats[5])
        for seq in reported_sequences or []:
            nls.reported_sequences.append(str(seq))

        packed_stats = statistics_pb2.CustomStatistic(name='nl_statistics')
        packed_stats.any.Pack(nls)
        return statistics_pb2.FeatureNameStatistics(
            custom_stats=[packed_stats])
Example #13
0
 def _create_expected_feature_name_statistics(
         self,
         feature_coverage=None,
         avg_token_length=None,
         token_len_quantiles=None,
         sorted_token_names_and_counts=None,
         reported_sequences=None,
         token_statistics=None):
     """Builds the expected FeatureNameStatistics proto for NL stats tests.

     Each supplied argument is reflected both in the NaturalLanguageStatistics
     message (packed into the 'nl_statistics' custom stat) and as individual
     custom stats mirroring the same values.

     Args:
       feature_coverage: Optional numeric feature coverage.
       avg_token_length: Optional numeric average token length.
       token_len_quantiles: Optional iterable of
         (low_value, high_value, sample_count) histogram bucket tuples.
       sorted_token_names_and_counts: Optional iterable of
         (token_name, count) pairs, already sorted by rank.
       reported_sequences: Optional iterable of reported sequences; each is
         stringified before being recorded.
       token_statistics: Optional mapping from token (str or int) to a tuple
         (frequency, fraction_of_sequences, per_sequence_min_frequency,
         per_sequence_max_frequency, per_sequence_avg_frequency, positions).

     Returns:
       A statistics_pb2.FeatureNameStatistics with the assembled custom stats.
     """
     custom_stats = []
     nls = statistics_pb2.NaturalLanguageStatistics()
     if feature_coverage is not None:
         nls.feature_coverage = feature_coverage
         custom_stats.append(
             statistics_pb2.CustomStatistic(name='nl_feature_coverage',
                                            num=feature_coverage))
     if avg_token_length:
         nls.avg_token_length = avg_token_length
         custom_stats.append(
             statistics_pb2.CustomStatistic(name='nl_avg_token_length',
                                            num=nls.avg_token_length))
     if token_len_quantiles:
         for low_value, high_value, sample_count in token_len_quantiles:
             nls.token_length_histogram.type = statistics_pb2.Histogram.QUANTILES
             nls.token_length_histogram.buckets.add(
                 low_value=low_value,
                 high_value=high_value,
                 sample_count=sample_count)
         custom_stats.append(
             statistics_pb2.CustomStatistic(
                 name='nl_token_length_histogram',
                 histogram=nls.token_length_histogram))
     if sorted_token_names_and_counts:
         for index, (token_name,
                     count) in enumerate(sorted_token_names_and_counts):
             nls.rank_histogram.buckets.add(low_rank=index,
                                            high_rank=index,
                                            label=token_name,
                                            sample_count=count)
         custom_stats.append(
             statistics_pb2.CustomStatistic(
                 name='nl_rank_tokens', rank_histogram=nls.rank_histogram))
     if token_statistics:
         for k, v in token_statistics.items():
             ts = nls.token_statistics.add(frequency=v[0],
                                           fraction_of_sequences=v[1],
                                           per_sequence_min_frequency=v[2],
                                           per_sequence_max_frequency=v[3],
                                           per_sequence_avg_frequency=v[4])
             # String tokens and integer tokens use different oneof fields.
             if isinstance(k, str):
                 ts.string_token = k
             else:
                 ts.int_token = k
             ts.positions.CopyFrom(v[5])
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_token_frequency'.format(k),
                     num=ts.frequency))
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_fraction_of_examples'.format(k),
                     num=ts.fraction_of_sequences))
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_per_sequence_min_frequency'.format(k),
                     num=ts.per_sequence_min_frequency))
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_per_sequence_max_frequency'.format(k),
                     num=ts.per_sequence_max_frequency))
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_per_sequence_avg_frequency'.format(k),
                     num=ts.per_sequence_avg_frequency))
             custom_stats.append(
                 statistics_pb2.CustomStatistic(
                     name='nl_{}_token_positions'.format(k),
                     histogram=ts.positions))
     if reported_sequences:
         for r in reported_sequences:
             nls.reported_sequences.append(str(r))
         str_reported_sequences = '\n'.join(nls.reported_sequences)
         custom_stats.append(
             statistics_pb2.CustomStatistic(name='nl_reported_sequences',
                                            str=str_reported_sequences))
     my_proto = any_pb2.Any()
     # Any.Pack() mutates the message in place and returns None, so its
     # return value must not be passed directly as the `any` field.
     my_proto.Pack(nls)
     custom_stats.append(
         statistics_pb2.CustomStatistic(name='nl_statistics', any=my_proto))
     return statistics_pb2.FeatureNameStatistics(custom_stats=custom_stats)