def test(self, data_with_gold, verbose=False):
    graphs, gold_answers = data_with_gold
    predicted_indices = self.apply_on_batch(graphs, verbose=verbose)
    successes = deque()
    avg_metrics = np.zeros(3)
    for i, sorted_indices in enumerate(
            tqdm.tqdm(predicted_indices, ascii=True, disable=(not verbose))):
        sorted_indices = deque(sorted_indices)
        if sorted_indices:
            retrieved_answers = []
            # Walk down the ranked candidates until a graph yields a non-empty denotation.
            while not retrieved_answers and sorted_indices:
                index = sorted_indices.popleft()
                g = graphs[i][index]
                retrieved_answers = wdaccess.query_graph_denotations(g).get('e1', [])
            retrieved_answers = wdaccess.label_query_results(retrieved_answers)
            metrics = evaluation.retrieval_prec_rec_f1_with_altlabels(
                gold_answers[i], retrieved_answers)
            if metrics[-1]:  # record questions answered with a non-zero F-score
                successes.append((i, metrics[-1], g))
            avg_metrics += metrics
    avg_metrics /= len(gold_answers)
    return successes, avg_metrics
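A minimal usage sketch for the test method above (not part of the original example): it assumes a hypothetical ranker object of the class this method belongs to, plus pre-built candidate graphs and gold answers.

# Hypothetical usage; `ranker`, `candidate_graphs` and `gold_answers` are
# placeholder names, not identifiers taken from the original code base.
successes, avg_metrics = ranker.test((candidate_graphs, gold_answers), verbose=True)
precision, recall, f1 = avg_metrics
print("Macro-averaged P/R/F1: {:.3f} {:.3f} {:.3f}".format(precision, recall, f1))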
Example #2
def test_ground_with_model_hops():
    stages.HOP_TYPES = {'hopUp', 'hopDown'}
    ungrounded_graph = {
        'edgeSet': [],
        'fragments': [(['thailand'], 'LOCATION'), (['religion'], 'NN')],
        'tokens': ['what', 'religion', 'in', 'thailand', '?']
    }
    print(ungrounded_graph)
    chosen_graphs = staged_generation.generate_with_model(ungrounded_graph,
                                                          trainablemodel,
                                                          beam_size=10)
    g = chosen_graphs[0]
    model_answers = wdaccess.query_graph_denotations(g[0]).get("e1")
    print(model_answers)
    print(chosen_graphs)
Example #3
def ground_one_with_gold(s_g, gold_answers, min_fscore):
    grounded_graphs = [apply_grounding(s_g, p) for p in find_groundings(s_g)]
    logger.debug("Number of possible groundings: {}".format(
        len(grounded_graphs)))
    logger.debug("First one: {}".format(grounded_graphs[:1]))
    query_results_per_graph = [
        wdaccess.query_graph_denotations(g) for g in grounded_graphs
    ]
    topics_per_graph = [
        wdaccess.query_graph_topics(g) for g in grounded_graphs
    ]
    retrieved_answers = [r.get('e1', []) for r in query_results_per_graph]
    post_process_results = (wdaccess.label_query_results
                            if generation_p['label.query.results']
                            else wdaccess.answers_from_results)
    retrieved_answers = [
        post_process_results(answer_set) for answer_set in retrieved_answers
    ]
    retrieved_answers = [
        post_process_answers_given_graph(answer_set, grounded_graphs[i])
        for i, answer_set in enumerate(retrieved_answers)
    ]
    logger.debug("Number of retrieved answer sets: {}. Example: {}".format(
        len(retrieved_answers),
        retrieved_answers[0][:10] if len(retrieved_answers) > 0 else []))
    evaluation_results = [
        evaluation.retrieval_prec_rec_f1_with_altlabels(
            gold_answers, retrieved_answers[i])
        for i in range(len(grounded_graphs))
    ]
    assert len(retrieved_answers) == len(query_results_per_graph) == len(
        topics_per_graph)
    for i, r in enumerate(topics_per_graph):
        grounded_graphs[i] = apply_topics(grounded_graphs[i],
                                          r.get("topic", {}))
        r['e1'] = retrieved_answers[i]
    chosen_graphs = [(grounded_graphs[i], evaluation_results[i],
                      query_results_per_graph[i]['e1'])
                     for i in range(len(grounded_graphs))
                     if evaluation_results[i][2] > min_fscore]
    not_chosen_graphs = [(grounded_graphs[i], (0.0, 0.0, 0.0),
                          query_results_per_graph[i].get('num_answers', 0))
                         for i in range(len(grounded_graphs))
                         if evaluation_results[i][2] < 0.01]
    return chosen_graphs, not_chosen_graphs
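A short, hedged usage sketch for ground_one_with_gold (not from the original source): the ungrounded graph, the gold answer, and the threshold below are illustrative placeholders.

# Hypothetical call; `ungrounded_graph` is assumed to be a graph dict of the
# same shape as in the test functions below, and the 0.5 threshold is arbitrary.
chosen, not_chosen = ground_one_with_gold(ungrounded_graph,
                                          gold_answers=["buddhism"],
                                          min_fscore=0.5)
for graph, metrics, answers in chosen:
    print(metrics[-1], answers[:5])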
Example #4
def test_ground_with_model2():
    stages.HOP_TYPES = set()
    entity_linking.entity_linking_p["max.entity.options"] = 1
    entity_linking.entity_linking_p["np.parser"] = "aida"
    ungrounded_graph = {
        'edgeSet': [],
        'fragments': [(['Garry', 'Marshall'], 'NNP'),
                      (['Julia', 'Roberts'], 'NNP'), (['films'], 'NN')],
        'tokens': ("In which films directed by Garry Marshall "
                   "was Julia Roberts starring ?").split()
    }
    staged_generation.generation_p["strict.classes"] = False
    chosen_graphs = staged_generation.generate_with_model(ungrounded_graph,
                                                          trainablemodel,
                                                          beam_size=10)
    g = chosen_graphs[0]
    model_answers = wdaccess.query_graph_denotations(g[0]).get("e1")
    print(model_answers)
    print(chosen_graphs)
Example #5
def test_ground_with_model():
    stages.HOP_TYPES = set()
    webquestions_entities = webquestions.extract_question_entities()
    i = 3
    print(webquestions.get_question_tokens(i))
    entity_linking.entity_linking_p["max.entity.options"] = 1
    question_entities = webquestions_entities[i]
    ungrounded_graph = {
        'tokens': webquestions.get_question_tokens(i),
        'edgeSet': [],
        'fragments': question_entities[:2]
    }
    staged_generation.generation_p["strict.classes"] = False
    chosen_graphs = staged_generation.generate_with_model(ungrounded_graph,
                                                          trainablemodel,
                                                          beam_size=10)
    g = chosen_graphs[0]
    model_answers = wdaccess.query_graph_denotations(g[0]).get("e1")
    print(model_answers)
    print(chosen_graphs)
def generate(path_to_model, config_file_path):

    config = utils.load_config(config_file_path)
    if "evaluation" not in config:
        print("Evaluation parameters not in the config file!")
        sys.exit()
    config_global = config.get('global', {})
    np.random.seed(config_global.get('random.seed', 1))

    logger = logging.getLogger(__name__)
    logger.setLevel(config['logger']['level'])
    ch = logging.StreamHandler()
    ch.setLevel(config['logger']['level'])
    logger.addHandler(ch)
    # logging.basicConfig(level=config['logger']['level'])

    wdaccess.wdaccess_p["timeout"] = config['wikidata'].get("timeout", 20)
    wdaccess.wdaccess_p['wikidata_url'] = config['wikidata'].get("backend", "http://knowledgebase:8890/sparql")
    wdaccess.sparql_init()

    assert not(config['webquestions']["no.ne.tags"] and config['evaluation']["only.named.entities"])
    entity_linking.entity_linking_p["no.ne.tags"] = config['webquestions']["no.ne.tags"]
    assert not(entity_linking.entity_linking_p["no.ne.tags"] and config['evaluation']["only.named.entities"])
    entity_linking.entity_linking_p["max.entity.options"] = config['evaluation']["max.entity.options"]
    entity_linking.entity_linking_p["global.entity.grouping"] = not config['webquestions']["no.ne.tags"] and entity_linking.entity_linking_p["overlaping.nn.ne"]
    entity_linking.entity_linking_p["np.parser"] = config['evaluation']["np.parser"]

    wdaccess.wdaccess_p["restrict.hop"] = config['wikidata'].get("restrict.hop", False)
    wdaccess.update_sparql_clauses()
    staged_generation.generation_p["use.whitelist"] = config['evaluation'].get("use.whitelist", False)
    staged_generation.generation_p["strict.structure"] = config['evaluation'].get("strict.structure", False)
    staged_generation.generation_p["v.structure"] = config['evaluation'].get("v.structure", True)
    staged_generation.generation_p["max.num.entities"] = config['evaluation'].get("max.num.entities", 3)
    staged_generation.generation_p["topics.at.test"] = config['evaluation'].get("topics.at.test", False)
    if not config['wikidata'].get("addclass.action", False):
        stages.RESTRICT_ACTIONS = [stages.add_entity_and_relation, stages.last_relation_temporal,
                                   stages.add_temporal_relation, stages.last_relation_numeric]

    logger.debug("max.entity.options: {}".format(entity_linking.entity_linking_p["max.entity.options"]))
    if 'hop.types' in config['wikidata']:
        stages.HOP_TYPES = set(config['wikidata']['hop.types'])
    if 'arg.types' in config['wikidata']:
        stages.ARG_TYPES = set(config['wikidata']['arg.types'])
    if 'filter.out.relation.classes' in config['wikidata']:
        wdaccess.FILTER_RELATION_CLASSES = set(config['wikidata']['filter.out.relation.classes'])
    logger.debug("entity_linking: {}".format(entity_linking.entity_linking_p))
    with open(config['evaluation']['questions']) as f:
        webquestions_questions = json.load(f)
    webquestions = webquestions_io.WebQuestions(config['webquestions'], logger=logger)

    logger.debug('Extracting entities.')
    webquestions_entities = webquestions.extract_question_entities()
    if 'class' not in config['model']:
        config['model']['class'] = path_to_model.split('/')[-1].split('_')[0]
    logger.debug('Loading the model from: {}'.format(path_to_model))
    qa_model = getattr(models, config['model']['class'])(parameters=config['model'], logger=logger)
    qa_model.load_from_file(path_to_model)

    logger.debug('Testing')
    global_answers = []
    avg_metrics = np.zeros(3)
    len_webquestion = webquestions.get_dataset_size()
    for i in tqdm.trange(len_webquestion, ncols=100, ascii=True):
        question_entities = webquestions_entities[i]
        nes = [e for e in question_entities if e[1] != "NN"]
        if config['evaluation'].get('only.named.entities', False) and len(nes) > 0:
            question_entities = nes
        ungrounded_graph = {'tokens': webquestions.get_original_question_tokens(i),
                            'edgeSet': [],
                            'fragments': question_entities}
        chosen_graphs = staged_generation.generate_with_model(ungrounded_graph, qa_model, beam_size=config['evaluation'].get("beam.size", 10))
        model_answers = []
        g = ({},)
        if chosen_graphs:
            # Take the first ranked graph whose query returns any answers.
            j = 0
            while not model_answers and j < len(chosen_graphs):
                g = chosen_graphs[j]
                model_answers = wdaccess.query_graph_denotations(g[0]).get("e1", [])
                j += 1
        if config['evaluation'].get('label.answers', False):
            gold_answers = [e.lower() for e in webquestions_io.get_answers_from_question(webquestions_questions[i])]
            model_answers_labels = wdaccess.label_query_results(model_answers)
            model_answers_labels = staged_generation.post_process_answers_given_graph(model_answers_labels, g[0])
            metrics = evaluation.retrieval_prec_rec_f1_with_altlabels(gold_answers, model_answers_labels)
            global_answers.append((i, list(metrics), model_answers, model_answers_labels,
                                   [(c_g[0], float(c_g[1])) for c_g in chosen_graphs[:10]]))
        else:
            gold_answers = webquestions_io.get_answers_from_question(webquestions_questions[i])
            metrics = evaluation.retrieval_prec_rec_f1(gold_answers, model_answers)
            global_answers.append((i, list(metrics), model_answers,
                                   [(c_g[0], float(c_g[1])) for c_g in chosen_graphs[:10]]))
        avg_metrics += metrics
        if i % 100 == 0:
            logger.debug("Average f1 so far: {}".format((avg_metrics/(i+1))))
            with open(config['evaluation']["save.answers.to"], 'w') as answers_out:
                json.dump(global_answers, answers_out, sort_keys=True, indent=4)

    print("Average metrics: {}".format((avg_metrics/(len_webquestion))))

    logger.debug('Testing is finished')
    with open(config['evaluation']["save.answers.to"], 'w') as answers_out:
        json.dump(global_answers, answers_out, sort_keys=True, indent=4)
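The generate function expects a trained model path and a config file; a minimal command-line entry point could look like the sketch below. This wrapper is an assumption, not part of the original script.

# Hypothetical entry point; the original script's argument handling is not
# shown above, so this wrapper is only an illustration.
if __name__ == "__main__":
    import sys
    if len(sys.argv) != 3:
        print("Usage: python generate.py <path_to_model> <config_file_path>")
        sys.exit(1)
    generate(sys.argv[1], sys.argv[2])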