def _create_basic_sketch(items, weights=None, num_buckets=_NUM_BUCKETS):
    """Build a MisraGriesSketch and feed it one batch of values.

    Args:
        items: Values handed to `AddValues`.
        weights: Optional weights handed to `AddValues` alongside `items`.
            `None` (the default) means the values are added unweighted.
        num_buckets: Bucket count passed to the sketch constructor.

    Returns:
        The sketch after the batch has been added.
    """
    sketch = MisraGriesSketch(num_buckets)
    # Test for None explicitly instead of relying on truthiness: an
    # array-like `weights` may be falsy when empty (and would then be silently
    # dropped) or may not support boolean evaluation at all.
    if weights is not None:
        sketch.AddValues(items, weights)
    else:
        sketch.AddValues(items)
    return sketch
def add_input(
        self, accumulator: sketches.MisraGriesSketch,
        next_input: Tuple[np.ndarray, np.ndarray]) -> sketches.MisraGriesSketch:
    """Fold one (items, weights) batch into the sketch accumulator.

    Args:
        accumulator: Sketch that receives the batch.
        next_input: Pair of arrays `(items, weights)`.

    Returns:
        The same accumulator object. It is left untouched when `items` has
        no elements.
    """
    items, weights = next_input
    # Nothing to add for an empty batch.
    if not items.size:
        return accumulator

    item_array = pa.array(items)
    # Weights are converted to a float32 Arrow array before being added.
    weight_array = pa.array(weights, pa.float32())
    accumulator.AddValues(item_array, weight_array)
    return accumulator
def extract_output(self, accumulator: sketches.MisraGriesSketch) -> np.ndarray:
    """Convert the sketch's estimate into a NumPy array.

    Args:
        accumulator: Sketch whose `Estimate()` is extracted.

    Returns:
        The estimate's flattened fields, in reversed order, stacked with
        `np.dstack`. If that array has no elements, a 1x1 object array
        holding the empty-vocabulary dummy value for `self._input_dtype` is
        returned instead.
    """
    estimate = accumulator.Estimate()
    estimate.validate()
    # np.dstack requires a real sequence (list/tuple). A bare `reversed(...)`
    # iterator is deprecated input since NumPy 1.16 and is rejected with a
    # TypeError by newer releases, so materialize it first.
    columns = list(reversed(estimate.flatten()))
    result = np.dstack(columns)
    if result.size:
        return result
    return np.array(
        [[analyzers.get_empy_vocabulary_dummy_value(self._input_dtype)]],
        dtype=object)
def _update_combined_sketch_for_feature(
        self, feature_name: tfdv_types.FeaturePath, values: pa.Array,
        weights: Optional[np.ndarray],
        accumulator: Dict[tfdv_types.FeaturePath, _CombinedSketch]):
    """Add a feature's values (and weights, if any) to its combined sketch.

    Args:
        feature_name: Key of the feature's sketch in `accumulator`.
        values: Possibly nested array of values; flattened before being added.
        weights: Optional weights, indexed by the parent indices that
            `flatten_nested` returns.
        accumulator: Mapping from feature path to combined sketch. A new
            sketch is created and stored when the feature has none yet.
    """
    has_weights = weights is not None
    # NOTE(review): the second argument presumably asks flatten_nested to also
    # return parent indices -- confirm against arrow_util.
    flat_values, parent_idx = arrow_util.flatten_nested(values, has_weights)

    sketch = accumulator.get(feature_name)
    if sketch is None:
        sketch = _CombinedSketch(
            distinct=KmvSketch(self._num_kmv_buckets),
            topk_unweighted=MisraGriesSketch(self._num_misragries_buckets),
            topk_weighted=MisraGriesSketch(self._num_misragries_buckets),
        )

    if has_weights:
        # Select one weight per flattened value through the parent indices.
        flat_weights = pa.array(weights[parent_idx], type=pa.float32())
    else:
        flat_weights = None

    sketch.add(flat_values, flat_weights)
    accumulator[feature_name] = sketch
def make_mg_sketch():
    """Create a MisraGriesSketch and record its sizing parameters in gauges.

    The bucket count is the largest of the configured Misra-Gries bucket
    count, the number of top values, and the number of rank-histogram buckets.
    """
    bucket_count = max(
        self._num_misragries_buckets,
        self._num_top_values,
        self._num_rank_histogram_buckets,
    )

    self._num_mg_buckets_gauge.set(bucket_count)
    self._num_top_values_gauge.set(self._num_top_values)
    self._num_rank_histogram_buckets_gauge.set(
        self._num_rank_histogram_buckets)

    sketch_kwargs = dict(
        num_buckets=bucket_count,
        invalid_utf8_placeholder=constants.NON_UTF8_PLACEHOLDER,
        # Maximum sketch size: 32 * num_buckets * constant_factor.
        large_string_threshold=32,
        large_string_placeholder=constants.LARGE_BYTES_PLACEHOLDER,
    )
    return MisraGriesSketch(**sketch_kwargs)
def test_serialization(self):
    """A sketch survives a Serialize/Deserialize round trip."""
    original = _create_basic_sketch(pa.array(["a", "b", "c", "a"]))

    payload = original.Serialize()
    self.assertIsInstance(payload, bytes)

    restored = MisraGriesSketch.Deserialize(payload)
    self.assertIsInstance(restored, MisraGriesSketch)

    # One entry per distinct input value with its count.
    expected_counts = [
        {"values": value, "counts": count}
        for value, count in ((b"a", 2.0), (b"b", 1.0), (b"c", 1.0))
    ]
    self.assertEqual(restored.Estimate().to_pylist(), expected_counts)
def test_add_unsupported_type(self):
    """Adding a boolean array raises RuntimeError."""
    sketch = MisraGriesSketch(_NUM_BUCKETS)
    bool_values = pa.array([True, False], pa.bool_())

    with self.assertRaisesRegex(RuntimeError, "Unimplemented: bool"):
        sketch.AddValues(bool_values)
def encode_cache(self, accumulator: sketches.MisraGriesSketch) -> bytes:
    """Return the result of the accumulator's `Serialize()` call."""
    serialized = accumulator.Serialize()
    return serialized