def ground_one_with_gold(s_g, gold_answers, min_fscore):
    """Ground a graph and split the groundings by F1 against gold answers.

    Groundings are evaluated in order; the loop stops early as soon as one
    reaches ``MIN_F_SCORE_TO_STOP``.

    :param s_g: an ungrounded semantic graph
    :param gold_answers: gold answer set used to score retrieved denotations
    :param min_fscore: groundings with F1 above this are kept as positives
    :return: (chosen_graphs, not_chosen_graphs), both lists of WithScore
    """
    candidates = [apply_grounding(s_g, grounding)
                  for grounding in graph_queries.get_graph_groundings(s_g)]
    logger.debug("Number of possible groundings: {}".format(len(candidates)))
    logger.debug("First one: {}".format(candidates[:1]))
    chosen_graphs = []
    not_chosen_graphs = []
    best_f1 = 0.0
    for candidate in candidates:
        # Same early-exit as the original while-loop condition.
        if best_f1 >= MIN_F_SCORE_TO_STOP:
            break
        candidate.denotations = graph_queries.get_graph_denotations(candidate)
        scores = evaluation.retrieval_prec_rec_f1(gold_answers,
                                                  candidate.denotations)
        best_f1 = scores[2]
        if best_f1 > min_fscore:
            chosen_graphs.append(WithScore(candidate, scores))
        elif best_f1 < 0.05:
            # Clearly wrong groundings are kept as negative examples.
            not_chosen_graphs.append(WithScore(candidate, scores))
    return chosen_graphs, not_chosen_graphs
def generate_with_model(s, qa_model, beam_size=10):
    """Beam-search over graph extensions for sentence ``s``, scored by ``qa_model``.

    Starting from the sentence's initial (empty) graph, repeatedly pops a graph
    from the pool, tries each extension action in order until one yields
    model-approved graphs, and feeds those back into the pool. Stops when the
    pool is exhausted or after 100 iterations.

    :param s: a sentence object carrying an initial graph in ``s.graphs[0]``
    :param qa_model: model used by ground_with_model to score candidate graphs
    :param beam_size: beam width passed through to ground_with_model
    :return: all generated graphs as WithScore tuples, sorted by score descending
    """
    pool = [WithScore(s.graphs[0].graph,
                      (0.0, 0.0, 0.0))]  # pool of possible parses
    generated_graphs = []
    iterations = 0

    # Extension actions, tried in order; the first action combines one-leg and
    # two-leg entity/relation additions (the latter restricted to
    # LONG_LEG_RELATIONS).
    actions = [
        lambda x: stages.add_entity_and_relation(
            x, leg_length=1) + stages.add_entity_and_relation(
                x, leg_length=2, fixed_relations=stages.LONG_LEG_RELATIONS),
        stages.last_edge_numeric_constraint, stages.add_relation
    ]

    # Hard cap of 100 iterations guards against a pool that never drains.
    while pool and iterations < 100:
        iterations += 1
        g = pool.pop(0)
        logger.debug("Pool length: {}, Graph: {}".format(len(pool), g))
        # Child graphs must beat the parent's model score to be kept.
        master_score = g.scores[2]
        a_i = 0
        chosen_graphs = []
        # Stop at the first action that produces at least one chosen graph.
        while a_i < len(actions) and not chosen_graphs:
            suggested_graphs = actions[a_i](g[0])
            # Keep graphs with fewer than two non-question edges that touch a
            # Wikidata entity node (ids starting with "Q").
            suggested_graphs = [
                s_g for s_g in suggested_graphs
                if sum(1 for e in s_g.edges if any(
                    n.startswith("Q") for n in e.nodes()
                    if n) and graph_queries.QUESTION_VAR not in e.nodes()) < 2
            ]
            suggested_graphs = [
                s_g for s_g in suggested_graphs
                if graph_queries.verify_grounding(s_g)
            ]
            logger.debug("Suggested graphs:{}, {}".format(
                len(suggested_graphs), suggested_graphs))
            chosen_graphs += ground_with_model(suggested_graphs,
                                               s,
                                               qa_model,
                                               min_score=master_score,
                                               beam_size=beam_size,
                                               verify_with_wikidata=True)
            a_i += 1

        logger.debug("Chosen graphs length: {}".format(len(chosen_graphs)))
        if len(chosen_graphs) > 0:
            logger.debug("Extending the pool.")
            pool.extend(chosen_graphs)
            logger.debug("Extending the generated graph set: {}".format(
                len(chosen_graphs)))
            generated_graphs.extend(chosen_graphs)
    logger.debug("Iterations {}".format(iterations))
    logger.debug("Generated graphs {}".format(len(generated_graphs)))
    # Sort by the scores tuple (model score in position 2), best first.
    generated_graphs = sorted(generated_graphs,
                              key=lambda x: x[1],
                              reverse=True)
    return generated_graphs
# Example #3
def ground_with_model(input_graphs, s, qa_model, min_score, beam_size=10, verify_with_wikidata=True):
    """Ground candidate graphs and rank them with the trained QA model.

    :param input_graphs: a list of equivalent graph extensions to choose from.
    :param s: sentence
    :param qa_model: a model to evaluate graphs
    :param min_score: filter out graphs that receive a score lower than that from the model.
    :param beam_size: size of the beam
    :param verify_with_wikidata: whether groundings are verified against Wikidata
    :return: a list of selected graphs with size = beam_size
    """
    logger.debug("Input graphs: {}".format(len(input_graphs)))
    logger.debug("First input one: {}".format(input_graphs[:1]))

    # Expand every input graph into all of its possible groundings.
    candidates = []
    for graph in input_graphs:
        groundings = graph_queries.get_graph_groundings(
            graph, use_wikidata=verify_with_wikidata)
        for grounding in groundings:
            candidates.append(apply_grounding(graph, grounding))
    candidates = filter_second_hops(candidates)
    logger.debug("Number of possible groundings: {}".format(len(candidates)))
    if not candidates:
        return []

    # Pack the candidates into dummy sentence objects in chunks, so the model
    # can score them batch-wise.
    sentences = []
    for start in range(0, len(candidates), V.MAX_NEGATIVE_GRAPHS):
        chunk = candidates[start:start + V.MAX_NEGATIVE_GRAPHS]
        dummy_sentence = sentence.Sentence()
        dummy_sentence.__dict__.update(s.__dict__)
        dummy_sentence.graphs = [WithScore(c, (0.0, 0.0, min_score))
                                 for c in chunk]
        sentences.append(dummy_sentence)
    if not sentences:
        return []
    samples = V.encode_for_model(sentences, qa_model._model.__class__.__name__)
    model_scores = qa_model.predict_batchwise(*samples).view(-1).data

    logger.debug("model_scores: {}".format(model_scores))
    # Keep only candidates that beat min_score, best-scoring first.
    scored = [WithScore(candidate, (0.0, 0.0, score))
              for candidate, score in zip(candidates, model_scores)
              if score > min_score]
    scored.sort(key=lambda x: x[1], reverse=True)
    del scored[beam_size:]
    logger.debug("Number of chosen groundings: {}".format(len(scored)))
    return scored
def sentence_object_hook(obj):
    """JSON ``object_hook`` that rebuilds project objects from decoded dicts.

    A dict is recognized as a Sentence, SemanticGraph, or edge when it carries
    every attribute of the corresponding (default-constructed) instance;
    anything else is returned unchanged.
    """
    if all(attr in obj for attr in Sentence().__dict__):
        restored = Sentence()
        restored.__dict__.update(obj)
        # Graph entries were serialized as (graph, scores) pairs.
        restored.graphs = [WithScore(*pair) for pair in restored.graphs]
        return restored
    if all(attr in obj for attr in SemanticGraph().__dict__):
        graph = SemanticGraph()
        graph.__dict__.update(obj)
        # Re-wrap the raw edge list in an EdgeList container.
        graph.edges = EdgeList()
        graph.edges._list = obj['edges']
        return graph
    if all(attr in obj for attr in DUMMY_EDGE.__dict__):
        edge = copy(DUMMY_EDGE)
        edge.__dict__.update(obj)
        return edge
    return obj
# Example #5
 def __init__(self, input_text=None, tagged=None, entities=None):
     """
     A sentence object.

     :param input_text: raw input text as a string
     :param tagged: a list of dict objects, one per token, with the output of the POS and NER taggers, see utils
                   for more info
     :param entities: a list of tuples, where each tuple is an entity link (first position is the KB id and
                      the second position is the label)
     """
     self.input_text = input_text if input_text else ""
     self.tagged = tagged if tagged else []
     self.tokens = [t['originalText'] for t in self.tagged]
     # Keep only the entity fields used downstream.
     self.entities = [{k: e[k]
                       for k in {'type', 'linkings', 'token_ids'}}
                      for e in entities] if entities else []
     # Add year mentions (cardinal numbers NER-tagged as dates) as
     # pseudo-entities that link to their own surface form.
     self.entities += [{
         'type': 'YEAR',
         'linkings': [(t['originalText'], t['originalText'])],
         'token_ids': [t['index'] - 1]
     } for t in self.tagged if t['pos'] == 'CD' and t['ner'] == 'DATE']
     # Fix: the original called get_question_type twice on the same text;
     # compute it once and branch (the two outcomes are mutually exclusive).
     question_type = get_question_type(self.input_text)
     if question_type == "person":
         self.entities.append({
             'type': 'NN',
             'linkings': [("Q5", 'human')],
             'token_ids': [0]
         })
     elif question_type == "location":
         self.entities.append({
             'type': 'NN',
             'linkings': [("Q618123", 'geographical object')],
             'token_ids': [0]
         })
     # Start with a single empty semantic graph carrying the free entities.
     self.graphs = [
         WithScore(
             SemanticGraph(free_entities=self.entities, tokens=self.tokens),
             (0.0, 0.0, 0.0))
     ]