def traverse_nodes(node_id=0, operator=None, threshold=None, feature=None, conditions=None):
    """Recursively walk a fitted tree and record one Rule per non-root leaf.

    This is a closure: it reads ``tree``, ``feature_names``, ``rules``,
    ``RuleCondition`` and ``Rule`` from the enclosing scope (not visible
    here). ``tree`` is presumably a fitted scikit-learn ``Tree`` object
    (it exposes ``n_node_samples``, ``children_left``, ``children_right``,
    ``feature``, ``threshold`` and ``value``) -- confirm against the caller.

    Parameters
    ----------
    node_id : int
        Node to visit. Node 0 is treated as the root and contributes no
        condition of its own.
    operator : str or None
        ``"<="`` or ``">"`` -- the comparison on the edge from the parent
        split into ``node_id``.
    threshold : float or None
        Split threshold of the parent node.
    feature : int or None
        Index of the feature the parent node split on; mapped through
        ``feature_names`` when that is not None.
    conditions : list of RuleCondition or None
        Conditions accumulated on the path from the root to the parent.
        ``None`` (the default) means "no conditions"; a shared mutable
        ``[]`` default is deliberately avoided.

    Returns
    -------
    None
        Results are delivered only by calling ``rules.update`` on each leaf.
    """
    if conditions is None:
        conditions = []

    if node_id != 0:
        feature_name = feature_names[feature] if feature_names is not None else feature
        rule_condition = RuleCondition(
            feature_index=feature,
            threshold=threshold,
            operator=operator,
            # Fraction of the root's samples that reach this node.
            support=tree.n_node_samples[node_id] / float(tree.n_node_samples[0]),
            feature_name=feature_name,
        )
        # Build a fresh list so the left and right subtrees never share a path.
        new_conditions = conditions + [rule_condition]
    else:
        new_conditions = []

    # A node is treated as internal when its two child ids differ
    # (scikit-learn presumably stores the same sentinel for both on leaves).
    if tree.children_left[node_id] != tree.children_right[node_id]:
        split_feature = tree.feature[node_id]
        split_threshold = tree.threshold[node_id]
        traverse_nodes(tree.children_left[node_id], "<=", split_threshold, split_feature, new_conditions)
        traverse_nodes(tree.children_right[node_id], ">", split_threshold, split_feature, new_conditions)
    elif new_conditions:
        # Leaf: the conjunction of path conditions becomes one rule, tagged with
        # tree.value[node_id][0][0] (presumably the leaf's prediction -- confirm).
        new_rule = Rule(new_conditions, tree.value[node_id][0][0])
        rules.update([new_rule])
    # Otherwise the tree consists of a root node only: there is no rule to emit.
    return None
def test_rule_condition_hashing_different3():
    """Conditions differing only in the first positional argument are unequal.

    The first argument is presumably the feature index -- confirm against
    RuleCondition's signature.
    """
    lhs = RuleCondition(2, 5, ">", 0.4)
    rhs = RuleCondition(1, 5, ">", 0.4)
    assert lhs != rhs
def test_rule_condition_hashing_different1():
    """Conditions differing only in the second positional argument are unequal.

    The second argument is presumably the threshold -- confirm against
    RuleCondition's signature.
    """
    lhs = RuleCondition(1, 4, "<=", 0.4)
    rhs = RuleCondition(1, 5, "<=", 0.4)
    assert lhs != rhs
def test_rule_condition_hashing_equal2():
    """Conditions differing only in the 4th argument compare and hash equal.

    The 4th positional argument is presumably ``support``; this test pins
    that equality ignores it.
    """
    first = RuleCondition(1, 5, "<=", 0.5)
    second = RuleCondition(1, 5, "<=", 0.4)
    assert first == second
    # The test name promises hashing: objects that compare equal must also
    # hash equally, or they misbehave as set members / dict keys.
    assert hash(first) == hash(second)
import numpy as np
from imodels.util.rules import RuleCondition, Rule
from imodels.util.transforms import FriedScale
from imodels.rule_set.rule_fit import RuleFitRegressor

# Module-level fixtures. They are not referenced by the tests visible here;
# presumably used further down the file.
rule_condition_smaller = RuleCondition(1, 5, "<=", 0.4)
rule_condition_greater = RuleCondition(0, 1, ">", 0.1)

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

## Testing RuleCondition
# RuleCondition is built positionally below. The argument order is presumably
# (feature_index, threshold, operator, support) -- confirm against its signature.


def test_rule_condition_hashing_equal1():
    """Identically constructed conditions compare equal and hash equally."""
    a = RuleCondition(1, 5, "<=", 0.4)
    b = RuleCondition(1, 5, "<=", 0.4)
    assert a == b
    # The test name promises hashing: equal objects must also hash equally,
    # otherwise they misbehave as set members or dict keys.
    assert hash(a) == hash(b)


def test_rule_condition_hashing_equal2():
    """Conditions differing only in the 4th argument compare and hash equal."""
    a = RuleCondition(1, 5, "<=", 0.5)
    b = RuleCondition(1, 5, "<=", 0.4)
    assert a == b
    assert hash(a) == hash(b)


def test_rule_condition_hashing_different1():
    """Conditions differing only in the 2nd argument are unequal."""
    assert RuleCondition(1, 4, "<=", 0.4) != RuleCondition(1, 5, "<=", 0.4)


def test_rule_condition_hashing_different2():
    """Conditions differing only in the operator are unequal."""
    assert RuleCondition(1, 5, ">", 0.4) != RuleCondition(1, 5, "<=", 0.4)


def test_rule_condition_hashing_different3():
    """Conditions differing only in the 1st argument are unequal."""
    assert RuleCondition(2, 5, ">", 0.4) != RuleCondition(1, 5, ">", 0.4)