Пример #1
0
  def test_large_string_threshold(self):
    """Checks large-string replacement across serialize, add and merge.

    Both sketches are built with large_string_threshold=2 and the placeholder
    b"<LARGE>". The expected estimate lists "a", "c" and "d" under their own
    values, while "bbb", "eeff" and "gghh" are all counted under the
    placeholder.
    """
    large_string_config = dict(
        large_string_threshold=2,
        large_string_placeholder=b"<LARGE>")
    first_batch = pa.array(["a", "bbb", "c", "d", "eeff"])
    second_batch = pa.array(["a", "gghh"])

    left = sketches.MisraGriesSketch(_NUM_BUCKETS, **large_string_config)
    left.AddValues(first_batch)

    right = sketches.MisraGriesSketch(_NUM_BUCKETS, **large_string_config)
    right.AddValues(second_batch)

    # Round-trip both sketches through their serialized form before using
    # them any further.
    left_bytes = left.Serialize()
    right_bytes = right.Serialize()
    restored_left = sketches.MisraGriesSketch.Deserialize(left_bytes)
    restored_right = sketches.MisraGriesSketch.Deserialize(right_bytes)

    # second_batch contributes twice: once added directly to the restored
    # sketch and once through the merge of the other restored sketch.
    restored_left.AddValues(second_batch)
    restored_left.Merge(restored_right)

    estimate = restored_left.Estimate()
    estimate.validate(full=True)
    expected = [
        {"values": b"<LARGE>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
        {"values": b"c", "counts": 1.0},
        {"values": b"d", "counts": 1.0},
    ]
    self.assertEqual(estimate.to_pylist(), expected)
Пример #2
0
  def test_replace_invalid_utf8(self):
    """Checks invalid-UTF-8 replacement across serialize, add and merge.

    Both sketches are built with invalid_utf8_placeholder=b"<BYTES>". The
    expected estimate counts every byte string marked "invalid" below under
    the placeholder and keeps b"a" under its own value.
    """
    placeholder_config = dict(invalid_utf8_placeholder=b"<BYTES>")
    first_batch = pa.array([
        b"a",
        b"\x80",  # invalid
        b"\xC1",  # invalid
    ])
    second_batch = pa.array([
        b"\xc0\x80",  # invalid
        b"a",
    ])

    left = sketches.MisraGriesSketch(_NUM_BUCKETS, **placeholder_config)
    left.AddValues(first_batch)

    right = sketches.MisraGriesSketch(_NUM_BUCKETS, **placeholder_config)
    right.AddValues(second_batch)

    # Round-trip both sketches through their serialized form before using
    # them any further.
    left_bytes = left.Serialize()
    right_bytes = right.Serialize()
    restored_left = sketches.MisraGriesSketch.Deserialize(left_bytes)
    restored_right = sketches.MisraGriesSketch.Deserialize(right_bytes)

    # second_batch contributes twice: once added directly to the restored
    # sketch and once through the merge of the other restored sketch.
    restored_left.AddValues(second_batch)
    restored_left.Merge(restored_right)

    estimate = restored_left.Estimate()
    estimate.validate(full=True)
    expected = [
        {"values": b"<BYTES>", "counts": 4.0},
        {"values": b"a", "counts": 3.0},
    ]
    self.assertEqual(estimate.to_pylist(), expected)
Пример #3
0
  def test_invalid_large_string_replacing_config(self):
    """Supplying only one of the two large-string options must raise."""
    expected_error = (
        "Must provide both or neither large_string_threshold and "
        "large_string_placeholder")
    # Each config sets exactly one of the two options that belong together.
    incomplete_configs = (
        {"large_string_threshold": 1024},
        {"large_string_placeholder": b"<L>"},
    )
    for config in incomplete_configs:
      with self.assertRaisesRegex(RuntimeError, expected_error):
        sketches.MisraGriesSketch(_NUM_BUCKETS, **config)
Пример #4
0
def _create_basic_sketch(items, weights=None, num_buckets=_NUM_BUCKETS):
    """Build a MisraGriesSketch and feed it a single batch of items.

    Args:
        items: Values handed to `AddValues`.
        weights: Optional weights handed to `AddValues` together with
            `items`. `None` (the default) adds the items unweighted.
        num_buckets: Argument passed to the `MisraGriesSketch` constructor.

    Returns:
        The sketch after the batch has been added.
    """
    sketch = sketches.MisraGriesSketch(num_buckets)
    # Compare against None explicitly instead of truth-testing: the truth
    # value of an array-like is either ambiguous (NumPy raises for more than
    # one element) or would make an empty weights array look like "no
    # weights were given" and silently drop it.
    if weights is None:
        sketch.AddValues(items)
    else:
        sketch.AddValues(items, weights)
    return sketch
Пример #5
0
 def __init__(self,
              invalidate=False,
              num_in_vocab_tokens: int = 0,
              total_num_tokens: int = 0,
              sum_in_vocab_token_lengths: int = 0,
              num_examples: int = 0) -> None:
     """Set up the counters, sketches and per-token containers.

     Args:
       invalidate: True only if this feature should never be considered,
         e.g. some value_lists have inconsistent types or the feature
         doesn't have an NL domain.
       num_in_vocab_tokens: Initial value of the attribute of the same name.
       total_num_tokens: Initial value of the attribute of the same name.
       sum_in_vocab_token_lengths: Initial value of the attribute of the
         same name.
       num_examples: Initial value of the attribute of the same name.
     """
     # Both quantiles sketches are built from the same module-level settings.
     quantiles_args = (
         _QUANTILES_SKETCH_ERROR,
         _QUANTILES_SKETCH_NUM_ELEMENTS,
         _QUANTILES_SKETCH_NUM_STREAMS,
     )

     self.invalidate = invalidate
     self.num_in_vocab_tokens = num_in_vocab_tokens
     self.total_num_tokens = total_num_tokens
     self.sum_in_vocab_token_lengths = sum_in_vocab_token_lengths
     self.num_examples = num_examples
     self.vocab_token_length_quantiles = sketches.QuantilesSketch(
         *quantiles_args)

     # The sequence-length bounds start unset.
     self.min_sequence_length = None
     self.max_sequence_length = None
     self.sequence_length_quantiles = sketches.QuantilesSketch(
         *quantiles_args)

     self.token_occurrence_counts = sketches.MisraGriesSketch(
         _NUM_MISRAGRIES_SKETCH_BUCKETS)
     # Missing keys get a fresh _TokenStats on first access.
     self.token_statistics = collections.defaultdict(_TokenStats)
     self.reported_sequences_coverage = []
     self.reported_sequences_avg_token_length = []
Пример #6
0
 def test_no_replace_invalid_utf8(self):
   """With no placeholder configured, b"\\x80" is expected back unchanged."""
   plain_sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
   invalid_bytes = pa.array([b"\x80"])
   plain_sketch.AddValues(invalid_bytes)
   estimate = plain_sketch.Estimate()
   expected = [{"values": b"\x80", "counts": 1.0}]
   self.assertEqual(estimate.to_pylist(), expected)
Пример #7
0
 def test_add_unsupported_type(self):
     """Adding a boolean array is expected to raise an UNIMPLEMENTED error."""
     bool_values = pa.array([True, False], pa.bool_())
     sketch = sketches.MisraGriesSketch(_NUM_BUCKETS)
     expected_error = "UNIMPLEMENTED: bool"
     with self.assertRaisesRegex(RuntimeError, expected_error):
         sketch.AddValues(bool_values)
Пример #8
0
 def create_accumulator(self) -> sketches.MisraGriesSketch:
     """Return a new MisraGriesSketch constructed from `self._top_k`."""
     accumulator = sketches.MisraGriesSketch(self._top_k)
     return accumulator