def check(data_file='/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.jsonl',
          db_file='/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.db',
          grammar_file='sql_asdl.txt'):
    """Sanity-check the SQL <-> ASDL AST round trip over a WikiSQL split.

    For every example: parse the gold SQL into an ASDL AST, derive the
    oracle action sequence, replay it on a fresh Hypothesis, then convert
    the AST back to a Query and assert it equals the original.

    Args:
        data_file: path to a WikiSQL ``*.jsonl`` split (one JSON example per line).
        db_file: path to the matching SQLite database.
        grammar_file: path to the SQL ASDL grammar text file.
    """
    from asdl.hypothesis import Hypothesis

    # NOTE(review): engine is unused here (execution-based comparison is
    # disabled); kept so a missing/broken .db file still surfaces early.
    engine = DBEngine(db_file)

    with open(grammar_file) as f:
        grammar = ASDLGrammar.from_text(f.read())
    transition_system = SqlTransitionSystem(grammar)

    with open(data_file) as f:
        for line in f:
            example = json.loads(line)
            query = Query.from_dict(example['sql'])

            asdl_ast = sql_query_to_asdl_ast(query, grammar)
            asdl_ast.sanity_check()

            # Replaying the oracle action sequence must be possible on a
            # fresh hypothesis (apply_action raises otherwise).
            actions = transition_system.get_actions(asdl_ast)
            hyp = Hypothesis()
            for action in actions:
                hyp.apply_action(action)

            # Round trip: AST -> Query must reproduce the original query.
            query_reconstr = asdl_ast_to_sql_query(asdl_ast)
            assert query == query_reconstr
            print(query)
# Example #2
# 0
    def is_hyp_correct(self, example, hyp):
        """Return True iff the hypothesis query executes to the gold answer."""
        table_id = example.meta['table_id']

        # Convert the hypothesis AST back to a query and restore raw tokens
        # before executing it against the example's table.
        predicted_query = detokenize_query(
            asdl_ast_to_sql_query(hyp.tree), example.meta, example.table)
        predicted_answer = self.execution_engine.execute_query(
            table_id, predicted_query, lower=True)

        # Execute the gold query for comparison.
        gold_query = Query.from_tokenized_dict(example.meta['query'])
        gold_answer = self.execution_engine.execute_query(
            table_id, gold_query, lower=True)

        return predicted_answer == gold_answer
def asdl_ast_to_sql_query(asdl_ast):
    """Convert a ``Select`` ASDL AST back into a WikiSQL ``Query``.

    Grammar: stmt = Select(agg_op? agg, column_name col_name, cond_expr* condition)
    """
    # Aggregation index 0 denotes "no aggregation" (optional field absent).
    agg_node = asdl_ast['agg'].value
    if agg_node is None:
        agg_op_idx = 0
    else:
        agg_op_idx = ctr_name2agg_idx[agg_node.production.constructor.name]

    # Each condition node yields a (column, comparison-op, value) triple.
    conditions = [
        (cond['col_idx'].value,
         ctr_name2cmp_op_idx[cond['op'].value.production.constructor.name],
         cond['value'].value)
        for cond in asdl_ast['conditions'].value
    ]

    return Query(asdl_ast['col_idx'].value, agg_op_idx, conditions)
# Example #4
# 0
from asdl.hypothesis import Hypothesis

# Debug driver: for each dataset example, build the gold ASDL AST, derive
# the oracle action sequence, replay it on a fresh hypothesis, and print
# the intermediate structures for manual inspection.
tmp = []  # collected action sequences, one list per example
engine = DBEngine('./data_model/wikisql/data/train.db')
with open('./asdl/lang/sql/sql_asdl.txt') as grammar_f:
    grammar = ASDLGrammar.from_text(grammar_f.read())
transition_system = SqlTransitionSystem(grammar)

# NOTE(review): `data_file` is not defined in this chunk — presumably set
# earlier in the file; confirm before running.
with open(data_file, encoding='utf-8') as data_f:
    for ids, line in enumerate(data_f):
        example = json.loads(line)
        print(example['sql'])
        query = Query.from_dict(example['sql']).lower()
        print(query)

        asdl_ast = sql_query_to_asdl_ast(query, grammar)
        asdl_ast.sanity_check()
        print(asdl_ast.to_string())

        # Replay the oracle actions; the rebuilt tree should match the
        # gold AST (printed for eyeballing rather than asserted).
        actions = transition_system.get_actions(asdl_ast)
        tmp.append(actions)
        hyp = Hypothesis()
        print(actions)
        for action in actions:
            hyp.apply_action(action)
        print(hyp.tree)
# Example #5
# 0
def load_dataset(transition_system, dataset_file, table_file):
    """Load a WikiSQL split into ``WikiSqlExample``s with oracle action sequences.

    Args:
        transition_system: ``SqlTransitionSystem`` providing the grammar and
            the AST -> action-sequence oracle.
        dataset_file: path to an annotated ``*.jsonl`` split; the matching
            SQLite database is expected alongside it with a ``.db`` suffix.
        table_file: path to the tables ``*.jsonl`` file with table schemas.

    Returns:
        list of ``WikiSqlExample``, one per dataset line.
    """
    examples = []
    # NOTE(review): engine is currently unused (execution-based validation
    # below was disabled); kept so a missing .db file still surfaces early.
    engine = DBEngine(dataset_file[:-len('jsonl')] + 'db')

    # Map table id -> raw table entry (supplies the column types below).
    tables = dict()
    with open(table_file) as f:
        for line in f:
            table_entry = json.loads(line)
            tables[table_entry['id']] = table_entry

    with open(dataset_file) as f:
        for idx, line in tqdm.tqdm(enumerate(f)):
            entry = json.loads(line)
            # Drop bulky annotation fields that are never used downstream.
            del entry['seq_input']
            del entry['seq_output']
            del entry['where_output']

            query = Query.from_tokenized_dict(entry['query']).lower()

            # Rebuild conditions from the tokenized annotation: each value is
            # the space-joined word sequence. zip truncates to the shorter of
            # `words`/`after`, matching the original loop's behavior.
            tokenized_conditions = [
                [col, op, ' '.join(word for word, _ in
                                   zip(val_entry['words'], val_entry['after']))]
                for col, op, val_entry in entry['query']['conds']
            ]
            tokenized_query = Query(sel_index=entry['query']['sel'],
                                    agg_index=entry['query']['agg'],
                                    conditions=tokenized_conditions)

            # Gold AST and its oracle action sequence.
            asdl_ast = sql_query_to_asdl_ast(tokenized_query,
                                             transition_system.grammar)
            asdl_ast.sanity_check()
            actions = transition_system.get_actions(asdl_ast)

            question_tokens = entry['question']['words']
            tgt_action_infos = get_action_infos(question_tokens,
                                                actions,
                                                force_copy=True)

            # Replay the actions on a fresh hypothesis to validate that the
            # action-info sequence mirrors the oracle exactly.
            hyp = Hypothesis()
            for action, action_info in zip(actions, tgt_action_infos):
                assert action == action_info.action
                hyp.apply_action(action)

            # Round-trip check: AST -> Query must reproduce the tokenized query.
            # (The hypothesis-tree reconstruction is computed for parity with
            # the disabled execution-based validation.)
            reconstructed_query_from_hyp = asdl_ast_to_sql_query(hyp.tree)
            reconstructed_query = asdl_ast_to_sql_query(asdl_ast)
            assert tokenized_query == reconstructed_query

            header = [
                TableColumn(name=detokenize(col_name),
                            tokens=col_name['words'],
                            type=col_type) for (col_name, col_type) in
                zip(entry['table']['header'], tables[entry['table_id']]['types'])
            ]
            table = WikiSqlTable(header=header)

            examples.append(WikiSqlExample(idx=idx,
                                           question=question_tokens,
                                           table=table,
                                           tgt_actions=tgt_action_infos,
                                           tgt_code=query,
                                           tgt_ast=asdl_ast,
                                           meta=entry))

    return examples