Ejemplo n.º 1
0
 def test_confidence(self):
     """Confidence of {age=young, spectacle-prescrip=myope} -> {astigmatism=yes} equals 2/4."""
     database = CsvDatabase(self.path)
     young = AttributeValue('age', 'young')
     myope = AttributeValue('spectacle-prescrip', 'myope')
     astigmatic = AttributeValue('astigmatism', 'yes')
     antecedent = ItemSet.create_itemset(young, myope)
     consequent = ItemSet.create_itemset(astigmatic)
     observed = database.confidence(antecedent, consequent)
     self.assertEqual(observed, 2 / 4)
Ejemplo n.º 2
0
 def test_support_count(self):
     """Check support counts of a 1-item and a 2-item itemset, then the database counter."""
     database = CsvDatabase(self.path)
     young = AttributeValue('age', 'young')
     myope = AttributeValue('spectacle-prescrip', 'myope')
     # Each itemset is created and immediately checked, in this order, so the
     # sequence of calls reaching the database matches the original test.
     cases = (
         ((young,), 8),
         ((young, myope), 4),
     )
     for members, expected_support in cases:
         itemset = ItemSet.create_itemset(*members)
         self.assertEqual(database.support_count(itemset), expected_support)
     # NOTE(review): what `counter` tracks is not visible here; the expected
     # value 5 depends on the calls made above.
     self.assertEqual(database.counter, 5)
Ejemplo n.º 3
0
 def test_distinct_attr_value(self):
     """The database exposes 12 distinct attribute values, including the listed pairs."""
     database = CsvDatabase(self.path)
     distinct = database.get_distinct_attr_values()
     self.assertEqual(len(distinct), 12)
     expected_pairs = (
         ('age', 'young'),
         ('age', 'pre-presbyopic'),
         ('age', 'presbyopic'),
         ('spectacle-prescrip', 'myope'),
         ('spectacle-prescrip', 'hypermetrope'),
     )
     for attribute, value in expected_pairs:
         self._check(attribute, value, distinct)
Ejemplo n.º 4
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-i', help='input file (csv file)')
    parser.add_argument('--minsup', '-m', type=float, help='minimum support')
    parser.add_argument('--minconf',
                        '-c',
                        type=float,
                        help='minimun confidence')
    parser.add_argument('--numrule',
                        '-n',
                        type=int,
                        help='max number of rules output')
    args = parser.parse_args()

    db = CsvDatabase(args.input)
    Factory.setup_db(db)
    apriori = Apriori(db, args.minsup, args.minconf, args.numrule)
    frequent_itemsets = apriori.generate_frequent_itemset()
    for level in frequent_itemsets:
        print(level, ': ', len(frequent_itemsets[level]))
        for itemset in frequent_itemsets[level]:
            print(itemset)
    print('Rules:')
    for rule in apriori.generate_all_confidence_rules():
        print(rule, '(confidence: %.4f)' % (rule.confidence))
Ejemplo n.º 5
0
 def test_generate_all_rule(self):
     """With minsup=minconf=0.3, a rule cap of 10 yields 6 rules and a cap of 5 yields 5."""
     database = CsvDatabase(self.path)
     for max_rules, expected_count in ((10, 6), (5, 5)):
         miner = Apriori(database, 0.3, 0.3, max_rules)
         produced = miner.generate_all_confidence_rules()
         self.assertEqual(len(produced), expected_count)
Ejemplo n.º 6
0
 def test_generate_rule_confidence(self):
     """A fixed 4-item itemset yields 3 rules at minconf 0.5 (minsup 0, cap 100)."""
     database = CsvDatabase(self.path)
     members = [
         AttributeValue('age', 'young'),
         AttributeValue('spectacle-prescrip', 'myope'),
         AttributeValue('astigmatism', 'yes'),
         AttributeValue('tear-prod-rate', 'reduced'),
     ]
     itemset = ItemSet.create_itemset(*members)
     miner = Apriori(database, 0, 0.5, 100)
     produced = miner.generate_confident_rules(itemset)
     self.assertEqual(len(produced), 3)
Ejemplo n.º 7
0
def main():
    """Benchmark Apriori rule generation on test_data/zero_one.csv over a support sweep.

    For each support in ``np.linspace(low_support, high_support, interval)``
    (module-level settings defined outside this function), build an Apriori
    miner with the module-level ``confidence`` threshold and no rule cap
    (``None``). Generate all confident rules, and record the elapsed time
    in seconds and the number of rules. A summary line is printed per
    support value. Both series are then handed to ``graph_runtime`` and
    ``graph_numrules``.
    """
    db = CsvDatabase('test_data/zero_one.csv')
    Factory.setup_db(db)
    runtimes = []  # elapsed time per support value, in seconds
    numrules = []  # number of confident rules per support value
    for support in np.linspace(low_support, high_support, interval):
        # perf_counter is the clock Python documents for measuring short
        # durations; monotonic can have coarse resolution on some platforms.
        start_time = time.perf_counter()
        apriori = Apriori(db, support, confidence, None)
        rules = apriori.generate_all_confidence_rules()
        elapsed = time.perf_counter() - start_time
        runtimes.append(elapsed)
        numrule = len(rules)
        numrules.append(numrule)
        print('Support %.2f, runtime %.9f, numrules %d' %
              (support, elapsed, numrule))
    graph_runtime(runtimes)
    graph_numrules(numrules)
Ejemplo n.º 8
0
 def setUp(self):
     """Hand a CsvDatabase for test_data/contact-lenses.csv (under the CWD) to Factory.setup_db."""
     csv_path = os.path.join(os.getcwd(), 'test_data/contact-lenses.csv')
     Factory.setup_db(CsvDatabase(csv_path))
Ejemplo n.º 9
0
 def setUp(self):
     """Assign a CsvDatabase for test_data/contact-lenses.csv (under the CWD) to Factory.database."""
     csv_path = os.path.join(os.getcwd(), 'test_data/contact-lenses.csv')
     Factory.database = CsvDatabase(csv_path)