Beispiel #1
0
	def __init__(self, raw_question, raw_answerset, raw_query, raw_row, id, parser):
		"""Build one question/answer item from the raw fields of a dataset row.

		The raw question, answer set and query are wrapped in ``Question``,
		``AnswerSet`` and ``SPARQL`` objects, each of which is handed the
		matching ``parser.parse_*`` callable.

		:param raw_question: raw question payload, passed to ``Question``.
		:param raw_answerset: raw answer payload, passed to ``AnswerSet``.
		:param raw_query: raw query payload, passed to ``SPARQL``.
		:param raw_row: the original row, stored unchanged on ``self.raw_row``.
		:param id: identifier stored on ``self.id`` (the name shadows the
			builtin but is kept because callers may pass it by keyword).
		:param parser: object providing ``parse_question``,
			``parse_answerset`` and ``parse_sparql``.
		"""
		self.raw_row = raw_row
		self.id = id

		# The attributes are assigned exactly once with their final types; the
		# former ``[]`` placeholders were overwritten immediately and only
		# suggested that ``question`` / ``sparql`` could be lists.
		self.question = Question(raw_question, parser.parse_question)
		self.answerset = AnswerSet(raw_answerset, parser.parse_answerset)
		self.sparql = SPARQL(raw_query, parser.parse_sparql)
Beispiel #2
0
    def sort_query(self, linker, kb, parser, qapair, question_type_classifier, force_gold=True):
        """Generate candidate queries for ``qapair``, rank them and score them against the gold answer.

        Steps performed:

        1. The gold SPARQL query is executed on ``kb`` (with ``"https"``
           replaced by ``"http"``) and the parsed result is stored in
           ``qapair.answerset``.
        2. ``question_type_classifier`` predicts the question type and its
           confidence.
        3. ``linker.do`` supplies entities and ontologies; their non-empty
           subsets are tried, largest first, with ``self.generate_query``
           until one pair yields candidates.  ``generate_query`` also
           replaces the question type and confidence.  Type ``2`` is then
           treated as a COUNT query and type ``1`` as an ASK query.
        4. Candidates are sorted by ``'confidence'`` (descending),
           de-duplicated on their ``'where'`` clause and evaluated in order,
           trying the suggested target variable first and the other one of
           ``?u_0`` / ``?u_1`` second.  Evaluation stops at the first
           candidate whose answer set equals the gold answer set.
        5. If the pivot candidate (the correct one, or otherwise the wrong
           one with the largest overlap) shares its confidence with other
           candidates, precision and recall are recomputed over that whole
           tie group.

        :param linker: object whose ``do(qapair, force_gold=...)`` returns
            ``(entities, ontologies)``.
        :param kb: knowledge-base client providing ``query``, ``query_where``
            and ``sparql_query``.
        :param parser: object providing ``parse_queryresult``.
        :param qapair: item with ``sparql``, ``question.text`` and
            ``answerset``; ``answerset`` is overwritten by this method.
        :param question_type_classifier: object providing ``predict`` and
            ``predict_proba``.
        :param force_gold: forwarded to ``linker.do``.
        :return: ``(status, output_where, question_type, type_confidence,
            precision, recall)`` where ``status`` is one of
            ``"-Not_Applicable"``, ``"-Linker_failed"``, ``"-without_path"``,
            ``"correct"`` or ``"-incorrect"``.
        """
        logger.info(qapair.sparql)
        logger.info(qapair.question.text)

        # Get Answer from KB online
        _status, raw_answer_true = kb.query(qapair.sparql.query.replace("https", "http"))
        qapair.answerset = AnswerSet(raw_answer_true, parser.parse_queryresult)

        question = qapair.question.text
        question_type = question_type_classifier.predict([question])

        # predict_proba is evaluated once; the result may be array-like
        # because it is indexed with the (array-valued) prediction.
        type_confidence = question_type_classifier.predict_proba([question])[0][question_type]
        if isinstance(type_confidence, (np.ndarray, list)):
            type_confidence = float(type_confidence[0])

        question_type = int(question_type)

        entities, ontologies = linker.do(qapair, force_gold=force_gold)
        precision = None
        recall = None

        if qapair.answerset is None or len(qapair.answerset) == 0:
            return "-Not_Applicable", [], question_type, type_confidence, precision, recall

        if entities is None or ontologies is None:
            recall = 0.0
            return "-Linker_failed", [], question_type, type_confidence, precision, recall

        logger.info("start finding the minimal subgraph")

        entity_subsets = self._subsets_largest_first(entities)
        relation_subsets = self._subsets_largest_first(ontologies)

        # Same visiting order as a nested "for entities: for relations" loop,
        # but the cross product is never materialised.
        generated_queries = []
        for entity_subset, relation_subset in itertools.product(entity_subsets, relation_subsets):
            generated_queries, question_type, type_confidence = self.generate_query(
                question, entity_subset, relation_subset)
            if len(generated_queries) > 0:
                break

        # NOTE: the candidate list used to be extended with itself here, which
        # only duplicated every candidate before the de-duplication below.
        if len(generated_queries) == 0:
            recall = 0.0
            return "-without_path", [], question_type, type_confidence, precision, recall

        # The query kind follows the type reported by generate_query.
        count_query = int(question_type) == 2
        ask_query = int(question_type) == 1

        # Rank by confidence, highest first.
        scores = np.array([candidate['confidence'] for candidate in generated_queries])
        inds = scores.argsort()[::-1]
        sorted_queries = [generated_queries[s] for s in inds]
        scores = [scores[s] for s in inds]

        # Keep only the first (best-ranked) candidate for each WHERE clause.
        seen_where = []
        uniqueid = []
        for i, candidate in enumerate(sorted_queries):
            if candidate['where'] not in seen_where:
                seen_where.append(candidate['where'])
                uniqueid.append(i)
        sorted_queries = [sorted_queries[i] for i in uniqueid]
        scores = [scores[i] for i in uniqueid]

        # Index ranges of candidates that share the same confidence score.
        s_ind = []
        s_i = 0
        for tie_size in Counter(sorted(scores, reverse=True)).values():
            s_ind.append(range(s_i, s_i + tie_size))
            s_i += tie_size

        output_where = [{"query": " .".join(item["where"]), "correct": False, "target_var": "?u_0"}
                        for item in sorted_queries]
        for item in output_where:
            logger.info(item["query"])

        correct = False
        correct_index = None
        wrongd = {}  # candidate index -> overlap with the gold answer set

        for idx, where in enumerate(sorted_queries):
            answerset, target_var = self._initial_answer(kb, parser, where, count_query, ask_query)
            output_where[idx]["target_var"] = target_var
            matched = self._matches_gold(kb, qapair, where, target_var, answerset, count_query, ask_query)

            if not matched:
                # Second chance with the other target variable.
                answerset, target_var = self._flipped_answer(
                    kb, parser, where, target_var, count_query, ask_query)
                matched = self._matches_gold(kb, qapair, where, target_var, answerset, count_query, ask_query)

            output_where[idx]["target_var"] = target_var
            if matched:
                correct = True
                output_where[idx]["correct"] = True
                recall = 1.0
                precision = 1.0
                correct_index = idx
                break

            output_where[idx]["correct"] = False
            intersect = answerset.intersect(qapair.answerset)
            recall = intersect / len(qapair.answerset)
            # An empty candidate answer set has no defined precision; report 0
            # instead of raising ZeroDivisionError.
            precision = intersect / len(answerset) if len(answerset) > 0 else 0.0
            wrongd[idx] = intersect

        # Pivot: the correct candidate, otherwise the wrong candidate with the
        # largest overlap (first one wins on ties).
        if correct:
            pivot = correct_index
        else:
            pivot, _overlap = max(wrongd.items(), key=lambda x: x[1])

        # here the precision and recall is calculated based on the number of
        # correct generated queries within the pivot's tie group
        for si in s_ind:
            if pivot not in si:
                continue
            if len(si) > 1:
                c_answer = []
                t_answer = []
                for j in si:
                    where = sorted_queries[j]
                    answerset, target_var = self._initial_answer(kb, parser, where, count_query, ask_query)
                    # Do not clobber the variable that produced the correct
                    # answer for the entry already marked as correct.
                    if not output_where[j]["correct"]:
                        output_where[j]["target_var"] = target_var
                    matched = self._matches_gold(
                        kb, qapair, where, target_var, answerset, count_query, ask_query)

                    # Candidates whose first answer set is empty do not
                    # contribute to the group metrics.
                    if len(answerset) > 0:
                        if not matched:
                            answerset, target_var = self._flipped_answer(
                                kb, parser, where, target_var, count_query, ask_query)
                            matched = self._matches_gold(
                                kb, qapair, where, target_var, answerset, count_query, ask_query)
                        if matched:
                            c_answer.append(len(answerset))
                        else:
                            c_answer.append(answerset.intersect(qapair.answerset))
                        t_answer.append(len(answerset))

                # Keep the per-candidate metrics when the group returned no
                # answers at all (previously a ZeroDivisionError).
                if sum(t_answer) > 0:
                    precision = sum(c_answer) / sum(t_answer)
                    recall = min(sum(c_answer) / len(qapair.answerset), 1.0)
            break

        return "correct" if correct else "-incorrect", output_where, question_type, type_confidence, precision, recall

    @staticmethod
    def _subsets_largest_first(items):
        """Return all non-empty combinations of ``items``, in reversed generation order (largest first)."""
        subsets = []
        for size in range(1, len(items) + 1):
            subsets.extend(itertools.combinations(items, size))
        subsets.reverse()
        return subsets

    @staticmethod
    def _initial_answer(kb, parser, where, count_query, ask_query):
        """Return ``(answerset, target_var)`` for a candidate using its cached answer or its suggested variable."""
        if "answer" in where:
            return where["answer"], where["target_var"]
        target_var = "?u_" + str(where["suggested_id"])
        raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query)
        return AnswerSet(raw_answer, parser.parse_queryresult), target_var

    @staticmethod
    def _flipped_answer(kb, parser, where, target_var, count_query, ask_query):
        """Query the candidate again with the other of ``?u_0`` / ``?u_1``; return ``(answerset, target_var)``."""
        target_var = "?u_1" if target_var == "?u_0" else "?u_0"
        raw_answer = kb.query_where(where["where"], target_var, count_query, ask_query)
        return AnswerSet(raw_answer, parser.parse_queryresult), target_var

    @staticmethod
    def _matches_gold(kb, qapair, where, target_var, answerset, count_query, ask_query):
        """Return whether ``answerset`` equals the gold answer set; print ``error`` if the SPARQL comparison disagrees."""
        # NOTE(review): ``ds`` is not a parameter of this method; it has to be
        # available as a module-level name, exactly as in the original code.
        sparql = SPARQL(kb.sparql_query(where["where"], target_var, count_query, ask_query),
                        ds.parser.parse_sparql)
        answereq = (answerset == qapair.answerset)
        try:
            sparqleq = (sparql == qapair.sparql)
        except Exception:
            # A failing comparison counts as "not equal"; interpreter-level
            # signals such as KeyboardInterrupt are no longer swallowed.
            sparqleq = False

        if answereq != sparqleq:
            print("error")
        return answereq
Beispiel #3
0
def qg(linker, kb, parser, qapair, question_type_classifier, double_relation_classifier, force_gold=True):
    """Build candidate WHERE clauses from the minimal subgraph and check them against the gold answer.

    The gold query of ``qapair`` is executed on ``kb`` (with ``"https"``
    replaced by ``"http"``) and its parsed result replaces
    ``qapair.answerset``.  ASK / COUNT / ORDER BY handling is derived from
    the text of the gold query.  Every candidate produced by
    ``QueryBuilder.to_where_statement`` is evaluated, first with its own
    target variable and, if that does not reproduce the gold answer, with the
    other one of ``?u_0`` / ``?u_1``.  All candidates are evaluated; there is
    no early exit after the first match.

    ``question_type_classifier`` and ``double_relation_classifier`` are
    accepted but not used by this function.

    :return: ``(status, output_where)`` with ``status`` one of
        ``"-Linker_failed"``, ``"-without_path"``, ``"correct"`` or
        ``"-incorrect"``; ``output_where`` holds one dict per candidate with
        the keys ``"query"``, ``"correct"`` and ``"target_var"``.
    """
    logger.info(qapair.sparql)
    logger.info(qapair.question.text)

    # Refresh the gold answer set by running the gold query on the KB.
    _status, gold_raw = kb.query(qapair.sparql.query.replace("https", "http"))
    qapair.answerset = AnswerSet(gold_raw, parser.parse_queryresult)

    # Query kind is read off the gold query text.
    is_ask = "ASK " in qapair.sparql.query
    is_count = "COUNT(" in qapair.sparql.query
    is_sorted = "order by" in qapair.sparql.raw_query.lower()

    entities, ontologies = linker.do(qapair, force_gold=force_gold)
    if entities is None or ontologies is None:
        return "-Linker_failed", []

    graph = Graph(kb)
    builder = QueryBuilder()

    logger.info("start finding the minimal subgraph")

    graph.find_minimal_subgraph(entities, ontologies, ask_query=is_ask, sort_query=is_sorted)
    logger.info(graph)
    wheres = builder.to_where_statement(
        graph,
        parser.parse_queryresult,
        ask_query=is_ask,
        count_query=is_count,
        sort_query=is_sorted,
    )

    output_where = []
    for candidate in wheres:
        output_where.append({"query": " .".join(candidate["where"]), "correct": False, "target_var": "?u_0"})
    for entry in output_where:
        logger.info(entry["query"])
    if len(wheres) == 0:
        return "-without_path", output_where

    found = False
    for entry, where in zip(output_where, wheres):
        clause = where["where"]

        # First attempt: cached answer if present, else the suggested variable.
        if "answer" in where:
            answerset = where["answer"]
            target_var = where["target_var"]
        else:
            target_var = "?u_" + str(where["suggested_id"])
            answerset = AnswerSet(
                kb.query_where(clause, target_var, is_count, is_ask), parser.parse_queryresult)

        entry["target_var"] = target_var
        candidate_sparql = SPARQL(kb.sparql_query(clause, target_var, is_count, is_ask), ds.parser.parse_sparql)
        answers_match = answerset == qapair.answerset
        if answers_match != (candidate_sparql == qapair.sparql):
            print("error")

        if answers_match:
            found = True
            entry["correct"] = True
            entry["target_var"] = target_var
            continue

        # Second attempt with the other target variable.
        target_var = "?u_1" if target_var == "?u_0" else "?u_0"
        raw_answer = kb.query_where(clause, target_var, is_count, is_ask)
        print("Q_H ",)
        print(raw_answer)
        print("Q_")
        answerset = AnswerSet(raw_answer, parser.parse_queryresult)

        candidate_sparql = SPARQL(kb.sparql_query(clause, target_var, is_count, is_ask), ds.parser.parse_sparql)
        answers_match = answerset == qapair.answerset
        if answers_match != (candidate_sparql == qapair.sparql):
            print("error")

        if answers_match:
            found = True
            entry["correct"] = True
            entry["target_var"] = target_var

    return "correct" if found else "-incorrect", output_where