def save_to_cache_for_jackknife(self, key, val, split_by=None):
  """Used to monkey patch the save_to_cache() during Jackknife.precompute().

  What cache_key to use for the point estimate of Jackknife is tricky because
  we want to support two use cases at the same time.
  1. We want sumx to be computed only once in
     MetricList([Jackknife(sumx), sumx]).compute_on(df,
     return_dataframe=False), so the key for point estimate should be the
     same sumx uses.
  2. But then it will fail when multiple Jackknifes are involved. For example,
     (Jackknife(unit1, sumx) - Jackknife(unit2, sumx)).compute_on(df) will
     fail because two Jackknifes share point estimate but not LOO estimates.
     When the 2nd Jackknife precomputes its point estimate, as it uses the
     same key as the 1st one, it will mistakenly assume LOO has been cached,
     but unfortunately it's not true.
  The solution here is we use different keys for different Jackknifes, so LOO
  will always be precomputed. Additionally we cache the point estimate again
  with the key other Metrics like Sum would use so they can reuse it.

  Args:
    self: An instance of metrics.Metric.
    key: The cache key currently being used in computation.
    val: The value to cache.
    split_by: Something that can be passed into df.group_by().
  """
  key = self.wrap_cache_key(key, split_by)
  # Jackknife-private keys are tuples of the form ('_RESERVED', 'jk', base_key,
  # ...). For those, also cache `val` under the plain base_key so ordinary
  # Metrics (e.g. Sum) can reuse the point estimate.
  if isinstance(key.key, tuple) and key.key[:2] == ('_RESERVED', 'jk'):
    # Copy so a mutation through one cache entry can't corrupt the other.
    val = val.copy() if isinstance(val, (pd.Series, pd.DataFrame)) else val
    base_key = key.key[2]
    base_key = utils.CacheKey(base_key, key.where, key.split_by, key.slice_val)
    self.cache[base_key] = val
    # Temporary keys must be tracked so they can be cleaned up later.
    if utils.is_tmp_key(base_key):
      self.tmp_cache_keys.add(base_key)
  # Copy again for the same reason before storing under the Jackknife key.
  val = val.copy() if isinstance(val, (pd.Series, pd.DataFrame)) else val
  self.cache[key] = val
def precompute_loo(self, df, split_by=None):
  """Precomputes leave-one-out (LOO) results to make Jackknife faster.

  For Sum, Count and Mean, it's possible to compute the LOO estimates in a
  vectorized way. LOO mean is just LOO sum / LOO count. Here we precompute
  and cache the LOO results.

  Args:
    self: The Mean instance calling this function.
    df: The DataFrame passed to Mean.compute_slices().
    split_by: The split_by passed to Mean.compute_slices().

  Returns:
    Same as what normal Mean.compute_slices() would have returned.
  """
  # NOTE(review): `unit` and `original_split_by` are not defined in this
  # function; they appear to be closure variables from an enclosing
  # monkey-patch factory — confirm against the surrounding code.
  data = df.copy()
  split_by_with_unit = [unit] + split_by if split_by else [unit]
  if self.weight:
    # Weighted mean: LOO mean = (total - bucket) weighted sum over
    # (total - bucket) weight.
    weighted_var = '_weighted_%s' % self.var
    data[weighted_var] = data[self.var] * data[self.weight]
    total_sum = self.group(data, split_by)[weighted_var].sum()
    total_weight = self.group(data, split_by)[self.weight].sum()
    bucket_sum = self.group(data, split_by_with_unit)[weighted_var].sum()
    bucket_sum = utils.adjust_slices_for_loo(bucket_sum, original_split_by)
    bucket_weight = self.group(data, split_by_with_unit)[self.weight].sum()
    bucket_weight = utils.adjust_slices_for_loo(bucket_weight,
                                               original_split_by)
    loo_sum = total_sum - bucket_sum
    loo_weight = total_weight - bucket_weight
    if split_by:
      # total - bucket_sum might put the unit as the innermost level, but we
      # want the unit as the outermost level.
      loo_sum = loo_sum.reorder_levels(split_by_with_unit)
      loo_weight = loo_weight.reorder_levels(split_by_with_unit)
    loo = loo_sum / loo_weight
    mean = total_sum / total_weight
  else:
    # Unweighted mean: LOO mean = LOO sum / LOO count.
    total_sum = self.group(data, split_by)[self.var].sum()
    bucket_sum = self.group(data, split_by_with_unit)[self.var].sum()
    bucket_sum = utils.adjust_slices_for_loo(bucket_sum, original_split_by)
    total_ct = self.group(data, split_by)[self.var].count()
    bucket_ct = self.group(data, split_by_with_unit)[self.var].count()
    bucket_ct = utils.adjust_slices_for_loo(bucket_ct, original_split_by)
    loo_sum = total_sum - bucket_sum
    loo_ct = total_ct - bucket_ct
    loo = loo_sum / loo_ct
    mean = total_sum / total_ct
    if split_by:
      # Put the unit as the outermost index level; see comment above.
      loo = loo.reorder_levels(split_by_with_unit)
  # Cache the LOO estimate of every unit bucket under a reserved Jackknife
  # key so the Jackknife computation can look them up instead of recomputing.
  buckets = loo.index.get_level_values(0).unique() if split_by else loo.index
  for bucket in buckets:
    key = utils.CacheKey(('_RESERVED', 'Jackknife', unit, bucket),
                         self.cache_key.where, split_by)
    self.save_to_cache(key, loo.loc[bucket])
    self.tmp_cache_keys.add(key)
  return mean
def test_cache_key_where(self):
  """Filters accumulate across wrapping and add_filters, dedup'd as a set."""
  key = utils.CacheKey('foo', 'where', ['bar', 'baz'])
  key = utils.CacheKey(key, 'where2')
  key.add_filters(['where', 'where3'])
  self.assertEqual({'where', 'where2', 'where3'}, key.where)
  self.assertEqual('(where) & (where2) & (where3)', key.all_filters)
def test_cache_key_includes(self):
  """A key with more split_by columns includes one with a subset of them."""
  broader = utils.CacheKey('foo', 'where', ['bar', 'baz'])
  narrower = utils.CacheKey('foo', 'where', ['bar'])
  self.assertTrue(broader.includes(narrower))
def test_cache_key_extend(self):
  """extend() appends only the split_by columns not already present."""
  key = utils.CacheKey('foo', 'where', ['bar', 'baz'])
  key.extend(['baz', 'qux'])
  expected = utils.CacheKey('foo', 'where', ['bar', 'baz', 'qux'])
  self.assertEqual(expected, key)
def test_cache_key_init_with_cache_key(self):
  """Constructing a CacheKey from a CacheKey yields an equal key."""
  original = utils.CacheKey('foo', 'where', ['bar'])
  duplicate = utils.CacheKey(original)
  self.assertEqual(original, duplicate)
def wrap_cache_key(self, key, split_by=None, where=None, slice_val=None):
  """Normalizes `key` into a utils.CacheKey with filter and split context.

  Args:
    self: An instance with `cache_key` and `where` attributes.
    key: The raw key to wrap; passed through as None if None.
    split_by: Columns the computation is split by.
    where: Explicit filter; when absent, inherited from self.cache_key if
      set, otherwise from self.where.
    slice_val: The slice value to attach to the key.

  Returns:
    None if `key` is None, otherwise a utils.CacheKey.
  """
  if key is None:
    return None
  filters = where
  if filters is None and self.cache_key:
    filters = self.cache_key.where
  return utils.CacheKey(key, filters or self.where, split_by, slice_val)