コード例 #1
0
ファイル: entity_linking.py プロジェクト: DenXX/aqqu
def get_number_of_external_entities():
    import scorer_globals
    globals.read_configuration('config_webentity.cfg')
    parser = CoreNLPParser.init_from_config()
    entity_linker = WebSearchResultsExtenderEntityLinker.init_from_config()
    entity_linker.topn_entities = 100000
    scorer_globals.init()

    parameters = translator.TranslatorParameters()
    parameters.require_relation_match = False
    parameters.restrict_answer_type = False

    datasets = ["webquestions_split_train", "webquestions_split_dev",]
    # datasets = ["webquestions_split_train_externalentities", "webquestions_split_dev_externalentities",]
    # datasets = ["webquestions_split_train_externalentities3", "webquestions_split_dev_externalentities3",]

    external_entities_count = []
    for dataset in datasets:
        queries = load_eval_queries(dataset)
        for index, query in enumerate(queries):
            entities = entity_linker.identify_entities_in_tokens(parser.parse(query.utterance).tokens, text=query.utterance, find_dates=False)
            print "-------------------------"
            print query.utterance
            print "\n".join(map(str, sorted(entities, key=lambda entity: entity.external_entity_count, reverse=True)))

            external_entities_count.append(0)
            for entity in entities:
                if entity.external_entity:
                    external_entities_count[-1] += 1
            if index % 100 == 0:
                print >> sys.stderr, "%s queries processed" % index
    print "========================================="
    print external_entities_count
    print sum(external_entities_count)
    print len(external_entities_count)
コード例 #2
0
ファイル: example.py プロジェクト: mindis/KBQA
def gen_description_data(fn_wq_list, fn_out):

    globals.read_configuration("config.cfg")
    entity_linker = EntityLinker.init_from_config()
    parser = CoreNLPParser.init_from_config()

    fout = open(fn_out, 'w')
    for fn_wq in fn_wq_list:
        wq = json.load(open(fn_wq), encoding='utf8')
        for data in wq:
            tokens = parser.parse(data['utterance'])
            entities = entity_linker.identify_entities_in_tokens(tokens)
            neg_entities = set()
            for e in entities:
                mid = e.get_mid()
                if mid == '':
                    continue
                if mid.startswith('m.'):
                    neg_entities.add(mid)
                else:
                    print mid, e.name, data['utterance']
            neg_entities -= set([data['mid1']])
            instance = {
                'q': data['utterance'],
                'pos': data['mid1'],
                'neg': list(neg_entities)
            }
            print >> fout, json.dumps(instance,
                                      ensure_ascii=False).encode('utf8')
    fout.close()
コード例 #3
0
ファイル: dump_qa_entity_pairs.py プロジェクト: DenXX/aqqu
def print_sparql_queries():
    import argparse

    parser = argparse.ArgumentParser(description="Dump qa entity pairs.")
    parser.add_argument("--config",
                        default="config.cfg",
                        help="The configuration file to use.")
    parser.add_argument("--output",
                        help="The file to dump results to.")
    args = parser.parse_args()
    globals.read_configuration(args.config)
    scorer_globals.init()

    parameters = translator.TranslatorParameters()
    parameters.require_relation_match = False
    parameters.restrict_answer_type = False

    dataset = "webquestions_test_filter"

    sparql_backend = globals.get_sparql_backend(globals.config)
    queries = get_evaluated_queries(dataset, True, parameters)
    for index, query in enumerate(queries):
        print "--------------------------------------------"
        print query.utterance
        print "\n".join([str((entity.__class__, entity.entity)) for entity in query.eval_candidates[0].query_candidate.query.identified_entities])
        for eval_candidate in query.eval_candidates:
            query_candidate = eval_candidate.query_candidate
            query_candidate.sparql_backend = sparql_backend
            notable_types = query_candidate.get_answers_notable_types()
            if notable_types:
                print notable_types
                print query_candidate.graph_as_simple_string().encode("utf-8")
                print query_candidate.to_sparql_query().encode("utf-8")
                print "\n\n"
コード例 #4
0
ファイル: example.py プロジェクト: mindis/KBQA
def link_entity_in_simple_question(fn_in, fn_out):
    globals.read_configuration("config.cfg")
    entity_linker = EntityLinker.init_from_config()
    parser = CoreNLPParser.init_from_config()
    with open(fn_out, 'w') as fout:
        with open(fn_in) as fin:
            for line in fin:
                ll = line.decode('utf8').strip().split('\t')
                if len(ll) != 5:
                    continue
                tokens = parser.parse(ll[4])
                entities = entity_linker.identify_entities_in_tokens(tokens)
                neg_entities = set()
                for e in entities:
                    mid = e.get_mid()
                    if mid == '':
                        continue
                    if mid.startswith('m.'):
                        neg_entities.add(mid)
                    else:
                        print mid, e.name, ll[4]
                neg_entities -= set([ll[0]])
                line = json.dumps(
                    {
                        'q': ll[4],
                        'pos': ll[0],
                        'neg': list(neg_entities)
                    },
                    ensure_ascii=False).encode('utf8')
                print >> fout, line
コード例 #5
0
ファイル: ama.py プロジェクト: xiaozhuyfk/parallel
def main():
    """Parse the command line, load config and modules, run the command.

    Sub-commands: ``answer <question>``, ``play`` and ``test <dataset>``.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Choose to learn or test AMA')
    cli.add_argument('--config', default='config.cfg',
                     help='The configuration file to use')
    commands = cli.add_subparsers(help='command help')

    answer_cmd = commands.add_parser('answer')
    answer_cmd.add_argument('question')
    answer_cmd.set_defaults(which='answer')

    play_cmd = commands.add_parser('play')
    play_cmd.set_defaults(which='play')

    test_cmd = commands.add_parser('test', help='Test memory network')
    test_cmd.add_argument('dataset', help='The dataset to test')
    test_cmd.set_defaults(which='test')

    args = cli.parse_args()

    # The global configuration must be read before modules are initialised.
    globals.read_configuration(args.config)
    modules.init_from_config(args)

    command = args.which
    if command == 'test':
        test(args.dataset)
    elif command == 'answer':
        answer(args.question)
    elif command == 'play':
        play()
コード例 #6
0
ファイル: main.py プロジェクト: xiaozhuyfk/NeuralNet
def main():
    """Parse the command line, read the config and train or test.

    Sub-commands: ``train <dataset>`` and ``test <dataset>``.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Choose to learn or test AMA')
    cli.add_argument('--config', default='config.cfg',
                     help='The configuration file to use')
    commands = cli.add_subparsers(help='command help')

    # Both sub-commands take one positional dataset and record their own
    # name in ``which`` so the dispatch below can tell them apart.
    for command, command_help, dataset_help in (
            ('train', 'Train memory network', 'The dataset to train.'),
            ('test', 'Test memory network', 'The dataset to test')):
        command_parser = commands.add_parser(command, help=command_help)
        command_parser.add_argument('dataset', help=dataset_help)
        command_parser.set_defaults(which=command)

    args = cli.parse_args()

    # Read global config
    globals.read_configuration(args.config)

    selected = args.which
    if selected == 'train':
        train(args.dataset)
    elif selected == 'test':
        test(args.dataset)
コード例 #7
0
ファイル: entity_linking.py プロジェクト: DenXX/aqqu
def main_entity_link_text():
    globals.read_configuration('config.cfg')
    entity_linker = globals.get_entity_linker()
    parser = globals.get_parser()
    from text2kb.utils import get_questions_serps
    question_search_results = get_questions_serps()
    globals.logger.setLevel("DEBUG")
    import operator
    while True:
        print "Please enter some text: "
        text = sys.stdin.readline().strip().decode('utf-8')
        tokens = parser.parse(text).tokens
        print "Entities:", entity_linker.identify_entities_in_document(tokens, max_token_window=5)
        entities = {}
        tokens = {}

        if text in question_search_results:
            for doc in question_search_results[text][:10]:
                print doc
                title = doc.title
                snippet = doc.snippet
                snippet_tokens = parser.parse(title + "\n" + snippet).tokens
                for token in snippet_tokens:
                    if token.lemma not in tokens:
                        tokens[token.lemma] = 0
                    tokens[token.lemma] += 1
                for entity in entity_linker.identify_entities_in_document(snippet_tokens):
                    if entity['mid'] not in entities:
                        entities[entity['mid']] = entity
                    else:
                        entities[entity['mid']]['count'] += entity['count']
        print sorted(entities.values(), key=operator.itemgetter('count'), reverse=True)[:50]
コード例 #8
0
ファイル: entity_linking.py プロジェクト: DenXX/aqqu
def test_new_entity_linker():
    globals.read_configuration('config.cfg')
    from query_translator.translator import SparqlQueryTranslator
    query_translator = SparqlQueryTranslator.init_from_config()
    while True:
        question = sys.stdin.readline().strip()
        print "Translation: ", query_translator.translate_query(question)
コード例 #9
0
ファイル: learn_notable_types.py プロジェクト: DenXX/aqqu
def train_type_model():
    """Train and pickle a classifier that scores answer notable types.

    One training example is produced per (query, notable type) pair:
    label 1 when the type belongs to an entity matching a target answer,
    label 0 when it was only seen for entities of candidate predictions.
    The raw features/labels go to ``type_model_data.pickle``; the fitted
    vectorizer and classifier go to ``type-model.pickle``.
    """
    globals.read_configuration('config.cfg')
    parser = globals.get_parser()
    scorer_globals.init()

    datasets = ["webquestions_split_train", ]

    parameters = translator.TranslatorParameters()
    parameters.require_relation_match = False
    parameters.restrict_answer_type = False

    feature_extractor = FeatureExtractor(False, False, n_gram_types_features=True)
    features = []
    labels = []
    for dataset in datasets:
        for query in get_evaluated_queries(dataset, True, parameters):
            lemmas = [token.lemma for token in parser.parse(query.utterance).tokens]
            n_grams = get_grams_feats(lemmas)

            # Notable types of the entities matching the target answers.
            answer_entities = []
            for answer in query.target_result:
                answer_entities.extend(
                    KBEntity.get_entityid_by_name(answer, keep_most_triples=True))
            answer_types = [KBEntity.get_notable_type(entity_mid)
                            for entity_mid in answer_entities]
            # Falsy types are dropped from the positives only.
            correct_notable_types = set([found for found in answer_types if found])

            # Notable types of the entities predicted by any candidate.
            other_notable_types = set()
            for candidate in query.eval_candidates:
                predicted_entities = []
                for entity_name in candidate.prediction:
                    predicted_entities.extend(
                        KBEntity.get_entityid_by_name(entity_name, keep_most_triples=True))
                predicted_types = set([KBEntity.get_notable_type(entity_mid)
                                       for entity_mid in predicted_entities])
                other_notable_types.update(predicted_types)
            incorrect_notable_types = other_notable_types - correct_notable_types

            for notable_type in correct_notable_types | incorrect_notable_types:
                labels.append(1 if notable_type in correct_notable_types else 0)
                features.append(feature_extractor.extract_ngram_features(n_grams, [notable_type, ], "type"))

    with open("type_model_data.pickle", 'wb') as data_file:
        pickle.dump((features, labels), data_file)

    labels = LabelEncoder().fit_transform(labels)
    vec = DictVectorizer(sparse=True)
    X = vec.fit_transform(features)
    # Keep the top 5 percent of features by chi2 score and restrict the
    # vectorizer to the same subset.
    feature_selector = SelectPercentile(chi2, percentile=5).fit(X, labels)
    vec.restrict(feature_selector.get_support())
    X = feature_selector.transform(X)
    type_scorer = SGDClassifier(loss='log', class_weight='auto',
                                n_iter=1000,
                                alpha=1.0,
                                random_state=999,
                                verbose=5)
    type_scorer.fit(X, labels)
    with open("type-model.pickle", 'wb') as model_file:
        pickle.dump((vec, type_scorer), model_file)
コード例 #10
0
def main():
    """Command-line entry point to train, test or cross-validate a scorer.

    Sub-commands: ``train <scorer_name>``, ``test <scorer_name>
    <test_dataset>`` and ``cv <scorer_name> <dataset>``. Cached data is
    used unless ``--no-cached`` is given.
    """
    import argparse

    def add_avg_runs_option(command_parser):
        # Shared by the 'test' and 'cv' sub-commands.
        command_parser.add_argument('--avg_runs',
                                    type=int,
                                    default=1,
                                    help='Over how many runs to average.')

    cli = argparse.ArgumentParser(description='Learn or test a'
                                  ' scorer model.')
    cli.add_argument('--no-cached', default=False, action='store_true',
                     help='Don\'t use cached data if available.')
    cli.add_argument('--config', default='config.cfg',
                     help='The configuration file to use.')
    commands = cli.add_subparsers(help='command help')

    train_cmd = commands.add_parser('train', help='Train a scorer.')
    train_cmd.add_argument('scorer_name', help='The scorer to train.')
    train_cmd.set_defaults(which='train')

    test_cmd = commands.add_parser('test', help='Test a scorer.')
    test_cmd.add_argument('scorer_name', help='The scorer to test.')
    test_cmd.add_argument('test_dataset',
                          help='The dataset on which to test the scorer.')
    add_avg_runs_option(test_cmd)
    test_cmd.set_defaults(which='test')

    cv_cmd = commands.add_parser('cv', help='Cross-validate a scorer.')
    cv_cmd.add_argument('scorer_name', help='The scorer to test.')
    cv_cmd.add_argument('dataset',
                        help='The dataset on which to compute cv scores.')
    cv_cmd.add_argument('--n_folds', type=int, default=6,
                        help='The number of folds.')
    add_avg_runs_option(cv_cmd)
    cv_cmd.set_defaults(which='cv')

    args = cli.parse_args()
    # Read global config.
    globals.read_configuration(args.config)
    # Fix randomness.
    random.seed(999)
    use_cache = not args.no_cached

    command = args.which
    if command == 'train':
        train(args.scorer_name, use_cache)
    elif command == 'test':
        test(args.scorer_name, args.test_dataset, use_cache,
             avg_runs=args.avg_runs)
    elif command == 'cv':
        cv(args.scorer_name, args.dataset, use_cache,
           n_folds=args.n_folds, avg_runs=args.avg_runs)
コード例 #11
0
ファイル: learner.py プロジェクト: ShuaiyiLiu/aqqu
def main():
    """Parse arguments, seed the RNG and train, test or cross-validate.

    ``--no-cached`` disables the use of cached data; ``--config`` selects
    the configuration file. The chosen sub-command (``train``, ``test``
    or ``cv``) is stored in ``args.which`` and dispatched at the end.
    """
    import argparse

    top = argparse.ArgumentParser(description='Learn or test a'
                                              ' scorer model.')
    top.add_argument('--no-cached', action='store_true', default=False,
                     help='Don\'t use cached data if available.')
    top.add_argument('--config', default='config.cfg',
                     help='The configuration file to use.')
    sub = top.add_subparsers(help='command help')

    # train <scorer_name>
    p_train = sub.add_parser('train', help='Train a scorer.')
    p_train.add_argument('scorer_name', help='The scorer to train.')
    p_train.set_defaults(which='train')

    # test <scorer_name> <test_dataset> [--avg_runs N]
    p_test = sub.add_parser('test', help='Test a scorer.')
    p_test.add_argument('scorer_name', help='The scorer to test.')
    p_test.add_argument('test_dataset',
                        help='The dataset on which to test the scorer.')
    p_test.add_argument('--avg_runs', type=int, default=1,
                        help='Over how many runs to average.')
    p_test.set_defaults(which='test')

    # cv <scorer_name> <dataset> [--n_folds N] [--avg_runs N]
    p_cv = sub.add_parser('cv', help='Cross-validate a scorer.')
    p_cv.add_argument('scorer_name', help='The scorer to test.')
    p_cv.add_argument('dataset',
                      help='The dataset on which to compute cv scores.')
    p_cv.add_argument('--n_folds', type=int, default=6,
                      help='The number of folds.')
    p_cv.add_argument('--avg_runs', type=int, default=1,
                      help='Over how many runs to average.')
    p_cv.set_defaults(which='cv')

    args = top.parse_args()
    # Read global config.
    globals.read_configuration(args.config)
    # Fix randomness.
    random.seed(999)
    use_cache = not args.no_cached

    if args.which == 'train':
        train(args.scorer_name, use_cache)
        return
    if args.which == 'test':
        test(args.scorer_name, args.test_dataset, use_cache,
             avg_runs=args.avg_runs)
        return
    if args.which == 'cv':
        cv(args.scorer_name, args.dataset, use_cache, n_folds=args.n_folds,
           avg_runs=args.avg_runs)
コード例 #12
0
def main():
    import argparse
    parser = argparse.ArgumentParser(description="Console based translation.")
    parser.add_argument("ranker_name",
                        default="WQ_Ranker",
                        help="The ranker to use.")
    parser.add_argument("--config",
                        default="config.cfg",
                        help="The configuration file to use.")
    args = parser.parse_args()
    globals.read_configuration(args.config)
    if args.ranker_name not in scorer_globals.scorers_dict:
        logger.error("%s is not a valid ranker" % args.ranker_name)
        logger.error("Valid rankers are: %s " %
                     (" ".join(scorer_globals.scorers_dict.keys())))
    logger.info("Using ranker %s" % args.ranker_name)
    ranker = scorer_globals.scorers_dict[args.ranker_name]
    translator = QueryTranslator.init_from_config()
    translator.set_scorer(ranker)
    while True:
        sys.stdout.write("enter question> ")
        sys.stdout.flush()
        query = sys.stdin.readline().strip()
        logger.info("Translating query: %s" % query)
        results = translator.translate_and_execute_query(query)
        logger.info("Done translating query: %s" % query)
        logger.info("#candidates: %s" % len(results))
        if len(results) > 0:
            best_candidate = results[0].query_candidate

            for result in results:
                candidate = result.query_candidate
                relation = candidate.relations[-1]
                #last_node = candidate.nodes[-1]
                print candidate.graph_as_simple_string()
                print candidate.get_result(include_name=True)
                print candidate.pattern
                print relation.name
                print relation.source_node.entity.entity.name
                print ""
                #print len(candidate.relations), candidate.pattern
                #print last_node.name

            sparql_query = best_candidate.to_sparql_query()
            result_rows = results[0].query_result_rows
            result = []
            # Usually we get a name + mid.
            for r in result_rows:
                if len(r) > 1:
                    result.append("%s (%s)" % (r[1], r[0]))
                else:
                    result.append("%s" % r[0])
            logger.info("SPARQL query: %s" % sparql_query)
            logger.info("Result: %s " % " ".join(result))
コード例 #13
0
ファイル: entity_linking.py プロジェクト: DenXX/aqqu
def get_question_terms():
    import scorer_globals
    globals.read_configuration('config_webentity.cfg')
    scorer_globals.init()
    datasets = ["webquestionstrain", "webquestionstest",]

    question_tokens = set()
    for dataset in datasets:
        queries = load_eval_queries(dataset)
        for index, query in enumerate(queries):
            question_tokens.update(token for token in tokenize(query.utterance))
    print question_tokens
コード例 #14
0
ファイル: console_translator.py プロジェクト: DenXX/aqqu
def main():
    """Run an interactive console that translates questions to SPARQL.

    The ranker named on the command line (``WQ_Ranker`` when omitted) is
    looked up in ``scorer_globals.scorers_dict``. For every question read
    from stdin the ten best candidates are logged with their relations,
    result text and features, followed by the SPARQL query and result
    rows of the best one. A failure while handling one question is logged
    and the console keeps running; the loop ends at EOF.

    Exits with status 1 when the requested ranker is unknown.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Console based translation.")
    # nargs="?" makes the positional optional so its default is actually
    # used; argparse ignores ``default`` on a required positional.
    parser.add_argument("ranker_name",
                        nargs="?",
                        default="WQ_Ranker",
                        help="The ranker to use.")
    parser.add_argument("--config",
                        default="config.cfg",
                        help="The configuration file to use.")
    args = parser.parse_args()
    globals.read_configuration(args.config)
    scorer_globals.init()
    if args.ranker_name not in scorer_globals.scorers_dict:
        logger.error("%s is not a valid ranker" % args.ranker_name)
        logger.error("Valid rankers are: %s " % (" ".join(scorer_globals.scorers_dict.keys())))
        # Stop here: falling through would only fail with a bare KeyError
        # on the dictionary lookup below.
        sys.exit(1)
    logger.info("Using ranker %s" % args.ranker_name)
    ranker = scorer_globals.scorers_dict[args.ranker_name]
    translator = SparqlQueryTranslator.init_from_config()
    translator.set_scorer(ranker)
    while True:
        sys.stdout.write("enter question> ")
        sys.stdout.flush()
        line = sys.stdin.readline()
        if not line:
            # readline() returns '' only at end of input. Without this
            # check the loop never ends once stdin is closed, because
            # every failure below is caught and logged.
            break
        query = line.strip()
        try:
            logger.info("Translating query: %s" % query)
            results = translator.translate_and_execute_query(query)
            logger.info("Done translating query: %s" % query)
            logger.info("#candidates: %s" % len(results))
            logger.info("------------------- Candidate features ------------------")
            for rank, result in enumerate(results[:10]):
                logger.info("RANK " + str(rank))
                logger.info(result.query_candidate.relations)
                logger.info(result.query_candidate.get_results_text())
                if result.query_candidate.features:
                    logger.info("Features: " + str(result.query_candidate.features))
            logger.info("---------------------------------------------------------")
            if len(results) > 0:
                best_candidate = results[0].query_candidate
                sparql_query = best_candidate.to_sparql_query()
                result_rows = results[0].query_result_rows
                result = []
                # Usually we get a name + mid.
                for r in result_rows:
                    if len(r) > 1:
                        result.append("%s (%s)" % (r[1], r[0]))
                    else:
                        result.append("%s" % r[0])
                logger.info("SPARQL query: %s" % sparql_query)
                logger.info("Result: %s " % " ".join(result))
        except Exception:
            # Deliberately broad: one bad question must not kill the
            # console. logger.exception keeps the traceback, whereas
            # ``e.message`` is deprecated and frequently empty.
            logger.exception("Failed to translate query: %s", query)
コード例 #15
0
def gen_pos_data():
    """Produce POS-annotated and IOB-tagged data files under ``../data``.

    Steps, in order: add POS features to the SimpleQuestions train/dev
    file, build IOB files from the ``.pos.tmp`` files of both corpora,
    merge the two training IOB files and write the word list. The
    commented-out calls would produce the other ``.pos.tmp`` inputs.
    """
    globals.read_configuration('../config.cfg')

    def tmp(path):
        # Intermediate POS files carry a '.tmp' suffix.
        return path + '.tmp'

    # SimpleQuestions: linked input and POS output.
    sq_train_in = '../data/simple.train.dev.el.v2.new'
    sq_test_in = '../data/simple.test.el.v2.new'
    sq_train_pos = '../data/simple.train.dev.el.v2.pos'
    sq_test_pos = '../data/simple.test.el.v2.pos'

    # WebQuestions: complete input and POS output.
    wq_train_in = '../data/wq.train.complete.v2.new'
    wq_test_in = '../data/wq.test.complete.v2.new'
    wq_dev_in = '../data/wq.dev.complete.v2.new'

    wq_train_pos = '../data/wq.train.complete.v2.pos'
    wq_test_pos = '../data/wq.test.complete.v2.pos'
    wq_dev_pos = '../data/wq.dev.complete.v2.pos'

    # IOB-tagged outputs.
    wq_train_iob = '../data/wq.train.complete.v2.pos.iob'
    wq_test_iob = '../data/wq.test.complete.v2.pos.iob'

    sq_train_iob = '../data/simple.train.dev.el.v2.pos.iob'
    sq_test_iob = '../data/simple.test.el.v2.pos.iob'

    # Merged training data and vocabulary files.
    merged_train_iob = '../data/tag.train.pos.iob'
    word_list = "../data/tag.word.list"
    char_list = '../data/tag.char.list'
    pos_list = '../data/pos.list'
    # parser = CoreNLPParser.init_from_config()

    add_pos_feature(sq_train_in, tmp(sq_train_pos))
    # add_pos_feature(sq_test_in, tmp(sq_test_pos))

    # add_pos_feature(wq_train_in, tmp(wq_train_pos))
    # add_pos_feature(wq_dev_in, tmp(wq_dev_pos))
    # add_pos_feature(wq_test_in, tmp(wq_test_pos))

    gen_tagged_sentence_plus_pos([tmp(wq_train_pos), tmp(wq_dev_pos)],
                                 wq_train_iob, 'iob')
    gen_tagged_sentence_plus_pos([tmp(wq_test_pos)], wq_test_iob, 'iob')

    gen_tagged_sentence_plus_pos([tmp(sq_train_pos)], sq_train_iob, 'iob')
    gen_tagged_sentence_plus_pos([tmp(sq_test_pos)], sq_test_iob, 'iob')

    merge_file([wq_train_iob, sq_train_iob], merged_train_iob)
    gen_word_list_for_pos([tmp(sq_train_pos), tmp(wq_train_pos)], word_list)
コード例 #16
0
ファイル: example.py プロジェクト: mindis/KBQA
def link_entity_in_simple_question_mt(fn_in, fn_out):
    from multiprocessing import Pool
    MAX_POOL_NUM = 8

    num_line = 0
    with open(fn_in) as fin:
        for _ in fin:
            num_line += 1
    print "There are %d lines to process." % num_line
    chunk_size = 50
    parameters = []
    i = 0
    while i * chunk_size < num_line:
        parameters.append(
            (fn_in, i * chunk_size, min(num_line, (i + 1) * chunk_size)))
        i += 1

    pool = Pool(MAX_POOL_NUM)
    ret_list = pool.imap_unordered(link_entity_one, parameters)
    pool.close()
    globals.read_configuration("config.cfg")
    entity_linker = EntityLinker.init_from_config()
    with open(fn_out, 'w') as fout:
        for l in ret_list:
            for sentence, entity, tokens in l:
                entities = entity_linker.identify_entities_in_tokens(tokens)
                neg_entities = set()
                for e in entities:
                    mid = e.get_mid()
                    if mid == '':
                        continue
                    if mid.startswith('m.'):
                        neg_entities.add(mid)
                    else:
                        print mid, e.name, sentence
                neg_entities -= set([entity])
                print >> fout, json.dumps(
                    {
                        'q': sentence,
                        'pos': entity,
                        'neg': list(neg_entities)
                    },
                    ensure_ascii=False).encode('utf8')
    pool.join()
コード例 #17
0
 def __init__(self):
     """Create the Freebase client and load the mediator relations.

     The file named by the ``mediator-relations`` option of the
     ``FREEBASE`` section in ``../config.cfg`` is read line by line;
     every stripped line not starting with ``m.`` is stored in
     ``self.mediate_relations``.
     """
     self.esfreebase = EsFreebase()
     self.mediate_relations = set()
     conf = globals.read_configuration('../config.cfg')
     mediator_filename = conf.get('FREEBASE', 'mediator-relations')
     with open(mediator_filename) as fin:
         for raw in fin:
             relation = raw.decode('utf8').strip()
             if not relation.startswith('m.'):
                 self.mediate_relations.add(relation)
コード例 #18
0
ファイル: example.py プロジェクト: mindis/KBQA
def link_entity_one(params):
    """Parse the questions found in one line range of a question file.

    Args:
        params: tuple ``(fn, start, end)`` - path of a tab-separated file
            and the zero-based, half-open line range ``[start, end)``.

    Returns:
        A list of ``(question, entity, tokens)`` tuples - field 4, field 0
        and the parser output for field 4 - one for each line in the
        range that has exactly five tab-separated fields.
    """
    fn, start, end = params
    globals.read_configuration("config.cfg")
    parser = CoreNLPParser.init_from_config()
    ret = []

    # The context manager closes the file even if parsing raises; the
    # handle used to leak in that case.
    with open(fn) as fin:
        for lno, line in enumerate(fin):
            if lno >= end:
                break
            if lno < start:
                continue
            ll = line.decode('utf8').strip().split('\t')
            if len(ll) != 5:
                continue
            tokens = parser.parse(ll[4])
            ret.append((ll[4], ll[0], tokens))

    return ret
コード例 #19
0
ファイル: entity_linking.py プロジェクト: DenXX/aqqu
def main_entities():
    globals.read_configuration('config.cfg')
    from text2kb.utils import get_questions_serps
    from text2kb.utils import get_documents_entities
    serps = get_questions_serps()
    doc_entities = get_documents_entities()
    import operator
    while True:
        print "Please enter a question:"
        question = sys.stdin.readline().strip()
        if question in serps:
            docs = serps[question][:10]
            entities = {}
            for doc in docs:
                for entity in doc_entities[doc.url].itervalues():
                    e = (entity['mid'], entity['name'])
                    if e not in entities:
                        entities[e] = 0
                    entities[e] += entity['count']
            top_entities = entities.items()
            top_entities.sort(key=operator.itemgetter(1), reverse=True)
            print top_entities[:50]
コード例 #20
0
ファイル: console_translator.py プロジェクト: xiaozhuyfk/aqqu
def main():
    import argparse
    parser = argparse.ArgumentParser(description = "Console based translation.")
    parser.add_argument("ranker_name",
                        default = "WQ_Ranker",
                        help = "The ranker to use.")
    parser.add_argument("--config",
                        default = "config.cfg",
                        help = "The configuration file to use.")
    args = parser.parse_args()
    globals.read_configuration(args.config)
    if args.ranker_name not in scorer_globals.scorers_dict:
        logger.error("%s is not a valid ranker" % args.ranker_name)
        logger.error("Valid rankers are: %s " % (" ".join(scorer_globals.scorers_dict.keys())))
    logger.info("Using ranker %s" % args.ranker_name)
    ranker = scorer_globals.scorers_dict[args.ranker_name]
    translator = QueryTranslator.init_from_config()
    translator.set_scorer(ranker)

    writeFile(test_file, "", "w")

    linker = translator.entity_linker
    entities = linker.surface_index.get_entities_for_surface("spanish")
    for (e, score) in entities:
        print e.name, score

    """
    for i in xrange(len(rank_error)):
        query = rank_error[i]
        results = translator.translate_and_execute_query(query)
        if (len(results) > 0):
            correct = results[rank_pos[i]].query_candidate

            candidate = results[0].query_candidate
            sparql_query = candidate.to_sparql_query()
            correct_query = correct.to_sparql_query()

            result_rows = results[0].query_result_rows
            result = []
            for r in result_rows:
                if len(r) > 1:
                    result.append("%s (%s)" % (r[1], r[0]))
                else:
                    result.append("%s" % r[0])
            correct_result_rows = results[rank_pos[i]].query_result_rows
            correct_result = []
            for r in correct_result_rows:
                if len(r) > 1:
                    correct_result.append("%s (%s)" % (r[1], r[0]))
                else:
                    correct_result.append("%s" % r[0])

            extractor = FeatureExtractor(True, False, None)
            features = extractor.extract_features(candidate)
            y_features = extractor.extract_features(correct)
            diff = feature_diff(features, y_features)

            X = ranker.dict_vec.transform(diff)
            if ranker.scaler:
                X = ranker.scaler.transform(X)
            ranker.model.n_jobs = 1
            p = ranker.model.predict(X)
            c = ranker.label_encoder.inverse_transform(p)
            res = c[0]

            root_name = "Root Node: %s\n" % (candidate.root_node.entity.name.encode('utf-8'))
            query_str = "SPARQL query: %s\n" % (sparql_query.encode('utf-8'))
            graph_str = "Candidate Graph: %s\n" % (candidate.graph_as_string().encode('utf-8'))
            graph_str_simple = "Simple Candidate Graph: %s" % (candidate.graph_as_simple_string().encode('utf-8'))
            y_graph_str_simple = "Answer Candidate Graph: %s" % (correct.graph_as_simple_string().encode('utf-8'))
            result_str = "Result: %s\n" % ((" ".join(result)).encode('utf-8'))
            correct_result_str = "Correct Result: %s\n" % ((" ".join(correct_result)).encode('utf-8'))

            feature_str = "Result Features: %s\n" % (str(features).encode('utf-8'))
            y_feature_str = "Answer Features: %s\n" %(str(y_features).encode('utf-8'))
            diff_str = "Feature Diff: %s\n" %(str(diff).encode('utf-8'))

            x_str = "X vector: %s\n" % (str(X).encode('utf-8'))
            p_str = "Predict vector: %s\n" % (str(p).encode('utf-8'))
            c_str = "C vector: %s\n" % (str(c).encode('utf-8'))
            cmp_res = "Compare result: %d\n" % (res)

            writeFile(test_file, root_name, "a")
            writeFile(test_file, result_str, "a")
            writeFile(test_file, correct_result_str, "a")

            writeFile(test_file, graph_str_simple, "a")
            writeFile(test_file, y_graph_str_simple, "a")

            writeFile(test_file, feature_str, "a")
            writeFile(test_file, y_feature_str, "a")
            writeFile(test_file, diff_str, "a")

            writeFile(test_file, x_str, "a")
            writeFile(test_file, p_str, "a")
            writeFile(test_file, c_str, "a")
            writeFile(test_file, cmp_res, "a")
        writeFile(test_file, "\n", "a")
    """

    """
    for query in test_set + unidentified:
        results = translator.translate_and_execute_query(query)
        if (len(results) > 0):
            for i in xrange(len(results)):
                if (i > 10):
                    break
                candidate = results[i].query_candidate
                sparql_query = candidate.to_sparql_query()
                result_rows = results[i].query_result_rows
                result = []
                for r in result_rows:
                    if len(r) > 1:
                        result.append("%s (%s)" % (r[1], r[0]))
                    else:
                        result.append("%s" % r[0])

                extractor = FeatureExtractor(True, False, None)
                features = extractor.extract_features(candidate)

                root_name = "%d Root Node: %s\n" % (i+1, candidate.root_node.entity.name.encode('utf-8'))
                query_str = "%d SPARQL query: %s\n" % (i+1, sparql_query.encode('utf-8'))
                graph_str = "%d Candidate Graph: %s\n" % (i+1, candidate.graph_as_string().encode('utf-8'))
                graph_str_simple = "%d Simple Candidate Graph: %s" % (i+1, candidate.graph_as_simple_string().encode('utf-8'))
                result_str = "%d Result: %s\n" % (i+1, (" ".join(result)).encode('utf-8'))
                feature_str = "%d Features: %s\n" % (i+1, str(features).encode('utf-8'))
                writeFile(test_file, root_name, "a")
                #writeFile(test_file, graph_str, "a")
                writeFile(test_file, graph_str_simple, "a")
                writeFile(test_file, feature_str, "a")
                #writeFile(test_file, query_str, "a")
                writeFile(test_file, result_str, "a")
        writeFile(test_file, "\n", "a")
    """

    while True:
        sys.stdout.write("enter question> ")
        sys.stdout.flush()
        query = sys.stdin.readline().strip()
        logger.info("Translating query: %s" % query)
        results = translator.translate_and_execute_query(query)
        logger.info("Done translating query: %s" % query)
        logger.info("#candidates: %s" % len(results))
        if len(results) > 0:
            best_candidate = results[0].query_candidate
            sparql_query = best_candidate.to_sparql_query()
            result_rows = results[0].query_result_rows
            result = []
            # Usually we get a name + mid.
            for r in result_rows:
                if len(r) > 1:
                    result.append("%s (%s)" % (r[1], r[0]))
                else:
                    result.append("%s" % r[0])
            logger.info("SPARQL query: %s" % sparql_query)
            logger.info("Result: %s " % " ".join(result))
コード例 #21
0
ファイル: dump_qa_entity_pairs.py プロジェクト: DenXX/aqqu

# Script entry point: parse the command line, load the configuration,
# initialise the scorers and iterate over the configured datasets.
if __name__ == "__main__":
    # print_sparql_queries()
    # exit()

    import argparse

    parser = argparse.ArgumentParser(description="Dump qa entity pairs.")
    parser.add_argument("--config",
                        default="config.cfg",
                        help="The configuration file to use.")
    # --output has no default, so args.output is None when it is not given.
    parser.add_argument("--output",
                        help="The file to dump results to.")
    args = parser.parse_args()
    globals.read_configuration(args.config)
    scorer_globals.init()

    # Translator parameters with the relation-match requirement and the
    # answer-type restriction both switched off.
    parameters = translator.TranslatorParameters()
    parameters.require_relation_match = False
    parameters.restrict_answer_type = False

    # Alternative dataset selections, kept for reference.
    # datasets = ["webquestions_split_train", "webquestions_split_dev",]
    # datasets = ["webquestions_split_train_externalentities", "webquestions_split_dev_externalentities",]
    # datasets = ["webquestions_split_train_externalentities3", "webquestions_split_dev_externalentities3",]
    datasets = ["webquestions_train_externalentities_all", "webquestions_test_externalentities_all", ]

    # Accumulators, presumably filled in by the dataset loop below --
    # NOTE(review): the loop body is not visible here, confirm their use.
    count = 0
    correct_relations = set()
    positions = []
    for dataset in datasets:
コード例 #22
0
ファイル: web_features.py プロジェクト: DenXX/aqqu
        "token_tfidf": SparseVector.from_2pos(combined_doc_snippet_token2pos,
                                              element_calc_func=SparseVector.compute_tfidf_token_elements),
    }

    # Cache the computed vectors.
    _documents_vectors_cache[question_text] = (documents_vectors, snippets_vectors, fragment_vectors,
                                               combined_documents_vector, combined_document_snippets_vector)

    return documents_vectors, snippets_vectors, fragment_vectors, combined_documents_vector, combined_document_snippets_vector


def create_document_vectors_cache(questions):
    """Compute the document vectors of every question and pickle them.

    For each question the vectors are generated via
    ``generate_document_vectors`` and the pair
    ``(question, _documents_vectors_cache[question])`` is appended to the
    file named by the ``document-vectors`` option of the
    ``WebSearchFeatures`` configuration section.

    Args:
        questions: iterable of question texts.
    """
    cache_path = globals.config.get('WebSearchFeatures', 'document-vectors')
    logger.info("Caching document vectors...")
    # NOTE(review): mode 'wx' looks like the exclusive-create mode of
    # Python 2 / glibc, i.e. it refuses to overwrite an existing cache --
    # confirm on the target platform.
    with open(cache_path, 'wx') as cache_out:
        processed = 0
        for question in questions:
            # Every question token is mapped to the same one-element
            # position list.
            token_positions = {token: [1, ] for token in tokenize(question)}
            # Called for its side effect: the entry for this question in
            # _documents_vectors_cache is read right after the call.
            generate_document_vectors(question, token_positions, get_questions_serps())
            record = (question, _documents_vectors_cache[question])
            pickle.dump(record, cache_out)
            if processed % 100 == 0:
                logger.info("Cached document vectors for %d questions" % processed)
            processed += 1


if __name__ == "__main__":
    # Script entry point: set up logging, load the Wikipedia configuration
    # and build the document vector cache for all questions that have
    # stored search results.
    log_format = '%(asctime)s : %(levelname)s : %(module)s : %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    globals.read_configuration('config_wikipedia.cfg')
    serps = get_questions_serps()
    # The keys of the search results mapping are the question texts.
    create_document_vectors_cache(serps.keys())
コード例 #23
0
ファイル: learn_notable_types.py プロジェクト: DenXX/aqqu
    X = feature_selector.transform(X)
    type_scorer = SGDClassifier(loss='log', class_weight='auto',
                                n_iter=1000,
                                alpha=1.0,
                                random_state=999,
                                verbose=5)
    type_scorer.fit(X, labels)
    with open("type-model.pickle", 'wb') as out:
        pickle.dump((vec, type_scorer), out)


if __name__ == "__main__":
    # NOTE(review): the two statements below run the NPMI extraction and
    # then terminate the process, so everything after exit() is unreachable
    # as written -- confirm whether the data collection code below is still
    # meant to be runnable.
    extract_npmi_ngram_type_pairs()
    exit()

    globals.read_configuration('config.cfg')
    parser = globals.get_parser()
    scorer_globals.init()

    datasets = ["webquestions_split_train", ]
    # Alternative dataset selections, kept for reference.
    # datasets = ["webquestions_split_train_externalentities", "webquestions_split_dev_externalentities",]
    # datasets = ["webquestions_split_train_externalentities3", "webquestions_split_dev_externalentities3",]

    data = []
    for dataset in datasets:
        queries = load_eval_queries(dataset)
        for index, query in enumerate(queries):
            # The `token` attribute of every token of the parsed utterance.
            tokens = [token.token for token in parser.parse(query.utterance).tokens]
            # Flattened list of all entity ids found for any of the expected
            # answers of the query.
            answer_entities = [mid for answer in query.target_result
                               for mid in KBEntity.get_entityid_by_name(answer, keep_most_triples=True)]
            # One notable type per answer entity, in the same order.
            notable_types = [KBEntity.get_notable_type(entity_mid) for entity_mid in answer_entities]
コード例 #24
0
ファイル: application.py プロジェクト: xiaozhuyfk/parallel
from OpenSSL import SSL
import os

from relation_matching import modules
import globals

import argparse
# Command line: only the path of the configuration file can be chosen.
parser = argparse.ArgumentParser(description='Choose to learn or test AMA')

parser.add_argument('--config',
                    default='config.cfg',
                    help='The configuration file to use')
# NOTE: the command line is parsed at import time, i.e. also when this
# module is imported by something other than a direct script run.
args = parser.parse_args()

# Read global config
globals.read_configuration(args.config)

# Load modules
modules.init_from_config(args)

#context = SSL.Context(SSL.SSLv23_METHOD)
# Paths of the development certificate and key, located in the `certs`
# directory next to this file.
cer = os.path.join(os.path.dirname(__file__), 'certs/development.crt')
key = os.path.join(os.path.dirname(__file__), 'certs/development.key')
# (certificate, key) pair -- presumably handed to the server as its SSL
# context; its use is not visible here, confirm at the call site.
context = (cer, key)

# The Flask application object is published under two names, `application`
# and `app`.
application = Flask(__name__)
app = application
CORS(app)


@app.route('/')
コード例 #25
0
        #         if t[0] not in nodes_:
        #             nodes_[t[0]] = {'category': 1, 'name': t[0], 'value': 1}
        #         if t[0] == mid:
        #             nodes_[t[0]]['category'] = 0
        #             nodes_[t[0]]['value'] = 10
        #         if t[2] not in nodes_:
        #             nodes_[t[2]] = {'category': 1, 'name': t[2], 'value': 4}
        #             if t[3] == 0 or t[3] == 2:
        #                 nodes_[t[2]]['category'] = 2
        #
        #     for m in nodes_.keys():
        #         name, name_info = DBManager.get_name(m)
        #         nodes_[m]['label'] = name
        #
        #     nodes = nodes_.values()
        #     for t in subgraph:
        #         links.append({'source': nodes_[t[0]]['name'], 'target': nodes_[t[2]]['name'], 'weight': 2, 'name': t[1]})
        #     print 'node', nodes
        #     print 'links', links
        #
        #     return json.dumps({'nodes':nodes, 'links':links})

    def run(self, port, debug=False):
        """Start the wrapped application, listening on all interfaces.

        Args:
            port: port number passed on to ``self.app.run``.
            debug: debug flag passed on to ``self.app.run``.
        """
        server_options = {'host': '0.0.0.0', 'port': port, 'debug': debug}
        self.app.run(**server_options)


if __name__ == '__main__':
    # Script entry point: load the configuration from the parent directory,
    # build the demo from the first command line argument and serve it on
    # port 7778 with debugging switched on.
    globals.read_configuration('../config.cfg')
    constructor_arg = sys.argv[1]
    obj = KBQADemo(constructor_arg)
    obj.run(port=7778, debug=True)
コード例 #26
0
ファイル: learn_notable_types.py プロジェクト: DenXX/aqqu
def extract_npmi_ngram_type_pairs():
    globals.read_configuration('config.cfg')
    scorer_globals.init()

    datasets = ["webquestions_split_train", ]

    parameters = translator.TranslatorParameters()
    parameters.require_relation_match = False
    parameters.restrict_answer_type = False

    n_gram_type_counts = dict()
    type_counts = dict()
    n_gram_counts = dict()
    total = 0
    year_pattern = re.compile("[0-9]+")
    for dataset in datasets:
        queries = get_evaluated_queries(dataset, True, parameters)
        for index, query in enumerate(queries):
            if query.oracle_position != -1 and query.oracle_position <= len(query.eval_candidates):
                correct_candidate = query.eval_candidates[query.oracle_position - 1]
                logger.info(query.utterance)
                logger.info(correct_candidate.query_candidate)

                n_grams = set(get_n_grams_features(correct_candidate.query_candidate))

                answer_entities = [mid for answer in query.target_result
                                   if year_pattern.match(answer) is None
                                   for mid in KBEntity.get_entityid_by_name(answer, keep_most_triples=True)]
                correct_notable_types = set(filter(lambda x: x,
                                                   [KBEntity.get_notable_type(entity_mid)
                                                    for entity_mid in answer_entities]))

                for notable_type in correct_notable_types:
                    if notable_type not in type_counts:
                        type_counts[notable_type] = 0
                    type_counts[notable_type] += 1

                for n_gram in n_grams:
                    if n_gram not in n_gram_counts:
                        n_gram_counts[n_gram] = 0
                    n_gram_counts[n_gram] += 1

                    for notable_type in correct_notable_types:
                        pair = (n_gram, notable_type)
                        if pair not in n_gram_type_counts:
                            n_gram_type_counts[pair] = 0
                        n_gram_type_counts[pair] += 1

                total += 1

    npmi = dict()
    from math import log
    for n_gram_type_pair, n_gram_type_count in n_gram_type_counts.iteritems():
        if n_gram_type_count > 4:
            n_gram, type = n_gram_type_pair
            npmi[n_gram_type_pair] = (log(n_gram_type_count) - log(n_gram_counts[n_gram]) - log(type_counts[type]) +
                                        log(total)) / (-log(n_gram_type_count) + log(total))

    with open("type_model_npmi.pickle", 'wb') as out:
        pickle.dump(npmi, out)

    import operator
    npmi = sorted(npmi.items(), key=operator.itemgetter(1), reverse=True)
    print "\n".join(map(str, npmi[:50]))