示例#1
0
文件: utils.py 项目: bgulowaty/rules
def bound_rule(rule: Rule, x_train) -> Rule:
    """Return a new Rule whose features are all bounded by the training range.

    The per-feature limits are the column-wise minimum and maximum of
    *x_train* (``axis=0`` reductions, so rows are presumably samples and
    columns features — confirm against callers).

    * Every statement already in *rule* is passed to ``bound_if_needed``,
      and whatever that returns is added to the result.
    * Every feature index that has no statement at all in *rule* gets a
      full interval: ``MT (min - EPS)`` and ``LEQ max``.

    Args:
        rule: rule to bound; a new Rule is built and returned.
        x_train: training samples used to compute the per-feature limits.

    Returns:
        A Rule with the original statements plus the added bounds, carrying
        the same ``distribution_or_class`` as *rule*.
    """
    upper_limits = np.max(x_train, axis=0)
    lower_limits = np.min(x_train, axis=0)

    bounded = set(rule.statements)
    for existing in rule.statements:
        bounded.update(
            bound_if_needed(rule, existing, lower_limits, upper_limits)
        )

    for idx, (low, high) in enumerate(zip(lower_limits, upper_limits)):
        if len(rule.get_statements_for_feature(idx)) > 0:
            # Feature already constrained; handled via bound_if_needed above.
            continue
        bounded.update(
            {
                Statement(idx, Relation.MT, low - EPS),
                Statement(idx, Relation.LEQ, high),
            }
        )

    return Rule(bounded, rule.distribution_or_class)
示例#2
0
文件: utils.py 项目: bgulowaty/rules
def bound_if_needed(
    rule: Rule, statement: Statement, feature_min_values, feature_max_values
) -> Set[Statement]:
    """Return the statement that closes *statement*'s interval on its open side.

    Other statements on the same feature (excluding *statement* itself) are
    inspected:

    * A ``LEQ`` statement with no sibling at a lower threshold gets a lower
      bound ``MT (min - EPS)``.
    * An ``MT`` statement with no sibling at a higher threshold gets an upper
      bound ``LEQ max``.
    * In every other case nothing is added.

    The EPS shift matches ``bound_rule``'s handling of unconstrained features
    (``MT min - EPS`` / ``LEQ max``). Presumably ``MT`` is a strict ``>``, so
    shifting its threshold down keeps the training minimum covered, while
    ``LEQ`` already includes the maximum. Previously the shift was applied
    the other way round (``MT min`` / ``LEQ max - EPS``), which excluded both
    extreme training values.

    Args:
        rule: rule that contains *statement*.
        statement: statement whose open side may need bounding.
        feature_min_values: per-feature minimums, indexed by ``feature_idx``.
        feature_max_values: per-feature maximums, indexed by ``feature_idx``.

    Returns:
        A set holding zero or one new Statement.
    """
    feature_idx = statement.feature_idx
    siblings = set(rule.get_statements_for_feature(feature_idx)).difference(
        {statement}
    )

    if statement.relation == Relation.LEQ:
        has_lower = any(s.threshold < statement.threshold for s in siblings)
        if not has_lower:
            return {
                Statement(
                    feature_idx,
                    Relation.MT,
                    feature_min_values[feature_idx] - EPS,
                )
            }
    elif statement.relation == Relation.MT:
        has_higher = any(s.threshold > statement.threshold for s in siblings)
        if not has_higher:
            return {
                Statement(
                    feature_idx,
                    Relation.LEQ,
                    feature_max_values[feature_idx],
                )
            }

    return set()
示例#3
0
def to_rule(hyperrect: Hyperrectangle) -> Rule:
    """Convert a hyperrectangle into a Rule labelled with ``hyperrect.label``.

    Only features whose upper and lower boundaries are both finite contribute
    statements: ``LEQ upper`` and ``MT (lower - EPS)``. Presumably the EPS
    shift keeps the lower boundary itself inside a strict ``>`` comparison.

    NOTE(review): a feature that is bounded on only one side (the other side
    is +/-inf) is dropped entirely, which widens the rule. Confirm this is
    intended.
    """
    collected = set()

    for feature, bounds in hyperrect.get_statement_by_feature().items():
        upper = bounds.upper_boundary
        lower = bounds.lower_boundary
        if upper == np.inf or lower == -np.inf:
            continue
        collected.update(
            {
                Statement(feature.idx, Relation.LEQ, upper),
                Statement(feature.idx, Relation.MT, lower - EPS),
            }
        )

    return Rule(list(collected), hyperrect.label)
示例#4
0
def test_get_rule_corners_not_sorted_statements():
    """get_rule_edges on statements listed out of feature order.

    Feature 1 is listed before feature 0, and within each feature LEQ comes
    before MT.
    """
    shuffled_statements = [
        Statement(1, Relation.LEQ, 4),
        Statement(1, Relation.MT, 8),
        Statement(0, Relation.LEQ, 2),
        Statement(0, Relation.MT, 0),
    ]
    expected_corners = {(2, 4), (0, 4), (2, 8), (0, 8)}

    assert expected_corners == get_rule_edges(Rule(shuffled_statements, 0))
示例#5
0
def test_measurer_adjacent_2():
    """Two class-1 rules sharing feature 0's interval (1, 3] are ADJACENT.

    The first rule additionally constrains feature 2 and the second feature 1,
    each to (1, 3]. The check is made in both argument orders.
    """

    def interval(feature, low, high):
        return [
            Statement(feature, Relation.MT, low),
            Statement(feature, Relation.LEQ, high),
        ]

    rule1 = Rule(interval(0, 1, 3) + interval(2, 1, 3), 1)
    rule2 = Rule(interval(0, 1, 3) + interval(1, 1, 3), 1)

    for left, right in ((rule1, rule2), (rule2, rule1)):
        assert OverlappingMeasurer().measure(left, right) == AdjacentOrNot.ADJACENT
示例#6
0
    def as_rules(self, rects: Collection[Rectangle]) -> Collection[Rule]:
        """Translate each rectangle into a Rule, preserving input order.

        Every ``feature -> bound`` entry of ``rect.feature_by_bounds`` yields
        two statements, ``MT bound.lower`` followed by ``LEQ bound.upper``.
        The rule's class is ``self.rect_by_labeling[rect]`` when the rectangle
        has an entry there, otherwise ``self.default_class``.

        NOTE(review): no EPS shift is applied to the lower bound here; confirm
        this matches how Rectangle bounds are meant to be interpreted.
        """
        converted = []

        for rectangle in rects:
            bound_statements = [
                Statement(feature, relation, limit)
                for feature, bound in rectangle.feature_by_bounds.items()
                for relation, limit in (
                    (Relation.MT, bound.lower),
                    (Relation.LEQ, bound.upper),
                )
            ]

            if rectangle in self.rect_by_labeling:
                label = self.rect_by_labeling[rectangle]
            else:
                label = self.default_class

            converted.append(Rule(bound_statements, label))

        return converted
示例#7
0
文件: utils.py 项目: bgulowaty/rules
def test_join_consecutive():
    """Repeated same-relation statements on a feature collapse to one.

    Per the expected set, feature 0 keeps ``MT 10`` and ``LEQ 50``, and
    feature 1 keeps ``MT 10``.
    """
    raw_statements = [
        Statement(0, Relation.MT, 10),
        Statement(0, Relation.MT, 20),
        Statement(0, Relation.LEQ, 35),
        Statement(0, Relation.LEQ, 40),
        Statement(0, Relation.LEQ, 50),
        Statement(1, Relation.MT, 10),
        Statement(1, Relation.MT, 20),
    ]

    joined = join_consecutive_statements(Rule(raw_statements, 1))

    expected = {
        Statement(0, Relation.MT, 10),
        Statement(0, Relation.LEQ, 50),
        Statement(1, Relation.MT, 10),
    }
    assert set(joined.statements) == expected
示例#8
0
def test_span_calculation_2():
    """Spans on feature 0 when the highest statement is an MT.

    The expected output closes that last span with ``INFINITY``.
    """
    relation_by_threshold = {
        0: Relation.MT,
        1: Relation.LEQ,
        2: Relation.LEQ,
        3: Relation.LEQ,
        4: Relation.MT,
        5: Relation.LEQ,
        6: Relation.LEQ,
        7: Relation.MT,
    }
    statements = {
        Statement(0, relation, threshold)
        for threshold, relation in relation_by_threshold.items()
    }

    expected = {Span(3, 0), Span(6, 4), Span(INFINITY, 7)}
    assert expected == calculate_span_for_statements(statements)
示例#9
0
def test_measures_1():
    """BayesianRuleMeasures counts for the rule ``feature0 LEQ 5 -> class 0``.

    The exact meaning of each of the a..d cells is defined by
    BayesianRuleMeasures.
    """
    rule = Rule(
        statements=[Statement(0, Relation.LEQ, threshold=5)], distribution_or_class=0
    )
    x = [[value] for value in (0, 1, 2, 6, 7, 8)]
    y = [0, 0, 1, 0, 1, 0]

    measures = BayesianRuleMeasures.create(rule, x, y)

    expected_counts = {"a": 2, "b": 2, "c": 1, "d": 1}
    for measure_name, count in expected_counts.items():
        assert getattr(measures, measure_name)() == count
示例#10
0
        def recurse(node: int, statements: "Sequence[Statement]" = ()) -> Set[Rule]:
            """Collect one Rule per leaf reachable from *node* in ``tree_``.

            Args:
                node: index of the node in ``tree_`` to start from.
                statements: conditions accumulated on the path from the root
                    to *node*. The default is an immutable empty tuple rather
                    than a shared mutable ``[]``. Any sequence (e.g. a list)
                    is accepted, and it is never mutated.

            Returns:
                * When ``tree_.feature[node]`` is ``_tree.TREE_UNDEFINED``
                  (treated as a leaf): a single Rule whose statements are the
                  path conditions and whose distribution maps
                  ``self._skLearnTree.classes_[i]`` to
                  ``tree_.value[node][0][i]``.
                * Otherwise: the union of the rules from both children. The
                  left child gets ``feature LEQ threshold`` appended to the
                  path, the right child ``feature MT threshold``.
            """
            feature_idx: int = tree_.feature[node]

            if feature_idx == _tree.TREE_UNDEFINED:
                class_counts = tree_.value[node][0]
                counts_by_class = {
                    self._skLearnTree.classes_[idx]: count
                    for idx, count in enumerate(class_counts)
                }
                return {Rule(set(statements), counts_by_class)}

            threshold: float = tree_.threshold[node]

            # Build new lists for each branch so sibling paths never share state.
            left_rules = recurse(
                tree_.children_left[node],
                [*statements, Statement(feature_idx, Relation.LEQ, threshold)],
            )
            right_rules = recurse(
                tree_.children_right[node],
                [*statements, Statement(feature_idx, Relation.MT, threshold)],
            )
            return left_rules.union(right_rules)
示例#11
0
def test_span_calculation():
    """Spans on feature 0 when the lowest statement is an LEQ.

    The expected output leaves that first span unbounded below
    (``-INFINITY``).
    """
    relation_by_threshold = {
        1: Relation.LEQ,
        2: Relation.MT,
        3: Relation.LEQ,
        4: Relation.LEQ,
        5: Relation.MT,
        6: Relation.LEQ,
    }
    statements = {
        Statement(0, relation, threshold)
        for threshold, relation in relation_by_threshold.items()
    }

    expected = {Span(1, -INFINITY), Span(4, 2), Span(6, 5)}
    assert expected == calculate_span_for_statements(statements)
示例#12
0
文件: utils.py 项目: bgulowaty/rules
def test_bound_rule_2():
    """bound_rule adds a lower bound to a feature that only has an upper one.

    Feature 0 is already closed on both sides and must be left as-is.
    Feature 1 only has ``LEQ 20``, so a lower bound just below the training
    minimum (``-5 - EPS``) is expected.
    """
    original_statements = [
        Statement(0, Relation.MT, 0),
        Statement(0, Relation.LEQ, 5),
        Statement(1, Relation.LEQ, 20),
    ]
    training_data = [[-5, -5], [30, 30]]

    result = bound_rule(Rule(original_statements, 1), training_data)

    expected = {
        Statement(0, Relation.MT, 0),
        Statement(0, Relation.LEQ, 5),
        Statement(1, Relation.LEQ, 20),
        Statement(1, Relation.MT, -5 - EPS),
    }
    assert set(result.statements) == expected
示例#13
0
文件: join.py 项目: bgulowaty/rules
    return any(
        [span1.overlaps(span2) for span1, span2 in product(spans1, spans2)])

    if (adjacency == AdjacentOrNot.ADJACENT
            and rule1.classified_class != rule2.classified_class):
        return None

    return Rule(rule1.statements + rule2.statements, rule1.classified_class)


@pytest.mark.skip("to fix")
@pytest.mark.parametrize(
    "rule1,rule2,expected",
    [
        (
            Rule([Statement(0, Relation.MT, 0),
                  Statement(0, Relation.LEQ, 5)], 1),
            Rule([Statement(1, Relation.MT, 0),
                  Statement(1, Relation.LEQ, 5)], 1),
            Rule(
                [
                    Statement(0, Relation.MT, 0),
                    Statement(0, Relation.LEQ, 5),
                    Statement(1, Relation.MT, 0),
                    Statement(1, Relation.LEQ, 5),
                ],
                1,
            ),
        ),
        (
            Rule([Statement(0, Relation.MT, 0)], 1),
示例#14
0
def test_measurer_not_adjacent_real():
    """Measurer check on two rules with real-valued, non-round thresholds.

    NOTE(review): despite the "not_adjacent" name, both argument orders are
    asserted to be ``AdjacentOrNot.ADJACENT``. Confirm which one reflects the
    intended behaviour.
    """
    rule1_bounds = [
        (3, Relation.LEQ, 2.5),
        (2, Relation.MT, 1.1),
        (0, Relation.MT, 5.75),
        (1, Relation.LEQ, 3.700000047683716),
        (0, Relation.LEQ, 7.9),
        (2, Relation.LEQ, 4.950000047683716),
        (3, Relation.MT, 1.699999988079071),
        (1, Relation.MT, 2.0),
    ]
    rule2_bounds = [
        (3, Relation.LEQ, 2.5),
        (2, Relation.LEQ, 6.9),
        (2, Relation.MT, 2.350000023841858),
        (3, Relation.MT, 1.6500000357627869),
    ]

    def build(bounds, label):
        return Rule(
            [
                Statement(feature_idx=feature, relation=relation, threshold=threshold)
                for feature, relation, threshold in bounds
            ],
            label,
        )

    rule1 = build(rule1_bounds, 0)
    rule2 = build(rule2_bounds, 1)

    for left, right in ((rule1, rule2), (rule2, rule1)):
        assert OverlappingMeasurer().measure(left, right) == AdjacentOrNot.ADJACENT
示例#15
0
                                         rule2) == AdjacentOrNot.ADJACENT
    assert OverlappingMeasurer().measure(rule2,
                                         rule1) == AdjacentOrNot.ADJACENT


@pytest.mark.parametrize(
    "span1,span2,expected",
    [
        (Span(5, 3), Span(6, 2), True),
        (Span(5, 3), Span(0, -5), False),
        (Span(10, 0), Span(15, 5), True),
        (Span(10, 0), Span(5, -10), True),
        (Span(10, 0), Span(20, 15), False),
        (Span(10, 0), Span(-10, -20), False),
    ],
)
def test_span_overlapping(span1: Span, span2: Span, expected: bool):
    """Span.overlaps matches *expected* and is symmetric in its arguments."""
    for first, second in ((span1, span2), (span2, span1)):
        assert first.overlaps(second) == expected


@pytest.mark.skip("To fix")
@pytest.mark.parametrize(
    "statements1,statements2,expected",
    [
        # x LEQ 5 vs x MT 5: presumably disjoint if MT is a strict ">".
        ([Statement(0, Relation.LEQ, 5)], [Statement(0, Relation.MT, 5)], False),
    ],
)
def test_statements_overlap(statements1: List[Statement],
                            statements2: List[Statement], expected: bool):
    """statements_overlap matches *expected* in both argument orders."""
    for first, second in ((statements1, statements2), (statements2, statements1)):
        assert statements_overlap(first, second) == expected