def setUp(self):
     """
     Builds the fixture shared by the tests: the database loaded from the small annotations
     CSV, its fast_strong_cluster labels, a single-block blocking scheme, an EntityResolution
     object, and a LogisticMatchFunction with decision threshold 1.0.
     """
     csv_path = 'test_annotations_cleaned.csv'
     self._test_path = csv_path
     self._database = Database(csv_path)
     self._labels = fast_strong_cluster(self._database)
     self._blocking = BlockingScheme(self._database, single_block=True)
     self._er = EntityResolution()
     # 0.5 is the third argument of generate_pair_seed (presumably the class balance -- confirm)
     seed_pairs = generate_pair_seed(self._database, self._labels, 0.5)
     self._match_function = LogisticMatchFunction(self._database, self._labels, seed_pairs, 1.0)
示例#2
0
 def run(self):
     """
     Runs ER (weak connected components) on the test database at every threshold in self.thresholds

     A single ForestMatchFunction is built from the training database, labels and pair seed. It is
     tested once on the validation data and then reused for each threshold by resetting its
     decision threshold.

     :return labels_list: List of predicted labels.
                          labels_list[threshold_index] = dict [identifier, cluster label]
     :return metrics_list: List of metric objects.
                           metrics_list[threshold_index] = Metrics object
     :return new_metrics_list: List of NewMetrics objects.
                               new_metrics_list[threshold_index] = NewMetrics
     """
     weak_match_function = ForestMatchFunction(self._database_train, self._labels_train, self._train_pair_seed, 0.5)
     print('Testing pairwise match function on test database')
     # The returned object is not needed here; the call is kept for whatever evaluation it performs
     weak_match_function.test(self._database_validation, self._labels_validation, self._validation_seed)
     metrics_list = list()
     labels_list = list()
     new_metrics_list = list()
     # Computed once: neither the true test labels nor the blocking depend on the threshold
     class_balance_test = count_pairwise_class_balance(self._labels_test)
     blocks = BlockingScheme(self._database_test, single_block=True)
     for threshold in self.thresholds:
         print('Running entity resolution at threshold = %s' % (threshold,))
         weak_match_function.set_decision_threshold(threshold)
         labels_pred = weak_connected_components(self._database_test, weak_match_function, blocks)
         metrics_list.append(Metrics(self._labels_test, labels_pred))
         new_metrics_list.append(NewMetrics(self._database_test, labels_pred, weak_match_function, class_balance_test))
         labels_list.append(labels_pred)
     return labels_list, metrics_list, new_metrics_list
 def test_completeness(self):
     """
     Checks that no record is lost: the train and test label sets together cover identifiers
     0..999, and both the strong-cluster labels and the ER predictions on the test database
     are keyed by exactly the identifiers of its records.
     """
     source_db = Database('test_annotations_10000_cleaned.csv', max_records=1000,
                          header_path='test_annotations_10000_cleaned_header.csv')
     train_db = source_db.sample_and_remove(800)
     test_db = source_db  # the same object the training sample was taken from
     train_labels = fast_strong_cluster(train_db)
     test_labels = fast_strong_cluster(test_db)
     resolver = EntityResolution()
     seed_pairs = generate_pair_seed(train_db, train_labels, 0.5)
     matcher = LogisticMatchFunction(train_db, train_labels, seed_pairs, 0.99)
     blocks = BlockingScheme(test_db)
     predicted_labels = resolver.run(test_db, matcher, blocks, cores=2)

     # Every one of the 1000 loaded records is labelled exactly once across the two splits
     self.assertEqual(len(train_labels) + len(test_labels), 1000)
     combined_ids = list(train_labels.keys()) + list(test_labels.keys())
     self.assertEqual(sorted(combined_ids), list(range(1000)))

     # The test split's identifiers agree with both labelings
     self.assertEqual(len(get_ids(test_db.records)), len(test_db.records))
     self.assertEqual(get_ids(test_db.records), sorted(test_labels.keys()))
     self.assertEqual(get_ids(test_db.records), sorted(predicted_labels.keys()))
 def setUp(self):
     """
     Builds the fixture: the database loaded from the small annotations CSV, a default
     BlockingScheme over it, and an EntityResolution object.
     """
     csv_path = 'test_annotations_cleaned.csv'
     self._test_path = csv_path
     loaded = Database(csv_path)
     self._database = loaded
     self._blocking = BlockingScheme(loaded)
     self._er = EntityResolution()
class MyTestCase(unittest.TestCase):
    """
    Tests for entity resolution, R-Swoosh merging and fast strong clustering.
    """

    def setUp(self):
        """
        Loads the small annotations database and builds its strong-cluster labels, a
        single-block blocking scheme, an EntityResolution object and a LogisticMatchFunction
        with decision threshold 1.0.
        """
        csv_path = 'test_annotations_cleaned.csv'
        self._test_path = csv_path
        self._database = Database(csv_path)
        self._labels = fast_strong_cluster(self._database)
        self._blocking = BlockingScheme(self._database, single_block=True)
        self._er = EntityResolution()
        seed_pairs = generate_pair_seed(self._database, self._labels, 0.5)
        self._match_function = LogisticMatchFunction(self._database, self._labels, seed_pairs, 1.0)

    def _strongly_merged_copy(self):
        """
        Returns a deep copy of the fixture database with its strong clusters merged.
        The fixture database itself is left untouched.
        """
        strong = fast_strong_cluster(self._database)
        clone = deepcopy(self._database)
        clone.merge(strong)
        return clone

    def _first_four_records(self):
        """
        Returns records 0..3 of the (unmerged) fixture database, in order.
        """
        source = self._database.records
        return source[0], source[1], source[2], source[3]

    def test_run(self):
        """
        ER via run() must give the same entities as merging records 0, 1 and 3 by hand
        and leaving record 2 on its own.
        """
        merged_db = self._strongly_merged_copy()
        single_block = BlockingScheme(merged_db, single_block=True)
        predicted = self._er.run(merged_db, self._match_function, single_block, cores=2)
        merged_db.merge(predicted)
        resolved = {entity for _, entity in merged_db.records.iteritems()}
        r0, r1, r2, r3 = self._first_four_records()
        r0.merge(r1)
        r0.merge(r3)
        self.assertTrue(test_object_set({r0, r2}, resolved))

    def test_rswoosh(self):
        """
        rswoosh() on the strongly merged records must match the manually merged records.
        """
        merged_db = self._strongly_merged_copy()
        candidates = {record for _, record in merged_db.records.iteritems()}
        self._er._match_function = self._match_function
        swooshed = self._er.rswoosh(candidates)
        # Compare to manually merged records
        r0, r1, r2, r3 = self._first_four_records()
        r1.merge(r3)
        r0.merge(r1)
        expected = {r0, r2}
        self.assertEqual(len(swooshed), len(expected))
        self.assertTrue(test_object_set(expected, swooshed))

    def test_merge_duped_records(self):
        """
        Merges all entities containing the same record identifier
        """
        merged_db = self._strongly_merged_copy()
        self._er._match_function = self._match_function
        candidates = {record for _, record in merged_db.records.iteritems()}
        swooshed = self._er.rswoosh(candidates)
        # Compare to manually constructed clusters with duplicates
        r0, r1, r2, r3 = self._first_four_records()
        r0.merge(r1)
        r1.merge(r3)
        overlapping = {0: r0, 1: r1, 2: r2, 3: r3}
        deduped = merge_duped_records(overlapping)
        self.assertEqual(len(deduped), len(swooshed))
        self.assertTrue(test_object_set(deduped, swooshed))

    def test_deep_copy(self):
        """
        A deep-copied record starts equal to the original and diverges once it is modified.
        """
        original = self._database.records[0]
        duplicate = deepcopy(self._database.records)[0]
        self.assertEqual(duplicate, original)
        duplicate.features[0].add('Santa Clause')
        self.assertNotEqual(duplicate, original)

    def test_completeness(self):
        """
        No record is lost: train and test labels together cover identifiers 0..999, and the
        test database's identifiers match both its strong-cluster labels and the ER output.
        """
        source_db = Database('test_annotations_10000_cleaned.csv', max_records=1000,
                             header_path='test_annotations_10000_cleaned_header.csv')
        train_db = source_db.sample_and_remove(800)
        test_db = source_db  # the same object the training sample was taken from
        train_labels = fast_strong_cluster(train_db)
        test_labels = fast_strong_cluster(test_db)
        resolver = EntityResolution()
        seed_pairs = generate_pair_seed(train_db, train_labels, 0.5)
        matcher = LogisticMatchFunction(train_db, train_labels, seed_pairs, 0.99)
        blocks = BlockingScheme(test_db)
        predicted = resolver.run(test_db, matcher, blocks, cores=2)

        self.assertEqual(len(train_labels) + len(test_labels), 1000)
        combined_ids = list(train_labels.keys()) + list(test_labels.keys())
        self.assertEqual(sorted(combined_ids), list(range(1000)))

        self.assertEqual(len(get_ids(test_db.records)), len(test_db.records))
        self.assertEqual(get_ids(test_db.records), sorted(test_labels.keys()))
        self.assertEqual(get_ids(test_db.records), sorted(predicted.keys()))

    def test_fast_strong_cluster(self):
        """
        On the small database records 0, 1 and 3 share cluster 0 and record 2 is cluster 1.
        """
        expected = {0: 0, 1: 0, 2: 1, 3: 0}
        self.assertEqual(fast_strong_cluster(self._database), expected)

    def test_fast_strong_cluster_large(self):
        """
        Strong clustering assigns exactly one label per record in both the train and test split.
        """
        source_db = Database('test_annotations_10000_cleaned.csv', max_records=1000,
                             header_path='test_annotations_10000_cleaned_header.csv')
        train_db = source_db.sample_and_remove(800)
        test_db = source_db
        train_labels = fast_strong_cluster(train_db)
        test_labels = fast_strong_cluster(test_db)
        self.assertEqual(len(train_labels), len(train_db.records))
        self.assertEqual(len(test_labels), len(test_db.records))
示例#6
0
def synthetic_sizes():
    """
    Sizes experiment here

    Stages:
      1. Grow one synthetic database from 10 to 100 entities (10 records each), keeping a deep
         copy at every size, and corrupt each copy with a shared noise matrix.
      2. Fit a LogisticMatchFunction on a separately corrupted copy of the smallest database.
      3. Sweep the decision threshold on the smallest database and plot the pairwise scores
         (the "small" threshold itself is fixed at 0.6 below).
      4. Sweep the threshold on the largest database and pick the one maximizing the F1 lower bound.
      5. Run ER on every database size at both thresholds, write the scores to
         'synthetic_sizes_temp.csv' and plot precision/recall at the small threshold versus size.

    Shows three matplotlib figures (blocking on plt.show()) and returns nothing.
    """
    resolution = 88
    number_features = 10
    number_entities = np.linspace(10, 100, num=resolution).astype(int)
    records_per_entity = 10
    train_class_balance = 0.5
    corruption_multiplier = .001

    ## Build nested databases of increasing size
    databases = list()
    db = SyntheticDatabase(number_entities[0], records_per_entity, number_features=number_features)
    databases.append(deepcopy(db))
    # Entities to add between consecutive sizes
    for add in np.diff(number_entities):
        db.add(add, records_per_entity)
        databases.append(deepcopy(db))
    # One noise matrix sized for the largest database; each snapshot uses its leading rows
    corruption = np.random.normal(loc=0.0, scale=1.0, size=[number_entities[-1]*records_per_entity, number_features])
    train = deepcopy(databases[0])
    validation = deepcopy(databases[0])
    train.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(train.database.records), number_features]))
    validation.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(validation.database.records), number_features]))
    for snapshot in databases:
        snapshot.corrupt(corruption_multiplier*corruption[:len(snapshot.database.records), :])
    er = EntityResolution()
    train_pair_seed = generate_pair_seed(train.database, train.labels, train_class_balance)
    weak_match_function = LogisticMatchFunction(train.database, train.labels, train_pair_seed, 0.5)
    # NOTE(review): third argument is the literal 0.5 here; confirm it matches what test() expects
    weak_match_function.test(validation.database, validation.labels, 0.5)

    ## Optimize ER on small dataset
    thresholds = np.linspace(0, 1.0, 10)
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    for threshold in thresholds:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[0].database), weak_match_function, single_block=True,
                             max_block_size=np.inf, cores=1)
        met = Metrics(databases[0].labels, labels_pred)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
    plt.plot(thresholds, pairwise_precision, label='Precision')
    plt.plot(thresholds, pairwise_recall, label='Recall')
    plt.plot(thresholds, pairwise_f1, label='F1')
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on small dataset')
    # Fixed by hand rather than taken from the argmax of pairwise_f1 above
    small_optimal_threshold = 0.6
    print('Optimal small threshold set at = %s' % (small_optimal_threshold,))
    plt.show()

    ## Possible score by optimizing on larger dataset
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    thresholds_largedataset = np.linspace(0.6, 1.0, 8)
    precision_lower_bound = list()
    recall_lower_bound = list()
    f1_lower_bound = list()
    for threshold in thresholds_largedataset:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[-1].database), weak_match_function, single_block=True,
                             max_block_size=np.inf, cores=1)
        met = Metrics(databases[-1].labels, labels_pred)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(databases[-1].labels)
        new_metric = NewMetrics(databases[-1].database, labels_pred, weak_match_function, class_balance_test)
        precision_lower_bound.append(new_metric.precision_lower_bound)
        recall_lower_bound.append(new_metric.recall_lower_bound)
        f1_lower_bound.append(new_metric.f1_lower_bound)
    plt.plot(thresholds_largedataset, pairwise_precision, label='Precision', color='r')
    plt.plot(thresholds_largedataset, pairwise_recall, label='Recall', color='b')
    plt.plot(thresholds_largedataset, pairwise_f1, label='F1', color='g')
    plt.plot(thresholds_largedataset, precision_lower_bound, label='Precision Bound', color='r', linestyle=':')
    plt.plot(thresholds_largedataset, recall_lower_bound, label='Recall Bound', color='b', linestyle=':')
    plt.plot(thresholds_largedataset, f1_lower_bound, label='F1 Bound', color='g', linestyle=':')
    # The large-dataset threshold is chosen from the F1 lower bound, not the true F1
    i = np.argmax(np.array(f1_lower_bound))
    large_optimal_threshold = thresholds_largedataset[i]
    print('Optimal large threshold automatically set at = %s' % (large_optimal_threshold,))
    print('If not correct: debug.')
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on large dataset')
    plt.show()

    ## Run on all dataset sizes
    database_sizes = list()
    small_pairwise_precision = list()
    small_pairwise_recall = list()
    small_pairwise_f1 = list()
    large_precision_bound = list()
    large_precision_bound_lower_ci = list()
    large_precision_bound_upper_ci = list()
    large_precision = list()
    large_recall_bound = list()
    large_recall_bound_lower_ci = list()
    large_recall_bound_upper_ci = list()
    large_recall = list()
    large_f1 = list()
    large_f1_bound = list()
    for snapshot in databases:
        size = len(snapshot.database.records)
        print('Analyzing synthetic database with %s records' % (size,))
        database_sizes.append(size)
        # Scores at the threshold tuned on the smallest database
        weak_match_function.set_decision_threshold(small_optimal_threshold)
        labels_pred = er.run(snapshot.database, weak_match_function, single_block=True, max_block_size=np.inf, cores=1)
        met = Metrics(snapshot.labels, labels_pred)
        small_pairwise_precision.append(met.pairwise_precision)
        small_pairwise_recall.append(met.pairwise_recall)
        small_pairwise_f1.append(met.pairwise_f1)
        # Scores and lower bounds at the threshold tuned on the largest database
        weak_match_function.set_decision_threshold(large_optimal_threshold)
        labels_pred = er.run(snapshot.database, weak_match_function, single_block=True, max_block_size=np.inf, cores=1)
        met = Metrics(snapshot.labels, labels_pred)
        large_precision.append(met.pairwise_precision)
        large_recall.append(met.pairwise_recall)
        large_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(snapshot.labels)
        new_metric = NewMetrics(snapshot.database, labels_pred, weak_match_function, class_balance_test)
        large_precision_bound.append(new_metric.precision_lower_bound)
        large_recall_bound.append(new_metric.recall_lower_bound)
        large_f1_bound.append(new_metric.f1_lower_bound)
        large_precision_bound_lower_ci.append(new_metric.precision_lower_bound_lower_ci)
        large_precision_bound_upper_ci.append(new_metric.precision_lower_bound_upper_ci)
        large_recall_bound_lower_ci.append(new_metric.recall_lower_bound_lower_ci)
        large_recall_bound_upper_ci.append(new_metric.recall_lower_bound_upper_ci)

    # The with-block closes the file; no explicit close() is needed afterwards
    with open('synthetic_sizes_temp.csv', 'wb') as f:
        f.write('Database size, Precision (small opt), Recall (small opt), F1 (small opt), Precision (large opt), Precision bound (large opt), Lower CI, Upper CI, Recall (large opt), Recall bound (large opt), Lower CI, Upper CI, F1 (large opt), F1 bound (large opt)\n')
        writer = csv.writer(f)
        writer.writerows(izip(database_sizes, small_pairwise_precision, small_pairwise_recall, small_pairwise_f1, large_precision, large_precision_bound, large_precision_bound_lower_ci, large_precision_bound_upper_ci, large_recall, large_recall_bound, large_recall_bound_lower_ci, large_recall_bound_upper_ci, large_f1, large_f1_bound))

    # Plot the per-size scores at the small-dataset threshold. These lists have one entry per
    # database, like database_sizes; the threshold-sweep lists above have one entry per threshold
    # and cannot be plotted against database size.
    plt.figure()
    plt.plot(database_sizes, small_pairwise_precision, label='Precision', color='#4477AA', linewidth=3)
    plt.plot(database_sizes, small_pairwise_recall, label='Recall', color='#CC6677', linewidth=3)
    plt.ylim([0, 1.05])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    plt.legend(title='Pairwise:', loc='lower left')
    plt.xlabel('Number of Records')
    plt.ylabel('Pairwise Score')
    plt.title('Performance Degredation')
    plt.show()