コード例 #1
0
 def check_literal(self, atr_col_name, best_foil, best_l, i, n0, p0, unique_values, op):
     """Score one candidate literal and keep it if it beats the current best.

     The candidate is ``Literal(atr_col_name, op, unique_values[i])``. Its
     ``p``/``n`` counts come from ``self.count_p_n_literal`` and its score
     from ``count_foil_grow(p0, n0, p, n)``.

     Returns:
         A ``(best_foil, best_l)`` pair. If the candidate's score is
         strictly greater than the incoming ``best_foil``, the pair holds
         the new score and a deep copy of the candidate; otherwise the
         incoming values are returned unchanged.
     """
     candidate = Literal(atr_col_name, op, unique_values[i])
     pos, neg = self.count_p_n_literal(candidate)
     gain = count_foil_grow(p0, n0, pos, neg)
     if gain > best_foil:
         return gain, copy.deepcopy(candidate)
     return best_foil, best_l
コード例 #2
0
 def choose_best_literal(self, atr_col_name, best_foil, best_l, df, n0, p0, unique_values, values_to_literal):
     """Grow an ``'in'`` literal one value at a time while the score improves.

     For each row index ``0 .. len(unique_values) - 1`` the value
     ``df.at[row, 'value']`` is appended to ``values_to_literal`` (the
     caller's list is mutated in place) and a literal
     ``Literal(atr_col_name, 'in', values_to_literal)`` is scored with
     ``count_foil_grow``. The search stops at the first candidate whose
     score is not strictly greater than the best seen so far.

     Returns:
         ``(best_foil, best_l)`` where ``best_l`` is a deep copy of the
         best-scoring literal found, or the incoming values when no
         candidate improved on them.
     """
     for row in range(len(unique_values)):
         values_to_literal.append(df.at[row, 'value'])
         candidate = Literal(atr_col_name, 'in', values_to_literal)
         pos, neg = self.count_p_n_literal(candidate)
         gain = count_foil_grow(p0, n0, pos, neg)
         if not gain > best_foil:
             # First non-improving extension ends the search.
             break
         # Deep copy: the value list keeps growing on later iterations.
         best_foil, best_l = gain, copy.deepcopy(candidate)
     return best_foil, best_l
コード例 #3
0
 def choose_best_literal(self, best_foil, best_l, n0, new_literal, old_rule,
                         p0, p_to_n, var):
     """Extend ``new_literal`` entry by entry while the rule's score improves.

     On each step ``[var, p_to_n[idx][0]]`` is appended to ``new_literal``
     (the caller's list is mutated in place), the rule
     ``old_rule + new_literal`` is counted with ``self.count_p_n_rule`` and
     scored with ``count_foil_grow``. The search stops at the first step
     whose score is not strictly greater than the best seen so far.

     Returns:
         ``(best_foil, best_l)`` where ``best_l`` is a deep copy of
         ``new_literal`` at its best-scoring length, or the incoming values
         when no step improved on them.
     """
     for idx in range(len(p_to_n)):
         new_literal.append([var, p_to_n[idx][0]])
         candidate_rule = old_rule + new_literal
         pos, neg = self.count_p_n_rule(candidate_rule)
         gain = count_foil_grow(p0, n0, pos, neg)
         if not gain > best_foil:
             # First non-improving extension ends the search.
             break
         # Deep copy: new_literal keeps growing on later iterations.
         best_foil, best_l = gain, copy.deepcopy(new_literal)
     return best_foil, best_l
コード例 #4
0
 def prune_rule(self, rule):
     """Try removing each literal of ``rule``, last to first, and validate it.

     For every literal position (taken from a snapshot of the rule made
     before any pruning) a trial copy of the current rule without that
     literal is counted. The literal is deleted from ``rule`` itself when
     ``count_foil_grow(p0, n0, p, n)`` is at most ``self.prune_param``,
     where ``p0``/``n0`` are the trial's counts and ``p``/``n`` the
     current rule's. ``rule`` is modified in place.

     Returns:
         The pruned ``rule``, or ``None`` when its final counts give
         ``p == 0`` or ``n >= p``, or when it has no literals left.
     """
     snapshot = copy.deepcopy(rule)
     for idx in reversed(range(len(rule.literals))):
         literal = snapshot.literals[idx]
         trial = copy.deepcopy(rule)
         trial.delete_literal(literal)
         p, n = self.count_p_n_rule(rule)
         p0, n0 = self.count_p_n_rule(trial)
         if count_foil_grow(p0, n0, p, n) <= self.prune_param:
             rule.delete_literal(literal)

     # Recount for the rule as it stands after all deletions.
     p, n = self.count_p_n_rule(rule)
     rejected = p == 0 or n >= p or len(rule.literals) == 0
     return None if rejected else rule
コード例 #5
0
    def prune_rule(self, rule):
        """Try dropping each attribute's literals from ``rule`` and validate it.

        ``rule`` is a sequence of literals whose first element identifies
        the attribute. For every distinct attribute (iterated in reverse of
        the set's iteration order) a copy of the current rule without any
        literal of that attribute is counted. The copy replaces the current
        rule when ``count_foil_grow(p0, n0, p, n)`` is at most
        ``self.prune_param``, where ``p0``/``n0`` are the copy's counts and
        ``p``/``n`` the current rule's. The caller's ``rule`` object is
        never mutated.

        Returns:
            The pruned rule, or ``None`` when the rule is empty (before or
            after pruning) or when its final counts give ``p == 0`` or
            ``n >= p``.
        """
        if len(rule) == 0:
            return None
        attributes = list({literal[0] for literal in rule})

        for attribute in reversed(attributes):
            # Deep copy so the candidate shares nothing with the current rule.
            pruned_rule = [
                literal
                for literal in copy.deepcopy(rule)
                if literal[0] != attribute
            ]
            p, n = self.count_p_n_rule(rule)
            p0, n0 = self.count_p_n_rule(pruned_rule)
            if count_foil_grow(p0, n0, p, n) <= self.prune_param:
                rule = pruned_rule

        if len(rule) == 0:
            return None
        # Recount for the rule actually being returned: the counts left over
        # from the last loop iteration were taken before that iteration's
        # pruning decision and may describe a rule that no longer exists.
        p, n = self.count_p_n_rule(rule)
        if p == 0 or n >= p:
            return None
        return rule
コード例 #6
0
 def grow_rule_inductive(self):
     """Greedily grow a rule by repeatedly adding the best-scoring literal.

     Each round scores every ``(i, j)`` pair, where ``i`` indexes
     ``self.col_val_tables`` and ``j`` indexes ``self.col_val_tables[i]``,
     by appending ``[i, j]`` to the rule grown so far and calling
     ``count_foil_grow(p0, n0, p, n)``. Here ``p0``/``n0`` are the counts
     of the rule so far and ``p``/``n`` those of the extended rule. The
     best pair is appended as a tuple when its score is strictly greater
     than ``self.grow_param``; otherwise growing stops.

     Returns:
         The grown rule as a list of ``(i, j)`` tuples (possibly empty).
     """
     best_rule = []
     while True:
         # The rule grown so far is the same for every candidate in this
         # round, so its counts are computed once instead of per candidate.
         p0, n0 = self.count_p_n_rule(best_rule)
         best_foil = -math.inf
         best_l = None
         for i in range(len(self.col_val_tables)):
             for j in range(len(self.col_val_tables[i])):
                 # best_rule only holds tuples, so a new list is enough to
                 # keep the candidate independent of the rule so far.
                 new_rule = best_rule + [[i, j]]
                 p, n = self.count_p_n_rule(new_rule)
                 tmp_foil = count_foil_grow(p0, n0, p, n)
                 if tmp_foil > best_foil:
                     best_foil = tmp_foil
                     best_l = (i, j)
         if best_foil > self.grow_param:
             best_rule.append(best_l)
         else:
             break
     return best_rule
コード例 #7
0
 def test_count_foil_grow_all_zeros(self):
     """With all four counts zero the score must be negative infinity."""
     score = count_foil_grow(0, 0, 0, 0)
     self.assertEqual(-math.inf, score)
コード例 #8
0
 def test_count_foil_grow_all_not_zero(self):
     """With p0=7, n0=3, p=6, n=1 the score rounds to 1.7531 (4 places)."""
     expected = round(1.7531, 4)
     actual = round(count_foil_grow(7, 3, 6, 1), 4)
     self.assertEqual(expected, actual)
コード例 #9
0
 def test_count_foil_grow_p_zero(self):
     """With p equal to zero the score must be negative infinity."""
     score = count_foil_grow(6, 4, 0, 2)
     self.assertEqual(-math.inf, score)
コード例 #10
0
 def test_count_foil_grow_p0_and_n0_zeros(self):
     """With p0=n0=0, p=8, n=4 the score rounds to 5.3333 (4 places)."""
     expected = round(5.3333, 4)
     actual = round(count_foil_grow(0, 0, 8, 4), 4)
     self.assertEqual(expected, actual)