Example #1
import base64

import numpy as np

import eli5
from eli5.formatters.as_dict import format_as_dict
from eli5.formatters.html import format_as_html


def explain_pred(input_data, model):
    y_preds = []
    y_probs = []
    encoded_htmls = []
    for i in input_data:
        expl = eli5.explain_prediction(
            model.steps[-1][1],
            i,
            model.steps[0][1],
            target_names=['Compliant', 'Not Compliant'],
            top=10)
        html_explanation = format_as_html(expl,
                                          force_weights=False,
                                          show_feature_values=True).replace(
                                              "\n", "").strip()
        # decode to str so np.column_stack below gets uniform string columns
        encoded_html = base64.b64encode(
            bytes(html_explanation, encoding='utf-8')).decode('ascii')
        encoded_htmls.append(encoded_html)
        expl_dict = format_as_dict(expl)
        targets = expl_dict['targets'][0]
        target = targets['target']
        y_pred = 1 if target.startswith('N') else 0
        y_prob = targets['proba']
        if len(i.split()) < 3:
            # one or two words can't be non-compliant
            y_pred = 0
            y_prob = 1.0
        y_prob = f'{round(y_prob, 3) * 100}%'
        y_preds.append(y_pred)
        y_probs.append(y_prob)
    inferences = np.column_stack((y_probs, y_preds, encoded_htmls))

    return inferences
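
A minimal sketch of how the array returned above might be consumed. Here `model` is assumed to be a fitted scikit-learn Pipeline whose first step is the text vectorizer and whose last step is the classifier (which is what explain_pred expects), and the sample texts are placeholders:

import base64

# Hypothetical inputs: a fitted vectorizer+classifier Pipeline and two texts to score.
inferences = explain_pred(["short text", "a longer sentence that needs review"], model)
print(inferences.shape)  # one row per input: (probability, prediction, encoded HTML)

# The HTML explanation can be restored from its base64 form, e.g. for rendering in a UI.
html_explanation = base64.b64decode(inferences[0, 2]).decode('utf-8')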
Example #2
import numpy as np

from eli5.base import Explanation, FeatureWeight, FeatureWeights, TargetExplanation
from eli5.formatters.as_dict import format_as_dict


def test_format_as_dict():
    assert format_as_dict(
        Explanation(
            estimator='some estimator',
            targets=[
                TargetExplanation(
                    'y',
                    feature_weights=FeatureWeights(
                        pos=[FeatureWeight('a', np.float32(13.0))],
                        neg=[])),
            ],
        )) == {
            'estimator': 'some estimator',
            'targets': [
                {
                    'target': 'y',
                    'feature_weights': {
                        'pos': [{
                            'feature': 'a',
                            'weight': 13.0,
                            'std': None,
                            'value': None,
                        }],
                        'pos_remaining': 0,
                        'neg': [],
                        'neg_remaining': 0,
                    },
                    'score': None,
                    'proba': None,
                    'weighted_spans': None,
                    'heatmap': None,
                },
            ],
            'decision_tree': None,
            'description': None,
            'error': None,
            'feature_importances': None,
            'highlight_spaces': None,
            'is_regression': False,
            'method': None,
            'transition_features': None,
            'image': None,
        }
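
Because format_as_dict returns plain dicts and lists, its output is straightforward to serialize. A minimal sketch using the same eli5.base objects as the test above; json.dumps is given a default=float fallback as a guard in case any numpy scalars survive the conversion:

import json

import numpy as np
from eli5.base import Explanation, FeatureWeight, FeatureWeights, TargetExplanation
from eli5.formatters.as_dict import format_as_dict

expl = Explanation(
    estimator='some estimator',
    targets=[
        TargetExplanation(
            'y',
            feature_weights=FeatureWeights(
                pos=[FeatureWeight('a', np.float32(13.0))],
                neg=[])),
    ],
)

# default=float handles any numpy scalar that is not already a builtin float.
print(json.dumps(format_as_dict(expl), indent=2, default=float))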
Example #3
import eli5
from eli5.formatters.as_dict import format_as_dict
from eli5.formatters.html import format_as_html

# `stopwords` is assumed to be a module-level collection of stop words.


def sklearn_predict(lines, current_app):
    estimator = current_app.config['ESTIMATOR']
    explanations = []
    y_preds = []
    y_probs = []
    raw_lines = []
    clean_lines = []
    for raw_line in lines:
        # strip stop words for the model, but keep the original text for the response
        clean_line = ' '.join(t for t in raw_line.split() if t not in stopwords)
        if not clean_line.strip():
            continue
        raw_lines.append(raw_line)
        clean_lines.append(clean_line)
        expl = eli5.explain_prediction(
            estimator.steps[-1][1],
            clean_line,
            estimator.steps[0][1],
            target_names=['Compliant', 'Not Compliant'],
            top=10)
        html_explanation = format_as_html(expl,
                                          force_weights=False,
                                          show_feature_values=True).replace(
                                              "\n", "").strip()
        explanations.append(html_explanation)
        expl_dict = format_as_dict(expl)
        targets = expl_dict['targets'][0]
        target = targets['target']
        y_pred = 1 if target.startswith('N') else 0
        y_prob = targets['proba']
        if len(clean_line.split()) < 3:
            # one or two words can't be non-compliant
            y_pred = 0
            y_prob = 1.0
        y_preds.append(y_pred)
        y_probs.append(y_prob)
    y_probs = [f'{round(y_prob, 3) * 100}%' for y_prob in y_probs]
    # zip over the lines actually scored, so skipped empty lines stay aligned
    data = zip(y_preds, raw_lines, clean_lines, y_probs, explanations)
    results = [
        dict(y_pred=y, line=l, clean_line=cl, y_prob=y_prob, expl=expl)
        for y, l, cl, y_prob, expl in data
    ]
    return results
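
sklearn_predict reads its model from current_app.config['ESTIMATOR'] and expects a fitted Pipeline with the vectorizer as the first step and the classifier as the last. A hypothetical wiring sketch; the toy texts, labels, and estimator choices are assumptions, and `stopwords` must be defined at module level as noted above:

from flask import Flask, current_app
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Toy training data; 0 = Compliant, 1 = Not Compliant (assumed label encoding).
texts = ["the clause follows the approved policy",
         "never disclose the payment terms to anyone"]
labels = [0, 1]

pipeline = Pipeline([
    ('vect', TfidfVectorizer()),    # first step: text vectorizer
    ('clf', LogisticRegression()),  # last step: classifier
]).fit(texts, labels)

app = Flask(__name__)
app.config['ESTIMATOR'] = pipeline

with app.app_context():
    results = sklearn_predict(["please review the payment terms"], current_app)
    print(results[0]['y_pred'], results[0]['y_prob'])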
Example #4
from flask import jsonify, request
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

import eli5
from eli5.formatters import as_dict, html, text
from eli5.lime import TextExplainer
from eli5.sklearn import explain_prediction

# `load_clf_and_vocabulary` (and `stopwords` in the commented-out block) are
# project-specific helpers assumed to be defined elsewhere in the application.


def explain_review_prediction():
	"""
	Explain a specific prediction using the eli5 library
	"""
	data = request.get_json(force=True)

	# Use the original documents, not the corrected ones
	target_names = ['negative', 'neutral', 'positive', 'very_negative', 'very_positive']
	clf, vocabulary = load_clf_and_vocabulary(data['classifier'], data['vocabModel'], data['tfIdf'], False)
	vect = CountVectorizer(vocabulary=vocabulary)
	vect._validate_vocabulary()

	# reviews = load_files(dir_path + '/../../data/reviews/not_corrected')
	# text_train, text_test, y_train, y_test = train_test_split(reviews.data, reviews.target, test_size=0.2, random_state=0)

	# if data['tfIdf']:
	# 	if data['vocabModel'] == 'unigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(1, 1)).fit(text_train)
	# 	elif data['vocabModel'] == 'bigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(2, 2)).fit(text_train)
	# 	elif data['vocabModel'] == 'trigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(3, 3)).fit(text_train)
	# else:
	# 	if data['vocabModel'] == 'unigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(1, 1)).fit(text_train)
	# 	elif data['vocabModel'] == 'bigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(2, 2)).fit(text_train)
	# 	elif data['vocabModel'] == 'trigram':
	# 		vect = CountVectorizer(min_df=5, stop_words=stopwords, ngram_range=(3, 3)).fit(text_train)

	if data['classifier'] == 'LR':
		explanation = explain_prediction.explain_prediction_linear_classifier(clf, data['review'], vec=vect, top=10, target_names=target_names)
		div = html.format_as_html(explanation, include_styles=False)
		style = html.format_html_styles()

		txt = text.format_as_text(explanation, show=eli5.formatters.fields.ALL, highlight_spaces=True, show_feature_values=True)
		print(txt)

		return jsonify({
			'div': div,
			'style': style
		})

	elif data['classifier'] in ('SVM', 'MLP'):
		te = TextExplainer(n_samples=100, clf=LogisticRegression(solver='newton-cg'), vec=vect, random_state=0)
		te.fit(data['review'], clf.predict_proba)
		explanation = te.explain_prediction(top=10, target_names=target_names)
		div = html.format_as_html(explanation, include_styles=False)
		style = html.format_html_styles()

		distorted_texts = []

		for sample in te.samples_:
			sample_explanation = explain_prediction.explain_prediction_linear_classifier(te.clf_, sample, vec=te.vec_, top=10, target_names=target_names)
			dict_explanation = as_dict.format_as_dict(sample_explanation)

			curr = {
				'text': sample
			}

			# Map each class name to its predicted probability for this perturbed sample.
			for c in dict_explanation['targets']:
				if c['target'] in target_names:
					curr[c['target']] = c['proba']

			distorted_texts.append(curr)

		review_explanation = as_dict.format_as_dict(explanation)
		probabilities = {}

		# Map each class name to its predicted probability for the original review.
		for c in review_explanation['targets']:
			if c['target'] in target_names:
				probabilities[c['target']] = c['proba']

		return jsonify({
			'div': div,
			'style': style,
			'original_text': data['review'],
			'probabilities': probabilities,
			'distorted_texts': distorted_texts,
			'metrics': te.metrics_
		})
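
The handler reads everything from the POSTed JSON body. A hypothetical client call; the route URL and port are assumptions (the route decorator is not shown above), but the payload keys are exactly the ones the handler reads:

import requests

payload = {
    'classifier': 'SVM',      # 'LR', 'SVM' or 'MLP', matching the branches above
    'vocabModel': 'unigram',  # forwarded to load_clf_and_vocabulary
    'tfIdf': False,           # forwarded to load_clf_and_vocabulary
    'review': 'The food was great but the service was painfully slow.',
}

# Hypothetical URL; the real path depends on how the view function is registered.
response = requests.post('http://localhost:5000/explain_review_prediction', json=payload)
print(response.json()['probabilities'])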