def __init__(self, example_weight_map: ExampleWeightMap, **kwargs):
  """Initializes a weighted lift statistics generator.

  Args:
    example_weight_map: an ExampleWeightMap that maps a FeaturePath to its
      corresponding weight column.
    **kwargs: The set of args to be passed to _LiftStatsGenerator.
  """
  has_weights = bool(example_weight_map.all_weight_features())
  # The unweighted generator always runs, using an empty weight map so that
  # every example counts equally.
  self._unweighted_generator = _LiftStatsGenerator(
      example_weight_map=ExampleWeightMap(), **kwargs)
  self._has_any_weight = has_weights
  # A weighted generator is only constructed when at least one weight
  # feature is configured.
  if has_weights:
    self._weighted_generator = _LiftStatsGenerator(
        example_weight_map=example_weight_map, **kwargs)
def _to_topk_tuples(
    sliced_record_batch: Tuple[types.SliceKey, pa.RecordBatch],
    bytes_features: FrozenSet[types.FeaturePath],
    categorical_features: FrozenSet[types.FeaturePath],
    example_weight_map: ExampleWeightMap,
) -> Iterable[Tuple[Tuple[types.SliceKey, types.FeaturePathTuple, Any],
                    Union[int, Tuple[int, Union[int, float]]]]]:
  """Generates tuples for computing top-k and uniques from the input."""
  slice_key, record_batch = sliced_record_batch
  weighted = bool(example_weight_map.all_weight_features())
  for path, leaf_array, weight_array in arrow_util.enumerate_arrays(
      record_batch,
      example_weight_map=example_weight_map,
      enumerate_leaves_only=True):
    leaf_type = stats_util.get_feature_type_from_arrow_type(
        path, leaf_array.type)
    # Bytes features are excluded from top-k/uniques computation.
    if path in bytes_features:
      continue
    # Only STRING features and INT features declared categorical participate.
    is_string = leaf_type == statistics_pb2.FeatureNameStatistics.STRING
    is_categorical_int = (
        leaf_type == statistics_pb2.FeatureNameStatistics.INT and
        path in categorical_features)
    if not (is_string or is_categorical_int):
      continue
    flat_values, parent_indices = arrow_util.flatten_nested(
        leaf_array, weight_array is not None)
    steps = path.steps()
    if weight_array is not None and flat_values:
      # Slow path: weighted uniques.
      for value, count, weight in _weighted_unique(
          np.asarray(flat_values), weight_array[parent_indices]):
        yield (slice_key, steps, value), (count, weight)
    else:
      value_counts = flat_values.value_counts()
      pairs = zip(value_counts.field('values').to_pylist(),
                  value_counts.field('counts').to_pylist())
      if weighted:
        # Weights are configured but absent for this feature: emit a unit
        # weight alongside each count.
        for value, count in pairs:
          yield (slice_key, steps, value), (count, 1)
      else:
        for value, count in pairs:
          yield (slice_key, steps, value), count