コード例 #1
0
    def update_scene(self):
        """Redraw the force plot and send the scores table on the output.

        The scene and the ``mo_info``/``bv_info`` captions are always reset
        first. When there are no results, ``None`` is sent on
        ``Outputs.scores``. Otherwise the first instance of the results is
        drawn for ``target_index``, the captions are filled with that
        instance's model output and the base value, and a scores table built
        from the instance's values for the target is sent.
        """
        self.clear_scene()
        self.mo_info = ""
        self.bv_info = ""

        if self.__results is None:
            self.Outputs.scores.send(None)
            return

        first = 0  # only the first instance is plotted
        high, low = 0, 1  # indices of the high/low halves of each instance
        results = self.__results
        data = results.transformed_data
        pred = results.predictions
        base = results.base_value
        values, _, labels, ranges = prepare_force_plot_data(
            results.values, data, pred, self.target_index)

        inst_values = values[first]
        inst_labels = labels[first]
        plot_data = PlotData(
            high_values=inst_values[high],
            # the low half is passed to the plot in reversed order
            low_values=inst_values[low][::-1],
            high_labels=inst_labels[high],
            low_labels=inst_labels[low][::-1],
            value_range=ranges[first],
            model_output=pred[first][self.target_index],
            base_value=base[self.target_index],
        )
        self.setup_plot(plot_data)

        self.mo_info = f"Model prediction: {_str(plot_data.model_output)}"
        self.bv_info = f"Base value: {_str(plot_data.base_value)}"

        # values are indexed per target, then [instance, feature]
        assert isinstance(self.__results.values, list)
        first_row = self.__results.values[self.target_index][first, :]
        attr_names = [attr.name for attr in data.domain.attributes]
        self.Outputs.scores.send(
            self.create_scores_table(first_row, attr_names))
コード例 #2
0
    def test_prepare_force_plot_no_top_n_features(self):
        """Without ``top_n_features`` every non-zero feature is kept.

        Only sizes are checked: for each of the three instances, the first
        and second halves of the SHAP values, segments and labels must hold
        the expected number of items, and one range is produced per instance.
        """
        target_shap = np.array(
            [[1, -2, 6, 5], [-2, -3, -1, -5], [1, 2, 4, 5]])
        shap_values = [np.random.random((3, 4)), target_shap]
        predictions = np.array([[1, 2], [1, 3], [1, 4]])

        shaps, segments, labels, ranges = prepare_force_plot_data(
            shap_values, self.iris[:4], predictions, 1)

        # (size of part 0, size of part 1) for each instance
        expected_sizes = [(3, 1), (0, 4), (4, 0)]
        for produced in (shaps, segments, labels):
            self.assertEqual(len(produced), 3)
            for row, (n_first, n_second) in zip(produced, expected_sizes):
                self.assertEqual(len(row[0]), n_first)
                self.assertEqual(len(row[1]), n_second)

        self.assertEqual(len(ranges), 3)
コード例 #3
0
    def test_prepare_force_plot_data_zero_shap(self):
        """
        Features with a SHAP value of exactly 0 must be absent from every
        structure returned by prepare_force_plot_data (values, segments,
        labels), and must not affect the ranges.
        """
        zero_containing = np.array(
            [[1, -2, 6, 0], [-2, -3, 0, -5], [1, 0, 4, 5]])
        shap_values = [np.random.random((3, 4)), zero_containing]
        predictions = np.array([[1, 2], [1, 3], [1, 4]])

        shaps, segments, labels, ranges = prepare_force_plot_data(
            shap_values, self.iris[:4], predictions, 1)

        expected_shaps = [
            ([6, 1], [-2]),
            ([], [-5, -3, -2]),
            ([5, 4, 1], []),
        ]
        expected_segments = [
            ([(2, -4), (-4, -5)], [(2, 4)]),
            ([], [(3, 8), (8, 11), (11, 13)]),
            ([(4, -1), (-1, -5), (-5, -6)], []),
        ]
        expected_labels = [
            (
                [("petal length", 1.4), ("sepal length", 5.1)],
                [("sepal width", 3.5)],
            ),
            (
                [],
                [
                    ("petal width", 0.2),
                    ("sepal width", 3.0),
                    ("sepal length", 4.9),
                ],
            ),
            (
                [
                    ("petal width", 0.2),
                    ("petal length", 1.3),
                    ("sepal length", 4.7),
                ],
                [],
            ),
        ]
        expected_ranges = [(-5, 4), (3, 13), (-6, 4)]

        self.assertListEqual(expected_shaps, shaps)
        self.assertListEqual(expected_segments, segments)
        self.assertListEqual(expected_labels, labels)
        self.assertListEqual(expected_ranges, ranges)
コード例 #4
0
    def test_prepare_force_plot_data_target_1(self):
        """Check the full output for target class 1 with top_n_features=3.

        Only the three features with the largest absolute SHAP value per
        instance are expected to be kept.
        """
        target_shap = np.array(
            [[1, -2, 6, 5], [-2, -3, -1, -5], [1, 2, 4, 5]])
        shap_values = [np.random.random((3, 4)), target_shap]
        predictions = np.array([[1, 2], [1, 3], [1, 4]])

        shaps, segments, labels, ranges = prepare_force_plot_data(
            shap_values, self.iris[:4], predictions, 1, top_n_features=3)

        expected_shaps = [
            ([6, 5], [-2]),
            ([], [-5, -3, -2]),
            ([5, 4, 2], []),
        ]
        expected_segments = [
            ([(2, -4), (-4, -9)], [(2, 4)]),
            ([], [(3, 8), (8, 11), (11, 13)]),
            ([(4, -1), (-1, -5), (-5, -7)], []),
        ]
        expected_labels = [
            (
                [("petal length", 1.4), ("petal width", 0.2)],
                [("sepal width", 3.5)],
            ),
            (
                [],
                [
                    ("petal width", 0.2),
                    ("sepal width", 3.0),
                    ("sepal length", 4.9),
                ],
            ),
            (
                [
                    ("petal width", 0.2),
                    ("petal length", 1.3),
                    ("sepal width", 3.2),
                ],
                [],
            ),
        ]
        expected_ranges = [(-9, 4), (3, 13), (-7, 4)]

        self.assertListEqual(expected_shaps, shaps)
        self.assertListEqual(expected_segments, segments)
        self.assertListEqual(expected_labels, labels)
        self.assertListEqual(expected_ranges, ranges)