Code Example #1
# Imports assume the opentapioca package layout and the pynif library;
# adjust the paths if your layout differs.
from pynif import NIFCollection

from opentapioca.languagemodel import BOWLanguageModel
from opentapioca.wikidatagraph import WikidataGraph
from opentapioca.tagger import Tagger
from opentapioca.classifier import SimpleTagClassifier


def train_classifier(collection, bow, pagerank, dataset, output, max_iter):
    """
    Trains a tag classifier on a NIF dataset, running a grid search
    over the classifier's hyperparameters.
    """
    if output is None:
        output = 'trained_classifier.pkl'

    # Load the bag-of-words language model and the Wikidata graph
    # with its precomputed PageRank values.
    bow_model = BOWLanguageModel()
    bow_model.load(bow)
    graph = WikidataGraph()
    graph.load_pagerank(pagerank)

    # The tagger proposes candidate entities for each phrase;
    # the classifier learns to pick the correct candidates.
    tagger = Tagger(collection, bow_model, graph)
    nif_dataset = NIFCollection.load(dataset)
    clf = SimpleTagClassifier(tagger)
    max_iter = int(max_iter)

    # Hyperparameter grid explored by cross-validation.
    parameter_grid = []
    for max_distance in [50, 75, 150, 200]:
        for similarity, beta in [('one_step', 0.2), ('one_step', 0.1),
                                 ('one_step', 0.3)]:
            for C in [10.0, 1.0, 0.1]:
                for smoothing in [0.8, 0.6, 0.5, 0.4, 0.3]:
                    parameter_grid.append({
                        'nb_steps': 4,
                        'max_similarity_distance': max_distance,
                        'C': C,
                        'similarity': similarity,
                        'beta': beta,
                        'similarity_smoothing': smoothing,
                    })

    best_params = clf.crossfit_model(nif_dataset, parameter_grid,
                                     max_iter=max_iter)
    print('#########')
    print(best_params)
    clf.save(output)
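
For context, a minimal sketch of how this function might be invoked; every file name below is a hypothetical placeholder, not a file shipped with the project.

# Hypothetical invocation of train_classifier; all paths are placeholders.
train_classifier(
    collection='wd_collection',      # Solr collection with indexed items
    bow='bow.pkl',                   # serialized BOWLanguageModel
    pagerank='graph.pgrank.npy',     # precomputed PageRank vector
    dataset='training.ttl',          # NIF dataset with gold annotations
    output='trained_classifier.pkl',
    max_iter=100)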
Code Example #2
# Imports and the TestCase wrapper below are reconstructed for context;
# the module paths and the class name assume the opentapioca package
# layout and may need adjusting.
import os
import unittest

from pynif import NIFCollection

from opentapioca.languagemodel import BOWLanguageModel
from opentapioca.wikidatagraph import WikidataGraph
from opentapioca.tagger import Tagger
from opentapioca.taggerfactory import TaggerFactory, CollectionAlreadyExists
from opentapioca.indexingprofile import IndexingProfile
from opentapioca.readers.dumpreader import WikidataDumpReader
from opentapioca.classifier import SimpleTagClassifier


class ClassifierTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.testdir = os.path.dirname(os.path.abspath(__file__))

        # Load dummy bow
        bow_fname = os.path.join(cls.testdir, 'data/sample_bow.pkl')
        cls.bow = BOWLanguageModel()
        cls.bow.load(bow_fname)

        # Load dummy graph
        graph_fname = os.path.join(cls.testdir,
                                   'data/sample_wikidata_items.npz')
        pagerank_fname = os.path.join(cls.testdir,
                                      'data/sample_wikidata_items.pgrank.npy')
        cls.graph = WikidataGraph()
        cls.graph.load_from_matrix(graph_fname)
        cls.graph.load_pagerank(pagerank_fname)

        # Load dummy profile
        cls.profile = IndexingProfile.load(
            os.path.join(cls.testdir, 'data/all_items_profile.json'))

        # Set up the Solr index (TODO delete this) and the tagger
        cls.tf = TaggerFactory()
        cls.collection_name = 'wd_test_collection'
        try:
            cls.tf.create_collection(cls.collection_name)
        except CollectionAlreadyExists:
            pass
        cls.tf.index_stream(
            cls.collection_name,
            WikidataDumpReader(
                os.path.join(cls.testdir,
                             'data/sample_wikidata_items.json.bz2')),
            cls.profile)
        cls.tagger = Tagger(cls.collection_name, cls.bow, cls.graph)

        # Load NIF dataset
        cls.nif = NIFCollection.load(
            os.path.join(cls.testdir, 'data/five-affiliations.ttl'))

        cls.classifier = SimpleTagClassifier(cls.tagger,
                                             max_similarity_distance=10,
                                             similarity_smoothing=2)
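
Given this fixture, a test method might exercise the classifier with a tiny parameter grid, mirroring the crossfit_model call from Code Example #1. This is only a sketch; the method name and grid values are invented.

    def test_crossfit_model(self):
        # Hypothetical test using the fixture above; the single-entry
        # grid reuses the parameter names from Code Example #1.
        parameter_grid = [{
            'nb_steps': 2,
            'max_similarity_distance': 10,
            'C': 1.0,
            'similarity': 'one_step',
            'beta': 0.2,
            'similarity_smoothing': 0.5,
        }]
        best_params = self.classifier.crossfit_model(
            self.nif, parameter_grid, max_iter=50)
        self.assertIsNotNone(best_params)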
Code Example #3
    @classmethod
    def setUpClass(cls):
        # Load the two sample NIF collections used by the tests.
        testdir = os.path.dirname(os.path.abspath(__file__))
        cls.dbpedia_nif = NIFCollection.load(
            os.path.join(testdir, 'data/sample_dbpedia.ttl'))
        cls.wikipedia_nif = NIFCollection.load(
            os.path.join(testdir, 'data/sample_wikipedia.ttl'))
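
A first test on this fixture could simply assert that both files parsed. The sketch below assumes pynif's NIFCollection exposes its annotated contexts through a contexts attribute; the method name is invented.

    def test_collections_loaded(self):
        # Hypothetical smoke test: each sample file should yield at
        # least one annotated context.
        self.assertTrue(len(self.dbpedia_nif.contexts) > 0)
        self.assertTrue(len(self.wikipedia_nif.contexts) > 0)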