Example #1
 def test_plot(self):
     synthetic_original = SyntheticDatabase(10, 5)
     synthetic_original.plot(synthetic_original.labels, title='No Noise')
     plt.show()
     synthetic_noise = deepcopy(synthetic_original)
     noise = (np.random.randn(50, 2)*0.02).tolist()
     synthetic_noise.corrupt(noise)
     synthetic_noise.plot(synthetic_noise.labels, title='Even Noise')
     plt.show()
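One detail worth making explicit: the noise passed to corrupt() has one row per record (10 entities × 5 records = 50) and one column per feature. A minimal, hypothetical helper for building a correctly shaped noise array, assuming only the .database.records dict and .database.feature_descriptor.number attribute that the later examples use:

import numpy as np

def matched_noise(synthetic, scale=0.02):
    # One noise row per record, one column per feature.
    n_records = len(synthetic.database.records)
    n_features = synthetic.database.feature_descriptor.number
    return (np.random.randn(n_records, n_features) * scale).tolist()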
Example #2
 def test_synthetic_sample_and_remove(self):
     synthetic = SyntheticDatabase(10, 10, number_features=3)  # 100 records, 10 entities, 10 records each
     synthetic2 = synthetic.sample_and_remove(50)
     self.assertEqual(len(synthetic.database.records), 50)
     self.assertEqual(len(synthetic2.database.records), 50)
     self.assertEqual(set(synthetic.database.records.keys()), set(synthetic.labels.keys()))
     self.assertEqual(set(synthetic2.database.records.keys()), set(synthetic2.labels.keys()))
     self.assertEqual(len(set(synthetic.database.records.keys()) & set(synthetic2.database.records.keys())), 0)
     self.assertEqual(len(set(synthetic.labels.keys()) & set(synthetic2.labels.keys())), 0)
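The assertions above pin down the contract of sample_and_remove: it moves n records, together with their labels, into a new object, leaving the two key sets disjoint. A hypothetical dict-splitting sketch of that contract, not the library's actual implementation:

import random

def sample_and_remove_sketch(records, labels, n):
    # Move n randomly chosen records (and their labels) into new dicts.
    keys = random.sample(list(records.keys()), n)
    taken_records = dict((k, records.pop(k)) for k in keys)
    taken_labels = dict((k, labels.pop(k)) for k in keys)
    return taken_records, taken_labels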
Example #3
 def test_synthetic_number_features(self):
     synthetic = SyntheticDatabase(10, 10, number_features=2)
     self.assertEqual(synthetic.database.feature_descriptor.names, ['Name_0', 'Name_1'])
     self.assertEqual(synthetic.database.feature_descriptor.types, ['float', 'float'])
     self.assertEqual(synthetic.database.feature_descriptor.strengths, ['weak', 'weak'])
     self.assertEqual(synthetic.database.feature_descriptor.blocking, ['', ''])
     self.assertEqual(synthetic.database.feature_descriptor.pairwise_uses, ['numerical_difference', 'numerical_difference'])
     self.assertEqual(len(synthetic.database.records[0].features), 2)
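The assertions imply a parallel-list layout for the feature descriptor. A hedged sketch of a structure that would satisfy exactly these assertions; the real FeatureDescriptor class may carry additional state:

class FeatureDescriptorSketch(object):
    def __init__(self, number):
        self.number = number
        self.names = ['Name_%d' % i for i in range(number)]
        self.types = ['float'] * number
        self.strengths = ['weak'] * number
        self.blocking = [''] * number
        self.pairwise_uses = ['numerical_difference'] * number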
Example #4
 def test_synthetic(self):
     synthetic = SyntheticDatabase(10, 10, number_features=10)  # 100 records, 10 entities, 10 records each
     self.assertEqual(len(synthetic.labels), 100)
     self.assertEqual(synthetic.labels[0], 0)
     self.assertEqual(synthetic.labels[99], 9)
     self.assertEqual(len(synthetic.database.records), 100)
     self.assertEqual(synthetic.database.records[0].features, synthetic.database.records[1].features)  # distinct records of the same entity share identical uncorrupted features
     self.assertEqual(synthetic.database.records[98].features, synthetic.database.records[99].features)
     self.assertNotEqual(synthetic.database.records[9].features, synthetic.database.records[10].features)
     self.assertEqual(synthetic.database.feature_descriptor.number, 10)
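The label assertions fix the record layout: records are generated entity by entity, so record i carries label i // records_per_entity. A one-line reconstruction of that mapping (inferred from the assertions, not taken from the library):

number_entities, records_per_entity = 10, 10
labels = dict((i, i // records_per_entity)
              for i in range(number_entities * records_per_entity))
assert labels[0] == 0 and labels[99] == 9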
Example #5
 def test_corrupt(self):
     synthetic = SyntheticDatabase(3, 3, number_features=2)
     self.assertEqual(synthetic.database.records[0].features, synthetic.database.records[1].features)
     corruption = list(np.random.normal(loc=0.0, scale=1.0, size=[9, 2]))
     synthetic.corrupt(corruption)
     self.assertNotEqual(synthetic.database.records[0].features, synthetic.database.records[1].features)
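What the test relies on is that corrupt() adds one noise row to each record's features element-wise, so the previously identical duplicates of an entity diverge. A standalone sketch of that effect, under that assumption:

import numpy as np

features = [[1.0, 2.0], [1.0, 2.0]]  # two duplicate records of one entity
noise = np.random.normal(loc=0.0, scale=1.0, size=[2, 2])
corrupted = [list(np.asarray(f) + n) for f, n in zip(features, noise)]
assert corrupted[0] != corrupted[1]  # two identical noise rows have probability zero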
Example #6
 def test_synthetic_add(self):
     synthetic = SyntheticDatabase(10, 10, number_features=3)  # 100 records, 10 entities, 10 records each
     self.assertEqual(len(synthetic.database.records), 100)
     synthetic.add(5, 10)
     self.assertEqual(len(synthetic.database.records), 150)
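The test only checks the record count, but given the one-label-per-record invariant visible in Example #4, a natural (unverified) follow-up assertion would be:

# After synthetic.add(5, 10) above: 5 new entities with 10 records each,
# keeping labels in lockstep with records.
assert len(synthetic.labels) == len(synthetic.database.records) == 150
assert max(synthetic.labels.values()) == 14  # entity ids 0..14 (assumed)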
Example #7
def main():
    """
    Runs a single entity resolution on data (real or synthetic) using a match function (logistic regression, decision
    tree, or random forest)
    """
    data_type = 'real'
    decision_threshold = 0.7
    train_class_balance = 0.5
    max_block_size = 1000
    cores = 2
    if data_type == 'synthetic':
        database_train = SyntheticDatabase(100, 10, 10)
        corruption = 0.1
        corruption_array = corruption*np.random.normal(loc=0.0, scale=1.0, size=[1000,
                                                       database_train.database.feature_descriptor.number])
        database_train.corrupt(corruption_array)

        database_validation = SyntheticDatabase(100, 10, 10)
        corruption_array = corruption*np.random.normal(loc=0.0, scale=1.0, size=[1000,
                                                       database_validation.database.feature_descriptor.number])
        database_validation.corrupt(corruption_array)

        database_test = SyntheticDatabase(10, 10, 10)  # 100 records, 10 entities, 10 records each
        corruption_array = corruption*np.random.normal(loc=0.0, scale=1.0, size=[100,
                                                       database_test.database.feature_descriptor.number])
        database_test.corrupt(corruption_array)
        labels_train = database_train.labels
        labels_validation = database_validation.labels
        labels_test = database_test.labels
        database_train = database_train.database
        database_validation = database_validation.database
        database_test = database_test.database
        single_block = True
    elif data_type == 'real':
        # Uncomment to use all features (annotations and LM)
        #database_train = Database('../data/trafficking/cluster_subsample0_10000.csv', header_path='../data/trafficking/cluster_subsample_header_all.csv')
        #database_validation = Database('../data/trafficking/cluster_subsample1_10000.csv', header_path='../data/trafficking/cluster_subsample_header_all.csv')
        #database_test = Database('../data/trafficking/cluster_subsample2_10000.csv', header_path='../data/trafficking/cluster_subsample_header_all.csv')

        # Uncomment to only use annotation features
        #database_train = Database('../data/trafficking/cluster_subsample0_10000.csv', header_path='../data/trafficking/cluster_subsample_header_annotations.csv')
        #database_validation = Database('../data/trafficking/cluster_subsample1_10000.csv', header_path='../data/trafficking/cluster_subsample_header_annotations.csv')
        #database_test = Database('../data/trafficking/cluster_subsample2_10000.csv', header_path='../data/trafficking/cluster_subsample_header_annotations.csv')

        # Uncomment to only use LM features
        database_train = Database('../data/trafficking/cluster_subsample0_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv')
        database_validation = Database('../data/trafficking/cluster_subsample1_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv')
        database_test = Database('../data/trafficking/cluster_subsample2_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv')

        labels_train = fast_strong_cluster(database_train)
        labels_validation = fast_strong_cluster(database_validation)
        labels_test = fast_strong_cluster(database_test)
        single_block = False
    else:
        raise Exception('Invalid experiment type: ' + data_type)

    entities = deepcopy(database_test)
    blocking_scheme = BlockingScheme(entities, max_block_size, single_block=single_block)

    train_seed = generate_pair_seed(database_train, labels_train, train_class_balance, require_direct_match=True, max_minor_class=5000)
    validation_seed = generate_pair_seed(database_validation, labels_validation, 0.5, require_direct_match=True, max_minor_class=5000)
    # forest_all = ForestMatchFunction(database_all_train, labels_train, train_seed, decision_threshold)
    # forest_all.test(database_all_validation, labels_validation, validation_seed)
    # tree_all = TreeMatchFunction(database_all_train, labels_train, train_seed, decision_threshold)
    # tree_all.test(database_all_validation, labels_validation, validation_seed)
    # logistic_all = LogisticMatchFunction(database_all_train, labels_train, train_seed, decision_threshold)
    # logistic_all.test(database_all_validation, labels_validation, validation_seed)

    forest_annotations = ForestMatchFunction(database_train, labels_train, train_seed, decision_threshold)
    roc = forest_annotations.test(database_validation, labels_validation, validation_seed)
    #roc.make_plot()
    #plt.show()

    # tree_annotations = TreeMatchFunction(database_annotations_train, labels_train, train_seed, decision_threshold)
    # tree_annotations.test(database_annotations_validation, labels_validation, validation_seed)
    # logistic_annotations = LogisticMatchFunction(database_annotations_train, labels_train, train_seed, decision_threshold)
    # logistic_annotations.test(database_annotations_validation, labels_validation, validation_seed)

    # forest_LM = ForestMatchFunction(database_LM_train, labels_train, train_seed, decision_threshold)
    # forest_LM.test(database_LM_validation, labels_validation, validation_seed)
    # tree_LM = TreeMatchFunction(database_LM_train, labels_train, train_seed, decision_threshold)
    # tree_LM.test(database_LM_validation, labels_validation, validation_seed)
    # logistic_LM = LogisticMatchFunction(database_LM_train, labels_train, train_seed, decision_threshold)
    # logistic_LM.test(database_LM_validation, labels_validation, validation_seed)

    # forest_all.roc.write_rates('match_forest_all.csv')
    # tree_all.roc.write_rates('match_tree_all.csv')
    # logistic_all.roc.write_rates('match_logistic_all.csv')
    #
    # forest_annotations.roc.write_rates('match_forest_annotations.csv')
    # tree_annotations.roc.write_rates('match_tree_annotations.csv')
    # logistic_annotations.roc.write_rates('match_logistic_annotations.csv')
    #
    # forest_LM.roc.write_rates('match_forest_LM.csv')
    # tree_LM.roc.write_rates('match_tree_LM.csv')
    # logistic_LM.roc.write_rates('match_logistic_LM.csv')
    # ax = forest_all.roc.make_plot()
    # _ = tree_all.roc.make_plot(ax=ax)
    # _ = logistic_all.roc.make_plot(ax=ax)
    # plt.show()
    #forest_annotations.roc.make_plot()
    #plt.show()

    #entities.merge(strong_labels)

    #er = EntityResolution()
    #weak_labels = er.run(entities, match_function, blocking_scheme, cores=cores)
    weak_labels = weak_connected_components(database_test, forest_annotations, blocking_scheme)
    entities.merge(weak_labels)
    #strong_labels = fast_strong_cluster(entities)
    #entities.merge(strong_labels)

    # out = open('ER.csv', 'w')
    # out.write('phone,cluster_id\n')
    # for cluster_counter, (entity_id, entity) in enumerate(entities.records.iteritems()):
    #     phone_index = 21
    #     for phone in entity.features[phone_index]:
    #         out.write(str(phone)+','+str(cluster_counter)+'\n')
    # out.close()

    print 'Metrics using strong features as surrogate label. Entity resolution run using weak and strong features'
    metrics = Metrics(labels_test, weak_labels)
    # estimated_test_class_balance = count_pairwise_class_balance(labels_test)
    # new_metrics = NewMetrics(database_all_test, weak_labels, forest_all, estimated_test_class_balance)
    metrics.display()
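Stripped of the commented-out variants, the live path through main() reduces to five steps. A condensed sketch using the same calls as above (the parameter values are the ones chosen in this example, not requirements of the API):

def resolve(database_train, database_test, decision_threshold=0.7):
    # 1. Surrogate labels from strong features.
    labels_train = fast_strong_cluster(database_train)
    # 2. Balanced pairwise training seed.
    seed = generate_pair_seed(database_train, labels_train, 0.5,
                              require_direct_match=True, max_minor_class=5000)
    # 3. Fit the random-forest match function.
    match_function = ForestMatchFunction(database_train, labels_train, seed, decision_threshold)
    # 4. Block the test records.
    blocking = BlockingScheme(deepcopy(database_test), 1000, single_block=False)
    # 5. Resolve via weak connected components.
    return weak_connected_components(database_test, match_function, blocking)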
Example #8
def experiment_wrapper(dataset_name):
    """
    Experiment wrapper, just takes the type of experiment, all parameters saved here
    :param dataset_name: Name of the database to run on, either synthetic, restaurant, abt-buy, trafficking
    """
    if dataset_name == 'synthetic':
        number_entities = 100
        records_per_entity = 10
        train_database_size = 200
        train_class_balance = 0.5
        validation_database_size = 200
        corruption = 0.001  #0.025
        number_thresholds = 30
        number_features = 10

        synthetic_database = SyntheticDatabase(number_entities, records_per_entity, number_features=number_features)
        corruption_array = corruption*np.random.normal(loc=0.0, scale=1.0,
                                                       size=[number_entities*records_per_entity,
                                                             synthetic_database.database.feature_descriptor.number])
        synthetic_database.corrupt(corruption_array)
        synthetic_train = synthetic_database.sample_and_remove(train_database_size)
        synthetic_validation = synthetic_database.sample_and_remove(validation_database_size)
        synthetic_test = synthetic_database
        thresholds = np.linspace(0, 1, number_thresholds)
        experiment = Experiment(synthetic_train.database, synthetic_validation.database, synthetic_test.database,
                                synthetic_train.labels, synthetic_validation.labels, synthetic_test.labels,
                                train_class_balance, thresholds)
        experiment.plot()
    else:
        number_thresholds = 5
        if dataset_name == 'restaurant':  # 864 records, 112 matches
            features_path = '../data/restaurant/merged.csv'
            labels_path = '../data/restaurant/labels.csv'
            train_database_size = 300
            train_class_balance = .4
            validation_database_size = 200
            database = Database(annotation_path=features_path)
        elif dataset_name == 'abt-buy':  # ~4900 records, 1300 matches
            features_path = '../data/Abt-Buy/merged.csv'
            labels_path = '../data/Abt-Buy/labels.csv'
            train_database_size = 300
            train_class_balance = 0.4
            validation_database_size = 300
            database = Database(annotation_path=features_path)
        elif dataset_name == 'trafficking':
            features_path = '../data/trafficking/features.csv'
            labels_path = '../data/trafficking/labels.csv'
            train_database_size = 300
            train_class_balance = 0.5
            validation_database_size = 300
            #database = Database(annotation_path=features_path)
        else:
            raise Exception('Invalid dataset name')
        thresholds = np.linspace(0, 1, number_thresholds)
        # labels = np.loadtxt(open(labels_path, 'rb'))
        # database_train = database.sample_and_remove(train_database_size)
        # database_validation = database.sample_and_remove(validation_database_size)
        # database_test = database
        # labels_train = dict()
        # labels_validation = dict()
        # labels_test = dict()
        # for identifier, label in enumerate(labels):
        #     if identifier in database_train.records:
        #         labels_train[identifier] = label
        #     elif identifier in database_validation.records:
        #         labels_validation[identifier] = label
        #     elif identifier in database_test.records:
        #         labels_test[identifier] = label
        #     else:
        #         raise Exception('Record identifier ' + str(identifier) + ' not in either database')
        ### Note: the block below loads the trafficking LM subsamples regardless of dataset_name
        database_train = Database('../data/trafficking/cluster_subsample0_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv', max_records=5000)
        database_validation = Database('../data/trafficking/cluster_subsample1_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv', max_records=5000)
        database_test = Database('../data/trafficking/cluster_subsample2_10000.csv', header_path='../data/trafficking/cluster_subsample_header_LM.csv', max_records=1000)

        labels_train = fast_strong_cluster(database_train)
        labels_validation = fast_strong_cluster(database_validation)
        labels_test = fast_strong_cluster(database_test)
        ###

        experiment = Experiment(database_train, database_validation, database_test,
                                labels_train, labels_validation, labels_test,
                                train_class_balance, thresholds)
        #print 'Saving results'
        #pickle.dump(experiment, open('experiment.p', 'wb'))
        experiment.plot()
    print 'Finished'
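A hedged usage note: the wrapper is presumably driven from a small entry point along these lines, with whichever of the four dataset names you want to run:

if __name__ == '__main__':
    experiment_wrapper('trafficking')  # or 'synthetic', 'restaurant', 'abt-buy'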
Example #9
def synthetic_sizes():
    """
    Sizes experiment here
    """
    resolution = 88
    number_features = 10
    number_entities = np.linspace(10, 100, num=resolution)
    number_entities = number_entities.astype(int)
    records_per_entity = 10
    #train_database_size = 100
    train_class_balance = 0.5
    #validation_database_size = 100
    corruption_multiplier = .001

    databases = list()
    db = SyntheticDatabase(number_entities[0], records_per_entity, number_features=number_features)
    databases.append(deepcopy(db))
    add_entities = [x - number_entities[i - 1] for i, x in enumerate(number_entities)][1:]
    for add in add_entities:
        db.add(add, records_per_entity)
        databases.append(deepcopy(db))
    corruption = np.random.normal(loc=0.0, scale=1.0, size=[number_entities[-1]*records_per_entity, number_features])
    train = deepcopy(databases[0])
    validation = deepcopy(databases[0])
    train.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(train.database.records), number_features]))
    validation.corrupt(corruption_multiplier*np.random.normal(loc=0.0, scale=1.0, size=[len(validation.database.records), number_features]))
    for db in databases:
        db.corrupt(corruption_multiplier*corruption[:len(db.database.records), :])
    er = EntityResolution()
    train_pair_seed = generate_pair_seed(train.database, train.labels, train_class_balance)
    weak_match_function = LogisticMatchFunction(train.database, train.labels, train_pair_seed, 0.5)
    validation_pair_seed = generate_pair_seed(validation.database, validation.labels, 0.5)
    ROC = weak_match_function.test(validation.database, validation.labels, validation_pair_seed)
    #ROC.make_plot()

    ## Optimize ER on small dataset
    thresholds = np.linspace(0, 1.0, 10)
    metrics_list = list()
    #new_metrics_list = list()
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    for threshold in thresholds:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[0].database), weak_match_function, single_block=True,
                             max_block_size=np.Inf, cores=1)
        met = Metrics(databases[0].labels, labels_pred)
        metrics_list.append(met)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
        #class_balance_test = get_pairwise_class_balance(databases[0].labels)
        #new_metrics_list.append(NewMetrics(databases[0].database, er, class_balance_test))
    plt.plot(thresholds, pairwise_precision, label='Precision')
    plt.plot(thresholds, pairwise_recall, label='Recall')
    plt.plot(thresholds, pairwise_f1, label='F1')
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on small dataset')
    #i = np.argmax(np.array(pairwise_f1))
    #small_optimal_threshold = thresholds[i]  # optimize this
    small_optimal_threshold = 0.6
    print 'Optimal small threshold set at =', small_optimal_threshold
    plt.show()

    ## Possible score by optimizing on larger dataset
    metrics_list = list()
    pairwise_precision = list()
    pairwise_recall = list()
    pairwise_f1 = list()
    thresholds_largedataset = np.linspace(0.6, 1.0, 8)
    precision_lower_bound = list()
    recall_lower_bound = list()
    f1_lower_bound = list()
    for threshold in thresholds_largedataset:
        weak_match_function.set_decision_threshold(threshold)
        labels_pred = er.run(deepcopy(databases[-1].database), weak_match_function, single_block=True,
                             max_block_size=np.Inf, cores=1)
        met = Metrics(databases[-1].labels, labels_pred)
        metrics_list.append(met)
        pairwise_precision.append(met.pairwise_precision)
        pairwise_recall.append(met.pairwise_recall)
        pairwise_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(databases[-1].labels)
        new_metric = NewMetrics(databases[-1].database, labels_pred, weak_match_function, class_balance_test)
        precision_lower_bound.append(new_metric.precision_lower_bound)
        recall_lower_bound.append(new_metric.recall_lower_bound)
        f1_lower_bound.append(new_metric.f1_lower_bound)
    plt.plot(thresholds_largedataset, pairwise_precision, label='Precision', color='r')
    plt.plot(thresholds_largedataset, pairwise_recall, label='Recall', color='b')
    plt.plot(thresholds_largedataset, pairwise_f1, label='F1', color='g')
    plt.plot(thresholds_largedataset, precision_lower_bound, label='Precision Bound', color='r', linestyle=':')
    plt.plot(thresholds_largedataset, recall_lower_bound, label='Recall Bound', color='b', linestyle=':')
    plt.plot(thresholds_largedataset, f1_lower_bound, label='F1 Bound', color='g', linestyle=':')
    i = np.argmax(np.array(f1_lower_bound))
    large_optimal_threshold = thresholds_largedataset[i]
    print 'Optimal large threshold automatically set at =', large_optimal_threshold
    print 'If not correct: debug.'
    plt.xlabel('Threshold')
    plt.legend()
    plt.ylabel('Score')
    plt.title('Optimizing ER on large dataset')
    plt.show()

    ## Run on all dataset sizes
    #new_metrics_list = list()
    database_sizes = list()
    small_pairwise_precision = list()
    small_pairwise_recall = list()
    small_pairwise_f1 = list()
    large_precision_bound = list()
    large_precision_bound_lower_ci = list()
    large_precision_bound_upper_ci = list()
    large_precision = list()
    large_recall_bound = list()
    large_recall_bound_lower_ci = list()
    large_recall_bound_upper_ci = list()
    large_recall = list()
    large_f1 = list()
    large_f1_bound = list()
    for db in databases:
        print 'Analyzing synthetic database with', len(db.database.records), 'records'
        database_sizes.append(len(db.database.records))
        weak_match_function.set_decision_threshold(small_optimal_threshold)
        labels_pred = er.run(db.database, weak_match_function, single_block=True, max_block_size=np.Inf, cores=1)
        met = Metrics(db.labels, labels_pred)
        small_pairwise_precision.append(met.pairwise_precision)
        small_pairwise_recall.append(met.pairwise_recall)
        small_pairwise_f1.append(met.pairwise_f1)
        weak_match_function.set_decision_threshold(large_optimal_threshold)
        labels_pred = er.run(db.database, weak_match_function, single_block=True, max_block_size=np.Inf, cores=1)
        met = Metrics(db.labels, labels_pred)
        large_precision.append(met.pairwise_precision)
        large_recall.append(met.pairwise_recall)
        large_f1.append(met.pairwise_f1)
        class_balance_test = count_pairwise_class_balance(db.labels)
        new_metric = NewMetrics(db.database, labels_pred, weak_match_function, class_balance_test)
        large_precision_bound.append(new_metric.precision_lower_bound)
        large_recall_bound.append(new_metric.recall_lower_bound)
        large_f1_bound.append(new_metric.f1_lower_bound)
        large_precision_bound_lower_ci.append(new_metric.precision_lower_bound_lower_ci)
        large_precision_bound_upper_ci.append(new_metric.precision_lower_bound_upper_ci)
        large_recall_bound_lower_ci.append(new_metric.recall_lower_bound_lower_ci)
        large_recall_bound_upper_ci.append(new_metric.recall_lower_bound_upper_ci)

    with open('synthetic_sizes_temp.csv', 'wb') as f:
        f.write('Database size, Precision (small opt), Recall (small opt), F1 (small opt), Precision (large opt), Precision bound (large opt), Lower CI, Upper CI, Recall (large opt), Recall bound (large opt), Lower CI, Upper CI, F1 (large opt), F1 bound (large opt)\n')
        writer = csv.writer(f)
        writer.writerows(izip(database_sizes, small_pairwise_precision, small_pairwise_recall, small_pairwise_f1, large_precision, large_precision_bound, large_precision_bound_lower_ci, large_precision_bound_upper_ci, large_recall, large_recall_bound, large_recall_bound_lower_ci, large_recall_bound_upper_ci, large_f1, large_f1_bound))
    plt.figure()
    plt.plot(database_sizes, small_pairwise_precision, label='Precision', color='#4477AA', linewidth=3)
    plt.plot(database_sizes, small_pairwise_recall, label='Recall', color='#CC6677', linewidth=3)
    #plt.plot(database_sizes, small_pairwise_f1, label='F1', color='#DDCC77', linewidth=2)
    plt.ylim([0, 1.05])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    plt.legend(title='Pairwise:', loc='lower left')
    plt.xlabel('Number of Records')
    plt.ylabel('Pairwise Score')
    plt.title('Performance Degradation')
    plt.show()
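One readability note on the setup at the top of this example: the list comprehension that derives add_entities is exactly a first difference over number_entities, so an equivalent and clearer form is:

add_entities = np.diff(number_entities).tolist()  # per-step entity increments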