예제 #1
0
    def query_parser(self, question: str, query_info: Dict[str, str],
                     entities_and_types_select: str,
                     entity_ids: List[List[str]], type_ids: List[List[str]],
                     rels_from_template: Optional[List[Tuple[str]]] = None) -> List[Tuple[str]]:
        """Fill a query template with entity/type/relation candidates and collect the answers.

        The template in ``query_info["query_template"]`` is lower-cased, its
        year/number filter placeholders are substituted with values extracted
        from the question, and every combination of candidate entities, types
        and relations is sent to ``self.wiki_parser.get_answer``.

        Args:
            question: the natural-language question.
            query_info: template description; the keys read here are
                "query_template", "rank_rels", "rel_types", "filter_rels",
                "property_types", "query_sequence" and "return_if_found".
            entities_and_types_select: two space-separated groups of
                "_"-separated 1-based positions (entities first, then types);
                positions that are not greater than zero are skipped.
            entity_ids: candidate ids for every entity slot; each list is cut
                to ``self.entities_to_leave`` items.
            type_ids: candidate ids for every type slot.
            rels_from_template: relation candidates to use instead of calling
                ``self.find_top_rels``.

        Returns:
            A list of rows; each row is the relation combination without its
            last element, followed by the "value" of every output variable.
            If ``return_if_found`` is truthy, the rows collected so far are
            returned as soon as a combination yields a non-empty result.
        """
        candidate_outputs = []
        question_tokens = nltk.word_tokenize(question)
        # The whole template is lower-cased, so the "P31" property id has to be restored.
        query = query_info["query_template"].lower().replace("wdt:p31", "wdt:P31")
        rels_for_search = query_info["rank_rels"]
        rel_types = query_info["rel_types"]
        rels_for_filter = query_info["filter_rels"]
        property_types = query_info["property_types"]
        query_seq_num = query_info["query_sequence"]
        return_if_found = query_info["return_if_found"]
        log.debug(f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}")
        # Triplets are the " . "-separated statements inside the first {...} group of the template.
        query_triplets = re.findall(r"{[ ]?(.*?)[ ]?}", query)[0].split(' . ')
        log.debug(f"(query_parser)query_triplets: {query_triplets}")
        query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets]
        # A triplet whose third element is a variable is searched "forw", otherwise "backw".
        triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type)
                             for search_source, triplet, rel_type in zip(rels_for_search, query_triplets, rel_types)
                             if search_source != "do_not_rank"]
        log.debug(f"(query_parser)rel_directions: {triplet_info_list}")
        rel_variables = re.findall(r":(r[\d]{1,2})", query)
        entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids]
        if rels_from_template is not None:
            rels = rels_from_template
        else:
            rels = [self.find_top_rels(question, entity_ids, triplet_info)
                    for triplet_info in triplet_info_list]

        # Split relation candidates into those used for filtering and those substituted into the query.
        rels_list_for_filter = []
        rels_list_for_fill = []
        filter_rel_variables = []
        fill_rel_variables = []
        for rel_variable, rel_list, is_filter in zip(rel_variables, rels, rels_for_filter):
            if is_filter:
                rels_list_for_filter.append(rel_list)
                filter_rel_variables.append(rel_variable)
            else:
                rels_list_for_fill.append(rel_list)
                fill_rel_variables.append(rel_variable)
        log.debug(f"(query_parser)rels: {rels}")
        log.debug(f"rel_variables {rel_variables}, filter_rel_variables: {filter_rel_variables}")
        log.debug(f"rels_list_for_filter: {rels_list_for_filter}")
        log.debug(f"rels_list_for_fill: {rels_list_for_fill}")

        # Relation positions that are variables in the template. De-duplicate while keeping
        # first-occurrence order so that the SELECT clause and the output columns are deterministic
        # (a set would order them by string hash, which varies between interpreter runs).
        rels_from_query = list(dict.fromkeys(
            triplet[1] for triplet in query_triplets if triplet[1].startswith('?')))
        if "count" in query:
            answer_ent = re.findall(r"as (\?[\S]+)", query)
        else:
            answer_ent = re.findall(r"select [\(]?([\S]+) ", query)

        filter_from_query = re.findall(r"contains\((\?\w), (.+?)\)", query)
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")

        year = extract_year(question_tokens, question)
        number = extract_number(question_tokens, question)
        log.debug(f"year {year}, number {number}")
        # Replace the "n" placeholder of every contains(...) filter; a year takes priority over a number.
        if year:
            for elem in filter_from_query:
                query = query.replace(f"{elem[0]}, n", f"YEAR({elem[0]}), {year}")
        elif number:
            for elem in filter_from_query:
                query = query.replace(f"{elem[0]}, n", f"{elem[0]}, {number}")
        # Add the relation variables to the selected variables so that they are returned too.
        query = query.replace(" where", f" {' '.join(rels_from_query)} where")

        log.debug(f"(query_parser)query_with_filtering: {query}")
        rel_combs = make_combs(rels_list_for_fill, permut=False)
        log.debug(f"(query_parser)rel_combs: {rel_combs[:3]}")
        import datetime
        start_time = datetime.datetime.now()
        entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')]
        log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
        selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
        selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
        entity_combs = make_combs(selected_entity_ids, permut=True)
        log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}")
        type_combs = make_combs(selected_type_ids, permut=False)
        log.debug(f"(query_parser)type_combs: {type_combs[:3]}")
        for entity_comb, type_comb, rel_comb in itertools.product(entity_combs, type_combs, rel_combs):
            filled_query, filter_rels = fill_online_query(query, entity_comb, type_comb, rel_comb,
                                                          fill_rel_variables, filter_rel_variables,
                                                          rels_list_for_filter)
            candidate_output = self.wiki_parser.get_answer(filled_query)

            out_vars = filter_rels + rels_from_query + answer_ent
            # Keep a row only if every property-type filter matches and no answer variable
            # is bound to a "http://www.wikidata.org/value" node.
            candidate_output = [
                output for output in candidate_output
                if all(filter_value in output[filter_var[1:]]["value"]
                       for filter_var, filter_value in property_types.items())
                and not any(output[ent[1:]]["value"].startswith("http://www.wikidata.org/value")
                            for ent in answer_ent)
            ]
            # The last element of the relation combination is dropped from the output row.
            candidate_outputs += [rel_comb[:-1] + [output[var[1:]]["value"] for var in out_vars]
                                  for output in candidate_output]
            if return_if_found and candidate_output:
                return candidate_outputs
        log.debug(f"(query_parser)loop time: {datetime.datetime.now() - start_time}")
        log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}")

        return candidate_outputs
예제 #2
0
    def query_parser(
        self,
        question: str,
        query_info: Dict[str, str],
        entities_and_types_select: str,
        entity_ids: List[List[str]],
        type_ids: List[List[str]],
        rels_from_template: Optional[List[Tuple[str]]] = None
    ) -> List[List[Union[Tuple[Any, ...], Any]]]:
        """Build filled queries for all candidate combinations and execute them in one batch.

        The template from ``query_info["query_template"]`` is lower-cased and
        patched with ``self.replace_tokens``; its triplets are ordered by
        ``query_info["query_sequence"]`` and filled with every combination of
        candidate entities, types and relations. All filled queries are passed
        to ``self.wiki_parser`` in a single call.

        Args:
            question: the natural-language question.
            query_info: template description; the keys read here are
                "query_template", "rank_rels", "rel_types", "query_sequence",
                "return_if_found", "define_sorting_order" and "property_types".
            entities_and_types_select: two space-separated groups of
                "_"-separated 1-based positions (entities first, then types);
                positions that are not greater than zero are skipped.
            entity_ids: candidate ids for every entity slot; each list is cut
                to ``self.entities_to_leave`` items.
            type_ids: candidate ids for every type slot.
            rels_from_template: relation candidates to use instead of calling
                ``self.find_top_rels``; every relation gets the score 1.0.

        Returns:
            A list of rows. Without ``self.return_all_possible_answers`` a row
            is the relations of the combination, the parser output and the
            combination confidence. With it, rows that share entities,
            relations and all output items but the last one are merged: the
            row ends with a tuple of the merged last items and the confidence
            of the first merged row.
        """
        question_tokens = nltk.word_tokenize(question)
        query = query_info["query_template"].lower()
        # Lower-casing may break case-sensitive tokens; self.replace_tokens restores them.
        for old_tok, new_tok in self.replace_tokens:
            query = query.replace(old_tok, new_tok)
        log.debug(
            f"\n_______________________________\nquery: {query}\n_______________________________\n"
        )
        rels_for_search = query_info["rank_rels"]
        rel_types = query_info["rel_types"]
        query_seq_num = query_info["query_sequence"]
        return_if_found = query_info["return_if_found"]
        define_sorting_order = query_info["define_sorting_order"]
        property_types = query_info["property_types"]
        log.debug(
            f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}"
        )
        # Triplets are the " . "-separated statements inside the first {...} group of the template.
        query_triplets = re.findall(r"{[ ]?(.*?)[ ]?}", query)[0].split(' . ')
        log.debug(f"(query_parser)query_triplets: {query_triplets}")
        query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets]
        # Reorder the triplets by their 1-based sequence numbers.
        query_sequence_dict = dict(zip(query_seq_num, query_triplets))
        query_sequence = [
            query_sequence_dict[i] for i in range(1, max(query_seq_num) + 1)
        ]
        # A triplet whose third element is a variable is searched "forw", otherwise "backw".
        triplet_info_list = [
            ("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type)
            for search_source, triplet, rel_type in zip(
                rels_for_search, query_triplets, rel_types)
            if search_source != "do_not_rank"
        ]
        log.debug(f"(query_parser)rel_directions: {triplet_info_list}")
        entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids]
        if rels_from_template is not None:
            rels = [[(rel, 1.0) for rel in rel_list]
                    for rel_list in rels_from_template]
        else:
            rels = [
                self.find_top_rels(question, entity_ids, triplet_info)
                for triplet_info in triplet_info_list
            ]
        log.debug(f"(query_parser)rels: {rels}")
        rels_from_query = [
            triplet[1] for triplet in query_triplets
            if triplet[1].startswith('?')
        ]
        answer_ent = re.findall(r"select [\(]?([\S]+) ", query)
        order_info_nt = namedtuple("order_info", ["variable", "sorting_order"])
        order_variable = re.findall(r"order by (asc|desc)\((.*)\)", query)
        if order_variable:
            # The sorting order is either derived from the question or taken from the template.
            if define_sorting_order:
                answers_sorting_order = order_of_answers_sorting(question)
            else:
                answers_sorting_order = order_variable[0][0]
            order_info = order_info_nt(order_variable[0][1],
                                       answers_sorting_order)
        else:
            order_info = order_info_nt(None, None)
        log.debug(f"question, order_info: {question}, {order_info}")
        filter_from_query = re.findall(r"contains\((\?\w), (.+?)\)", query)
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")

        year = extract_year(question_tokens, question)
        number = extract_number(question_tokens, question)
        log.debug(f"year {year}, number {number}")
        # Substitute the "n" placeholder of the filters; a year takes priority over a number.
        # Without either value, filters that consist of the bare placeholder are dropped.
        if year:
            filter_info = [(elem[0], elem[1].replace("n", year))
                           for elem in filter_from_query]
        elif number:
            filter_info = [(elem[0], elem[1].replace("n", number))
                           for elem in filter_from_query]
        else:
            filter_info = [
                elem for elem in filter_from_query if elem[1] != "n"
            ]
        for unk_prop, prop_type in property_types.items():
            filter_info.append((unk_prop, prop_type))
        # Log the filters that are actually applied (the raw ones are logged above).
        log.debug(f"(query_parser)filter_info: {filter_info}")
        rel_combs = make_combs(rels, permut=False)
        import datetime
        start_time = datetime.datetime.now()
        entity_positions, type_positions = [
            elem.split('_') for elem in entities_and_types_select.split(' ')
        ]
        log.debug(
            f"entity_positions {entity_positions}, type_positions {type_positions}"
        )
        selected_entity_ids = [
            entity_ids[int(pos) - 1] for pos in entity_positions
            if int(pos) > 0
        ]
        selected_type_ids = [
            type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0
        ]
        entity_combs = make_combs(selected_entity_ids, permut=True)
        type_combs = make_combs(selected_type_ids, permut=False)
        log.debug(
            f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]},"
            f" rel_combs: {rel_combs[:3]}")
        queries_list = []
        parser_info_list = []
        confidences_list = []
        all_combs_list = list(
            itertools.product(entity_combs, type_combs, rel_combs))
        if self.wiki_file_format == "pickle":
            total_entities_list = (
                list(itertools.chain.from_iterable(selected_entity_ids))
                + list(itertools.chain.from_iterable(selected_type_ids))
            )
            # The returned value is not used here; the call is kept because it presumably
            # makes the parser load the triplets of these entities - TODO confirm.
            self.wiki_parser(["parse_triplets"], [total_entities_list])
        for comb_num, (entity_comb, type_comb, rel_comb) in enumerate(all_combs_list):
            # Confidence of a combination is the product of its relation scores;
            # the last element of the relation combination is excluded.
            confidence = np.prod([score for _, score in rel_comb[:-1]])
            confidences_list.append(confidence)
            query_hdt_seq = [
                fill_query(query_hdt_elem, entity_comb, type_comb, rel_comb)
                for query_hdt_elem in query_sequence
            ]
            if comb_num == 0:
                log.debug(
                    f"\n__________________________\nfilled query: {query_hdt_seq}\n__________________________\n"
                )
            queries_list.append((rels_from_query + answer_ent, query_hdt_seq,
                                 filter_info, order_info, return_if_found))
            parser_info_list.append("query_execute")
            # NOTE(review): the check runs after appending and comb_num is 0-based, so up to
            # max_comb_num + 1 combinations are queued - confirm this is intended.
            if comb_num == self.max_comb_num:
                break

        candidate_outputs = []
        candidate_outputs_list = self.wiki_parser(parser_info_list,
                                                  queries_list)
        if self.use_api_requester and isinstance(
                candidate_outputs_list, list) and candidate_outputs_list:
            candidate_outputs_list = candidate_outputs_list[0]

        if isinstance(candidate_outputs_list, list) and candidate_outputs_list:
            outputs_len = len(candidate_outputs_list)
            all_combs_list = all_combs_list[:outputs_len]
            confidences_list = confidences_list[:outputs_len]
            for combs, confidence, candidate_output in zip(
                    all_combs_list, confidences_list, candidate_outputs_list):
                # Row layout: [entity_comb, rel_1, ..., rel_k, *output, confidence].
                candidate_outputs += [[combs[0]] +
                                      [rel for rel, _ in combs[2][:-1]] +
                                      output + [confidence]
                                      for output in candidate_output]
            if self.return_all_possible_answers:
                # Group rows by (entities, everything between the entities and the last two items)
                # and collect the last two items of each row per group.
                candidate_outputs_dict = defaultdict(list)
                for candidate_output in candidate_outputs:
                    group_key = (tuple(candidate_output[0]),
                                 tuple(candidate_output[1:-2]))
                    candidate_outputs_dict[group_key].append(
                        candidate_output[-2:])
                candidate_outputs = []
                for (_, candidate_rel_comb), grouped in candidate_outputs_dict.items():
                    candidate_outputs.append(
                        list(candidate_rel_comb) + [
                            tuple([ans for ans, _ in grouped]),
                            grouped[0][1]
                        ])
            else:
                # Drop the leading entity combination from every row.
                candidate_outputs = [
                    output[1:] for output in candidate_outputs
                ]
        log.debug(
            f"(query_parser)loop time: {datetime.datetime.now() - start_time}")
        log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}")

        return candidate_outputs
예제 #3
0
    def query_parser(self, question: str, query_info: Dict[str, str],
                     entities_and_types_select: str,
                     entity_ids: List[List[str]],
                     type_ids: List[List[str]],
                     rels_from_template: Optional[List[Tuple[str]]] = None) -> List[Tuple[str]]:
        """Fill the template triplets with candidate combinations and query ``self.wiki_parser``.

        The template from ``query_info["query_template"]`` is lower-cased, its
        triplets are ordered by ``query_info["query_sequence"]`` and filled with
        every combination of candidate entities, types and relations; each
        filled sequence is executed with a separate ``self.wiki_parser`` call.

        Args:
            question: the natural-language question.
            query_info: template description; the keys read here are
                "query_template", "rank_rels", "rel_types", "query_sequence"
                and "return_if_found".
            entities_and_types_select: two space-separated groups of
                "_"-separated 1-based positions (entities first, then types);
                positions that are not greater than zero are skipped.
            entity_ids: candidate ids for every entity slot; each list is cut
                to ``self.entities_to_leave`` items.
            type_ids: candidate ids for every type slot.
            rels_from_template: relation candidates to use instead of calling
                ``self.find_top_rels``.

        Returns:
            A list of rows; each row is the relation combination without its
            last element, followed by one parser output. If ``return_if_found``
            is truthy, the rows collected so far are returned as soon as a
            combination yields a non-empty result.
        """
        candidate_outputs = []
        question_tokens = nltk.word_tokenize(question)
        # The whole template is lower-cased, so the "P31" property id has to be restored.
        query = query_info["query_template"].lower().replace("wdt:p31", "wdt:P31")
        log.debug(f"\n_______________________________\nquery: {query}\n_______________________________\n")
        rels_for_search = query_info["rank_rels"]
        rel_types = query_info["rel_types"]
        query_seq_num = query_info["query_sequence"]
        return_if_found = query_info["return_if_found"]
        log.debug(f"(query_parser)query: {query}, {rels_for_search}, {query_seq_num}, {return_if_found}")
        # Triplets are the " . "-separated statements inside the first {...} group of the template.
        query_triplets = re.findall(r"{[ ]?(.*?)[ ]?}", query)[0].split(' . ')
        log.debug(f"(query_parser)query_triplets: {query_triplets}")
        query_triplets = [triplet.split(' ')[:3] for triplet in query_triplets]
        # Reorder the triplets by their 1-based sequence numbers.
        query_sequence_dict = dict(zip(query_seq_num, query_triplets))
        query_sequence = [query_sequence_dict[i] for i in range(1, max(query_seq_num) + 1)]
        # A triplet whose third element is a variable is searched "forw", otherwise "backw".
        triplet_info_list = [("forw" if triplet[2].startswith('?') else "backw", search_source, rel_type)
                             for search_source, triplet, rel_type in zip(rels_for_search, query_triplets, rel_types)
                             if search_source != "do_not_rank"]
        log.debug(f"(query_parser)rel_directions: {triplet_info_list}")
        entity_ids = [entity[:self.entities_to_leave] for entity in entity_ids]
        if rels_from_template is not None:
            rels = rels_from_template
        else:
            rels = [self.find_top_rels(question, entity_ids, triplet_info)
                    for triplet_info in triplet_info_list]
        log.debug(f"(query_parser)rels: {rels}")
        rels_from_query = [triplet[1] for triplet in query_triplets if triplet[1].startswith('?')]
        answer_ent = re.findall(r"select [\(]?([\S]+) ", query)
        order_info_nt = namedtuple("order_info", ["variable", "sorting_order"])
        order_variable = re.findall(r"order by (asc|desc)\((.*)\)", query)
        answers_sorting_order = order_of_answers_sorting(question)
        # The sorting order always comes from the question; the template only supplies the variable.
        if order_variable:
            order_info = order_info_nt(order_variable[0][1], answers_sorting_order)
        else:
            order_info = order_info_nt(None, None)
        log.debug(f"question, order_info: {question}, {order_info}")
        filter_from_query = re.findall(r"contains\((\?\w), (.+?)\)", query)
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")

        year = extract_year(question_tokens, question)
        number = extract_number(question_tokens, question)
        log.debug(f"year {year}, number {number}")
        # Substitute the "n" placeholder of the filters; a year takes priority over a number.
        # Without either value, filters that consist of the bare placeholder are dropped.
        if year:
            filter_info = [(elem[0], elem[1].replace("n", year)) for elem in filter_from_query]
        elif number:
            filter_info = [(elem[0], elem[1].replace("n", number)) for elem in filter_from_query]
        else:
            filter_info = [elem for elem in filter_from_query if elem[1] != "n"]
        log.debug(f"(query_parser)filter_from_query: {filter_from_query}")
        rel_combs = make_combs(rels, permut=False)
        import datetime
        start_time = datetime.datetime.now()
        entity_positions, type_positions = [elem.split('_') for elem in entities_and_types_select.split(' ')]
        log.debug(f"entity_positions {entity_positions}, type_positions {type_positions}")
        selected_entity_ids = [entity_ids[int(pos) - 1] for pos in entity_positions if int(pos) > 0]
        selected_type_ids = [type_ids[int(pos) - 1] for pos in type_positions if int(pos) > 0]
        entity_combs = make_combs(selected_entity_ids, permut=True)
        type_combs = make_combs(selected_type_ids, permut=False)
        log.debug(f"(query_parser)entity_combs: {entity_combs[:3]}, type_combs: {type_combs[:3]}, "
                  f"rel_combs: {rel_combs[:3]}")
        all_combs = itertools.product(entity_combs, type_combs, rel_combs)
        for comb_num, (entity_comb, type_comb, rel_comb) in enumerate(all_combs):
            query_hdt_seq = [fill_query(query_hdt_elem, entity_comb, type_comb, rel_comb)
                             for query_hdt_elem in query_sequence]
            if comb_num == 0:
                log.debug(f"\n_______________________________\nfilled query: {query_hdt_seq}\n_______________________________\n")
            candidate_output = self.wiki_parser(
                rels_from_query + answer_ent, query_hdt_seq, filter_info, order_info)
            # The last element of the relation combination is dropped from the output row.
            candidate_outputs += [rel_comb[:-1] + output for output in candidate_output]
            if return_if_found and candidate_output:
                return candidate_outputs
        log.debug(f"(query_parser)loop time: {datetime.datetime.now() - start_time}")
        log.debug(f"(query_parser)final outputs: {candidate_outputs[:3]}")

        return candidate_outputs