def test_lime_stability_indices_good_behaviour(self):
    """Check that ``check_stability`` returns CSI and VSI values in [0, 100].

    Fits a 500-tree random forest on the training data, picks one test row
    at a (seeded) random index, and runs the LIME stability check on it with
    ``model_regressor=None``. The result must be non-None and unpack into a
    ``(csi, vsi)`` pair whose values both lie in the closed range [0, 100].
    """
    # Seed the global NumPy RNG so the sampled row index is reproducible.
    # No random_state is given to the forest, so it presumably also draws
    # from this global RNG — confirm if determinism of the forest matters.
    np.random.seed(1)
    rf = RandomForestClassifier(n_estimators=500)
    rf.fit(self.train, self.labels_train)
    i = np.random.randint(0, self.test.shape[0])
    explainer = LimeTabularExplainerOvr(
        self.train,
        mode="classification",
        feature_names=self.feature_names,
        class_names=self.target_names,
        discretize_continuous=True,
    )
    indices = explainer.check_stability(
        data_row=self.test[i],
        predict_fn=rf.predict_proba,
        num_features=2,
        n_calls=10,
        index_verbose=True,
        verbose=True,
        model_regressor=None,
    )
    self.assertIsNotNone(indices)
    csi, vsi = indices
    # Chained comparison short-circuits and reads as the mathematical range,
    # unlike the bitwise `&` of two separately evaluated booleans.
    self.assertTrue(0 <= csi <= 100, "CSI Index value is not in the range [0,100]")
    self.assertTrue(0 <= vsi <= 100, "VSI Index value is not in the range [0,100]")
def test_lime_stability_indices_model_error(self):
    """Check that ``check_stability`` raises ``LocalModelError`` for this surrogate.

    Fits a 500-tree random forest, builds the same explainer as the
    good-behaviour test, and passes a ``LinearRegression`` instance as
    ``model_regressor``. The stability check is expected to raise
    ``LocalModelError``.
    """
    # Seed the global NumPy RNG so the sampled row index is reproducible.
    np.random.seed(1)
    rf = RandomForestClassifier(n_estimators=500)
    rf.fit(self.train, self.labels_train)
    lin_regr = LinearRegression(fit_intercept=True)
    i = np.random.randint(0, self.test.shape[0])
    explainer = LimeTabularExplainerOvr(
        self.train,
        mode="classification",
        feature_names=self.feature_names,
        class_names=self.target_names,
        discretize_continuous=True,
    )
    # Only the call under test sits inside assertRaises, so a failure during
    # setup is reported as an error instead of satisfying the expectation.
    # The return value is irrelevant here, so it is not bound to a name.
    with self.assertRaises(LocalModelError):
        explainer.check_stability(
            data_row=self.test[i],
            predict_fn=rf.predict_proba,
            num_features=2,
            n_calls=10,
            model_regressor=lin_regr,
        )