Example #1
0
def test_div():
    """Check GreedyDiv's greedy selection and its universe update.

    Builds a two-label toy database, asks for two rules out of a universe
    of five, and then verifies that `update_univ` leaves only the rule that
    was not already selected.
    """
    # Fake dataset
    Itemset.set_items(range(5))
    Itemset.clear_db()
    Itemset.set_db(0, [Transaction([0])])
    label1_db = [Transaction(items)
                 for items in ([1,2,3], [0,1,2], [1,2], [2,3,4])]
    Itemset.set_db(1, label1_db)

    # Candidate rules, all predicting label 1.
    U = [Rule(items, 1)
         for items in ({0}, {2,3}, {2,3,4}, {1,2,3}, {1,2})]

    greed = GreedyDiv(U, 10)
    rules = greed.greedy(2)
    first, second = rules[0], rules[1]
    print(first.s)
    print(second.s)
    assert first == Rule({1,2}, 1)
    assert second == Rule({2,3,4}, 1)

    # Replace the universe; {1,2} was selected above, so only {0} is
    # expected to remain.
    U = [Rule({0}, 1), Rule({1,2}, 1)]
    greed.update_univ(U)
    assert len(greed.U) == 1
    assert greed.U[0] == Rule({0}, 1)
Example #2
0
def test_decset():
    """Check DecisionSet dispersion and prediction on a toy database.

    Two rules for label 1 are added to a decision set whose default label
    is 0; the test then checks the dispersion of the rule set and the
    predictions for covered and uncovered transactions.
    """
    Itemset.set_items(range(5))
    # Start from an empty database, as the sibling tests do, so leftovers
    # from a previously run test cannot leak into this one.
    Itemset.clear_db()
    Itemset.set_db(0, [Transaction([0])])
    Itemset.set_db(1, [Transaction([1,2,3]),
                       Transaction([0,1,2]),
                       Transaction([1,2]),
                       Transaction([2,3,4]),
                      ])

    dec = DecisionSet()
    dec.set_default(0)

    sel = [Rule([1,2], 1), Rule([3], 1)]
    for r in sel:
        dec.add(r)
    dec.build()

    # Two-sided tolerance: without abs() the check would also pass for any
    # dispersion smaller than the expected value.
    assert abs(dispersion(dec.rules) - (1-1/4) * 2) < 1e-8

    assert dec.predict(Transaction([1,2,3])) == True
    assert dec.predict(Transaction([1,2])) == True
    assert dec.predict(Transaction([2,3,4])) == True
    assert dec.predict(Transaction([4])) == False
    assert dec.predict(Transaction([0])) == False

    assert dec.predict_and_rule(Transaction([1,2,3])) == (True, Rule([1,2], 1))
Example #3
0
def test_sampling():
    """Statistically check the itemset frequencies produced by `sample`.

    Each sampled itemset that matches one of the listed patterns must have a
    count within a tolerance band around its expected share of `nsamp`.
    Patterns that never appear in the sample are not checked.
    """
    def check_shares(samp, nsamp, expected, tol):
        # Every (pattern, share) pair is compared against every sampled
        # itemset; a match must fall inside (share - tol, share + tol).
        for itemset, cnt in samp.items():
            for pattern, share in expected:
                if pattern == itemset:
                    assert (share - tol) * nsamp < cnt < (share + tol) * nsamp

    Itemset.set_items(range(4))
    Itemset.clear_db()
    db0 = {Transaction(items) for items in ([0],)}
    db1 = {Transaction(items) for items in ([1,2], [1,3], [1])}
    Itemset.set_db(0, db0)
    Itemset.set_db(1, db1)

    n = 3000
    # NOTE(review): `{}` is an empty dict, while the later calls pass a set
    # for the same argument - confirm `sample` only needs membership tests.
    nsamp, samp = sample(n, [db0], db1, {}, mode=3)
    print(nsamp)
    check_shares(samp, nsamp,
                 [(frozenset([1]), 3/7),
                  (frozenset([1,2]), 1/7),
                  (frozenset([1,3]), 1/7),
                  (frozenset([2]), 1/7),
                  (frozenset([3]), 1/7)],
                 0.05)

    # Same sampler, now with the transactions of one selected rule marked
    # as covered.
    sel = [Rule([1,3], 1)]
    covered = {t for r in sel for t in r.trans()}
    nsamp, samp = sample(n, [db0], db1, covered, mode=3)
    print(nsamp)
    check_shares(samp, nsamp,
                 [(frozenset([1,2]), 0.5),
                  (frozenset([2]), 0.5)],
                 0.1)

    nsamp, samp = sample(n, [db0], db1, covered, mode=2)
    print(nsamp)
    check_shares(samp, nsamp,
                 [(frozenset([1]), 1/3),
                  (frozenset([1,2]), 1/3),
                  (frozenset([2]), 1/3)],
                 0.1)

    # Collecting covered transactions as a list keeps duplicates; as a set
    # they collapse.
    sel = [Rule([1], 1), Rule([1,3], 1)]
    covered = [t for r in sel for t in r.trans()]
    assert len(covered) == 4
    covered = {t for r in sel for t in r.trans()}
    assert len(covered) == 3
Example #4
0
    def greedy_once(self) -> Itemset:
        '''
        Perform one greedy step: pick the best-scoring candidate in self.U,
        move it from self.U into self.sel, and return it.

        With an empty selection the score is the candidate's own quality.
        Otherwise it is 0.5 * quality(candidate + selection) plus
        self.lamb * candidate.overlap(selection).

        Note that this mutates the instance (both self.U and self.sel).
        '''
        if len(self.sel) > 0:
            joint_quality = [Rule.quality([cand] + self.sel) for cand in self.U]
            overlaps = [cand.overlap(self.sel) for cand in self.U]
            scores = 0.5 * np.array(joint_quality) + self.lamb * np.array(overlaps)
        else:
            # Nothing selected yet: rank candidates by stand-alone quality.
            scores = [Rule.quality([cand]) for cand in self.U]

        best = np.argmax(scores)
        chosen = self.U[best]
        del self.U[best]
        self.sel.append(chosen)
        return chosen
Example #5
0
 def sample_from_each_label(self, labels_samp, nsamp, covered, mode):
     """Sample candidate rules for every label in `labels_samp`.

     For each label, `sample` is called with the databases of all other
     labels as the first database argument and the label's own database as
     the second. Each sampled itemset is wrapped in a `Rule` for that label.

     Args:
         labels_samp: iterable of labels to sample rules for.
         nsamp: forwarded to `sample` as the requested sample count.
         covered: forwarded to `sample` unchanged.
         mode: forwarded to `sample` as its `mode` keyword.

     Returns:
         A set of `Rule` objects collected over all labels. Labels for which
         `sample` reports zero samples contribute nothing.
     """
     rules = set()
     for label in labels_samp:
         # Databases of every label except the one being sampled.
         db0 = [db for other, db in Itemset.dbs.items() if other != label]
         nsamp_, samp = sample(nsamp,
                               db0,
                               Itemset.dbs[label],
                               covered,
                               mode=mode)
         if nsamp_ == 0:
             continue
         # Very time-consuming: one Rule is built per sampled itemset.
         # Feeding the set directly avoids re-copying an ever-growing list
         # for each label and the throwaway list(samp) copy.
         rules.update(Rule(s, label) for s in samp)
     return rules
Example #6
0
def obj(rules, lamb, qtype='kl', sep=False) -> "float | tuple[float, float]":
    """Evaluate the objective of a rule set: quality plus weighted dispersion.

    Args:
        rules: the rules to evaluate; passed to `Rule.quality` and
            `dispersion`.
        lamb: weight multiplied with the dispersion term.
        qtype: metric name forwarded to `Rule.quality` as `metric`.
        sep: when true, return the two terms separately instead of their sum.

    Returns:
        `quality + lamb * dispersion` when `sep` is false, otherwise the
        tuple `(quality, lamb * dispersion)`.
    """
    q = Rule.quality(rules, metric=qtype)
    d = dispersion(rules)
    if sep:
        return q, lamb * d
    return q + lamb * d
Example #7
0
    def train(self,
              X,
              Y,
              max_k=100,
              nsamp=100,
              lamb=None,
              q='kl',
              mode=3,
              rerun=True,
              min_recall_per_class=0.5):
        """Learn the rule set by alternating sampling and greedy selection.

        Each iteration samples candidate rules for the labels still in play,
        hands them to a `GreedyDiv` instance, and takes one greedy step. The
        chosen rule is either added to this object, or - when
        `self.enough(rule)` is true - its label is dropped from further
        sampling. Optionally a full greedy run over all sampled candidates
        replaces the result if it scores higher under `obj`.

        Args:
            X, Y: training data, passed to `prep_db`.
            max_k: the loop stops once `len(self)` reaches this value.
            nsamp: sample count requested per label in each iteration.
            lamb: dispersion weight. A number is used as is. 'max' / 'mean'
                derive it from the qualities of a preliminary sample (100
                per label); `None` or any other string yields 0.
            q: metric name used when deriving `lamb` from the sample.
            mode: sampling mode; 0 uses `sample_rn`, anything else goes
                through `sample_from_each_label`.
            rerun: whether to run the final full greedy pass.
            min_recall_per_class: not referenced in this method's body.
        """
        print('##### START #####')
        Itemset.clear_db()
        prep_db(X, Y)

        # Users may pass a number for lamb; strings and None are resolved here.
        if isinstance(lamb, str) or lamb is None:
            samp = self.sample_from_each_label(set(Itemset.labels), 100, set(),
                                               mode)
            if lamb == 'max':
                lamb = np.max([Rule.quality([r], metric=q) for r in samp])
            elif lamb == 'mean':
                lamb = np.mean([Rule.quality([r], metric=q) for r in samp])
            else:
                lamb = 0
            print('lamb:', lamb)

        greed = GreedyDiv([], lamb)
        U_all = []
        labels_samp = set(Itemset.labels)
        while len(self) < max_k and len(labels_samp) > 0:
            if mode == 0:
                samps = []
                for label in labels_samp:
                    _, samp = sample_rn(nsamp, label)
                    samp = [Rule(s, label)
                            for s in list(samp)]  # Very time-consuming
                    samps.extend(samp)
                U = set(samps)
            else:
                covered = {t for r in self.rules for t in r.trans()}
                U = self.sample_from_each_label(labels_samp, nsamp, covered,
                                                mode)
            print('nsamp (after):', len(U))
            if len(U) == 0:
                break
            U_all.extend(U)

            # Greedy
            greed.update_univ(U)
            r = greed.greedy_once()
            # Termination criteria. Also check zero sampling above.
            if self.enough(r):
                # Include at least one rule per class, except default class.
                labels_samp.remove(r.label)
                print('remove label:', r.label)
            else:
                # Print quality vs. dispersion. Distinct local names are used
                # so the `q` metric parameter is not overwritten.
                q_cur, d_cur = obj(self.rules, lamb, sep=True)
                q_new, d_new = obj(self.rules + [r], lamb, sep=True)
                print('inc q vs. d: {}, {}'.format(q_new - q_cur,
                                                   d_new - d_cur))

                self.add(r)
                # Stop sampling a label once its recall is (numerically) 1.
                if np.abs(recall(self.rules)[r.label] - 1.0) < 1e-8:
                    labels_samp.remove(r.label)
                print('#{} '.format(len(self.rules)), end='')
                printRules([r])

        # Consecutive greedy over all samples
        if rerun:
            greed.clear()
            greed.update_univ(list(set(U_all)))
            rules = greed.greedy(len(self.rules))
            # Evaluate each objective once and reuse it for the message.
            full_obj = obj(rules, lamb)
            cur_obj = obj(self.rules, lamb)
            if full_obj > cur_obj:
                print('Full greedy wins: {} > {}'.format(full_obj, cur_obj))
                self.reset(rules)

        default = self.set_default()
        print('default:', default)

        self.build()

        print('precision: ', precision(self).items())
        print('recall (coverage): ', recall(self.rules).items())
        print('ave disp: ', dispersion(self.rules, average=True))
        print('##### END #####')
Example #8
0
 def _sort_rules(self, rules):
     """Return `rules` reordered by descending 'acc' quality.

     Each rule is scored on its own via `Rule.quality([rule], metric='acc')`.
     """
     scores = np.array([Rule.quality([rule], metric='acc') for rule in rules])
     # Argsort of the negated scores gives positions from highest to lowest.
     order = np.argsort(-scores)
     ranked = []
     for pos in order:
         ranked.append(rules[pos])
     return ranked
Example #9
0
def test_itemset():
    """Exercise Rule equality/hashing, support, overlap, itemdiff and coverage.

    Also checks that constructing a Rule with items outside the configured
    item range raises.
    """
    Itemset.set_items(range(5))
    Itemset.clear_db()
    ls = [0,1]
    Itemset.set_db(ls[0], [])
    Itemset.set_db(ls[1], [Transaction({1,2,3})])

    # test eq and hash
    r = Rule({1,2,3,4}, 1)
    assert len(r) == 4
    assert r == Rule({1,2,3,4}, 1)
    assert r != Rule({1,2,3,4}, 0)
    rules = [Rule({1,2,3,4}, 1), Rule({1,2,3,4}, 1)]
    unique = list(set(rules))
    assert len(unique) == 1
    assert unique[0] == Rule({1,2,3,4}, 1)
    rules = [Rule({1,2,3,4}, 0), Rule({1,2,3,4}, 1)]
    unique = list(set(rules))
    assert len(unique) == 2

    assert r.support(ls) == 0
    assert Rule({4}, 1).support(ls) == 0
    assert Rule({2}, 1).support(ls) == 1

    # Two-sided tolerances: without abs() these checks would also pass for
    # any value below the expected one.
    assert abs(r.overlap([Rule({1,2}, 1)]) - 1) < 1e-8
    assert abs(r.overlap([Rule({1,2}, 0)]) - 1) < 1e-8
    assert abs(Rule({1,2,3}, 1).overlap([Rule({1,2}, 1)]) - 0) < 1e-8

    assert r.itemdiff(Rule({1,2}, 1)) == frozenset([3,4])

    # Replace both databases with duplicated transactions to test coverage.
    Itemset.set_db(0, [Transaction({0}),
                       Transaction({0}),
                       ])
    Itemset.set_db(1, [Transaction({0}),
                       Transaction({0}),
                       ])
    r = Rule({0}, 1)
    assert len(r.coverage(ls)) == 4
    assert len(r.coverage([0])) == 2
    assert len(r.coverage([1])) == 2

    # Items 5 and 6 are outside range(5) configured above.
    with pytest.raises(Exception):
        Rule({1, 2, 3, 4, 5, 6}, 1)