Пример #1
0
    def evaluate_dataset(self, examples, decode_results, fast_mode=False):
        """Prune non-executable hypotheses, then delegate to the base evaluator.

        When answer pruning is active, every hypothesis is detokenized and
        executed against its example's table. Hypotheses whose execution
        raises, or whose answer is empty, are dropped before evaluation.

        Args:
            examples: Examples to evaluate; each must expose ``meta``,
                ``table``, ``idx``, ``src_sent`` and ``tgt_code``.
            decode_results: Indexable pair. ``decode_results[0]`` is iterated
                in lockstep with ``examples`` and yields one hypothesis list
                per example; ``decode_results[1]`` is forwarded unchanged.
            fast_mode: If true, keep only the first hypothesis per example
                that executes to a non-empty answer. The flag is also
                forwarded to ``Evaluator.evaluate_dataset``.

        Returns:
            Whatever ``Evaluator.evaluate_dataset`` returns for the
            (possibly pruned) decode results.
        """
        # NOTE(review): pruning is force-enabled here, overriding any value
        # set elsewhere on the instance -- confirm this is intentional.
        self.answer_prune = True
        if self.answer_prune:
            filtered_decode_results = []
            for example, hyp_list in tqdm.tqdm(zip(examples,
                                                   decode_results[0])):
                pruned_hyps = []
                # ``hyp_list`` may be empty or None for an example.
                for hyp_id, hyp in enumerate(hyp_list or ()):
                    try:
                        # Check that the hypothesis is executable.
                        detokenized_hyp_query = detokenize_query(
                            hyp.code, example.meta, example.table)
                        hyp_answer = self.execution_engine.execute_query(
                            example.meta['table_id'],
                            detokenized_hyp_query,
                            lower=True)
                        has_answer = len(hyp_answer) != 0
                    except Exception:
                        # Best effort: report the failing hypothesis and move
                        # on. ``Exception`` (rather than a bare ``except``)
                        # lets KeyboardInterrupt/SystemExit propagate.
                        print("Exception in converting tree to code:",
                              file=sys.stdout)
                        print('-' * 60, file=sys.stdout)
                        print(
                            'Example: %s\nIntent: %s\nTarget Code:\n%s\nHypothesis[%d]:\n%s'
                            % (example.idx, ' '.join(
                                example.src_sent), example.tgt_code,
                               hyp_id, hyp.tree.to_string()),
                            file=sys.stdout)
                        print()
                        print(hyp.code)
                        traceback.print_exc(file=sys.stdout)
                        print('-' * 60, file=sys.stdout)
                        continue

                    if not has_answer:
                        # An empty execution result counts as a failed hypothesis.
                        continue

                    pruned_hyps.append(hyp)
                    if fast_mode:
                        break

                filtered_decode_results.append(pruned_hyps)

            decode_results = [filtered_decode_results, decode_results[1]]

        eval_results = Evaluator.evaluate_dataset(self, examples,
                                                  decode_results, fast_mode)

        return eval_results
def dump_wiki_sql_eval_file(dataset, decode_results, output_file):
    """Write one JSON line per example describing its top hypothesis.

    For each ``(example, hyps)`` pair, a JSON object is written on its own
    line. If ``hyps`` is non-empty, the object is ``{"error": false,
    "query": ...}`` where the query is the ``to_dict()`` form of the first
    hypothesis after detokenization. Otherwise it is ``{"error": true}``.

    Args:
        dataset: Iterable of examples exposing ``meta`` and ``table``.
        decode_results: Iterable of hypothesis lists, paired with ``dataset``
            by position.
        output_file: Path of the file to (over)write.
    """
    # The context manager closes the file even if serialization of one of
    # the examples raises part-way through the loop.
    with open(output_file, 'w') as f:
        for example, hyps in zip(dataset, decode_results):
            if hyps:
                top_hyp = hyps[0]
                query = detokenize_query(top_hyp.code, example.meta,
                                         example.table)
                result_dict = {'error': False, 'query': query.to_dict()}
            else:
                result_dict = {'error': True}

            f.write(json.dumps(result_dict) + '\n')
Пример #3
0
    def is_hyp_correct(self, example, hyp):
        """Tell whether a hypothesis executes to the same answer as the reference.

        The hypothesis tree is converted to a query, detokenized against the
        example's metadata and table, and executed. The reference query is
        built from ``example.meta['query']`` and executed on the same table.

        Args:
            example: Example exposing ``meta`` (with ``'table_id'`` and
                ``'query'`` entries) and ``table``.
            hyp: Hypothesis exposing a ``tree`` attribute.

        Returns:
            The result of comparing the reference answer to the hypothesis
            answer with ``==``.
        """
        # Hypothesis side: tree -> query -> detokenized query -> answer.
        candidate_query = detokenize_query(
            asdl_ast_to_sql_query(hyp.tree), example.meta, example.table)
        run_query = self.execution_engine.execute_query
        predicted = run_query(
            example.meta['table_id'], candidate_query, lower=True)

        # Reference side: stored tokenized query -> answer.
        gold_query = Query.from_tokenized_dict(example.meta['query'])
        expected = run_query(example.meta['table_id'], gold_query, lower=True)

        return expected == predicted