Example #1
def train_classifier(collection, bow, pagerank, dataset, output, max_iter):
    """
    Trains a tag classifier on a NIF dataset.
    """
    if output is None:
        output = 'trained_classifier.pkl'
    b = BOWLanguageModel()
    b.load(bow)
    graph = WikidataGraph()
    graph.load_pagerank(pagerank)
    tagger = Tagger(collection, b, graph)
    d = NIFCollection.load(dataset)
    clf = SimpleTagClassifier(tagger)
    max_iter = int(max_iter)

    # Hyperparameter grid searched by crossfit_model
    parameter_grid = []
    for max_distance in [50, 75, 150, 200]:
        for similarity, beta in [('one_step', 0.2), ('one_step', 0.1),
                                 ('one_step', 0.3)]:
            for C in [10.0, 1.0, 0.1]:
                for smoothing in [0.8, 0.6, 0.5, 0.4, 0.3]:
                    parameter_grid.append({
                        'nb_steps': 4,
                        'max_similarity_distance': max_distance,
                        'C': C,
                        'similarity': similarity,
                        'beta': beta,
                        'similarity_smoothing': smoothing,
                    })

    best_params = clf.crossfit_model(d, parameter_grid, max_iter=max_iter)
    print('#########')
    print(best_params)
    clf.save(output)
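For reference, a direct invocation might look like the sketch below; every path and the collection name here are placeholders, not files shipped with the project:

train_classifier(
    collection='wd_collection',        # Solr collection holding the indexed dump
    bow='bow.pkl',                     # serialized BOWLanguageModel
    pagerank='wikidata.pgrank.npy',    # pagerank vector for the WikidataGraph
    dataset='training_set.ttl',        # NIF dataset with gold annotations
    output=None,                       # defaults to 'trained_classifier.pkl'
    max_iter=100,
)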
Example #2
def preprocess(filename, outfile):
    """
    Preprocesses a Wikidata .json.bz2 dump into a TSV format representing its adjacency matrix.
    """
    if outfile is None:
        outfile = '.'.join(filename.split('.')[:-2] + ["unsorted.tsv"])
    g = WikidataGraph()
    g.preprocess_dump(filename, outfile)
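When outfile is None, the default name drops the two trailing extensions; for instance (the dump name below is a placeholder):

# 'wikidata-20190101.json.bz2' -> 'wikidata-20190101.unsorted.tsv'
preprocess('wikidata-20190101.json.bz2', None)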
Example #3
def pagerank_shell(filename):
    """
    Interactively retrieves the pagerank of chosen items.
    """
    g = WikidataGraph()
    g.load_pagerank(filename)
    while True:
        qid = input('>>> ')
        print(g.get_pagerank(qid))
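As written, the loop only exits on an unhandled EOFError or KeyboardInterrupt; a slightly more forgiving variant of the loop (a sketch, not part of the original) catches end-of-input:

while True:
    try:
        qid = input('>>> ')
    except (EOFError, KeyboardInterrupt):
        break  # exit cleanly on Ctrl-D / Ctrl-C
    print(g.get_pagerank(qid))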
Example #4
def test_compute_pagerank(self):
    graph = WikidataGraph()
    graph.load_from_matrix(
        os.path.join(self.testdir, 'data/sample_wikidata_items.npz'))
    graph.compute_pagerank()
    self.assertTrue(0.0003 < graph.get_pagerank('Q45') < 0.0004)
Example #5
def compute_pagerank(filename, outfile):
    """
    Computes the pagerank of a Wikidata adjacency matrix as represented by a Numpy sparse matrix in NPZ format.
    """
    if outfile is None:
        outfile = '.'.join(filename.split('.')[:-1] + ['pgrank.npy'])
    g = WikidataGraph()
    g.load_from_matrix(filename)
    g.compute_pagerank()
    g.save_pagerank(outfile)
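With outfile left as None, the default name swaps the final extension for 'pgrank.npy', matching the pagerank file names used in the tests:

# 'sample_wikidata_items.npz' -> 'sample_wikidata_items.pgrank.npy'
compute_pagerank('sample_wikidata_items.npz', None)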
Example #6
def compile(filename, outfile):
    """
    Compiles a sorted preprocessed Wikidata dump in TSV format to a Numpy sparse matrix.
    """
    if outfile is None:
        outfile = '.'.join(filename.split('.')[:-1] + ['npz'])
    g = WikidataGraph()
    g.load_from_preprocessed_dump(filename)
    g.save_matrix(outfile)
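The default output name replaces the final extension with 'npz'. The TSV must be sorted beforehand (Example #11 shows that an unsorted dump raises ValueError), and the name compile shadows the Python built-in, which is harmless for a CLI entry point but worth remembering when importing it:

# 'sample_wikidata_items.tsv' -> 'sample_wikidata_items.npz'
compile('sample_wikidata_items.tsv', None)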
Example #7
    @classmethod
    def setUpClass(cls):
        cls.testdir = os.path.dirname(os.path.abspath(__file__))

        # Load dummy bow
        bow_fname = os.path.join(cls.testdir, 'data/sample_bow.pkl')
        cls.bow = BOWLanguageModel()
        cls.bow.load(bow_fname)

        # Load dummy graph
        graph_fname = os.path.join(cls.testdir,
                                   'data/sample_wikidata_items.npz')
        pagerank_fname = os.path.join(cls.testdir,
                                      'data/sample_wikidata_items.pgrank.npy')
        cls.graph = WikidataGraph()
        cls.graph.load_from_matrix(graph_fname)
        cls.graph.load_pagerank(pagerank_fname)

        # Load dummy profile
        cls.profile = IndexingProfile.load(
            os.path.join(cls.testdir, 'data/all_items_profile.json'))

        # Setup solr index (TODO delete this) and tagger
        cls.tf = TaggerFactory()
        cls.collection_name = 'wd_test_collection'
        try:
            cls.tf.create_collection(cls.collection_name)
        except CollectionAlreadyExists:
            pass
        cls.tf.index_stream(
            cls.collection_name,
            WikidataDumpReader(
                os.path.join(cls.testdir,
                             'data/sample_wikidata_items.json.bz2')),
            cls.profile)
        cls.tagger = Tagger(cls.collection_name, cls.bow, cls.graph)

        # Load NIF dataset
        cls.nif = NIFCollection.load(
            os.path.join(cls.testdir, 'data/five-affiliations.ttl'))

        cls.classifier = SimpleTagClassifier(cls.tagger,
                                             max_similarity_distance=10,
                                             similarity_smoothing=2)
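The TODO above suggests the collection should be cleaned up after the run; a matching tearDownClass sketch, reusing the delete_collection call seen in Example #8, might look like:

@classmethod
def tearDownClass(cls):
    # Drop the test collection so repeated runs start from a clean index.
    try:
        cls.tf.delete_collection(cls.collection_name)
    except requests.exceptions.RequestException:
        pass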
Example #8
    @classmethod
    def setUpClass(cls):
        super(TaggerTest, cls).setUpClass()
        testdir = os.path.dirname(os.path.abspath(__file__))

        # Load dummy bow
        bow_fname = os.path.join(testdir, 'data/sample_bow.pkl')
        cls.bow = BOWLanguageModel()
        cls.bow.load(bow_fname)

        # Load dummy graph
        graph_fname = os.path.join(testdir, 'data/sample_wikidata_items.npz')
        pagerank_fname = os.path.join(testdir,
                                      'data/sample_wikidata_items.pgrank.npy')
        cls.graph = WikidataGraph()
        cls.graph.load_from_matrix(graph_fname)
        cls.graph.load_pagerank(pagerank_fname)

        # Load indexing profile
        cls.profile = IndexingProfile.load(
            os.path.join(testdir, 'data/all_items_profile.json'))

        # Setup solr index
        cls.tf = TaggerFactory()
        cls.collection_name = 'wd_test_collection'
        try:
            cls.tf.delete_collection('wd_test_collection')
        except requests.exceptions.RequestException:
            pass
        cls.tf.create_collection(cls.collection_name)
        cls.tf.index_stream(
            'wd_test_collection',
            WikidataDumpReader(
                os.path.join(testdir, 'data/sample_wikidata_items.json.bz2')),
            cls.profile)

        cls.sut = Tagger(cls.collection_name, cls.bow, cls.graph)
Example #9
import logging
import os

import settings

from opentapioca.wikidatagraph import WikidataGraph
from opentapioca.languagemodel import BOWLanguageModel
from opentapioca.tagger import Tagger
from opentapioca.classifier import SimpleTagClassifier

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s')

tapioca_dir = os.path.dirname(__file__)

bow = BOWLanguageModel()
if settings.LANGUAGE_MODEL_PATH:
    bow.load(settings.LANGUAGE_MODEL_PATH)
graph = WikidataGraph()
if settings.PAGERANK_PATH:
    graph.load_pagerank(settings.PAGERANK_PATH)
tagger = None
classifier = None
if settings.SOLR_COLLECTION:
    tagger = Tagger(settings.SOLR_COLLECTION, bow, graph)
    classifier = SimpleTagClassifier(tagger)
    if settings.CLASSIFIER_PATH:
        classifier.load(settings.CLASSIFIER_PATH)


def jsonp(view):
    """
    Decorator for views that return JSON
    """
Example #10
def test_compile_dump(self):
    graph = WikidataGraph()
    graph.load_from_preprocessed_dump(
        os.path.join(self.testdir, 'data/sample_wikidata_items.tsv'))
    graph.mat.check_format()
    self.assertEqual(graph.shape, 3942)
Example #11
def test_compile_unordered_dump(self):
    graph = WikidataGraph()
    with self.assertRaises(ValueError):
        graph.load_from_preprocessed_dump(
            os.path.join(self.testdir,
                         'data/sample_wikidata_items.unsorted.tsv'))