def _verify(output):
      """Verifies that the output meeds the expectations."""
      if only_match_expected_feature_stats:
        features_in_stats = set(
            [types.FeaturePath.from_proto(f.path) for f in output.features])
        self.assertTrue(set(expected_feature_stats.keys())
                        .issubset(features_in_stats))
      else:
        self.assertEqual(  # pylint: disable=g-generic-assert
            len(output.features), len(expected_feature_stats),
            '{}, {}'.format(output, expected_feature_stats))
      for actual_feature_stats in output.features:
        actual_path = types.FeaturePath.from_proto(actual_feature_stats.path)
        expected_stats = expected_feature_stats.get(actual_path)
        if (only_match_expected_feature_stats and expected_stats is None):
          continue
        compare.assertProtoEqual(
            self,
            actual_feature_stats,
            expected_stats,
            normalize_numbers=True)

      self.assertEqual(  # pylint: disable=g-generic-assert
          len(output.cross_features), len(expected_cross_feature_stats),
          '{}, {}'.format(output, expected_cross_feature_stats))
      for actual_cross_feature_stats in output.cross_features:
        cross = (actual_cross_feature_stats.path_x.step[0],
                 actual_cross_feature_stats.path_y.step[0])
        compare.assertProtoEqual(
            self,
            actual_cross_feature_stats,
            expected_cross_feature_stats[cross],
            normalize_numbers=True)
 def _assert_combiner_output_equal(self, statistics, combiner, expected):
   accumulators = [
       combiner.add_input(combiner.create_accumulator(), statistic)
       for statistic in statistics
   ]
   actual = combiner.extract_output(combiner.merge_accumulators(accumulators))
   compare.assertProtoEqual(self, actual, expected, normalize_numbers=True)
Example #3
def assert_feature_proto_equal(test, actual, expected):
    """Ensures feature protos are equal.

  Args:
    test: The test case.
    actual: The actual feature proto.
    expected: The expected feature proto.
  """

    test.assertEqual(len(actual.custom_stats), len(expected.custom_stats))
    expected_custom_stats = {}
    for expected_custom_stat in expected.custom_stats:
        expected_custom_stats[expected_custom_stat.name] = expected_custom_stat

    for actual_custom_stat in actual.custom_stats:
        test.assertTrue(actual_custom_stat.name in expected_custom_stats)
        expected_custom_stat = expected_custom_stats[actual_custom_stat.name]
        compare.assertProtoEqual(test,
                                 actual_custom_stat,
                                 expected_custom_stat,
                                 normalize_numbers=True)
    del actual.custom_stats[:]
    del expected.custom_stats[:]

    # Compare the rest of the proto without numeric custom stats
    compare.assertProtoEqual(test, actual, expected, normalize_numbers=True)
Example #4
def assert_feature_proto_equal(
        test: absltest.TestCase, actual: statistics_pb2.FeatureNameStatistics,
        expected: statistics_pb2.FeatureNameStatistics) -> None:
    """Ensures feature protos are equal.

  Args:
    test: The test case.
    actual: The actual feature proto.
    expected: The expected feature proto.
  """

    test.assertLen(actual.custom_stats, len(expected.custom_stats))
    expected_custom_stats = {}
    for expected_custom_stat in expected.custom_stats:
        expected_custom_stats[expected_custom_stat.name] = expected_custom_stat

    for actual_custom_stat in actual.custom_stats:
        test.assertIn(actual_custom_stat.name, expected_custom_stats)
        expected_custom_stat = expected_custom_stats[actual_custom_stat.name]
        compare.assertProtoEqual(test,
                                 expected_custom_stat,
                                 actual_custom_stat,
                                 normalize_numbers=True)
    del actual.custom_stats[:]
    del expected.custom_stats[:]

    # Compare the rest of the proto without numeric custom stats
    compare.assertProtoEqual(test, expected, actual, normalize_numbers=True)
Example #5
 def testAssertEqualWithStringArg(self):
   pb = compare_test_pb2.Large()
   pb.string_ = 'abc'
   pb.float_ = 1.234
   compare.assertProtoEqual(self, """
         string_: 'abc'
         float_: 1.234
       """, pb)
 def _assert_mi_output_equal(self, batch, expected, schema, label_feature):
     """Checks that MI computation is correct."""
     actual = sklearn_mutual_information.SkLearnMutualInformation(
         label_feature, schema, TEST_SEED).compute(batch)
     compare.assertProtoEqual(self,
                              actual,
                              expected,
                              normalize_numbers=True)
Example #8
  def testLargeProtoData(self):
    # Proto size should be larger than 2**16.
    number_of_entries = 2**13
    string_value = 'dummystr'  # Has length of 2**3.
    pb1_txt = 'strings: "dummystr"\n' * number_of_entries
    pb2 = compare_test_pb2.Small(strings=[string_value] * number_of_entries)
    compare.assertProtoEqual(self, pb1_txt, pb2)

    with self.assertRaises(AssertionError):
      compare.assertProtoEqual(self, pb1_txt + 'strings: "Should fail."', pb2)
Example #9
 def test_remove_anomaly_types_removes_diff_regions(self):
     anomaly_types_to_remove = set([
         anomalies_pb2.AnomalyInfo.ENUM_TYPE_BYTES_NOT_STRING,
     ])
     # The anomaly_info has multiple diff regions.
     anomalies = text_format.Parse(
         """
    anomaly_info {
      key: "feature_1"
      value {
           description: "Expected bytes but got string. Examples contain "
             "values missing from the schema."
           severity: ERROR
           short_description: "Multiple errors"
           diff_regions {
             removed {
               start: 1
               contents: "Test contents"
             }
           }
           diff_regions {
             added {
               start: 1
               contents: "Test contents"
             }
           }
           reason {
             type: ENUM_TYPE_BYTES_NOT_STRING
             short_description: "Bytes not string"
             description: "Expected bytes but got string."
           }
           reason {
             type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
             short_description: "Unexpected string values"
             description: "Examples contain values missing from the schema."
           }
       }
     }""", anomalies_pb2.Anomalies())
     expected_result = text_format.Parse(
         """
    anomaly_info {
      key: "feature_1"
      value {
           description: "Examples contain values missing from the schema."
           severity: ERROR
           short_description: "Unexpected string values"
           reason {
             type: ENUM_TYPE_UNEXPECTED_STRING_VALUES
             short_description: "Unexpected string values"
             description: "Examples contain values missing from the schema."
           }
       }
     }""", anomalies_pb2.Anomalies())
     anomalies_util.remove_anomaly_types(anomalies, anomaly_types_to_remove)
     compare.assertProtoEqual(self, anomalies, expected_result)
Example #10
 def test_remove_anomaly_types_does_not_change_proto(
         self, anomaly_types_to_remove, input_anomalies_proto_text):
     """Tests where remove_anomaly_types does not modify the Anomalies proto."""
     input_anomalies_proto = text_format.Parse(input_anomalies_proto_text,
                                               anomalies_pb2.Anomalies())
     expected_anomalies_proto = anomalies_pb2.Anomalies()
     expected_anomalies_proto.CopyFrom(input_anomalies_proto)
     anomalies_util.remove_anomaly_types(input_anomalies_proto,
                                         anomaly_types_to_remove)
     compare.assertProtoEqual(self, input_anomalies_proto,
                              expected_anomalies_proto)
  def test_generate_statistics_in_memory_empty_examples(self):
    examples = []
    expected_result = text_format.Parse(
        """
        datasets {
          num_examples: 0
        }""", statistics_pb2.DatasetFeatureStatisticsList())

    result = stats_impl.generate_statistics_in_memory(examples)
    compare.assertProtoEqual(
        self, result, expected_result, normalize_numbers=True)
Example #12
 def _equal(actual_results):
     """Matcher for comparing a list of DatasetFeatureStatistics protos."""
     test.assertLen(expected_results, len(actual_results))
     # Sort both lists of protos by their string representation so that the
     # comparison does not depend on the input order.
     sorted_expected_results = sorted(expected_results, key=str)
     sorted_actual_results = sorted(actual_results, key=str)
     for index, actual in enumerate(sorted_actual_results):
         compare.assertProtoEqual(test,
                                  actual,
                                  sorted_expected_results[index],
                                  normalize_numbers=True)
Example #13
def assert_feature_proto_equal_with_error_on_custom_stats(
        test,
        actual,
        expected,
        relative_error_threshold=0.05,
        absolute_error_threshold=0.05):
    """Compares feature protos and ensures custom stats are almost equal.

  A numeric custom stat is almost equal if
  expected * (1 - relative_error_threshold) - absolute_error_threshold < actual
  AND
  actual < expected * (1 + relative_error_threshold) + absolute_error_threshold

  All other proto fields are compared directly.

  Args:
    test: The test case.
    actual: The actual feature proto.
    expected: The expected feature proto.
    relative_error_threshold: The relative error permitted between custom stats
      in expected and actual.
    absolute_error_threshold: The absolute error permitted between custom stats
      in expected and actual.
  """

    test.assertEqual(len(actual.custom_stats), len(expected.custom_stats))
    expected_custom_stats = {}
    for expected_custom_stat in expected.custom_stats:
        expected_custom_stats[expected_custom_stat.name] = expected_custom_stat

    for i, actual_custom_stat in enumerate(actual.custom_stats):
        test.assertTrue(actual_custom_stat.name in expected_custom_stats)
        expected_custom_stat = expected_custom_stats[actual_custom_stat.name]
        # Compare numeric custom stats with error margin
        if actual_custom_stat.WhichOneof(
                'val') == 'num' and expected_custom_stat.WhichOneof(
                    'val') == 'num':
            test.assertBetween(
                actual_custom_stat.num,
                expected_custom_stat.num * (1 - relative_error_threshold) -
                absolute_error_threshold,
                expected_custom_stat.num * (1 + relative_error_threshold) +
                absolute_error_threshold,
                msg=actual_custom_stat.name +
                ' is not within the expected range.')
            del actual.custom_stats[i]
            del expected.custom_stats[i]

    # Compare the rest of the proto without numeric custom stats
    compare.assertProtoEqual(
        test, actual, expected, normalize_numbers=True)
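
A minimal usage sketch of the almost-equal check above (the protos below are hypothetical, and it assumes the helper is importable into the test module alongside absltest): with the default 5% relative and 0.05 absolute thresholds, a 'Mutual Information' value of 0.52 lies inside the band 0.5 * 0.95 - 0.05 = 0.425 to 0.5 * 1.05 + 0.05 = 0.575, so only the remaining fields must match exactly.

from absl.testing import absltest
from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import statistics_pb2


class AlmostEqualCustomStatsTest(absltest.TestCase):

  def test_numeric_custom_stat_within_tolerance(self):
    actual = text_format.Parse(
        """
        name: 'feature_1'
        custom_stats { name: 'Mutual Information' num: 0.52 }
        """, statistics_pb2.FeatureNameStatistics())
    expected = text_format.Parse(
        """
        name: 'feature_1'
        custom_stats { name: 'Mutual Information' num: 0.5 }
        """, statistics_pb2.FeatureNameStatistics())
    # 0.425 < 0.52 < 0.575, so the numeric custom stat passes the almost-equal
    # check; the protos are then compared without it.
    assert_feature_proto_equal_with_error_on_custom_stats(
        self, actual, expected)


if __name__ == '__main__':
  absltest.main()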
Example #14
 def test_remove_anomaly_types_changes_proto(self, anomaly_types_to_remove,
                                             input_anomalies_proto_text,
                                             expected_anomalies_proto_text):
     """Tests where remove_anomaly_types modifies the Anomalies proto."""
     input_anomalies_proto = text_format.Parse(input_anomalies_proto_text,
                                               anomalies_pb2.Anomalies())
     expected_anomalies_proto = text_format.Parse(
         expected_anomalies_proto_text, anomalies_pb2.Anomalies())
     anomalies_util.remove_anomaly_types(input_anomalies_proto,
                                         anomaly_types_to_remove)
     compare.assertProtoEqual(self, input_anomalies_proto,
                              expected_anomalies_proto)
Example #15
    def _AssertProtoEquals(self, a, b):
        """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then uses assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
    """
        if not compare.ProtoEq(a, b):
            compare.assertProtoEqual(self, a, b, normalize_numbers=True)
Example #16
  def _AssertProtoEquals(self, a, b):
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then uses assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
    """
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True)
Example #17
 def test_make_dataset_feature_stats_proto(self):
     stats = {
         types.FeaturePath(['feature_1']): {
             'Mutual Information': 0.5,
             'Correlation': 0.1
         },
         types.FeaturePath(['feature_2']): {
             'Mutual Information': 0.8,
             'Correlation': 0.6
         }
     }
     expected = {
         types.FeaturePath(['feature_1']):
         text_format.Parse(
             """
         path {
           step: 'feature_1'
         }
         custom_stats {
           name: 'Correlation'
           num: 0.1
         }
         custom_stats {
           name: 'Mutual Information'
           num: 0.5
         }
        """, statistics_pb2.FeatureNameStatistics()),
         types.FeaturePath(['feature_2']):
         text_format.Parse(
             """
         path {
           step: 'feature_2'
         }
         custom_stats {
           name: 'Correlation'
           num: 0.6
         }
         custom_stats {
           name: 'Mutual Information'
           num: 0.8
         }
        """, statistics_pb2.FeatureNameStatistics())
     }
     actual = stats_util.make_dataset_feature_stats_proto(stats)
     self.assertEqual(len(actual.features), len(expected))
     for actual_feature_stats in actual.features:
         compare.assertProtoEqual(self,
                                  actual_feature_stats,
                                  expected[types.FeaturePath.from_proto(
                                      actual_feature_stats.path)],
                                  normalize_numbers=True)
 def _equal(actual_results: List[
     Tuple[types.SliceKey, statistics_pb2.DatasetFeatureStatistics]]):
   """Matcher for comparing a list of DatasetFeatureStatistics protos."""
   if len(actual_results) == 1 and len(expected_results) == 1:
     # If appropriate use proto matcher for better errors
     test.assertEqual(expected_results[0][0], actual_results[0][0])
     compare.assertProtoEqual(test, expected_results[0][1],
                              actual_results[0][1], normalize_numbers=True)
   else:
     test.assertCountEqual(
         [(k, _DatasetFeatureStatisticsComparatorWrapper(v))
          for k, v in expected_results],
         [(k, _DatasetFeatureStatisticsComparatorWrapper(v))
          for k, v in actual_results])
 def test_make_feature_stats_proto_with_topk_stats_weighted(self):
   expected_result = text_format.Parse(
       """
       path {
         step: 'fa'
       }
       type: STRING
       string_stats {
         weighted_string_stats {
           top_values {
             value: 'a'
             frequency: 4
           }
           top_values {
             value: 'c'
             frequency: 3
           }
           top_values {
             value: 'd'
             frequency: 2
           }
           rank_histogram {
             buckets {
               low_rank: 0
               high_rank: 0
               label: "a"
               sample_count: 4.0
             }
             buckets {
               low_rank: 1
               high_rank: 1
               label: "c"
               sample_count: 3.0
             }
           }
         }
   }""", statistics_pb2.FeatureNameStatistics())
   value_counts = [('a', 4), ('c', 3), ('d', 2), ('b', 2)]
   top_k_value_count_list = [
       top_k_uniques_stats_generator.FeatureValueCount(
           value_count[0], value_count[1])
       for value_count in value_counts
   ]
   result = (
       top_k_uniques_stats_generator
       .make_feature_stats_proto_with_topk_stats(
           types.FeaturePath(['fa']),
           top_k_value_count_list, False, True, 3, 1, 2))
   compare.assertProtoEqual(self, result, expected_result)
Example #20
 def test_load_sharded_pattern(self):
     full_stats_proto = statistics_pb2.DatasetFeatureStatisticsList()
     text_format.Parse(_STATS_PROTO, full_stats_proto)
     tmp_dir = self.create_tempdir()
     tmp_path = os.path.join(tmp_dir, 'statistics-0-of-1')
     writer = tf.compat.v1.io.TFRecordWriter(tmp_path)
     for dataset in full_stats_proto.datasets:
         shard = statistics_pb2.DatasetFeatureStatisticsList()
         shard.datasets.append(dataset)
         writer.write(shard.SerializeToString())
     writer.close()
     view = stats_util.load_sharded_statistics(
         input_path_prefix=tmp_path.rstrip('-0-of-1'),
         io_provider=statistics_io_impl.get_io_provider('tfrecords'))
     compare.assertProtoEqual(self, view.proto(), full_stats_proto)
Example #21
 def assertCombinerOutputEqual(self, batches, generator, expected_result):
     """Tests a combiner statistics generator."""
     accumulators = [
         generator.add_input(generator.create_accumulator(), batch)
         for batch in batches
     ]
     result = generator.extract_output(
         generator.merge_accumulators(accumulators))
     self.assertEqual(len(result.features), len(expected_result))
     for actual_feature_stats in result.features:
         compare.assertProtoEqual(
             self,
             actual_feature_stats,
             expected_result[actual_feature_stats.name],
             normalize_numbers=True)
Example #22
    def assertCombinerOutputEqual(self, batches, generator, expected_result):
        """Tests a combiner statistics generator.

    This runs the generator twice to cover different behavior. There must be at
    least two input batches in order to test the generator's merging behavior.

    Args:
      batches: A list of batches of test data.
      generator: The CombinerStatsGenerator to test.
      expected_result: Dict mapping feature name to FeatureNameStatistics proto
        that it is expected the generator will return for the feature.
    """
        # Run generator to check that merge_accumulators() works correctly.
        accumulators = [
            generator.add_input(generator.create_accumulator(), batch)
            for batch in batches
        ]
        result = generator.extract_output(
            generator.merge_accumulators(accumulators))
        self.assertEqual(  # pylint: disable=g-generic-assert
            len(result.features), len(expected_result),
            '{}, {}'.format(result, expected_result))
        for actual_feature_stats in result.features:
            compare.assertProtoEqual(
                self,
                actual_feature_stats,
                expected_result[types.FeaturePath.from_proto(
                    actual_feature_stats.path)],
                normalize_numbers=True)

        # Run generator to check that add_input() works correctly when adding
        # inputs to a non-empty accumulator.
        accumulator = generator.create_accumulator()

        for batch in batches:
            accumulator = generator.add_input(accumulator, batch)

        result = generator.extract_output(accumulator)
        self.assertEqual(len(result.features), len(expected_result))  # pylint: disable=g-generic-assert
        for actual_feature_stats in result.features:
            compare.assertProtoEqual(
                self,
                actual_feature_stats,
                expected_result[types.FeaturePath.from_proto(
                    actual_feature_stats.path)],
                normalize_numbers=True)
 def test_make_dataset_feature_stats_proto(self):
     stats = {
         'feature_1': {
             'Mutual Information': 0.5,
             'Correlation': 0.1
         },
         'feature_2': {
             'Mutual Information': 0.8,
             'Correlation': 0.6
         }
     }
     expected = {
         'feature_1':
         text_format.Parse(
             """
         name: 'feature_1'
         custom_stats {
           name: 'Correlation'
           num: 0.1
         }
         custom_stats {
           name: 'Mutual Information'
           num: 0.5
         }
        """, statistics_pb2.FeatureNameStatistics()),
         'feature_2':
         text_format.Parse(
             """
         name: 'feature_2'
         custom_stats {
           name: 'Correlation'
           num: 0.6
         }
         custom_stats {
           name: 'Mutual Information'
           num: 0.8
         }
        """, statistics_pb2.FeatureNameStatistics())
     }
     actual = stats_util.make_dataset_feature_stats_proto(stats)
     self.assertEqual(len(actual.features), len(expected))
     for actual_feature_stats in actual.features:
         compare.assertProtoEqual(self,
                                  actual_feature_stats,
                                  expected[actual_feature_stats.name],
                                  normalize_numbers=True)
Example #24
    def assertCombinerOutputEqual(
        self,
        input_batches: List[types.ValueBatch],
        generator: stats_generator.CombinerFeatureStatsGenerator,
        expected_result: statistics_pb2.FeatureNameStatistics,
        feature_path: types.FeaturePath = types.FeaturePath([''])
    ) -> None:
        """Tests a feature combiner statistics generator.

    This runs the generator twice to cover different behavior. There must be at
    least two input batches in order to test the generator's merging behavior.

    Args:
      input_batches: A list of batches of test data.
      generator: The CombinerFeatureStatsGenerator to test.
      expected_result: The FeatureNameStatistics proto that it is expected the
        generator will return.
      feature_path: The FeaturePath to use. If not specified, a default value
        is used.
    """
        # Run generator to check that merge_accumulators() works correctly.
        accumulators = [
            generator.add_input(generator.create_accumulator(), feature_path,
                                input_batch) for input_batch in input_batches
        ]
        result = generator.extract_output(
            generator.merge_accumulators(accumulators))
        compare.assertProtoEqual(self,
                                 result,
                                 expected_result,
                                 normalize_numbers=True)

        # Run generator to check that add_input() works correctly when adding
        # inputs to a non-empty accumulator.
        accumulator = generator.create_accumulator()

        for input_batch in input_batches:
            accumulator = generator.add_input(accumulator, feature_path,
                                              input_batch)

        result = generator.extract_output(accumulator)
        compare.assertProtoEqual(self,
                                 result,
                                 expected_result,
                                 normalize_numbers=True)
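
The docstring above stresses that at least two input batches are needed so that merge_accumulators() is actually exercised. The following is a hypothetical, self-contained sketch (the toy counting generator is not part of TFDV; its names and counting logic are illustrative assumptions) that walks through the same create/add/merge/extract lifecycle the helper drives.

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import statistics_pb2


class _ToyCountGenerator(object):
  """Hypothetical combiner that counts how many values it has seen."""

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, feature_path, input_batch):
    return accumulator + len(input_batch)

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return text_format.Parse(
        'custom_stats { name: "count" num: %d }' % accumulator,
        statistics_pb2.FeatureNameStatistics())


generator = _ToyCountGenerator()
# Two batches, so merge_accumulators() has more than one accumulator to merge.
batches = [[1, 2], [3]]
accumulators = [
    generator.add_input(generator.create_accumulator(), None, batch)
    for batch in batches
]
result = generator.extract_output(generator.merge_accumulators(accumulators))
assert result.custom_stats[0].num == 3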
Example #25
 def test_make_feature_stats_proto_with_topk_stats_unsorted_value_counts(self):
   expected_result = text_format.Parse(
       """
       name: 'fa'
       type: STRING
       string_stats {
         top_values {
           value: 'a'
           frequency: 4
         }
         top_values {
           value: 'c'
           frequency: 3
         }
         top_values {
           value: 'd'
           frequency: 2
         }
         rank_histogram {
           buckets {
             low_rank: 0
             high_rank: 0
             label: "a"
             sample_count: 4.0
           }
           buckets {
             low_rank: 1
             high_rank: 1
             label: "c"
             sample_count: 3.0
           }
         }
   }""", statistics_pb2.FeatureNameStatistics())
   # 'b' has a lower count than 'c'.
   value_counts = [('a', 4), ('b', 2), ('c', 3), ('d', 2)]
   top_k_value_count_list = [
       top_k_stats_generator.FeatureValueCount(value_count[0], value_count[1])
       for value_count in value_counts
   ]
   result = top_k_stats_generator.make_feature_stats_proto_with_topk_stats(
       'fa', top_k_value_count_list, False, False, 3, 2)
   compare.assertProtoEqual(self, result, expected_result)
Example #26
  def _matcher(actual):
    """Matcher function for comparing DatasetFeatureStatisticsList proto."""
    try:
      test.assertEqual(len(actual), 1)
      # Get the dataset stats from DatasetFeatureStatisticsList proto.
      actual_stats = actual[0].datasets[0]
      expected_stats = expected_result.datasets[0]

      test.assertEqual(actual_stats.num_examples, expected_stats.num_examples)
      test.assertEqual(len(actual_stats.features), len(expected_stats.features))

      expected_features = {}
      for feature in expected_stats.features:
        expected_features[feature.name] = feature

      for feature in actual_stats.features:
        compare.assertProtoEqual(
            test,
            feature,
            expected_features[feature.name],
            normalize_numbers=True)
    except AssertionError as e:
      raise util.BeamAssertException('Failed assert: ' + str(e))
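
For context, the matcher pattern above is the one Beam's testing utilities expect: util.assert_that materializes the PCollection and hands its contents to the callable, and assertion failures surface as BeamAssertException. Below is a minimal, self-contained sketch of that wiring (the pipeline and matcher are illustrative only, not TFDV code).

import apache_beam as beam
from apache_beam.testing import util


def _expect_single_element(actual):
  """Toy matcher: fails the pipeline assertion unless exactly one element."""
  if len(actual) != 1:
    raise util.BeamAssertException(
        'Expected exactly one element, got: {}'.format(actual))


with beam.Pipeline() as p:
  pcoll = p | beam.Create(['only_element'])
  util.assert_that(pcoll, _expect_single_element)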
Example #27
    def assertCombinerOutputEqual(self, input_batches, generator,
                                  expected_result):
        """Tests a feature combiner statistics generator.

    This runs the generator twice to cover different behavior. There must be at
    least two input batches in order to test the generator's merging behavior.

    Args:
      input_batches: A list of batches of test data.
      generator: The CombinerFeatureStatsGenerator to test.
      expected_result: The FeatureNameStatistics proto that it is expected the
        generator will return.
    """
        # Run generator to check that merge_accumulators() works correctly.
        accumulators = [
            generator.add_input(generator.create_accumulator(), input_batch)
            for input_batch in input_batches
        ]
        result = generator.extract_output(
            generator.merge_accumulators(accumulators))
        compare.assertProtoEqual(self,
                                 result,
                                 expected_result,
                                 normalize_numbers=True)

        # Run generator to check that add_input() works correctly when adding
        # inputs to a non-empty accumulator.
        accumulator = generator.create_accumulator()

        for input_batch in input_batches:
            accumulator = generator.add_input(accumulator, input_batch)

        result = generator.extract_output(accumulator)
        compare.assertProtoEqual(self,
                                 result,
                                 expected_result,
                                 normalize_numbers=True)
Example #28
    def test_valid_stats_options_json_round_trip(self):
        feature_allowlist = ['a']
        schema = schema_pb2.Schema(feature=[schema_pb2.Feature(name='f')])
        vocab_paths = {'a': '/path/to/a'}
        label_feature = 'label'
        weight_feature = 'weight'
        sample_rate = 0.01
        num_top_values = 21
        frequency_threshold = 2
        weighted_frequency_threshold = 2.0
        num_rank_histogram_buckets = 1001
        num_values_histogram_buckets = 11
        num_histogram_buckets = 11
        num_quantiles_histogram_buckets = 11
        epsilon = 0.02
        infer_type_from_schema = True
        desired_batch_size = 100
        enable_semantic_domain_stats = True
        semantic_domain_stats_sample_rate = 0.1
        per_feature_weight_override = {types.FeaturePath(['a']): 'w'}
        add_default_generators = True
        use_sketch_based_topk_uniques = True
        experimental_result_partitions = 3

        options = stats_options.StatsOptions(
            feature_allowlist=feature_allowlist,
            schema=schema,
            vocab_paths=vocab_paths,
            label_feature=label_feature,
            weight_feature=weight_feature,
            sample_rate=sample_rate,
            num_top_values=num_top_values,
            frequency_threshold=frequency_threshold,
            weighted_frequency_threshold=weighted_frequency_threshold,
            num_rank_histogram_buckets=num_rank_histogram_buckets,
            num_values_histogram_buckets=num_values_histogram_buckets,
            num_histogram_buckets=num_histogram_buckets,
            num_quantiles_histogram_buckets=num_quantiles_histogram_buckets,
            epsilon=epsilon,
            infer_type_from_schema=infer_type_from_schema,
            desired_batch_size=desired_batch_size,
            enable_semantic_domain_stats=enable_semantic_domain_stats,
            semantic_domain_stats_sample_rate=semantic_domain_stats_sample_rate,
            per_feature_weight_override=per_feature_weight_override,
            add_default_generators=add_default_generators,
            experimental_use_sketch_based_topk_uniques=
            use_sketch_based_topk_uniques,
            experimental_result_partitions=experimental_result_partitions,
        )

        options_json = options.to_json()
        options = stats_options.StatsOptions.from_json(options_json)

        self.assertEqual(feature_allowlist, options.feature_allowlist)
        compare.assertProtoEqual(self, schema, options.schema)
        self.assertEqual(vocab_paths, options.vocab_paths)
        self.assertEqual(label_feature, options.label_feature)
        self.assertEqual(weight_feature, options.weight_feature)
        self.assertEqual(sample_rate, options.sample_rate)
        self.assertEqual(num_top_values, options.num_top_values)
        self.assertEqual(frequency_threshold, options.frequency_threshold)
        self.assertEqual(weighted_frequency_threshold,
                         options.weighted_frequency_threshold)
        self.assertEqual(num_rank_histogram_buckets,
                         options.num_rank_histogram_buckets)
        self.assertEqual(num_values_histogram_buckets,
                         options.num_values_histogram_buckets)
        self.assertEqual(num_histogram_buckets, options.num_histogram_buckets)
        self.assertEqual(num_quantiles_histogram_buckets,
                         options.num_quantiles_histogram_buckets)
        self.assertEqual(epsilon, options.epsilon)
        self.assertEqual(infer_type_from_schema,
                         options.infer_type_from_schema)
        self.assertEqual(desired_batch_size, options.desired_batch_size)
        self.assertEqual(enable_semantic_domain_stats,
                         options.enable_semantic_domain_stats)
        self.assertEqual(semantic_domain_stats_sample_rate,
                         options.semantic_domain_stats_sample_rate)
        self.assertEqual(per_feature_weight_override,
                         options._per_feature_weight_override)
        self.assertEqual(add_default_generators,
                         options.add_default_generators)
        self.assertEqual(use_sketch_based_topk_uniques,
                         options.experimental_use_sketch_based_topk_uniques)
        self.assertEqual(experimental_result_partitions,
                         options.experimental_result_partitions)
Example #29
 def testNormalizesFloat(self):
   pb1 = compare_test_pb2.Large()
   pb1.double_ = 4.0
   pb2 = compare_test_pb2.Large()
   pb2.double_ = 4
   compare.assertProtoEqual(self, pb1, pb2, normalize_numbers=True)
Example #30
 def testNormalizesNumbers(self):
   pb1 = compare_test_pb2.Large()
   pb1.int64_ = 4
   pb2 = compare_test_pb2.Large()
   pb2.int64_ = 4
   compare.assertProtoEqual(self, pb1, pb2)
Example #31
 def assertProtoEqual(self, a, b, **kwargs):
   if isinstance(a, six.string_types) and isinstance(b, six.string_types):
     a, b = LargePbs(a, b)
   compare.assertProtoEqual(self, a, b, **kwargs)
Example #32
    def test_stats_options_json_round_trip(self):
        generators = [
            lift_stats_generator.LiftStatsGenerator(
                schema=None,
                y_path=types.FeaturePath(['label']),
                x_paths=[types.FeaturePath(['feature'])])
        ]
        feature_whitelist = ['a']
        schema = schema_pb2.Schema(feature=[schema_pb2.Feature(name='f')])
        label_feature = 'label'
        weight_feature = 'weight'
        slice_functions = [slicing_util.get_feature_value_slicer({'b': None})]
        sample_rate = 0.01
        num_top_values = 21
        frequency_threshold = 2
        weighted_frequency_threshold = 2.0
        num_rank_histogram_buckets = 1001
        num_values_histogram_buckets = 11
        num_histogram_buckets = 11
        num_quantiles_histogram_buckets = 11
        epsilon = 0.02
        infer_type_from_schema = True
        desired_batch_size = 100
        enable_semantic_domain_stats = True
        semantic_domain_stats_sample_rate = 0.1

        options = stats_options.StatsOptions(
            generators=generators,
            feature_whitelist=feature_whitelist,
            schema=schema,
            label_feature=label_feature,
            weight_feature=weight_feature,
            slice_functions=slice_functions,
            sample_rate=sample_rate,
            num_top_values=num_top_values,
            frequency_threshold=frequency_threshold,
            weighted_frequency_threshold=weighted_frequency_threshold,
            num_rank_histogram_buckets=num_rank_histogram_buckets,
            num_values_histogram_buckets=num_values_histogram_buckets,
            num_histogram_buckets=num_histogram_buckets,
            num_quantiles_histogram_buckets=num_quantiles_histogram_buckets,
            epsilon=epsilon,
            infer_type_from_schema=infer_type_from_schema,
            desired_batch_size=desired_batch_size,
            enable_semantic_domain_stats=enable_semantic_domain_stats,
            semantic_domain_stats_sample_rate=semantic_domain_stats_sample_rate
        )

        options_json = options.to_json()
        options = stats_options.StatsOptions.from_json(options_json)

        self.assertIsNone(options.generators)
        self.assertEqual(feature_whitelist, options.feature_whitelist)
        compare.assertProtoEqual(self, schema, options.schema)
        self.assertEqual(label_feature, options.label_feature)
        self.assertEqual(weight_feature, options.weight_feature)
        self.assertIsNone(options.slice_functions)
        self.assertEqual(sample_rate, options.sample_rate)
        self.assertEqual(num_top_values, options.num_top_values)
        self.assertEqual(frequency_threshold, options.frequency_threshold)
        self.assertEqual(weighted_frequency_threshold,
                         options.weighted_frequency_threshold)
        self.assertEqual(num_rank_histogram_buckets,
                         options.num_rank_histogram_buckets)
        self.assertEqual(num_values_histogram_buckets,
                         options.num_values_histogram_buckets)
        self.assertEqual(num_histogram_buckets, options.num_histogram_buckets)
        self.assertEqual(num_quantiles_histogram_buckets,
                         options.num_quantiles_histogram_buckets)
        self.assertEqual(epsilon, options.epsilon)
        self.assertEqual(infer_type_from_schema,
                         options.infer_type_from_schema)
        self.assertEqual(desired_batch_size, options.desired_batch_size)
        self.assertEqual(enable_semantic_domain_stats,
                         options.enable_semantic_domain_stats)
        self.assertEqual(semantic_domain_stats_sample_rate,
                         options.semantic_domain_stats_sample_rate)