示例#1
0
def test_measurer_adjacent_2():
    """Both rules bound feature 0 with ``MT 1`` / ``LEQ 3``; rule1 also bounds
    feature 2 and rule2 feature 1 the same way. OverlappingMeasurer must
    report ADJACENT regardless of argument order.
    """
    def interval_rule(features):
        # One MT 1 / LEQ 3 pair per feature, in the order given, class 1.
        statements = []
        for feature in features:
            statements.append(Statement(feature, Relation.MT, 1))
            statements.append(Statement(feature, Relation.LEQ, 3))
        return Rule(statements, 1)

    rule1 = interval_rule((0, 2))
    rule2 = interval_rule((0, 1))

    for first, second in ((rule1, rule2), (rule2, rule1)):
        assert OverlappingMeasurer().measure(first,
                                             second) == AdjacentOrNot.ADJACENT
示例#2
0
文件: utils.py 项目: bgulowaty/rules
def bound_rule(rule: Rule, x_train) -> Rule:
    """Return a copy of *rule* in which every feature is bounded by the data.

    For each existing statement, the statements suggested by
    ``bound_if_needed`` are added. Every feature (column of ``x_train``) that
    *rule* does not mention at all gets the pair
    ``MT column_min - EPS`` / ``LEQ column_max``.

    :param rule: rule to bound; it is not modified.
    :param x_train: 2-D array-like of training samples, one column per feature.
    :return: a new ``Rule`` with the same ``distribution_or_class``.
    """
    highs = np.max(x_train, axis=0)
    lows = np.min(x_train, axis=0)

    bounded = set(rule.statements)
    for statement in rule.statements:
        bounded.update(bound_if_needed(rule, statement, lows, highs))

    for feature_idx, (low, high) in enumerate(zip(lows, highs)):
        if len(rule.get_statements_for_feature(feature_idx)) > 0:
            # Feature already constrained; bound_if_needed handled it above.
            continue
        bounded.update(
            {
                Statement(feature_idx, Relation.MT, low - EPS),
                Statement(feature_idx, Relation.LEQ, high),
            }
        )

    return Rule(bounded, rule.distribution_or_class)
示例#3
0
文件: utils.py 项目: bgulowaty/rules
def join_consecutive_statements(rule: Rule) -> Rule:
    """Collapse runs of same-relation statements on each feature.

    For every feature the statements are ordered by threshold. Whenever two
    neighbours share a relation, one of them is dropped:

    * for ``Relation.MT`` the higher threshold is dropped, so the lowest
      ``MT`` of a run is kept;
    * for any other relation the lower threshold is dropped, so the highest
      threshold of a run is kept.

    Equal (duplicate) statements are treated as one.

    :param rule: rule to simplify; it is not modified.
    :return: a new ``Rule`` with the same ``distribution_or_class``.
    """
    all_statements = set()

    for feature in rule.get_features():
        # De-duplicate first: with a repeated statement the pairwise pass
        # below would compare it with itself and remove its only copy from
        # the set (a third copy would raise KeyError). dict.fromkeys keeps
        # first-seen order, so the stable sort breaks threshold ties as before.
        unique = list(dict.fromkeys(rule.get_statements_for_feature(feature)))
        kept = set(unique)

        ordered = sorted(unique, key=lambda s: s.threshold)
        for current, following in zip(ordered, ordered[1:]):
            if current.relation != following.relation:
                continue
            if current.relation == Relation.MT:
                kept.remove(following)
            else:
                kept.remove(current)

        all_statements |= kept

    return Rule(all_statements, rule.distribution_or_class)
示例#4
0
文件: join.py 项目: bgulowaty/rules
def join_if_possible(rule1: Rule, rule2: Rule) -> Optional[Rule]:
    """Merge two rules unless their spans overlap on a shared feature.

    Statements of each rule are grouped by ``feature_idx`` and turned into
    spans with ``calculate_span_for_statements``. If ``any_span_overlaps``
    reports an overlap for any feature present in both rules, no join is made.

    :return: ``None`` on overlap. Otherwise a new ``Rule`` with the statements
        of both rules and a distribution (a ``defaultdict`` defaulting to 0)
        whose per-label values are the sums of both rules' values.
    """
    def feature_of(statement):
        return statement.feature_idx

    def spans_by_feature(rule):
        # itertools.groupby only groups *consecutive* items, so the statements
        # must be sorted by feature first; otherwise a feature split across
        # several runs would keep only the span of its last run.
        ordered = sorted(rule.statements, key=feature_of)
        return {
            feature: calculate_span_for_statements(group)
            for feature, group in groupby(ordered, key=feature_of)
        }

    spans1 = spans_by_feature(rule1)
    spans2 = spans_by_feature(rule2)

    if any(
        any_span_overlaps(spans1[feature], spans2[feature])
        for feature in spans1.keys() & spans2.keys()
    ):
        return None

    # NOTE(review): this assumes distribution_or_class is a mapping; a plain
    # class label would fail on .items() — confirm callers never pass one.
    new_distribution = defaultdict(int)
    for distribution in (rule1.distribution_or_class,
                         rule2.distribution_or_class):
        for label, value in distribution.items():
            new_distribution[label] += value

    # Unpacking works whether Rule stores its statements as a list or a set.
    return Rule([*rule1.statements, *rule2.statements], new_distribution)
示例#5
0
文件: utils.py 项目: bgulowaty/rules
def bound_if_needed(
    rule: Rule, statement: Statement, feature_min_values, feature_max_values
) -> Set[Statement]:
    """Return the opposite-side bound *statement* is missing, if any.

    Only the other statements of *rule* on ``statement.feature_idx`` are
    considered:

    * a ``Relation.LEQ`` statement with no other statement at a lower
      threshold gets the lower bound ``MT feature_min - EPS``;
    * a ``Relation.MT`` statement with no other statement at a higher
      threshold gets the upper bound ``LEQ feature_max``;
    * otherwise nothing is added.

    :param feature_min_values: per-feature minima, indexed by feature_idx.
    :param feature_max_values: per-feature maxima, indexed by feature_idx.
    :return: a set with the single new statement, or an empty set.
    """
    idx = statement.feature_idx
    others = set(rule.get_statements_for_feature(idx)) - {statement}

    # EPS goes on the MT (lower) side and not on the LEQ side. This is the
    # same convention bound_rule uses for unconstrained features, and it keeps
    # samples equal to the min/max inside the bounds. The previous placement
    # (MT min, LEQ max - EPS) excluded those extreme samples.
    if statement.relation == Relation.LEQ:
        if not any(s.threshold < statement.threshold for s in others):
            return {Statement(idx, Relation.MT, feature_min_values[idx] - EPS)}
    elif statement.relation == Relation.MT:
        if not any(s.threshold > statement.threshold for s in others):
            return {Statement(idx, Relation.LEQ, feature_max_values[idx])}

    return set()
示例#6
0
文件: join.py 项目: bgulowaty/rules
def any_span_overlaps(spans1: Set[Span], spans2: Set[Span]) -> bool:
    """Return True if some span of *spans1* overlaps some span of *spans2*.

    Every pair of the Cartesian product is tested with ``overlaps``.
    Evaluation stops at the first overlapping pair, and empty input yields
    False.
    """
    return any(
        span1.overlaps(span2) for span1, span2 in product(spans1, spans2)
    )
示例#7
0
def to_rule(hyperrect: Hyperrectangle) -> Rule:
    """Convert *hyperrect* into a ``Rule`` labelled with ``hyperrect.label``.

    A feature whose lower and upper boundaries are both finite contributes
    ``LEQ upper`` and ``MT lower - EPS``. A feature with an infinite boundary
    on either side contributes nothing.
    """
    collected = set()

    for feature, bounds in hyperrect.get_statement_by_feature().items():
        upper = bounds.upper_boundary
        lower = bounds.lower_boundary
        if upper == np.inf or lower == -np.inf:
            # NOTE(review): the finite side of a half-bounded feature is
            # dropped as well; confirm this is intended.
            continue
        collected.add(Statement(feature.idx, Relation.LEQ, upper))
        collected.add(Statement(feature.idx, Relation.MT, lower - EPS))

    return Rule(list(collected), hyperrect.label)
示例#8
0
def test_measurer_not_adjacent_real():
    """OverlappingMeasurer must classify these two rules the same way in
    either argument order.

    NOTE(review): the test name says "not adjacent" but both assertions
    expect ``AdjacentOrNot.ADJACENT``; confirm which one is intended.
    """
    def make_rule(triples, label):
        # Build a Rule from (feature_idx, relation, threshold) triples.
        return Rule(
            [
                Statement(feature_idx=f, relation=r, threshold=t)
                for f, r, t in triples
            ],
            label,
        )

    rule1 = make_rule(
        (
            (3, Relation.LEQ, 2.5),
            (2, Relation.MT, 1.1),
            (0, Relation.MT, 5.75),
            (1, Relation.LEQ, 3.700000047683716),
            (0, Relation.LEQ, 7.9),
            (2, Relation.LEQ, 4.950000047683716),
            (3, Relation.MT, 1.699999988079071),
            (1, Relation.MT, 2.0),
        ),
        0,
    )
    rule2 = make_rule(
        (
            (3, Relation.LEQ, 2.5),
            (2, Relation.LEQ, 6.9),
            (2, Relation.MT, 2.350000023841858),
            (3, Relation.MT, 1.6500000357627869),
        ),
        1,
    )

    for first, second in ((rule1, rule2), (rule2, rule1)):
        assert OverlappingMeasurer().measure(first,
                                             second) == AdjacentOrNot.ADJACENT
示例#9
0
def test_measures_1():
    """Single-statement rule (feature 0 ``LEQ 5``, class 0) over six
    one-feature samples; checks the a/b/c/d counters of BayesianRuleMeasures.
    """
    rule = Rule(
        statements=[Statement(0, Relation.LEQ, threshold=5)],
        distribution_or_class=0,
    )
    samples = [[0], [1], [2], [6], [7], [8]]
    labels = [0, 0, 1, 0, 1, 0]

    measures = BayesianRuleMeasures.create(rule, samples, labels)

    # Checked in order a, b, c, d; the first mismatch fails the test.
    for counter_name, expected in (("a", 2), ("b", 2), ("c", 1), ("d", 1)):
        assert getattr(measures, counter_name)() == expected
示例#10
0
def test_get_rule_corners_not_sorted_statements():
    """Statements listed feature 1 first and not in threshold order; the
    edges must be every combination of feature-0 thresholds {0, 2} with
    feature-1 thresholds {4, 8}.
    """
    statements = [
        Statement(1, Relation.LEQ, 4),
        Statement(1, Relation.MT, 8),
        Statement(0, Relation.LEQ, 2),
        Statement(0, Relation.MT, 0),
    ]
    r1 = Rule(statements, 0)

    edges = get_rule_edges(r1)

    expected = {(x0, x1) for x0 in (0, 2) for x1 in (4, 8)}
    assert expected == edges
示例#11
0
    def as_rules(self, rects: Collection[Rectangle]) -> Collection[Rule]:
        """Turn each rectangle into a ``Rule``.

        Every ``(feature, bound)`` in ``rect.feature_by_bounds`` becomes an
        ``MT bound.lower`` and a ``LEQ bound.upper`` statement. The label is
        ``self.rect_by_labeling[rect]`` when the rectangle has an entry there,
        otherwise ``self.default_class``.

        :return: a list of rules in the same order as *rects*.
        """
        converted = []

        for rect in rects:
            rect_statements = []
            for feature, bound in rect.feature_by_bounds.items():
                rect_statements += [
                    Statement(feature, Relation.MT, bound.lower),
                    Statement(feature, Relation.LEQ, bound.upper),
                ]

            if rect in self.rect_by_labeling:
                label = self.rect_by_labeling[rect]
            else:
                label = self.default_class

            converted.append(Rule(rect_statements, label))

        return converted
示例#12
0
def to_nnge_hyper(rule: Rule, x) -> Hyperrectangle:
    """Convert *rule* into a ``Hyperrectangle`` labelled with its class.

    Every feature of *rule* must carry exactly two statements. Their sorted
    thresholds ``low <= high`` become a REAL-typed ``HyperrectStatement``
    with ``lower_boundary = low - EPS`` and ``upper_boundary = high``.

    :param rule: rule to convert.
    :param x: unused; kept so existing callers keep working.
    :raises ValueError: if a feature does not have exactly two statements.
    """
    hyperrect_statements = set()

    for feature_idx, statements in rule.get_statements_by_feature().items():
        if len(statements) != 2:
            # ValueError subclasses Exception, so callers catching the
            # previously raised bare Exception still catch it.
            raise ValueError(
                "feature {} needs exactly 2 statements to form a "
                "hyperrectangle boundary, got {}".format(
                    feature_idx, len(statements)
                )
            )
        # NOTE(review): the relations of the two statements are not checked;
        # two LEQ (or two MT) statements are accepted as-is.
        low, high = sorted(s.threshold for s in statements)
        feature = HyperrectFeature(feature_idx, FeatureType.REAL)
        hyperrect_statements.add(
            HyperrectStatement(
                feature,
                lower_boundary=low - EPS,
                upper_boundary=high,
            )
        )

    return Hyperrectangle(hyperrect_statements, label=rule.classified_class)
示例#13
0
文件: utils.py 项目: bgulowaty/rules
def test_bound_rule():
    """Missing sides are filled from x_train's column extremes.

    A missing lower bound is expected as ``MT min - EPS`` and a missing upper
    bound as ``LEQ max``.
    """
    rule = Rule([Statement(0, Relation.MT, 0), Statement(1, Relation.LEQ, 20)], 1)

    bounded_rule = bound_rule(rule, [[-5, -5], [30, 30]])

    expected = {
        Statement(0, Relation.MT, 0),
        Statement(0, Relation.LEQ, 30),
        Statement(1, Relation.LEQ, 20),
        Statement(1, Relation.MT, -5 - EPS),
    }
    assert set(bounded_rule.statements) == expected
示例#14
0
文件: utils.py 项目: bgulowaty/rules
def test_join_consecutive():
    """Per feature, only the lowest MT and the highest LEQ threshold of each
    same-relation run survive.
    """
    triples = (
        (0, Relation.MT, 10),
        (0, Relation.MT, 20),
        (0, Relation.LEQ, 35),
        (0, Relation.LEQ, 40),
        (0, Relation.LEQ, 50),
        (1, Relation.MT, 10),
        (1, Relation.MT, 20),
    )
    rule = Rule([Statement(f, r, t) for f, r, t in triples], 1)

    actual = join_consecutive_statements(rule)

    expected = {
        Statement(0, Relation.MT, 10),
        Statement(0, Relation.LEQ, 50),
        Statement(1, Relation.MT, 10),
    }
    assert set(actual.statements) == expected
示例#15
0
        def recurse(
            node: int, statements: "Optional[List[Statement]]" = None
        ) -> Set[Rule]:
            """Collect one ``Rule`` per leaf of the subtree rooted at *node*.

            *statements* are the conditions on the path from the root to
            *node* (``None`` means an empty path). Going left appends
            ``LEQ threshold`` for the node's feature; going right appends
            ``MT threshold``. At a leaf, the per-class values in
            ``tree_.value[node][0]`` are keyed by
            ``self._skLearnTree.classes_`` and become the rule's distribution.
            """
            # None instead of a mutable [] default; the list is never mutated
            # here, but a shared default is a trap for future edits.
            path = [] if statements is None else statements
            feature_idx: int = tree_.feature[node]

            if feature_idx == _tree.TREE_UNDEFINED:
                # NOTE(review): depending on the scikit-learn version,
                # tree_.value may hold per-class sample counts or normalised
                # fractions; confirm which one downstream code expects.
                per_class_values = tree_.value[node][0]
                distribution = {
                    self._skLearnTree.classes_[idx]: value
                    for idx, value in enumerate(per_class_values)
                }
                return {Rule(set(path), distribution)}

            threshold: float = tree_.threshold[node]
            left_rules = recurse(
                tree_.children_left[node],
                path + [Statement(feature_idx, Relation.LEQ, threshold)],
            )
            right_rules = recurse(
                tree_.children_right[node],
                path + [Statement(feature_idx, Relation.MT, threshold)],
            )
            return left_rules.union(right_rules)
示例#16
0
文件: join.py 项目: bgulowaty/rules
    # NOTE(review): orphaned fragment; the `def` line this body belongs to is
    # missing here. Everything after the first `return` is unreachable, and
    # `adjacency`, `rule1` and `rule2` are not defined in this scope.
    return any(
        [span1.overlaps(span2) for span1, span2 in product(spans1, spans2)])

    if (adjacency == AdjacentOrNot.ADJACENT
            and rule1.classified_class != rule2.classified_class):
        return None

    return Rule(rule1.statements + rule2.statements, rule1.classified_class)


@pytest.mark.skip("to fix")
@pytest.mark.parametrize(
    "rule1,rule2,expected",
    [
        (
            Rule([Statement(0, Relation.MT, 0),
                  Statement(0, Relation.LEQ, 5)], 1),
            Rule([Statement(1, Relation.MT, 0),
                  Statement(1, Relation.LEQ, 5)], 1),
            Rule(
                [
                    Statement(0, Relation.MT, 0),
                    Statement(0, Relation.LEQ, 5),
                    Statement(1, Relation.MT, 0),
                    Statement(1, Relation.LEQ, 5),
                ],
                1,
            ),
        ),
        (
            Rule([Statement(0, Relation.MT, 0)], 1),
            Rule([Statement(1, Relation.MT, 0)], 2),