def test_understand(self):
    """Check that the explainer's accuracy metric for one instance exceeds 0.9.

    An MLP classifier is fitted on a train split of the two-moons data set,
    a ``Gpx`` explainer is built around ``clf.predict_proba`` and asked to
    explain test instance 30. ``understand(metric='loss')`` is exercised
    first; a ``ValueError`` from it is logged and does not fail the test.
    The assertion is always made on ``understand(metric='accuracy')`` so the
    test can neither pass vacuously nor compare a loss value against the
    accuracy threshold.
    """
    x, y = make_moons(n_samples=1500, noise=.4, random_state=17)
    clf = MLPClassifier()
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=.8, test_size=.2, random_state=17)
    clf.fit(x_train, y_train)

    # NOTE(review): the explainer receives the full data set (x, y), not
    # only the training split - confirm this is intended.
    gpx = Gpx(clf.predict_proba, x_train=x, y_train=y,
              feature_names=['x', 'y'])
    gpx.explaining(x_test[30, :])

    # Kept under its own name so the label vector ``y`` is not shadowed.
    test_proba = clf.predict_proba(x_test)
    gpx.logger.info(gpx.proba_transform(test_proba))

    # The 'loss' metric may be rejected with ValueError; log it and carry on.
    try:
        gpx.understand(metric='loss')
    except ValueError as e:
        gpx.logger.exception(e)

    # Always assert on the accuracy metric, whichever path was taken above.
    accuracy = gpx.understand(metric='accuracy')
    message = 'test_understand accuracy {}'.format(accuracy)
    gpx.logger.info(message)
    self.assertGreater(accuracy, .9, message)
def test_grafic_sensibility(self):
    """Plot and animate how the explainer's prediction reacts to feature 0.

    An MLP classifier is fitted on the two-moons data set and a ``Gpx``
    explainer is built for test instance ``INSTANCE``. The explainer's
    prediction is drawn as a filled contour over a grid covering the
    sampled neighbourhood (``gpx.x_around``). An animation then overwrites
    the first feature of every neighbour with each value from the first
    column of ``gpx.max_min_matrix(noise_range=10)`` and recolours the
    points with the resulting prediction. The figure is shown, the
    animation is saved to ``sens_x_2.gif`` and the result of
    ``gpx.feature_sensitivity()`` is printed.

    The test makes no assertions; it only checks that the pipeline runs.
    """
    INSTANCE: int = 74

    x, y = make_moons(n_samples=1500, noise=.4, random_state=17)
    clf = MLPClassifier()
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=.8, test_size=.2, random_state=17)
    clf.fit(x_train, y_train)

    # NOTE(review): the explainer receives the full data set (x, y), not
    # only the training split - confirm this is intended.
    gpx = Gpx(clf.predict_proba, x_train=x, y_train=y, random_state=42,
              feature_names=['x', 'y'])
    gpx.explaining(x_test[INSTANCE, :])

    # Coordinates of the neighbourhood sampled around the instance.
    around_x1 = gpx.x_around[:, 0]
    around_x2 = gpx.x_around[:, 1]
    y_proba = gpx.proba_transform(gpx.y_around)

    # Grid covering the neighbourhood with a one-unit margin on each side.
    resolution = 0.02
    x1_min, x1_max = around_x1.min() - 1, around_x1.max() + 1
    x2_min, x2_max = around_x2.min() - 1, around_x2.max() + 1
    xm1, xm2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    grid_pred = gpx.gp_prediction(np.array([xm1.ravel(), xm2.ravel()]).T)

    fig, ax = plt.subplots()
    ax.set_xlim(x1_min, x1_max)
    # The second limit belongs to the y axis; calling set_xlim twice would
    # overwrite the x range and leave the y range unset.
    ax.set_ylim(x2_min, x2_max)

    # Pass the values as colours: the third positional argument of scatter
    # is the marker size, while set_array() below updates colour values.
    scat = plt.scatter(around_x1, around_x2, c=y_proba)

    def update(frame):
        """Move the animated points and recolour them for one frame."""
        points, values = frame
        scat.set_offsets(points)
        scat.set_array(values)

    mmm = gpx.max_min_matrix(noise_range=10)

    # One frame per value in the first column of the matrix: feature 0 of
    # every neighbour is set to that value and the prediction recomputed.
    frames = []
    for value in mmm[:, 0]:
        shifted = gpx.x_around.copy()
        shifted[:, 0] = value
        # gp_prediction gets its own copy in case it modifies its input.
        frames.append((shifted, gpx.gp_prediction(shifted.copy())))

    animation = ani.FuncAnimation(fig, update, frames, interval=200,
                                  save_count=200)

    plt.contourf(xm1, xm2, grid_pred.reshape(xm1.shape), alpha=0.4)
    plt.scatter(around_x1, around_x2, c=y_proba)
    plt.show()

    writergif = ani.PillowWriter(fps=5)
    animation.save('sens_x_2.gif', writer=writergif)

    sens_gpx = gpx.feature_sensitivity()
    print(sens_gpx)