def _importance(self, feature):
    """Score ``feature`` by how much it moves the prediction off the prior.

    Two feature vectors are predicted: a baseline made of the bias feature
    followed by default ``pb.Feature()`` entries, and a second vector that
    is the same except that its final slot holds ``feature``.

    Returns:
        ``util.kl_divergence(with_feature_prediction, baseline_prediction)``.
    """
    num_features = self._config.num_features

    # Baseline: bias feature plus (num_features - 1) default features.
    baseline_vector = [util.bias_feature()]
    baseline_vector += [pb.Feature()] * (num_features - 1)
    baseline_prediction = self.predict(baseline_vector)

    # Same layout, but the last default feature is swapped for `feature`.
    probe_vector = [util.bias_feature()]
    probe_vector += [pb.Feature()] * (num_features - 2)
    probe_vector.append(feature)
    probe_prediction = self.predict(probe_vector)

    return util.kl_divergence(probe_prediction, baseline_prediction)
def test_importance_is_monotonic_in_mean(self):
    """Importance must not decrease as the weight's mean grows.

    The weight of a single feature is set to a Gaussian whose mean sweeps
    over ten evenly spaced values in [0.0, 3.0] at a fixed variance of 0.5.
    The importances collected along the sweep must already be in ascending
    order.
    """
    f = pb.Feature(feature=10, value=5)
    importances = []
    for mean in np.linspace(0.0, 3.0, 10):
        # Use the swept value. A hard-coded mean here would make every
        # iteration identical and the sortedness check trivially true.
        # float() hands the protobuf field a plain Python float.
        self._predictor._set_weight(
            f, pb.Gaussian(mean=float(mean), variance=0.5))
        importances.append(self._predictor._importance(f))
    self.assertEqual(sorted(importances), importances)
def next(self):
    """Draw one random sample (iterator-protocol hook).

    For every feature index below ``predictor_config.num_features`` a value
    is drawn with ``np.random.randint(0, self._cardinality(index))``. The
    resulting vector is labelled via ``self._label`` and the sample counter
    is incremented.

    Returns:
        A ``(feature_vector, label)`` tuple.
    """
    # NOTE(review): Python 3 looks up ``__next__`` rather than ``next``;
    # confirm an alias exists elsewhere if this file targets Python 3.
    num_features = self._simulation.predictor_config.num_features

    feature_vector = []
    for index in range(num_features):
        drawn_value = np.random.randint(0, self._cardinality(index))
        feature_vector.append(pb.Feature(feature=index, value=drawn_value))

    label = self._label(feature_vector)
    self._num_samples += 1
    return feature_vector, label
def _construct_biased_weights(simulation):
    """Randomly pick (feature, value) pairs to bias and choose a direction.

    Every pair with feature index in ``[1, num_features)`` and value in
    ``[0, feature_cardinality)`` is considered. A pair is selected when a
    uniform draw falls below ``simulation.biased_feature_proportion``. Its
    direction is then a second uniform draw compared against
    ``predictor_config.prior_probability``.

    Returns:
        A dict mapping the serialized feature to its boolean direction.
    """
    config = simulation.predictor_config
    feature_ids = range(1, config.num_features)
    feature_values = range(simulation.feature_cardinality)

    weights = {}
    for feature_id in feature_ids:
        for feature_value in feature_values:
            key = util.serialize_feature(
                pb.Feature(feature=feature_id, value=feature_value))

            # First draw decides whether this pair is biased at all.
            if not (np.random.rand() < simulation.biased_feature_proportion):
                continue

            # Second draw decides the direction of the bias.
            direction = np.random.rand() < config.prior_probability
            weights[key] = direction
            logger.info("Biased truth feature (%s, %s) to %s",
                        feature_id, feature_value, direction)
    return weights
def test_importance_of_set_feature(self):
    """A feature with an explicitly set weight has positive importance."""
    feature = pb.Feature(feature=10, value=5)
    weight = pb.Gaussian(mean=0.5, variance=0.5)
    self._predictor._set_weight(feature, weight)

    importance = self._predictor._importance(feature)
    self.assertGreater(importance, 0.0)
def test_importance_of_empty_feature(self):
    """A default-constructed feature has an importance of exactly 0.0."""
    empty_feature = pb.Feature()
    importance = self._predictor._importance(empty_feature)
    self.assertEqual(importance, 0.0)
def _create_feature_vector(num_features):
    """Build a feature vector: the bias feature, then default features.

    The bias feature comes first and is followed by ``num_features - 1``
    references to one default ``pb.Feature()`` instance (list repetition
    reuses the same object rather than copying it).
    """
    vector = [util.bias_feature()]
    vector.extend([pb.Feature()] * (num_features - 1))
    return vector
def deserialize_feature(string):
    """Parse ``string`` into a new ``pb.Feature`` message and return it."""
    feature = pb.Feature()
    feature.ParseFromString(string)
    return feature
def bias_feature():
    """Return the bias feature: a ``pb.Feature`` with feature=0, value=0."""
    bias = pb.Feature(feature=0, value=0)
    return bias