Example #1
# Project-level names used in this example (sys, json, DBEngine, ASDLGrammar,
# SqlTransitionSystem, Query, sql_query_to_asdl_ast, asdl_ast_to_sql_query)
# come from the surrounding WikiSQL semantic-parsing codebase.
class WikiSQLEvaluator(Evaluator):  # enclosing class; shown in full in Example #4
    def __init__(self, transition_system, args):
        super(WikiSQLEvaluator,
              self).__init__(transition_system=transition_system)

        print(f'load evaluation database {args.sql_db_file}', file=sys.stderr)
        self.execution_engine = DBEngine(args.sql_db_file)
        self.answer_prune = args.answer_prune

def check():
    data_file = '/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.jsonl'
    engine = DBEngine(
        '/Users/yinpengcheng/Research/SemanticParsing/WikiSQL/data/train.db')
    grammar = ASDLGrammar.from_text(open('sql_asdl.txt').read())
    transition_system = SqlTransitionSystem(grammar)
    from asdl.hypothesis import Hypothesis
    for line in open(data_file):
        example = json.loads(line)
        query = Query.from_dict(example['sql'])
        asdl_ast = sql_query_to_asdl_ast(query, grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)
        hyp = Hypothesis()

        for action in actions:
            hyp.apply_action(action)

        # Optional, commented-out check: reconstruct a query from the
        # hypothesis tree and make sure it executes to the same result
        # as the reference query.
        # hyp_query = asdl_ast_to_sql_query(hyp.tree)
        # hyp_query_result = engine.execute_query(example['table_id'], hyp_query)
        # ref_result = engine.execute_query(example['table_id'], query)
        # assert hyp_query_result == ref_result
        query_reconstr = asdl_ast_to_sql_query(asdl_ast)
        assert query == query_reconstr
        print(query)
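For reference, `check()` consumes WikiSQL's JSON-lines format, where each record carries its SQL program as a plain dict that Query.from_dict accepts. A self-contained sketch of that record shape (the values below are illustrative, not taken from the dataset):

import json

# One WikiSQL-style record, reduced to the fields check() reads.
# Each "conds" row is [column_index, operator_index, value].
line = json.dumps({
    'table_id': '1-1000181-1',
    'sql': {'sel': 0, 'agg': 0, 'conds': [[3, 0, 'some value']]},
})
example = json.loads(line)
print(example['sql'])  # the dict handed to Query.from_dict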
Example #3
engine = DBEngine('./data_model/wikisql/data/train.db')
grammar = ASDLGrammar.from_text(open('./asdl/lang/sql/sql_asdl.txt').read())
transition_system = SqlTransitionSystem(grammar)
data_file = './data_model/wikisql/data/train.jsonl'  # assumed path, alongside train.db above

from asdl.hypothesis import Hypothesis

for ids, line in enumerate(open(data_file, encoding='utf-8')):
    example = json.loads(line)
    print(example['sql'])
    query = Query.from_dict(example['sql']).lower()
    print(query)
    asdl_ast = sql_query_to_asdl_ast(query, grammar)
    asdl_ast.sanity_check()
    print(asdl_ast.to_string())
    # asdl_ast.sort_removedup_self()  # optional normalization pass
    # print(asdl_ast.to_string())
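A note on the .lower() call above: WikiSQL answers are compared case-insensitively (the evaluators below execute with lower=True), so the query is normalized before it is turned into an AST. A trivial illustration:

# Case normalization avoids spurious mismatches on condition values.
ref_value, hyp_value = 'New York', 'new york'
assert ref_value != hyp_value
assert ref_value.lower() == hyp_value.lower()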
Example #4
class WikiSQLEvaluator(Evaluator):
    def __init__(self, transition_system, args):
        super(WikiSQLEvaluator,
              self).__init__(transition_system=transition_system)

        print(f'load evaluation database {args.sql_db_file}', file=sys.stderr)
        self.execution_engine = DBEngine(args.sql_db_file)
        self.answer_prune = args.answer_prune

    def is_hyp_correct(self, example, hyp):
        hyp_query = asdl_ast_to_sql_query(hyp.tree)
        detokenized_hyp_query = detokenize_query(hyp_query, example.meta,
                                                 example.table)

        hyp_answer = self.execution_engine.execute_query(
            example.meta['table_id'], detokenized_hyp_query, lower=True)

        ref_query = Query.from_tokenized_dict(example.meta['query'])
        ref_answer = self.execution_engine.execute_query(
            example.meta['table_id'], ref_query, lower=True)

        result = ref_answer == hyp_answer

        return result

    def finemet(self, tree1, tree2, example, passed):
        # Fine-grained metric: compare the three components of a WikiSQL
        # query (SELECT column, aggregation, WHERE conditions) separately.
        hyp_query = asdl_ast_to_sql_query(tree1)
        detokenized_hyp_query = hyp_query
        # detokenized_hyp_query = detokenize_query(hyp_query, example.meta, example.table).lower()
        c1 = sorted(detokenized_hyp_query.conditions, key=lambda x: str(x))

        ref_query = asdl_ast_to_sql_query(tree2)
        c2 = sorted(ref_query.conditions, key=lambda x: str(x))

        # Conditions match iff the sorted lists agree element-wise.
        issim = len(c1) == len(c2) and all(
            str(a) == str(b) for a, b in zip(c1, c2))

        if passed and not issim:
            # Debug pause on examples that pass overall but whose
            # conditions disagree.
            print(detokenized_hyp_query)
            print(ref_query)
            print(c1)
            print(c2)
            input('press enter to continue')

        result = [
            detokenized_hyp_query.sel_index == ref_query.sel_index,
            detokenized_hyp_query.agg_index == ref_query.agg_index,
            issim,
        ]
        return result

    def evaluate_dataset(self, examples, decode_results, fast_mode=False):
        # NOTE: answer pruning is force-enabled here, overriding the
        # constructor's args.answer_prune setting.
        self.answer_prune = True
        if self.answer_prune:
            filtered_decode_results = []
            for example, hyp_list in tqdm.tqdm(zip(examples,
                                                   decode_results[0])):
                pruned_hyps = []
                if hyp_list:
                    for hyp_id, hyp in enumerate(hyp_list):
                        try:
                            # check if it is executable
                            detokenized_hyp_query = detokenize_query(
                                hyp.code, example.meta, example.table)
                            hyp_answer = self.execution_engine.execute_query(
                                example.meta['table_id'],
                                detokenized_hyp_query,
                                lower=True)
                            if len(hyp_answer) == 0:
                                continue

                            pruned_hyps.append(hyp)
                            if fast_mode: break
                        except Exception:
                            print("Exception in converting tree to code:",
                                  file=sys.stdout)
                            print('-' * 60, file=sys.stdout)
                            print(
                                'Example: %s\nIntent: %s\nTarget Code:\n%s\nHypothesis[%d]:\n%s'
                                % (example.idx, ' '.join(
                                    example.src_sent), example.tgt_code,
                                   hyp_id, hyp.tree.to_string()),
                                file=sys.stdout)
                            print()
                            print(hyp.code)
                            traceback.print_exc(file=sys.stdout)
                            print('-' * 60, file=sys.stdout)

                filtered_decode_results.append(pruned_hyps)

            decode_results = [filtered_decode_results, decode_results[1]]

        eval_results = Evaluator.evaluate_dataset(self, examples,
                                                  decode_results, fast_mode)

        return eval_results
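The condition comparison inside finemet is order-insensitive: both condition lists are sorted by their string form and compared element-wise. A self-contained sketch of the same test (the function name and sample triples are illustrative, not from the codebase):

# Order-insensitive equality of WHERE-condition triples, as in finemet.
def conditions_match(conds_a, conds_b):
    return sorted(map(str, conds_a)) == sorted(map(str, conds_b))

assert conditions_match([[0, 0, 'x'], [1, 2, 'y']],
                        [[1, 2, 'y'], [0, 0, 'x']])
assert not conditions_match([[0, 0, 'x']], [[0, 1, 'x']])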
Example #5
def load_dataset(transition_system, dataset_file, table_file):
    examples = []
    engine = DBEngine(dataset_file[:-len('jsonl')] + 'db')

    # load table
    tables = dict()
    for line in open(table_file):
        table_entry = json.loads(line)
        tables[table_entry['id']] = table_entry

    for idx, line in tqdm.tqdm(enumerate(open(dataset_file))):
        # if idx > 100: break
        entry = json.loads(line)
        del entry['seq_input']
        del entry['seq_output']
        del entry['where_output']

        query = Query.from_tokenized_dict(entry['query'])
        query = query.lower()

        tokenized_conditions = []
        for col, op, val_entry in entry['query']['conds']:
            # Space-join the value's word tokens; the 'after' whitespace
            # annotations are not needed for the tokenized query.
            tokenized_conditions.append([col, op, ' '.join(val_entry['words'])])
        tokenized_query = Query(sel_index=entry['query']['sel'],
                                agg_index=entry['query']['agg'],
                                conditions=tokenized_conditions)

        asdl_ast = sql_query_to_asdl_ast(tokenized_query,
                                         transition_system.grammar)
        asdl_ast.sanity_check()
        actions = transition_system.get_actions(asdl_ast)
        hyp = Hypothesis()

        question_tokens = entry['question']['words']
        tgt_action_infos = get_action_infos(question_tokens,
                                            actions,
                                            force_copy=True)

        for action, action_info in zip(actions, tgt_action_infos):
            assert action == action_info.action
            hyp.apply_action(action)

        reconstructed_query_from_hyp = asdl_ast_to_sql_query(hyp.tree)
        reconstructed_query = asdl_ast_to_sql_query(asdl_ast)

        assert tokenized_query == reconstructed_query
        # (Optional, commented-out sanity check: detokenize the condition
        # values of reconstructed_query_from_hyp with my_detokenize(), build
        # a detokenized Query from them, then execute both it and the
        # original query with engine.execute_query() and verify that the
        # two result sets match.)

        header = [
            TableColumn(name=detokenize(col_name),
                        tokens=col_name['words'],
                        type=col_type) for (col_name, col_type) in
            zip(entry['table']['header'], tables[entry['table_id']]['types'])
        ]
        table = WikiSqlTable(header=header)

        example = WikiSqlExample(idx=idx,
                                 question=question_tokens,
                                 table=table,
                                 tgt_actions=tgt_action_infos,
                                 tgt_code=query,
                                 tgt_ast=asdl_ast,
                                 meta=entry)

        examples.append(example)

        # print(query)

    return examples
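The tokenized conditions built above simply space-join the value's word tokens; the 'after' whitespace annotations are dropped and only restored later by detokenization. A self-contained sketch (the sample entry is illustrative of the annotated WikiSQL layout used in this file):

# Rebuild a condition value from its tokenized form, as in load_dataset.
cond_entry = [2, 0, {'words': ['new', 'york'], 'after': [' ', '']}]
col, op, val_entry = cond_entry
tokenized_cond = [col, op, ' '.join(val_entry['words'])]
assert tokenized_cond == [2, 0, 'new york']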
Example #6
class WikiSQLEvaluator(Evaluator):
    def __init__(self, transition_system, args):
        super(WikiSQLEvaluator,
              self).__init__(transition_system=transition_system)

        print(f'load evaluation database {args.sql_db_file}', file=sys.stderr)
        self.execution_engine = DBEngine(args.sql_db_file)
        self.answer_prune = args.answer_prune

    def is_hyp_correct(self, example, hyp):
        hyp_query = asdl_ast_to_sql_query(hyp.tree)
        detokenized_hyp_query = detokenize_query(hyp_query, example.meta,
                                                 example.table)

        hyp_answer = self.execution_engine.execute_query(
            example.meta['table_id'], detokenized_hyp_query, lower=True)

        ref_query = Query.from_tokenized_dict(example.meta['query'])
        ref_answer = self.execution_engine.execute_query(
            example.meta['table_id'], ref_query, lower=True)

        result = ref_answer == hyp_answer

        return result

    def evaluate_dataset(self, examples, decode_results, fast_mode=False):
        if self.answer_prune:
            filtered_decode_results = []
            for example, hyp_list in zip(examples, decode_results):
                pruned_hyps = []
                if hyp_list:
                    for hyp_id, hyp in enumerate(hyp_list):
                        try:
                            # check if it is executable
                            detokenized_hyp_query = detokenize_query(
                                hyp.code, example.meta, example.table)
                            hyp_answer = self.execution_engine.execute_query(
                                example.meta['table_id'],
                                detokenized_hyp_query,
                                lower=True)
                            if len(hyp_answer) == 0:
                                continue

                            pruned_hyps.append(hyp)
                            if fast_mode: break
                        except Exception:
                            print("Exception in converting tree to code:",
                                  file=sys.stdout)
                            print('-' * 60, file=sys.stdout)
                            print(
                                'Example: %s\nIntent: %s\nTarget Code:\n%s\nHypothesis[%d]:\n%s'
                                % (example.idx, ' '.join(
                                    example.src_sent), example.tgt_code,
                                   hyp_id, hyp.tree.to_string()),
                                file=sys.stdout)
                            print()
                            print(hyp.code)
                            traceback.print_exc(file=sys.stdout)
                            print('-' * 60, file=sys.stdout)

                filtered_decode_results.append(pruned_hyps)

            decode_results = filtered_decode_results

        eval_results = Evaluator.evaluate_dataset(self, examples,
                                                  decode_results, fast_mode)

        return eval_results
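The answer-pruning pass in evaluate_dataset keeps only hypotheses that both execute without error and return a non-empty result. A minimal sketch of that filter, with the execution engine abstracted into a callable (names are illustrative):

# Keep hypotheses whose execution succeeds with a non-empty answer,
# mirroring the answer_prune branch of evaluate_dataset.
def prune(hyp_list, execute):
    kept = []
    for hyp in hyp_list:
        try:
            if execute(hyp):      # non-empty result set
                kept.append(hyp)
        except Exception:
            continue              # unexecutable hypothesis is dropped
    return kept

assert prune(['a', 'b'], lambda h: ['row'] if h == 'a' else []) == ['a']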
Example #7
    # Snippet from a WikiSqlEvaluator class; the enclosing class
    # definition is not shown in this example.
    def __init__(self, args):
        super(WikiSqlEvaluator, self).__init__()
        self.db_engine = DBEngine(args.wikisql_db_file)
        # cache one parsed schema dict per line of the WikiSQL table file
        self.example_dicts = []
        for json_line in open(args.wikisql_table_file):
            self.example_dicts.append(json.loads(json_line))
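The table file read here is JSON-lines as well, one table schema per line. A self-contained sketch of the same loading pattern (file contents inlined for illustration):

import io
import json

# One schema object per line, keyed by table id.
table_file = io.StringIO('{"id": "1-1000181-1", "header": ["City", "State"]}\n')
example_dicts = [json.loads(line) for line in table_file]
assert example_dicts[0]['id'] == '1-1000181-1'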