def get_trees(self):
        """
        Convert the wrapped XGBoost model into verification trees.

        The booster is dumped as JSON (one string per tree). Every dumped tree is
        parsed and turned into a single ``Tree`` whose leaves are produced by
        ``self._get_leaf_nodes``.

        Class labels: when ``n_classes_`` is 2 every tree is labelled ``-1``;
        otherwise tree ``i`` is labelled ``i % n_classes_``, i.e. consecutive
        trees cycle through the classes.

        :return: A list of decision trees, in booster dump order.
        :rtype: `[Tree]`
        """
        import json
        from art.metrics.verification_decisions_trees import Box, Tree

        dumped_trees = self._model.get_booster().get_dump(dump_format="json")
        num_classes = self._model.n_classes_

        result = []
        for tree_idx, raw_json in enumerate(dumped_trees):
            # Two-class models use the sentinel -1; otherwise trees cycle through classes.
            label = -1 if num_classes == 2 else tree_idx % num_classes
            tree_root = json.loads(raw_json)
            leaves = self._get_leaf_nodes(tree_root, tree_idx, label, Box())
            result.append(Tree(class_id=label, leaf_nodes=leaves))
        return result
예제 #2
0
    def get_trees(self):
        """
        Get the decision trees of the wrapped tree ensemble.

        Every estimator in ``self._model.estimators_`` is wrapped in a
        ``ScikitlearnDecisionTreeClassifier`` and expanded into one ``Tree`` per
        class (``self._model.n_classes_`` of them). The result therefore holds
        ``len(estimators_) * n_classes_`` entries, ordered by estimator and then
        by class id. Every class gets its own ``Tree``, including when there are
        only two classes (there is no -1 / ``None`` shortcut here).

        :return: A list of decision trees.
        :rtype: `[Tree]`
        """
        from art.metrics.verification_decisions_trees import Box, Tree

        trees = []
        estimators = self._model.estimators_
        num_classes = self._model.n_classes_

        for i_tree, decision_tree_model in enumerate(estimators):
            # NOTE(review): the same Box is passed for every class of this tree;
            # presumably _get_leaf_nodes copies rather than mutates it — confirm.
            box = Box()
            decision_tree_classifier = ScikitlearnDecisionTreeClassifier(model=decision_tree_model)

            for class_label in range(num_classes):
                # pylint: disable=W0212
                trees.append(Tree(class_id=class_label,
                                  leaf_nodes=decision_tree_classifier._get_leaf_nodes(0, i_tree, class_label, box)))

        return trees
예제 #3
0
    def get_trees(self) -> list:
        """
        Convert the wrapped LightGBM booster into verification trees.

        Trees are read from ``self._model.dump_model()["tree_info"]``; the
        ``"tree_structure"`` of each entry is turned into one ``Tree`` whose
        leaves are produced by ``self._get_leaf_nodes``.

        Class labels: when the booster's class count is 2 every tree is labelled
        ``-1``; otherwise tree ``i`` is labelled ``i % class_count``.

        :return: A list of decision trees, in dump order.
        """
        from art.metrics.verification_decisions_trees import Box, Tree

        tree_infos = self._model.dump_model()["tree_info"]
        # pylint: disable=W0212
        # Reads the private (name-mangled) ``Booster.__num_class`` attribute.
        n_class = self._model._Booster__num_class

        # NOTE(review): the -1 shortcut only fires when the class count is exactly 2.
        # LightGBM's binary objective is believed to report num_class == 1, which
        # would take the modulo branch and label every tree 0 — confirm intent.
        result = []
        for tree_idx, info in enumerate(tree_infos):
            label = -1 if n_class == 2 else tree_idx % n_class
            leaves = self._get_leaf_nodes(info["tree_structure"], tree_idx, label, Box())
            result.append(Tree(class_id=label, leaf_nodes=leaves))
        return result
예제 #4
0
    def get_trees(self):
        """
        Get the decision trees of the wrapped gradient-boosting model.

        ``self._model.estimators_`` is unpacked as a 2-D array indexed
        ``[i_tree, i_class]``. Each entry is wrapped in a
        ``ScikitlearnDecisionTreeRegressor`` and converted into one ``Tree``.

        When the model produces a single score per stage (one column, or the
        literal two-column case) the trees carry ``class_id=None``. Otherwise the
        column index is used as the class id.

        :return: A list of decision trees, ordered by boosting stage and then by column.
        :rtype: `[Tree]`
        """
        from art.metrics.verification_decisions_trees import Box, Tree

        trees = []
        num_trees, num_classes = self._model.estimators_.shape

        # scikit-learn fits ONE regression tree per stage for binary classification
        # (and regression), so ``estimators_`` has shape (n_estimators, 1) there. It
        # only has shape (n_estimators, n_classes) for multiclass problems. The
        # former check ``num_classes == 2`` could therefore never match and binary
        # models silently got class id 0; ``<= 2`` covers the real binary shape
        # while still including the original literal case.
        single_score = num_classes <= 2

        for i_tree in range(num_trees):
            box = Box()

            for i_class in range(num_classes):
                decision_tree_regressor = ScikitlearnDecisionTreeRegressor(
                    model=self._model.estimators_[i_tree, i_class])
                class_label = None if single_score else i_class

                # pylint: disable=W0212
                trees.append(Tree(class_id=class_label,
                                  leaf_nodes=decision_tree_regressor._get_leaf_nodes(0, i_tree, class_label, box)))

        return trees