def interpret(self, raw_input):
    """
    Runs the interpretation command for the machine learning model. Handles both the "default"
    out-of-the-box interpretation for a certain set of UI component types, as well as the custom
    interpretation case (``self.interpretation`` is a callable).

    :param raw_input: a list of raw inputs to apply the interpretation(s) on, one per input interface.
    :return: for "default" interpretation, a list with one entry per input interface, each being the
        value returned by that interface's ``get_interpretation_scores``. Otherwise, the value returned
        by the custom interpreter, wrapped in a list when there is exactly one raw input.
    :raises ValueError: if the custom interpreter (outside a captured session) raises a ValueError whose
        message ends with "is not an element of this graph." -- it is re-raised with the TF1 error
        message. Any other ValueError from the interpreter propagates unchanged.
    """

    def preprocess_all(raw_values):
        # Apply each input interface's preprocess to the raw value at the same position.
        return [
            interface.preprocess(raw_values[j])
            for j, interface in enumerate(self.input_interfaces)
        ]

    if self.interpretation == "default":
        original_output = self.run_prediction(preprocess_all(raw_input))
        scores = []
        for i, x in enumerate(raw_input):
            input_interface = self.input_interfaces[i]
            # Copy so that only position i is swapped out for each neighbor value.
            neighbor_raw_input = list(raw_input)
            neighbors, interpret_kwargs = input_interface.get_interpretation_neighbors(x)
            interface_scores = []
            for neighbor_input in neighbors:
                neighbor_raw_input[i] = neighbor_input
                neighbor_output = self.run_prediction(
                    preprocess_all(neighbor_raw_input))
                interface_scores.append(
                    quantify_difference_in_label(
                        self, original_output, neighbor_output))
            scores.append(
                input_interface.get_interpretation_scores(
                    x, interface_scores, **interpret_kwargs))
        return scores

    # Custom interpretation: self.interpretation is called with the preprocessed inputs.
    processed_input = preprocess_all(raw_input)
    interpreter = self.interpretation
    if self.capture_session and self.session is not None:
        graph, sess = self.session
        with graph.as_default(), sess.as_default():
            interpretation = interpreter(*processed_input)
    else:
        try:
            interpretation = interpreter(*processed_input)
        except ValueError as exception:
            if str(exception).endswith("is not an element of this graph."):
                raise ValueError(strings.en["TF1_ERROR"]) from exception
            raise
    if len(raw_input) == 1:
        interpretation = [interpretation]
    return interpretation
def interpret(self, raw_input):
    """
    Runs the interpretation command for the machine learning model. Handles the "default"
    out-of-the-box interpretation for a certain set of UI component types, the "shap"
    interpretation, and the custom interpretation case (``self.interpretation`` is a callable).
    Mode names are matched case-insensitively.

    :param raw_input: a list of raw inputs to apply the interpretation(s) on, one per input component.
    :return: a tuple ``(scores, alternative_outputs)``:
        - "default": ``scores`` holds each input component's ``get_interpretation_scores`` result;
          ``alternative_outputs`` holds, per input component, the postprocessed outputs obtained
          for every neighbor input.
        - "shap": ``scores`` holds each input component's ``get_interpretation_scores`` result
          computed from SHAP values; ``alternative_outputs`` is ``[]``.
        - custom: the interpreter's return value (wrapped in a list when there is exactly one
          raw input) and ``[]``.
    :raises ValueError: if "shap" is requested but the ``shap`` package cannot be imported, or if
        a custom interpreter (outside a captured session) raises a ValueError whose message ends
        with "is not an element of this graph." -- it is re-raised with the TF1 error message.
        Any other ValueError from the interpreter propagates unchanged.
    """

    def preprocess_all(raw_values):
        # Apply each input component's preprocess to the raw value at the same position.
        return [
            component.preprocess(raw_values[j])
            for j, component in enumerate(self.input_components)
        ]

    def postprocess_all(output):
        # Apply each output component's postprocess to the prediction at the same position.
        return [
            component.postprocess(output[j])
            for j, component in enumerate(self.output_components)
        ]

    # self.interpretation is either a mode name or a user-supplied callable. Only a
    # string can be lower-cased; calling .lower() on a callable raised AttributeError
    # and made the custom-interpreter branch unreachable.
    if isinstance(self.interpretation, str):
        mode = self.interpretation.lower()
    else:
        mode = None

    if mode == "default":
        original_output = self.run_prediction(preprocess_all(raw_input))
        scores, alternative_outputs = [], []
        for i, x in enumerate(raw_input):
            input_component = self.input_components[i]
            # Copy so that only position i is replaced by each neighbor value.
            neighbor_raw_input = list(raw_input)
            by_tokens = input_component.interpret_by_tokens
            if by_tokens:
                tokens, neighbor_values, masks = input_component.tokenize(x)
                interpret_kwargs = {"masks": masks, "tokens": tokens}
            else:
                neighbor_values, interpret_kwargs = (
                    input_component.get_interpretation_neighbors(x))
            interface_scores, alternative_output = [], []
            for neighbor_input in neighbor_values:
                neighbor_raw_input[i] = neighbor_input
                neighbor_output = self.run_prediction(
                    preprocess_all(neighbor_raw_input))
                alternative_output.append(postprocess_all(neighbor_output))
                interface_scores.append(
                    quantify_difference_in_label(
                        self, original_output, neighbor_output))
            alternative_outputs.append(alternative_output)
            if not by_tokens:
                # Non-token scores are negated before scoring; token scores are passed as-is.
                interface_scores = [-score for score in interface_scores]
            scores.append(
                input_component.get_interpretation_scores(
                    x, neighbor_values, interface_scores, **interpret_kwargs))
        return scores, alternative_outputs

    if mode == "shap":
        try:
            import shap
        except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
            raise ValueError(
                "The package `shap` is required for this interpretation method. Try: `pip install shap`"
            ) from err
        processed_input = preprocess_all(raw_input)
        original_output = self.run_prediction(processed_input)
        scores = []
        for i, x in enumerate(raw_input):  # iterate over each input component
            input_component = self.input_components[i]
            tokens, _, masks = input_component.tokenize(x)

            def get_masked_prediction(binary_mask):
                # Construct masked versions of input i and score each against the
                # original output. This closes over the current iteration's i,
                # input_component and tokens; the explainer using it is local to
                # this iteration.
                masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
                preds = []
                for masked_x in masked_xs:
                    processed_masked_input = copy.deepcopy(processed_input)
                    processed_masked_input[i] = input_component.preprocess(masked_x)
                    new_output = self.run_prediction(processed_masked_input)
                    preds.append(
                        get_regression_or_classification_value(
                            self, original_output, new_output))
                return np.array(preds)

            num_total_segments = len(tokens)
            explainer = shap.KernelExplainer(
                get_masked_prediction, np.zeros((1, num_total_segments)))
            shap_values = explainer.shap_values(
                np.ones((1, num_total_segments)),
                nsamples=int(self.num_shap * num_total_segments),
                silent=True)
            scores.append(
                input_component.get_interpretation_scores(
                    x, None, shap_values[0], masks=masks, tokens=tokens))
        return scores, []

    # Custom interpretation: self.interpretation is called with the preprocessed inputs.
    processed_input = preprocess_all(raw_input)
    interpreter = self.interpretation
    if self.capture_session and self.session is not None:
        graph, sess = self.session
        with graph.as_default(), sess.as_default():
            interpretation = interpreter(*processed_input)
    else:
        try:
            interpretation = interpreter(*processed_input)
        except ValueError as exception:
            if str(exception).endswith("is not an element of this graph."):
                raise ValueError(strings.en["TF1_ERROR"]) from exception
            raise
    if len(raw_input) == 1:
        interpretation = [interpretation]
    return interpretation, []