コード例 #1
0
 def test_test(self):
     """
     Fit a logistic match function on a 5000-record sample and plot its ROC
     against the rest of the database.

     The test makes no assertions; it only checks that the train / test /
     plot pipeline runs end to end.
     """
     header_file = 'test_annotations_10000_cleaned_header.csv'
     full_db = Database('test_annotations_10000_cleaned.csv', header_path=header_file)
     # The sampled records form the training set. Whatever is left in
     # full_db is used as the held-out set (presumably sample_and_remove
     # drops the sampled records from full_db, as its name suggests).
     train_db = full_db.sample_and_remove(5000)
     held_out_db = full_db
     train_labels = fast_strong_cluster(train_db)
     held_out_labels = fast_strong_cluster(held_out_db)
     seed = generate_pair_seed(train_db, train_labels, 0.5)
     matcher = LogisticMatchFunction(train_db, train_labels, seed, 0.7)
     matcher.test(held_out_db, held_out_labels, 0.5).make_plot()
コード例 #2
0
 def setUp(self):
     """
     Build the fixtures shared by the tests: the database, its strong-cluster
     labels, a single-block blocking scheme, an EntityResolution instance and
     a logistic match function with decision threshold 1.0.
     """
     csv_path = 'test_annotations_cleaned.csv'
     self._test_path = csv_path
     self._database = Database(csv_path)
     self._labels = fast_strong_cluster(self._database)
     self._blocking = BlockingScheme(self._database, single_block=True)
     self._er = EntityResolution()
     # The pair seed is generated after the objects above, keeping the same
     # call order as before.
     seed = generate_pair_seed(self._database, self._labels, 0.5)
     self._match_function = LogisticMatchFunction(
         self._database, self._labels, seed, 1.0)
コード例 #3
0
 def test_completeness(self):
     """
     Check that no record id is lost or duplicated by clustering or by ER.

     1000 records are loaded, 800 are sampled out for training and the rest
     are resolved with two cores. The ids seen by the strong clustering and
     by the entity resolution output must match the ids in the databases.
     """
     full_db = Database('test_annotations_10000_cleaned.csv',
                        max_records=1000,
                        header_path='test_annotations_10000_cleaned_header.csv')
     train_db = full_db.sample_and_remove(800)
     test_db = full_db
     train_labels = fast_strong_cluster(train_db)
     test_labels = fast_strong_cluster(test_db)
     resolver = EntityResolution()
     seed = generate_pair_seed(train_db, train_labels, 0.5)
     matcher = LogisticMatchFunction(train_db, train_labels, seed, 0.99)
     blocks = BlockingScheme(test_db)
     predicted_labels = resolver.run(test_db, matcher, blocks, cores=2)

     # Train and test labels together must cover ids 0..999 exactly once.
     self.assertEqual(len(train_labels) + len(test_labels), 1000)
     combined_ids = list(train_labels.keys()) + list(test_labels.keys())
     combined_ids.sort()
     self.assertEqual(combined_ids, list(range(1000)))

     # The ids read from the test records must be one per record ...
     id_count = len(get_ids(test_db.records))
     self.assertEqual(id_count, len(test_db.records))
     # ... and must agree with both the strong labels and the ER prediction.
     self.assertEqual(get_ids(test_db.records), sorted(test_labels.keys()))
     self.assertEqual(get_ids(test_db.records), sorted(predicted_labels.keys()))
コード例 #4
0
 def setUp(self):
     """
     Build the shared fixtures: the database, a default blocking scheme and a
     logistic match function with decision threshold 0.5.
     """
     database = Database('test_annotations_cleaned.csv')
     self._database = database
     cluster_labels = fast_strong_cluster(database)
     seed = generate_pair_seed(database, cluster_labels, 0.5)
     self._blocking = BlockingScheme(database)
     self._match_function = LogisticMatchFunction(database, cluster_labels, seed, 0.5)
コード例 #5
0
class MyTestCase(unittest.TestCase):
    """Tests for the logistic match function and the pairwise feature helpers."""

    def setUp(self):
        """Load the small annotations database and construct a match function on it."""
        database = Database('test_annotations_cleaned.csv')
        self._database = database
        cluster_labels = fast_strong_cluster(database)
        seed = generate_pair_seed(database, cluster_labels, 0.5)
        self._blocking = BlockingScheme(database)
        self._match_function = LogisticMatchFunction(database, cluster_labels, seed, 0.5)

    def test_pairs(self):
        """The same pair seed must yield identical pairwise features on every call."""
        big_db = Database('test_annotations_10000_cleaned.csv',
                          header_path='test_annotations_10000_cleaned_header.csv')
        cluster_labels = fast_strong_cluster(big_db)
        seed = generate_pair_seed(big_db, cluster_labels, 0.5)
        # Disabled earlier check: two unseeded draws were expected to differ.
        # x1_a, x2_a, m_a = _get_pairs(database, labels, 10, balancing=True)
        # x1_b, x2_b, m_b = _get_pairs(database, labels, 10, balancing=True)
        # self.assertNotEqual(x1_a, x1_b)
        # self.assertNotEqual(x2_a, x2_b)
        # self.assertNotEqual(m_a, m_b)
        first_x1, first_x2, first_m = get_pairwise_features(big_db, cluster_labels, seed)
        second_x1, second_x2, second_m = get_pairwise_features(big_db, cluster_labels, seed)
        comparisons = (
            (first_x1, second_x1),
            (first_x2, second_x2),
            (first_m, second_m),
        )
        for first, second in comparisons:
            np.testing.assert_array_equal(first, second)

    def test_mean_imputation(self):
        """mean_imputation must return the per-column mean with NaN entries ignored."""
        nan = np.NaN
        samples = np.array([[1, 2, 3, 4],
                            [nan, 4, 5, nan],
                            [1, 6, nan, nan]])
        expected = np.array([1, 4, 4, 4])
        imputed = mean_imputation(samples)
        self.assertTrue((imputed == expected).all())

    def test_match(self):
        """After retraining on a hand-written labelling, each record matches a copy of itself."""
        # Fetch the records before retraining, in the same order as the labels.
        records = [self._database.records[index] for index in range(4)]
        manual_labels = {0: 0, 1: 0, 2: 1, 3: 1}
        seed = generate_pair_seed(self._database, manual_labels, 0.5)
        self._match_function._train(self._database, manual_labels, seed)
        for record in records:
            # match() returns an indexable result; element 0 is the decision.
            decision = self._match_function.match(record, deepcopy(record))
            self.assertTrue(decision[0])

    def test_test(self):
        """Train on a 5000-record sample and plot the ROC on the remaining records."""
        full_db = Database('test_annotations_10000_cleaned.csv',
                           header_path='test_annotations_10000_cleaned_header.csv')
        train_db = full_db.sample_and_remove(5000)
        held_out_db = full_db
        train_labels = fast_strong_cluster(train_db)
        held_out_labels = fast_strong_cluster(held_out_db)
        seed = generate_pair_seed(train_db, train_labels, 0.5)
        matcher = LogisticMatchFunction(train_db, train_labels, seed, 0.7)
        matcher.test(held_out_db, held_out_labels, 0.5).make_plot()

    def test_get_x1(self):
        """strong_match must link records 0-3 and 1-3 and no other pair among records 0..3."""
        records = [self._database.records[index] for index in range(4)]
        # (left index, right index, expected strong_match result)
        cases = (
            (0, 3, True),
            (1, 3, True),
            (0, 1, False),
            (0, 2, False),
            (1, 2, False),
            (2, 3, False),
        )
        for left, right, expected in cases:
            self.assertEqual(strong_match(records[left], records[right]), expected)

    def test_get_x2(self):
        """Weak features of record 0 against itself: zeros, then NaNs, then exp(-3)."""
        record = self._database.records[0]
        features = get_weak_pairwise_features(record, record)
        # Feature position -> source field, as annotated in the original test:
        #   0:[1] binary match   1:[2] date diff   2:[3] bin       3:[4] bin
        #   4:[7] bin            5:[8] num diff
        for position in range(6):
            self.assertEqual(features[position], 0)
        #   6:[9] bin    7:[10] num diff   8:[11] num diff   9:[12] bin
        #   10:[13] num diff   11:[14] num diff   12:[15] num diff
        #   13:[16] bin  14:[17] bin  15:[18] bin  16:[19] bin
        #   17:[24] bin  18:[25] bin
        for position in range(6, 19):
            self.assertTrue(isnan(features[position]))
        #   19:[26] number matches
        self.assertEqual(features[19], np.exp(-3))

    def test_number_matches(self):
        """number_matches counts shared elements and is NaN when one set is empty."""
        base = {1, 2, 3}
        overlapping = {3, 4, 5}
        empty = set()
        self.assertEqual(number_matches(base, base), 3)
        self.assertEqual(number_matches(base, overlapping), 1)
        self.assertTrue(isnan(number_matches(base, empty)))

    def test_numerical_difference(self):
        """numerical_difference is 0 on identical sets and NaN when one set is empty."""
        base = {1, 2, 3}
        shifted = {4, 5, 5}
        empty = set()
        self.assertEqual(numerical_difference(base, base), 0)
        self.assertEqual(numerical_difference(base, shifted), 1)
        self.assertTrue(isnan(numerical_difference(base, empty)))

    def test_binary_match(self):
        """binary_match is 1 on any overlap, 0 on none, and NaN when one set is empty."""
        base = {1, 2, 3}
        overlapping = {3, 4, 5}
        empty = set()
        disjoint = {5}
        self.assertEqual(binary_match(base, base), 1)
        self.assertEqual(binary_match(base, overlapping), 1)
        self.assertEqual(binary_match(base, disjoint), 0)
        self.assertTrue(isnan(binary_match(base, empty)))

    def test_levenshtein(self):
        """levenshtein on sets of strings: single names, multiple names, identical sets."""
        distance = levenshtein({'Matthew'}, {'Matt'})
        self.assertEqual(distance, 3)
        originals = {'abcd', 'efgh', 'ijkl'}
        variants = {'abbb', 'egfe', 'i'}
        distance = levenshtein(originals, variants)
        self.assertEqual(distance, 2)
        distance = levenshtein(originals, originals)
        self.assertEqual(distance, 0)
コード例 #6
0
def synthetic_sizes():
    """
    Sizes experiment: track pairwise ER quality as a synthetic database grows.

    Steps, in order:
      1. Build 88 nested snapshots of one synthetic database, growing from
         10 to 100 entities with 10 records per entity, and corrupt them
         with a shared noise matrix.
      2. Construct a logistic match function on a separately corrupted copy
         of the smallest snapshot and evaluate it on another corrupted copy.
      3. Sweep the decision threshold on the smallest snapshot and plot
         precision / recall / F1 (the "small" threshold is fixed at 0.6).
      4. Sweep the threshold on the largest snapshot and pick the value that
         maximises the F1 lower bound (the "large" threshold).
      5. Run ER on every snapshot with both thresholds, write the results to
         'synthetic_sizes_temp.csv' and plot precision / recall against size.

    Takes no arguments and returns nothing. Output is the matplotlib windows,
    the CSV file and progress messages on stdout.
    """
    resolution = 88
    number_features = 10
    number_entities = np.linspace(10, 100, num=resolution).astype(int)
    records_per_entity = 10
    train_class_balance = 0.5
    corruption_multiplier = .001

    # Grow a single database step by step and snapshot it at every size.
    db = SyntheticDatabase(number_entities[0], records_per_entity, number_features=number_features)
    databases = [deepcopy(db)]
    for add in np.diff(number_entities):
        db.add(add, records_per_entity)
        databases.append(deepcopy(db))

    # One noise matrix sized for the largest snapshot; each snapshot takes
    # the leading rows it needs.
    corruption = np.random.normal(loc=0.0, scale=1.0, size=[number_entities[-1]*records_per_entity, number_features])
    train = deepcopy(databases[0])
    validation = deepcopy(databases[0])
    train.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(train.database.records), number_features]))
    # train and validation are copies of the same snapshot, so train's record
    # count is used for both (assumes corrupt() does not change it).
    validation.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(train.database.records), number_features]))
    for db in databases:
        db.corrupt(corruption_multiplier*corruption[:len(db.database.records), :])

    er = EntityResolution()
    train_pair_seed = generate_pair_seed(train.database, train.labels, train_class_balance)
    weak_match_function = LogisticMatchFunction(train.database, train.labels, train_pair_seed, 0.5)
    # The returned ROC object is not used here; the call is kept as before.
    weak_match_function.test(validation.database, validation.labels, 0.5)

    ## Optimize ER on small dataset
    thresholds = np.linspace(0, 1.0, 10)
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    for threshold in thresholds:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[0].database), weak_match_function, single_block=True,
                             max_block_size=np.Inf, cores=1)
        met = Metrics(databases[0].labels, labels_pred)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
    plt.plot(thresholds, pairwise_precision, label='Precision')
    plt.plot(thresholds, pairwise_recall, label='Recall')
    plt.plot(thresholds, pairwise_f1, label='F1')
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on small dataset')
    # The small-dataset threshold is hard-coded, not taken from the argmax of
    # pairwise_f1 over the sweep above.
    small_optimal_threshold = 0.6
    print('Optimal small threshold set at = %s' % (small_optimal_threshold,))
    plt.show()

    ## Possible score by optimizing on larger dataset
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    thresholds_largedataset = np.linspace(0.6, 1.0, 8)
    precision_lower_bound = list()
    recall_lower_bound = list()
    f1_lower_bound = list()
    for threshold in thresholds_largedataset:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[-1].database), weak_match_function, single_block=True,
                             max_block_size=np.Inf, cores=1)
        met = Metrics(databases[-1].labels, labels_pred)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(databases[-1].labels)
        new_metric = NewMetrics(databases[-1].database, labels_pred, weak_match_function, class_balance_test)
        precision_lower_bound.append(new_metric.precision_lower_bound)
        recall_lower_bound.append(new_metric.recall_lower_bound)
        f1_lower_bound.append(new_metric.f1_lower_bound)
    plt.plot(thresholds_largedataset, pairwise_precision, label='Precision', color='r')
    plt.plot(thresholds_largedataset, pairwise_recall, label='Recall', color='b')
    plt.plot(thresholds_largedataset, pairwise_f1, label='F1', color='g')
    plt.plot(thresholds_largedataset, precision_lower_bound, label='Precision Bound', color='r', linestyle=':')
    plt.plot(thresholds_largedataset, recall_lower_bound, label='Recall Bound', color='b', linestyle=':')
    plt.plot(thresholds_largedataset, f1_lower_bound, label='F1 Bound', color='g', linestyle=':')
    # Choose the threshold with the best F1 lower bound.
    i = np.argmax(np.array(f1_lower_bound))
    large_optimal_threshold = thresholds_largedataset[i]
    print('Optimal large threshold automatically set at = %s' % (large_optimal_threshold,))
    print('If not correct: debug.')
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on large dataset')
    plt.show()

    ## Run on all dataset sizes
    database_sizes = list()
    small_pairwise_precision = list()
    small_pairwise_recall = list()
    small_pairwise_f1 = list()
    large_precision_bound = list()
    large_precision_bound_lower_ci = list()
    large_precision_bound_upper_ci = list()
    large_precision = list()
    large_recall_bound = list()
    large_recall_bound_lower_ci = list()
    large_recall_bound_upper_ci = list()
    large_recall = list()
    large_f1 = list()
    large_f1_bound = list()
    for db in databases:
        print('Analyzing synthetic database with %s records' % (len(db.database.records),))
        database_sizes.append(len(db.database.records))
        # NOTE(review): the sweeps above pass a deepcopy to er.run, but here
        # the same db.database is resolved twice -- confirm er.run does not
        # mutate its input.
        weak_match_function.set_decision_threshold(small_optimal_threshold)
        labels_pred = er.run(db.database, weak_match_function, single_block=True, max_block_size=np.Inf, cores=1)
        met = Metrics(db.labels, labels_pred)
        small_pairwise_precision.append(met.pairwise_precision)
        small_pairwise_recall.append(met.pairwise_recall)
        small_pairwise_f1.append(met.pairwise_f1)
        weak_match_function.set_decision_threshold(large_optimal_threshold)
        labels_pred = er.run(db.database, weak_match_function, single_block=True, max_block_size=np.Inf, cores=1)
        met = Metrics(db.labels, labels_pred)
        large_precision.append(met.pairwise_precision)
        large_recall.append(met.pairwise_recall)
        large_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(db.labels)
        new_metric = NewMetrics(db.database, labels_pred, weak_match_function, class_balance_test)
        large_precision_bound.append(new_metric.precision_lower_bound)
        large_recall_bound.append(new_metric.recall_lower_bound)
        large_f1_bound.append(new_metric.f1_lower_bound)
        large_precision_bound_lower_ci.append(new_metric.precision_lower_bound_lower_ci)
        large_precision_bound_upper_ci.append(new_metric.precision_lower_bound_upper_ci)
        large_recall_bound_lower_ci.append(new_metric.recall_lower_bound_lower_ci)
        large_recall_bound_upper_ci.append(new_metric.recall_lower_bound_upper_ci)

    # The with-block closes the file on exit, so no explicit close is needed.
    with open('synthetic_sizes_temp.csv', 'wb') as f:
        f.write('Database size, Precision (small opt), Recall (small opt), F1 (small opt), Precision (large opt), Precision bound (large opt), Lower CI, Upper CI, Recall (large opt), Recall bound (large opt), Lower CI, Upper CI, F1 (large opt), F1 bound (large opt)\n')
        writer = csv.writer(f)
        writer.writerows(izip(database_sizes, small_pairwise_precision, small_pairwise_recall, small_pairwise_f1, large_precision, large_precision_bound, large_precision_bound_lower_ci, large_precision_bound_upper_ci, large_recall, large_recall_bound, large_recall_bound_lower_ci, large_recall_bound_upper_ci, large_f1, large_f1_bound))

    plt.figure()
    # Plot the per-size series gathered in the loop above (one value per
    # snapshot). pairwise_precision / pairwise_recall still hold the 8 values
    # from the large-dataset threshold sweep and do not line up with
    # database_sizes.
    plt.plot(database_sizes, small_pairwise_precision, label='Precision', color='#4477AA', linewidth=3)
    plt.plot(database_sizes, small_pairwise_recall, label='Recall', color='#CC6677', linewidth=3)
    plt.ylim([0, 1.05])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    plt.legend(title='Pairwise:', loc='lower left')
    plt.xlabel('Number of Records')
    plt.ylabel('Pairwise Score')
    plt.title('Performance Degredation')
    plt.show()