def test_large_string_threshold(self):
    """Strings longer than the threshold are counted under one placeholder.

    Exercises AddValues, a Serialize/Deserialize round-trip, further
    AddValues on the restored sketch, and Merge, with both sketches sharing
    the same large-string replacement configuration.
    """
    def make_sketch():
        # Both sketches must be configured identically so they can be merged.
        return sketches.MisraGriesSketch(
            _NUM_BUCKETS,
            large_string_threshold=2,
            large_string_placeholder=b"<LARGE>")

    first_batch = pa.array(["a", "bbb", "c", "d", "eeff"])
    second_batch = pa.array(["a", "gghh"])

    left = make_sketch()
    left.AddValues(first_batch)
    right = make_sketch()
    right.AddValues(second_batch)

    # Round-trip both sketches through their serialized form.
    left_bytes, right_bytes = left.Serialize(), right.Serialize()
    left = sketches.MisraGriesSketch.Deserialize(left_bytes)
    right = sketches.MisraGriesSketch.Deserialize(right_bytes)

    left.AddValues(second_batch)
    left.Merge(right)

    result = left.Estimate()
    result.validate(full=True)
    # "bbb", "eeff" and "gghh" (twice) exceed the threshold of 2.
    self.assertEqual(result.to_pylist(), [
        {"values": b"<LARGE>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
        {"values": b"c", "counts": 1.0},
        {"values": b"d", "counts": 1.0},
    ])
def test_replace_invalid_utf8(self):
    """Invalid UTF-8 byte strings are counted under one placeholder.

    Exercises AddValues, a Serialize/Deserialize round-trip, further
    AddValues on the restored sketch, and Merge, with both sketches sharing
    the same invalid-UTF-8 replacement configuration.
    """
    def make_sketch():
        # Both sketches must be configured identically so they can be merged.
        return sketches.MisraGriesSketch(
            _NUM_BUCKETS, invalid_utf8_placeholder=b"<BYTES>")

    first_batch = pa.array([
        b"a",
        b"\x80",  # invalid UTF-8: lone continuation byte
        b"\xC1",  # invalid UTF-8
    ])
    second_batch = pa.array([
        b"\xc0\x80",  # invalid UTF-8: overlong encoding
        b"a",
    ])

    left = make_sketch()
    left.AddValues(first_batch)
    right = make_sketch()
    right.AddValues(second_batch)

    # Round-trip both sketches through their serialized form.
    left_bytes, right_bytes = left.Serialize(), right.Serialize()
    left = sketches.MisraGriesSketch.Deserialize(left_bytes)
    right = sketches.MisraGriesSketch.Deserialize(right_bytes)

    left.AddValues(second_batch)
    left.Merge(right)

    result = left.Estimate()
    result.validate(full=True)
    self.assertEqual(result.to_pylist(), [
        {"values": b"<BYTES>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
    ])
def test_invalid_large_string_replacing_config(self):
    """Supplying only one of the two large-string options must fail."""
    expected_message = (
        "Must provide both or neither large_string_threshold and "
        "large_string_placeholder")
    partial_configs = (
        {"large_string_threshold": 1024},
        {"large_string_placeholder": b"<L>"},
    )
    for partial_config in partial_configs:
        with self.assertRaisesRegex(RuntimeError, expected_message):
            _ = sketches.MisraGriesSketch(_NUM_BUCKETS, **partial_config)
def _create_basic_sketch(items, weights=None, num_buckets=_NUM_BUCKETS):
    """Build a MisraGriesSketch and feed it `items`, optionally weighted.

    Args:
      items: values passed to `MisraGriesSketch.AddValues`.
      weights: optional weights forwarded to `AddValues` alongside `items`.
        When omitted (None), `items` are added without weights.
      num_buckets: number of buckets the sketch is constructed with.

    Returns:
      The newly constructed sketch after `AddValues` has been called.
    """
    sketch = sketches.MisraGriesSketch(num_buckets)
    # Compare against None instead of truth-testing: bool() on a multi-element
    # NumPy array raises ValueError, and an explicitly passed empty container
    # would otherwise be silently ignored.
    if weights is None:
        sketch.AddValues(items)
    else:
        sketch.AddValues(items, weights)
    return sketch
def __init__(self,
             invalidate=False,
             num_in_vocab_tokens: int = 0,
             total_num_tokens: int = 0,
             sum_in_vocab_token_lengths: int = 0,
             num_examples: int = 0) -> None:
    """Initialize counters, sketches and report buffers.

    Args:
      invalidate: True only if this feature should never be considered, e.g.
        some value_lists have inconsistent types or the feature doesn't have
        an NL domain.
      num_in_vocab_tokens: initial value for `self.num_in_vocab_tokens`.
      total_num_tokens: initial value for `self.total_num_tokens`.
      sum_in_vocab_token_lengths: initial value for
        `self.sum_in_vocab_token_lengths`.
      num_examples: initial value for `self.num_examples`.
    """
    # Both quantiles sketches are built from the same configuration.
    quantiles_config = (_QUANTILES_SKETCH_ERROR,
                        _QUANTILES_SKETCH_NUM_ELEMENTS,
                        _QUANTILES_SKETCH_NUM_STREAMS)

    self.invalidate = invalidate
    self.num_in_vocab_tokens = num_in_vocab_tokens
    self.total_num_tokens = total_num_tokens
    self.sum_in_vocab_token_lengths = sum_in_vocab_token_lengths
    self.num_examples = num_examples

    self.vocab_token_length_quantiles = sketches.QuantilesSketch(
        *quantiles_config)

    # Unset until a sequence length is observed.
    self.min_sequence_length = None
    self.max_sequence_length = None
    self.sequence_length_quantiles = sketches.QuantilesSketch(
        *quantiles_config)

    self.token_occurrence_counts = sketches.MisraGriesSketch(
        _NUM_MISRAGRIES_SKETCH_BUCKETS)
    # A fresh _TokenStats is created on first access to each key
    # (presumably keyed by token — confirm against callers).
    self.token_statistics = collections.defaultdict(_TokenStats)

    self.reported_sequences_coverage = []
    self.reported_sequences_avg_token_length = []
def test_no_replace_invalid_utf8(self):
    """Without a placeholder configured, invalid UTF-8 is kept verbatim."""
    sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
    sketch.AddValues(pa.array([b"\x80"]))
    estimate = sketch.Estimate().to_pylist()
    self.assertEqual(estimate, [{"values": b"\x80", "counts": 1.0}])
def test_add_unsupported_type(self):
    """Adding a boolean array is rejected as unimplemented."""
    sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
    bool_values = pa.array([True, False], pa.bool_())
    with self.assertRaisesRegex(RuntimeError, "UNIMPLEMENTED: bool"):
        sketch.AddValues(bool_values)
def create_accumulator(self) -> sketches.MisraGriesSketch:
    """Return a new MisraGriesSketch with `self._top_k` buckets."""
    num_buckets = self._top_k
    return sketches.MisraGriesSketch(num_buckets)