def test_div():
    """Check GreedyDiv's greedy selection and universe update on a toy dataset."""
    # Toy dataset: five items, one transaction under label 0, four under label 1.
    Itemset.set_items(range(5))
    Itemset.clear_db()
    Itemset.set_db(0, [Transaction([0])])
    Itemset.set_db(1, [
        Transaction([1, 2, 3]),
        Transaction([0, 1, 2]),
        Transaction([1, 2]),
        Transaction([2, 3, 4]),
    ])

    candidates = [
        Rule({0}, 1),
        Rule({2, 3}, 1),
        Rule({2, 3, 4}, 1),
        Rule({1, 2, 3}, 1),
        Rule({1, 2}, 1),
    ]
    greed = GreedyDiv(candidates, 10)

    # Two greedy picks; print their `s` attribute for inspection.
    picked = greed.greedy(2)
    print(picked[0].s)
    print(picked[1].s)
    assert picked[0] == Rule({1, 2}, 1)
    assert picked[1] == Rule({2, 3, 4}, 1)

    # After swapping in a new universe that contains an already-picked rule,
    # only the rule that was not picked is expected to remain.
    replacement = [Rule({0}, 1), Rule({1, 2}, 1)]
    greed.update_univ(replacement)
    assert len(greed.U) == 1
    assert greed.U[0] == Rule({0}, 1)
def test_decset():
    """Check DecisionSet dispersion and prediction on a toy dataset."""
    Itemset.set_items(range(5))
    # Start from an empty database, as the other tests in this file do,
    # so no transactions from an earlier test can leak in.
    Itemset.clear_db()
    Itemset.set_db(0, [Transaction([0])])
    Itemset.set_db(1, [
        Transaction([1, 2, 3]),
        Transaction([0, 1, 2]),
        Transaction([1, 2]),
        Transaction([2, 3, 4]),
    ])

    dec = DecisionSet()
    dec.set_default(0)
    sel = [Rule([1, 2], 1), Rule([3], 1)]
    for r in sel:
        dec.add(r)
    dec.build()

    # Two-sided tolerance check: without abs() the assertion would also pass
    # for any dispersion smaller than the expected value.
    assert abs(dispersion(dec.rules) - (1 - 1/4) * 2) < 1e-8

    assert dec.predict(Transaction([1, 2, 3])) == True
    assert dec.predict(Transaction([1, 2])) == True
    assert dec.predict(Transaction([2, 3, 4])) == True
    assert dec.predict(Transaction([4])) == False
    assert dec.predict(Transaction([0])) == False
    assert dec.predict_and_rule(Transaction([1, 2, 3])) == (True, Rule([1, 2], 1))
def test_sampling():
    """Statistical check that sampled itemset counts land near expected shares."""

    def _assert_freqs(counts, total, expected):
        # Only itemsets that actually occur in the sample are checked; each
        # entry of `expected` is (itemset, expected share, tolerance).
        for s, cnt in counts.items():
            for key, share, tol in expected:
                if key == s:
                    assert (share - tol) * total < cnt < (share + tol) * total

    Itemset.set_items(range(4))
    Itemset.clear_db()
    db0 = {Transaction([0])}
    db1 = {Transaction([1, 2]), Transaction([1, 3]), Transaction([1])}
    Itemset.set_db(0, db0)
    Itemset.set_db(1, db1)
    n = 3000

    # Mode 3 with nothing covered yet.
    nsamp, samp = sample(n, [db0], db1, {}, mode=3)
    print(nsamp)
    _assert_freqs(samp, nsamp, [
        (frozenset([1]), 3/7, 0.05),
        (frozenset([1, 2]), 1/7, 0.05),
        (frozenset([1, 3]), 1/7, 0.05),
        (frozenset([2]), 1/7, 0.05),
        (frozenset([3]), 1/7, 0.05),
    ])

    # Mode 3 with the transactions of Rule([1, 3], 1) marked as covered.
    sel = [Rule([1, 3], 1)]
    covered = {t for r in sel for t in r.trans()}
    nsamp, samp = sample(n, [db0], db1, covered, mode=3)
    print(nsamp)
    _assert_freqs(samp, nsamp, [
        (frozenset([1, 2]), 0.5, 0.1),
        (frozenset([2]), 0.5, 0.1),
    ])

    # Mode 2 with the same covered set.
    nsamp, samp = sample(n, [db0], db1, covered, mode=2)
    print(nsamp)
    _assert_freqs(samp, nsamp, [
        (frozenset([1]), 1/3, 0.1),
        (frozenset([1, 2]), 1/3, 0.1),
        (frozenset([2]), 1/3, 0.1),
    ])

    # Covered transactions: 4 with duplicates across rules, 3 once deduplicated.
    sel = [Rule([1], 1), Rule([1, 3], 1)]
    covered = [t for r in sel for t in r.trans()]
    assert len(covered) == 4
    covered = {t for r in sel for t in r.trans()}
    assert len(covered) == 3
def greedy_once(self) -> Itemset:
    """Select one more candidate and return it.

    This changes the state of the instance: the chosen candidate is removed
    from ``self.U`` and appended to ``self.sel``.

    While the selection is empty, candidates are scored by their own quality.
    Afterwards the score is ``0.5 * quality(candidate + selection)`` plus
    ``self.lamb * candidate.overlap(selection)``.
    """
    if self.sel:
        joint_quality = [Rule.quality([cand] + self.sel) for cand in self.U]
        overlaps = [cand.overlap(self.sel) for cand in self.U]
        scores = 0.5 * np.array(joint_quality) + self.lamb * np.array(overlaps)
    else:
        scores = [Rule.quality([cand]) for cand in self.U]

    # Move the highest-scoring candidate from the universe to the selection.
    best = np.argmax(scores)
    chosen = self.U[best]
    del self.U[best]
    self.sel.append(chosen)
    return chosen
def sample_from_each_label(self, labels_samp, nsamp, covered, mode):
    """Sample patterns for every label in ``labels_samp`` and return them as a set of rules.

    For each label, ``sample`` is called with the databases of all other
    labels, the label's own database, ``covered`` and ``mode``. Labels for
    which ``sample`` reports zero samples contribute nothing.
    """
    collected = []
    for label in labels_samp:
        other_dbs = [db for other, db in Itemset.dbs.items() if other != label]
        drawn, patterns = sample(nsamp, other_dbs, Itemset.dbs[label],
                                 covered, mode=mode)
        if drawn != 0:
            # Building the Rule objects is the step flagged as very time-consuming.
            collected.extend(Rule(pattern, label) for pattern in patterns)
    return set(collected)
def obj(rules, lamb, qtype='kl', sep=False) -> "float | tuple[float, float]":
    """Objective of a rule set: quality plus ``lamb`` times dispersion.

    Args:
        rules: rules to evaluate; passed to ``Rule.quality`` and ``dispersion``.
        lamb: weight applied to the dispersion term.
        qtype: metric name forwarded to ``Rule.quality``.
        sep: when true, return the two terms separately instead of their sum.

    Returns:
        ``quality + lamb * dispersion`` by default, or the pair
        ``(quality, lamb * dispersion)`` when ``sep`` is true.
    """
    quality = Rule.quality(rules, metric=qtype)
    weighted_dispersion = lamb * dispersion(rules)
    if sep:
        return quality, weighted_dispersion
    return quality + weighted_dispersion
def train(self, X, Y, max_k=100, nsamp=100, lamb=None, q='kl', mode=3,
          rerun=True, min_recall_per_class=0.5):
    """Fit the rule set by repeatedly sampling candidate rules and picking one greedily.

    Args:
        X, Y: training data, forwarded to ``prep_db``.
        max_k: the sampling loop stops once ``len(self)`` reaches this value.
        nsamp: sample size passed to the sampler for each label in each round.
        lamb: weight of the dispersion term. A number is used as given;
            ``'max'`` / ``'mean'`` derive it from the single-rule qualities of
            a preliminary sample; ``None`` or any other string yields 0.
        q: quality metric name, used when deriving ``lamb``.
        mode: 0 samples with ``sample_rn``; any other value is forwarded to
            ``sample_from_each_label``.
        rerun: when true, run a fresh greedy pass over every rule sampled so
            far and keep its result if its objective is higher.
        min_recall_per_class: not referenced in this method body.
    """
    print('##### START #####')
    Itemset.clear_db()
    prep_db(X, Y)

    # Users may pass lamb as a number; a str (or None) asks for it to be derived.
    if isinstance(lamb, str) or lamb is None:
        # NOTE(review): this sample is drawn even when lamb ends up 0; kept
        # to preserve existing behaviour - confirm whether it can be skipped.
        samp = self.sample_from_each_label(set(Itemset.labels), 100, set(), mode)
        if lamb == 'max':
            lamb = np.max([Rule.quality([r], metric=q) for r in samp])
        elif lamb == 'mean':
            lamb = np.mean([Rule.quality([r], metric=q) for r in samp])
        else:
            lamb = 0
    print('lamb:', lamb)

    greed = GreedyDiv([], lamb)
    U_all = []
    labels_samp = set(Itemset.labels)
    while len(self) < max_k and len(labels_samp) > 0:
        if mode == 0:
            samps = []
            for label in labels_samp:
                _, samp = sample_rn(nsamp, label)
                # Very time-consuming
                samps.extend(Rule(s, label) for s in samp)
            U = set(samps)
        else:
            covered = {t for r in self.rules for t in r.trans()}
            U = self.sample_from_each_label(labels_samp, nsamp, covered, mode)
        print('nsamp (after):', len(U))
        if len(U) == 0:
            break
        U_all.extend(U)

        # Greedy
        greed.update_univ(U)
        r = greed.greedy_once()

        # Termination criteria. Also check zero sampling above.
        if self.enough(r):
            # Include at least one rule per class, except default class.
            labels_samp.remove(r.label)
            print('remove label:', r.label)
        else:
            # Print quality vs. dispersion. Distinct local names keep the
            # metric parameter `q` from being overwritten.
            q_old, d_old = obj(self.rules, lamb, sep=True)
            q_new, d_new = obj(self.rules + [r], lamb, sep=True)
            print('inc q vs. d: {}, {}'.format(q_new - q_old, d_new - d_old))
            self.add(r)
            # Stop sampling a label once its recall reaches 1.
            if np.abs(recall(self.rules)[r.label] - 1.0) < 1e-8:
                labels_samp.remove(r.label)
            print('#{} '.format(len(self.rules)), end='')
            printRules([r])

    # Consecutive greedy over all samples
    if rerun:
        greed.clear()
        greed.update_univ(list(set(U_all)))
        rules = greed.greedy(len(self.rules))
        full_obj = obj(rules, lamb)
        incremental_obj = obj(self.rules, lamb)
        if full_obj > incremental_obj:
            print('Full greedy wins: {} > {}'.format(full_obj, incremental_obj))
            self.reset(rules)

    default = self.set_default()
    print('default:', default)
    self.build()
    print('precision: ', precision(self).items())
    print('recall (coverage): ', recall(self.rules).items())
    print('ave disp: ', dispersion(self.rules, average=True))
    print('##### END #####')
def _sort_rules(self, rules): #vals = np.array([Rule.quality([r]) for r in rules]) vals = np.array([Rule.quality([r], metric='acc') for r in rules]) idx = np.argsort(-vals) return [rules[i] for i in idx]
def test_itemset():
    """Check Rule equality/hashing, support, overlap, item difference and coverage."""
    Itemset.set_items(range(5))
    Itemset.clear_db()
    ls = [0, 1]
    Itemset.set_db(ls[0], [])
    Itemset.set_db(ls[1], [Transaction({1, 2, 3})])

    # test eq and hash
    r = Rule({1, 2, 3, 4}, 1)
    assert len(r) == 4
    assert r == Rule({1, 2, 3, 4}, 1)
    assert r != Rule({1, 2, 3, 4}, 0)

    # Equal rules collapse to one set element; rules differing in label do not.
    rules = [Rule({1, 2, 3, 4}, 1), Rule({1, 2, 3, 4}, 1)]
    unique = list(set(rules))
    assert len(unique) == 1
    assert unique[0] == Rule({1, 2, 3, 4}, 1)
    rules = [Rule({1, 2, 3, 4}, 0), Rule({1, 2, 3, 4}, 1)]
    unique = list(set(rules))
    assert len(unique) == 2

    assert r.support(ls) == 0
    assert Rule({4}, 1).support(ls) == 0
    assert Rule({2}, 1).support(ls) == 1

    # Two-sided tolerance checks: without abs() these would pass for any
    # value below the expected one.
    assert abs(r.overlap([Rule({1, 2}, 1)]) - 1) < 1e-8
    assert abs(r.overlap([Rule({1, 2}, 0)]) - 1) < 1e-8
    assert abs(Rule({1, 2, 3}, 1).overlap([Rule({1, 2}, 1)]) - 0) < 1e-8

    assert r.itemdiff(Rule({1, 2}, 1)) == frozenset([3, 4])

    Itemset.set_db(0, [Transaction({0}), Transaction({0})])
    Itemset.set_db(1, [Transaction({0}), Transaction({0})])
    r = Rule({0}, 1)
    assert len(r.coverage(ls)) == 4
    assert len(r.coverage([0])) == 2
    assert len(r.coverage([1])) == 2

    # Items 5 and 6 lie outside the range(5) passed to set_items above;
    # constructing such a rule is expected to raise.
    with pytest.raises(Exception):
        Rule({1, 2, 3, 4, 5, 6}, 1)