Exemple #1
0
    def test_batch_evaluate(self):
        """batch_evaluate should return one boolean per data frame row.

        The rule set's ``evaluate`` is replaced by a stub that only accepts the
        assignment matching the first row, so the expected result is
        ``[True, False]``.
        """
        frame = pd.DataFrame({'x1': [1, 2], 'x2': [-2, -1]})
        accepted_assignment = {'x1': 1, 'x2': -2}

        def side_effect_fn(assignment):
            # Only the first row's assignment is accepted by the stub.
            return bool(assignment == accepted_assignment)

        stub_ruleset = DnfRuleSet([], 'class1')
        stub_ruleset.evaluate = MagicMock(side_effect=side_effect_fn)

        pd.testing.assert_series_equal(
            batch_evaluate(stub_ruleset, frame),
            pd.Series([True, False]),
        )
Exemple #2
0
    def explain_multiclass(self):
        """
        Export rules to technical interchange format trxf from internal representation.

        Every (label, rules) entry of the internal rule map is converted to one
        DNF rule set, in the iteration order of the map. An empty rule set
        carrying the default label is appended as the last element.

        When the internal rule map is empty, a one-element list holding an empty
        rule set with the target label is returned, so the result is a list in
        every case.

        Returns:
            list(trxf.DnfRuleSet): -- Ordered list of rulesets
        """
        if not self._rule_map:
            # This branch used to return a bare DnfRuleSet, which contradicted
            # the documented list return type and broke callers that iterate
            # over the result. The rule set itself is unchanged; it is only
            # wrapped in a list.
            return [DnfRuleSet([], self.target_label)]

        res = [
            self._rules_to_trxf_dnf_ruleset(rules, label)
            for label, rules in self._rule_map.items()
        ]
        # The default-label rule set always closes the ordered list.
        res.append(DnfRuleSet([], self.default_label))
        return res
Exemple #3
0
    def explain(self):
        """
        Return the rules of the fitted target label as a trxf DNF rule set.

        The internal rule map is searched for the first entry whose label equals
        ``self.target_label``; that entry's rules are converted to trxf.

        If the internal rule map is empty, an empty DNF rule set labelled with
        ``self.target_label`` is returned instead.

        Returns:
            trxf.DnfRuleSet

        Raises:
            AssertionError: if ``self.target_label`` is None.
            Exception: if the rule map is non-empty but holds no entry for
                ``self.target_label``.
        """
        not_fitted_msg = ('Not fitted or not fitted for a specific pos value. Use export_rules '
                          'in the latter case. ')
        assert self.target_label is not None, not_fitted_msg

        rule_map = self._rule_map
        if not rule_map.items():
            return DnfRuleSet([], self.target_label)

        # First (label, rules) pair whose label equals the target label, if any.
        found = next(
            ((label, rules) for label, rules in rule_map.items() if label == self.target_label),
            None,
        )
        if found is None:
            raise Exception('No rules found for label: ' + str(self.target_label))

        found_label, found_rules = found
        return self._rules_to_trxf_dnf_ruleset(found_rules, found_label)
Exemple #4
0
    def _rules_to_trxf_dnf_ruleset(self, rules, label):
        """
        Build a trxf DNF rule set from internally represented rules.

        Each rule is converted to a trxf conjunction via
        ``_rule_to_trxf_conjunction``; the conjunctions, in the order of
        ``rules``, become the rule set's terms.

        Parameters
        ----------
        rules : list
            Rules for one target
        label : str
            The label of rules

        Returns
        -------
            DnfRuleSet
        """
        return DnfRuleSet(
            [self._rule_to_trxf_conjunction(single_rule) for single_rule in rules],
            label,
        )
Exemple #5
0
    def test_always_false_rule(self):
        """Fitting on the degenerate frame must export an empty rule set.

        The exported rule set is compared with an empty ``DnfRuleSet`` labelled
        'True' and is expected to evaluate a sample assignment to False.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        log = logging.getLogger(__name__)

        target_column = 'target'
        pos_value = 'True'

        frame = create_degenerate_test_df2()
        features = frame.drop(columns=[target_column])
        labels = frame[target_column]

        explainer = RipperExplainer()
        explainer.fit(features, labels, target_label=pos_value)

        exported = explainer.explain()
        log.info(exported)

        self.assertEqual(exported, DnfRuleSet([], 'True'))

        sample = {'int_col': 1, 'float_col': 1.2, 'str_col': 'foo'}
        self.assertFalse(exported.evaluate(sample))
Exemple #6
0
    def test_trxf_export(self):
        """Fitting on the test frame must export the single rule ``int_col == 1``.

        The exported rule set is compared with a ``DnfRuleSet`` labelled 'True'
        holding one conjunction with that predicate, and is expected to evaluate
        a sample assignment with ``int_col`` equal to 1 to True.
        """
        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
        log = logging.getLogger(__name__)

        target_column = 'target'
        pos_value = 'True'

        frame = create_test_df()
        features = frame.drop(columns=[target_column])
        labels = frame[target_column]

        explainer = RipperExplainer()
        explainer.fit(features, labels, target_label=pos_value)
        exported = explainer.explain()
        log.info(exported)

        # Expected export: one conjunction holding the predicate int_col == 1.
        expected = DnfRuleSet(
            [Conjunction([Predicate(Feature('int_col'), Relation.EQ, 1)])],
            'True',
        )
        self.assertEqual(exported, expected)

        sample = {'int_col': 1, 'float_col': 1.2, 'str_col': 'foo'}
        self.assertTrue(exported.evaluate(sample))