# Example #1
    def infer_query(self, question, db_id):
        """Run the full text-to-SQL pipeline for one natural-language question.

        Three stages:
          1. schema linking — match question tokens against the DB schema/content;
          2. serialization — build the SLML-tagged question string;
          3. prediction — constrained decoding, then SQL post-processing.

        Returns a dict with the SLML question, the predicted SQL string, and
        the decoder score (as a plain Python list).
        """
        # --- stage 1: schema linking ---------------------------------------
        parsed = self.stanza_model(question)
        lemmas = [w.lemma for s in parsed.sentences for w in s.words]
        context = SpiderDBContext(db_id,
                                  lemmas,
                                  tables_file=self.schema_path,
                                  dataset_path=self.db_path,
                                  stanza_model=self.stanza_model,
                                  schemas=self.schema,
                                  original_utterance=question)
        (value_match, value_alignment,
         exact_match, partial_match) = context.get_db_knowledge_graph(db_id)

        # Single-turn interaction record in the shape the serializer expects.
        turn = {
            'db_id': db_id,
            'question': question,
            'sql': '',
            'value_match': value_match,
            'value_alignment': value_alignment,
            'exact_match': exact_match,
            'partial_match': partial_match,
        }
        item = {'interaction': [turn]}

        # --- stage 2: serialization ----------------------------------------
        sources, _ = build_schema_linking_data(
            schema=self.database_schemas[db_id],
            question=question,
            item=item,
            turn_id=0,
            linking_type='default')
        slml_question = sources[0]

        # --- stage 3: prediction -------------------------------------------
        schemas, eval_foreign_key_maps = load_tables(self.schema_path)
        original_schemas = load_original_schemas(self.schema_path)
        alias_schema = get_alias_schema(schemas)
        decoded = decode_with_constrain(slml_question, alias_schema[db_id],
                                        self.model)
        # decoded[0]['text'] may be a plain string or a list of candidates;
        # take the first candidate in the latter case.
        text = decoded[0]['text']
        predict_sql = text if isinstance(text, str) else text[0]
        score = decoded[0]['score'].tolist()

        # Map the alias-constrained output back onto the original schema names.
        predict_sql = post_processing_sql(predict_sql,
                                          eval_foreign_key_maps[db_id],
                                          original_schemas[db_id],
                                          schemas[db_id])

        return {
            "slml_question": slml_question,
            "predict_sql": predict_sql,
            "score": score
        }
def read_spider_split(dataset_path, table_path, database_path):
    """Load one Spider dataset split and build per-database interaction records.

    Args:
        dataset_path: path to the split's JSON file (list of examples).
        table_path: path to the Spider ``tables.json`` schema file.
        database_path: directory containing the SQLite databases.

    Returns:
        Dict mapping ``db_id`` to a list of interaction dicts, each holding
        the utterance, gold SQL (disambiguated, values stripped), and the
        schema-linking results for that example.

    Examples whose gold SQL cannot be disambiguated are logged and skipped.
    """
    with open(dataset_path) as f:
        split_data = json.load(f)
    print('read_spider_split', dataset_path, len(split_data))

    schemas = read_dataset_schema(table_path, stanza_model)

    interaction_list = {}
    for i, ex in enumerate(tqdm(split_data)):
        db_id = ex['db_id']

        ex['query_toks_no_value'] = normalize_original_sql(ex['query_toks_no_value'])
        turn_sql = ' '.join(ex['query_toks_no_value'])
        # Patch a known bad gold query in Spider: the grouping column is
        # literally named 'f1', not the placeholder 'value'.
        turn_sql = turn_sql.replace('select count ( * ) from follows group by value',
                                    'select count ( * ) from follows group by f1')
        ex['query_toks_no_value'] = turn_sql.split(' ')

        ex = fix_number_value(ex)
        try:
            ex['query_toks_no_value'] = disambiguate_items(db_id, ex['query_toks_no_value'],
                                                           tables_file=table_path, allow_aliases=False)
        except Exception:
            # Best-effort skip of unparseable gold SQL; narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate.
            print(ex['query_toks'])
            continue

        final_sql_parse = ' '.join(ex['query_toks_no_value'])
        final_utterance = ' '.join(ex['question_toks']).lower()

        if stanza_model is not None:
            lemma_utterance_stanza = stanza_model(final_utterance)
            lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
            original_utterance = final_utterance
        else:
            # No lemmatizer available: fall back to whitespace tokens.
            original_utterance = lemma_utterance = final_utterance.split(' ')

        # Schema linking using db content.
        db_context = SpiderDBContext(db_id,
                                     lemma_utterance,
                                     tables_file=table_path,
                                     dataset_path=database_path,
                                     stanza_model=stanza_model,
                                     schemas=schemas,
                                     original_utterance=original_utterance)

        value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)

        if value_match != []:
            print(value_match, value_alignment)

        interaction = {}
        interaction['id'] = i
        interaction['database_id'] = db_id
        interaction['interaction'] = [{'utterance': final_utterance,
                                       'db_id': db_id,
                                       'query': ex['query'],
                                       'question': ex['question'],
                                       'sql': final_sql_parse,
                                       'value_match': value_match,
                                       'value_alignment': value_alignment,
                                       'exact_match': exact_match,
                                       'partial_match': partial_match,
                                       }]
        # setdefault replaces the manual "if db_id not in ..." initialisation.
        interaction_list.setdefault(db_id, []).append(interaction)

    return interaction_list