Пример #1
0
    def interpret(self, raw_input):
        """
        Produce interpretation results for a single set of raw inputs.

        When ``self.interpretation`` equals "default", each raw input is replaced in turn by the neighbors
        supplied by its input interface, the model is re-run for every neighbor, and the change relative to
        the original prediction is handed to the interface to be turned into scores. Otherwise
        ``self.interpretation`` is called as a custom interpreter on the preprocessed inputs.
        :param raw_input: a list of raw inputs to apply the interpretation(s) on.
        :return: for "default", a list with one scores entry per raw input; for a custom interpreter, its
            return value (wrapped in a one-element list when ``raw_input`` has exactly one item).
        :raises ValueError: re-raised from the custom interpreter; a message ending in
            "is not an element of this graph." is replaced by ``strings.en["TF1_ERROR"]``.
        """
        use_default = self.interpretation == "default"
        # Both modes start from the same preprocessed inputs.
        processed_input = [
            interface.preprocess(raw_input[pos])
            for pos, interface in enumerate(self.input_interfaces)
        ]

        if not use_default:
            custom_interpreter = self.interpretation
            session_available = self.capture_session and self.session is not None
            if session_available:
                # Run the interpreter inside the captured graph/session pair.
                graph, sess = self.session
                with graph.as_default(), sess.as_default():
                    result = custom_interpreter(*processed_input)
            else:
                try:
                    result = custom_interpreter(*processed_input)
                except ValueError as exception:
                    graph_mismatch = str(exception).endswith(
                        "is not an element of this graph.")
                    if graph_mismatch:
                        raise ValueError(strings.en["TF1_ERROR"])
                    raise exception
            # A single raw input gets its interpretation wrapped in a list.
            return [result] if len(raw_input) == 1 else result

        baseline_output = self.run_prediction(processed_input)
        all_scores = []
        for pos, value in enumerate(raw_input):
            interface = self.input_interfaces[pos]
            perturbed_raw = list(raw_input)
            # Element 0 holds the neighbor values, element 1 the keyword arguments for scoring.
            neighbor_info = interface.get_interpretation_neighbors(value)
            differences = []
            for candidate in neighbor_info[0]:
                perturbed_raw[pos] = candidate
                perturbed_processed = [
                    other.preprocess(perturbed_raw[k])
                    for k, other in enumerate(self.input_interfaces)
                ]
                perturbed_output = self.run_prediction(perturbed_processed)
                differences.append(
                    quantify_difference_in_label(
                        self, baseline_output, perturbed_output))
            all_scores.append(
                interface.get_interpretation_scores(
                    raw_input[pos], differences, **neighbor_info[1]))
        return all_scores
Пример #2
0
    def interpret(self, raw_input):
        """
        Runs the interpretation command for the machine learning model. Handles the built-in "default" and
        "shap" interpretations (matched case-insensitively when ``self.interpretation`` is a string) as well as
        the custom case, where ``self.interpretation`` is a callable applied to the preprocessed inputs.
        :param raw_input: a list of raw inputs to apply the interpretation(s) on.
        :return: a tuple ``(scores, alternative_outputs)``. For "default", ``scores`` has one entry per raw
            input and ``alternative_outputs`` holds, per raw input, the postprocessed outputs of every neighbor
            prediction. For "shap", ``alternative_outputs`` is an empty list. For a custom interpreter the first
            element is its return value (wrapped in a one-element list when ``raw_input`` has exactly one item)
            and the second is an empty list.
        :raises ValueError: if "shap" is requested but the `shap` package cannot be imported, or re-raised
            from a custom interpreter (a message ending in "is not an element of this graph." is replaced by
            ``strings.en["TF1_ERROR"]``).
        """

        def preprocess_all(raw_values):
            # Preprocess one raw value per input component, in component order.
            return [
                component.preprocess(raw_values[idx])
                for idx, component in enumerate(self.input_components)
            ]

        interpretation = self.interpretation
        # Only a string can name a built-in method; a callable has no .lower() and must fall through
        # to the custom branch instead of raising AttributeError.
        method = interpretation.lower() if isinstance(
            interpretation, str) else None

        if method == "default":
            original_output = self.run_prediction(preprocess_all(raw_input))
            scores, alternative_outputs = [], []
            for i, x in enumerate(raw_input):
                input_component = self.input_components[i]
                neighbor_raw_input = list(raw_input)
                by_tokens = input_component.interpret_by_tokens
                if by_tokens:
                    tokens, neighbor_values, masks = input_component.tokenize(
                        x)
                    score_kwargs = {"masks": masks, "tokens": tokens}
                else:
                    neighbor_values, score_kwargs = input_component.get_interpretation_neighbors(
                        x)

                interface_scores = []
                alternative_output = []
                for neighbor_input in neighbor_values:
                    # Swap only the i-th raw input for its neighbor and re-run the full pipeline.
                    neighbor_raw_input[i] = neighbor_input
                    neighbor_output = self.run_prediction(
                        preprocess_all(neighbor_raw_input))
                    alternative_output.append([
                        output_component.postprocess(neighbor_output[out_idx])
                        for out_idx, output_component in enumerate(
                            self.output_components)
                    ])
                    interface_scores.append(
                        quantify_difference_in_label(
                            self, original_output, neighbor_output))
                alternative_outputs.append(alternative_output)
                if not by_tokens:
                    # Neighbor-based scores are passed on with their sign flipped; token-based
                    # scores are forwarded as computed.
                    interface_scores = [-score for score in interface_scores]
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i], neighbor_values, interface_scores,
                        **score_kwargs))
            # Return only once every input has been scored (and also for an empty raw_input).
            return scores, alternative_outputs

        if method == "shap":
            scores = []
            try:
                import shap
            except ImportError as import_error:
                raise ValueError(
                    "The package `shap` is required for this interpretation method. Try: `pip install shap`"
                ) from import_error

            processed_input = preprocess_all(raw_input)
            original_output = self.run_prediction(processed_input)

            for i, x in enumerate(raw_input):  # iterate over each interface
                input_component = self.input_components[i]
                tokens, _, masks = input_component.tokenize(x)

                # The closure is consumed by the explainer within this same iteration, so it always
                # sees the current i / input_component / tokens.
                def get_masked_prediction(binary_mask):
                    # Construct masked versions of the input and predict on each of them.
                    masked_xs = input_component.get_masked_inputs(
                        tokens, binary_mask)
                    preds = []
                    for masked_x in masked_xs:
                        processed_masked_input = copy.deepcopy(processed_input)
                        processed_masked_input[i] = input_component.preprocess(
                            masked_x)
                        new_output = self.run_prediction(
                            processed_masked_input)
                        preds.append(
                            get_regression_or_classification_value(
                                self, original_output, new_output))
                    return np.array(preds)

                num_total_segments = len(tokens)
                explainer = shap.KernelExplainer(
                    get_masked_prediction, np.zeros((1, num_total_segments)))
                shap_values = explainer.shap_values(
                    np.ones((1, num_total_segments)),
                    nsamples=int(self.num_shap * num_total_segments),
                    silent=True)
                scores.append(
                    input_component.get_interpretation_scores(
                        raw_input[i],
                        None,
                        shap_values[0],
                        masks=masks,
                        tokens=tokens))
            return scores, []

        # Custom interpretation: self.interpretation is called with the preprocessed inputs.
        processed_input = preprocess_all(raw_input)
        if self.capture_session and self.session is not None:
            graph, sess = self.session
            with graph.as_default(), sess.as_default():
                result = interpretation(*processed_input)
        else:
            try:
                result = interpretation(*processed_input)
            except ValueError as exception:
                if str(exception).endswith(
                        "is not an element of this graph."):
                    raise ValueError(strings.en["TF1_ERROR"]) from exception
                raise
        if len(raw_input) == 1:
            result = [result]
        return result, []