예제 #1
0
 def test_merging_lt_with_gt_continuous_rules(self):
     """Merging `x > 1` and `x < 2` should produce `1 < x < 2`.

     The result must be an `IntervalRule` whose left and right components
     compare equal to the two rules that were merged.
     """
     rule1 = ContinuousRule('Rule', True, 1)
     rule2 = ContinuousRule('Rule', False, 2)
     new_rule = rule1.merge_with(rule2)
     self.assertIsInstance(new_rule, IntervalRule)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.left_rule, rule1)
     self.assertEqual(new_rule.right_rule, rule2)
 def test_merging_lt_with_gt_continuous_rules(self):
     """Merging `x > 1` and `x < 2` should produce `1 < x < 2`.

     The result must be an `IntervalRule` whose left and right components
     compare equal to the two rules that were merged.
     """
     rule1 = ContinuousRule('Rule', True, 1)
     rule2 = ContinuousRule('Rule', False, 2)
     new_rule = rule1.merge_with(rule2)
     self.assertIsInstance(new_rule, IntervalRule)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.left_rule, rule1)
     self.assertEqual(new_rule.right_rule, rule2)
 def test_merging_interval_rule_with_larger_lt_continuous_rule(self):
     """Merging `0 < x < 3` and `x > 1` should produce `1 < x < 3`.

     Only the type of the result and its new lower bound are asserted.
     """
     rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0),
                          ContinuousRule('Rule', False, 3))
     rule2 = ContinuousRule('Rule', True, 1)
     new_rule = rule1.merge_with(rule2)
     self.assertIsInstance(new_rule, IntervalRule)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.left_rule.value, 1)
 def test_merging_interval_rule_with_smaller_gt_continuous_rule(self):
     """Merging `0 < x < 3` and `x < 2` should produce `0 < x < 2`.

     Only the type of the result and its new upper bound are asserted.
     """
     rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0),
                          ContinuousRule('Rule', False, 3))
     rule2 = ContinuousRule('Rule', False, 2)
     new_rule = rule1.merge_with(rule2)
     self.assertIsInstance(new_rule, IntervalRule)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.right_rule.value, 2)
 def test_merging_interval_rules_generally(self):
     """Merging `0 < x < 4` and `2 < x < 6` should produce `2 < x < 4`."""
     rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0),
                          ContinuousRule('Rule', False, 4))
     rule2 = IntervalRule('Rule', ContinuousRule('Rule', True, 2),
                          ContinuousRule('Rule', False, 6))
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.left_rule.value, 2)
     self.assertEqual(new_rule.right_rule.value, 4)
 def test_merging_interval_rules_with_larger_lt_component(self):
     """Merging `0 < x < 4` and `1 < x < 4` should produce `1 < x < 4`."""
     rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0),
                          ContinuousRule('Rule', False, 4))
     rule2 = IntervalRule('Rule', ContinuousRule('Rule', True, 1),
                          ContinuousRule('Rule', False, 4))
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.left_rule.value, 1)
     self.assertEqual(new_rule.right_rule.value, 4)
 def test_merge_commutativity_on_interval_rules(self):
     """Merging two interval rules gives the same bounds in either order.

     `0 < x < 4` and `2 < x < 6` are merged both ways round and the
     resulting lower and upper bound values are compared.
     """
     rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0),
                          ContinuousRule('Rule', False, 4))
     rule2 = IntervalRule('Rule', ContinuousRule('Rule', True, 2),
                          ContinuousRule('Rule', False, 6))
     new_rule1 = rule1.merge_with(rule2)
     new_rule2 = rule2.merge_with(rule1)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule1.left_rule.value, new_rule2.left_rule.value)
     self.assertEqual(new_rule1.right_rule.value,
                      new_rule2.right_rule.value)
예제 #8
0
    def rules(self, node):
        """Return the list of rules collected on the way down to `node`.

        The rules of the parent are gathered recursively and extended with
        one rule describing the parent's split in the direction of `node`.
        Continuous rules on the same attribute are merged into a single
        rule (which is then moved to the end of the list), while every
        discrete rule is kept as an entry of its own.

        Parameters
        ----------
        node
            The node whose rules are requested.

        Returns
        -------
        list
            The rules in the order described above; an empty list when
            `node` is the root.
        """
        return list(self._keyed_rules(node).values())

    def _keyed_rules(self, node):
        """Build the ordered ``key -> rule`` mapping that backs `rules`.

        The mapping itself is handed down the recursion. Re-creating it
        from `rule.attr_name` at every level would drop the value-specific
        keys of discrete rules, so that two discrete rules on the same
        attribute (e.g. `#legs ≠ 2` and `#legs ≠ 4`) would overwrite each
        other one level further down.
        """
        if node != self.root:
            parent = self.parent(node)
            # A fresh mapping is built on every top-level call, so it is
            # safe to extend the parent's mapping in place.
            keyed = self._keyed_rules(parent)

            parent_attr = self.attribute(parent)
            # The compute_value of the parent attribute tells which kind
            # of split this is.
            parent_attr_cv = parent_attr.compute_value

            is_left_child = self.__left_child(parent) == node

            # The parent split variable is discrete
            if isinstance(parent_attr_cv, Indicator) and \
                    hasattr(parent_attr_cv.variable, 'values'):
                values = parent_attr_cv.variable.values
                attr_name = parent_attr_cv.variable.name
                is_binary = len(values) == 2
                # Only the left child of a non-binary variable yields a
                # "not equal" rule. The left child of a binary variable is
                # expressed as equality with the *other* value instead.
                eq = not (is_left_child and not is_binary)
                value = values[abs(parent_attr_cv.value -
                                   int(is_left_child and is_binary))]
                # Discrete rules should appear in their own lines and must
                # not be merged, so the key also contains the value. The
                # tuple key cannot clash with the plain attribute names
                # used for continuous rules.
                keyed[(attr_name, value)] = DiscreteRule(attr_name, eq, value)
            # The parent split variable is continuous
            else:
                attr_name = parent_attr.name
                new_rule = ContinuousRule(attr_name,
                                          not is_left_child,
                                          self._tree.threshold[parent],
                                          inclusive=is_left_child)
                # Merge with an existing rule on the same attribute and
                # move the merged rule to the end.
                if attr_name in keyed:
                    keyed[attr_name] = keyed[attr_name].merge_with(new_rule)
                    keyed.move_to_end(attr_name)
                else:
                    keyed[attr_name] = new_rule

            return keyed
        else:
            return OrderedDict()
예제 #9
0
 def test_merge_commutativity_on_continuous_rules(self):
     """Merging two continuous rules gives the same value in either order."""
     first = ContinuousRule('Rule1', True, 1)
     second = ContinuousRule('Rule1', True, 2)
     # Merge in both directions and compare the resulting thresholds.
     forward = first.merge_with(second)
     backward = second.merge_with(first)
     self.assertEqual(forward.value, backward.value)
예제 #10
0
 def test_merge_keeps_attr_name_on_continuous_rules(self):
     """Merging two rules on 'Rule1' yields a rule that is still on 'Rule1'."""
     rule1 = ContinuousRule('Rule1', True, 1)
     rule2 = ContinuousRule('Rule1', True, 2)
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.attr_name, 'Rule1')
예제 #11
0
 def test_merging_two_lt_continuous_rules(self):
     """Of `x < 1` and `x < 2`, the merged rule keeps `x < 1`."""
     tighter = ContinuousRule('Rule', False, 1)
     looser = ContinuousRule('Rule', False, 2)
     merged = tighter.merge_with(looser)
     # The smaller of the two upper bounds must survive the merge.
     self.assertEqual(merged.value, 1)
예제 #12
0
 def test_merging_lt_with_lte_rule(self):
     """Of `x ≤ 1` and `x < 1`, the merged rule is the strict `x < 1`."""
     inclusive_rule = ContinuousRule('Rule', False, 1, inclusive=True)
     strict_rule = ContinuousRule('Rule', False, 1, inclusive=False)
     merged = inclusive_rule.merge_with(strict_rule)
     # With equal thresholds the non-inclusive bound must win.
     self.assertEqual(merged.inclusive, False)
 def test_merging_lt_with_lte_rule(self):
     """Of `x ≤ 1` and `x < 1`, the merged rule is the strict `x < 1`."""
     inclusive_rule = ContinuousRule('Rule', False, 1, inclusive=True)
     strict_rule = ContinuousRule('Rule', False, 1, inclusive=False)
     merged = inclusive_rule.merge_with(strict_rule)
     # With equal thresholds the non-inclusive bound must win.
     self.assertEqual(merged.inclusive, False)
 def test_merging_two_lt_continuous_rules(self):
     """Of `x < 1` and `x < 2`, the merged rule keeps `x < 1`."""
     tighter = ContinuousRule('Rule', False, 1)
     looser = ContinuousRule('Rule', False, 2)
     merged = tighter.merge_with(looser)
     # The smaller of the two upper bounds must survive the merge.
     self.assertEqual(merged.value, 1)
 def test_merge_keeps_attr_name_on_continuous_rules(self):
     """Merging two rules on 'Rule1' yields a rule that is still on 'Rule1'."""
     rule1 = ContinuousRule('Rule1', True, 1)
     rule2 = ContinuousRule('Rule1', True, 2)
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.attr_name, 'Rule1')
 def test_merge_keeps_sign_on_continuous_rules(self):
     """Merging two rules that both have `sign=True` keeps `sign=True`."""
     rule1 = ContinuousRule('Rule1', True, 1)
     rule2 = ContinuousRule('Rule1', True, 2)
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.sign, True)
예제 #17
0
 def test_merge_keeps_sign_on_continuous_rules(self):
     """Merging two rules that both have `sign=True` keeps `sign=True`."""
     rule1 = ContinuousRule('Rule1', True, 1)
     rule2 = ContinuousRule('Rule1', True, 2)
     new_rule = rule1.merge_with(rule2)
     # `assertEqual` replaces the deprecated `assertEquals` alias, which
     # was removed from `unittest` in Python 3.12.
     self.assertEqual(new_rule.sign, True)
 def test_merge_commutativity_on_continuous_rules(self):
     """Merging two continuous rules gives the same value in either order."""
     first = ContinuousRule('Rule1', True, 1)
     second = ContinuousRule('Rule1', True, 2)
     # Merge in both directions and compare the resulting thresholds.
     forward = first.merge_with(second)
     backward = second.merge_with(first)
     self.assertEqual(forward.value, backward.value)