def demo_table(args, sp):
    """
    Run the semantic parser interactively on a toy in-memory table, reading
    natural language questions from standard input.
    """
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    vocabs = data_loader.load_vocabs(args)

    # Toy table: the first row holds the column names, the remaining rows hold values.
    table_array = [
        ['name', 'age', 'gender'],
        ['John', 18, 'male'],
        ['Kate', 19, 'female']
    ]
    table_name = 'employees'
    schema_graph = schema_loader.SchemaGraph(table_name)
    schema_graph.load_data_from_2d_array(table_array)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()

    while text:
        example = data_utils.Text2SQLExample(0, table_name, 0)
        example.text = text
        demo_preprocess(args, example, vocabs, schema_graph)
        output = sp.forward([example])
        for i, sql in enumerate(output['pred_decoded'][0]):
            print('Top {}: {}'.format(i, sql))
        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        text = sys.stdin.readline()
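# A minimal sketch of swapping a different toy table into the demo above
# (illustrative only; it assumes the same first-row-is-header convention that
# SchemaGraph.load_data_from_2d_array expects, and the table/column names here
# are hypothetical):
#
#     table_array = [
#         ['title', 'year', 'rating'],
#         ['Alien', 1979, 8.5],
#         ['Arrival', 2016, 7.9],
#     ]
#     schema_graph = schema_loader.SchemaGraph('movies')
#     schema_graph.load_data_from_2d_array(table_array)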
def load_semantic_parser(args):
    if args.model in model_index:
        sp = EncoderDecoderLFramework(args)
    else:
        raise NotImplementedError
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.cuda()
    sp.eval()
    return sp
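# Usage sketch (illustrative, not part of the original entry points): the
# functions in this module are typically dispatched from a driver driven by the
# module-level `args` namespace. The flag names below (`args.demo`,
# `args.fine_tune`, `args.inference`) are assumptions made for illustration:
#
#     sp = load_semantic_parser(args)
#     if args.demo:
#         demo_table(args, sp)
#     elif args.fine_tune:
#         fine_tune(sp)
#     elif args.inference:
#         inference(sp)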
def fine_tune(sp):
    dataset = data_loader.load_processed_data(args)
    fine_tune_data = dataset['fine-tune']
    print('{} fine-tuning examples loaded'.format(len(fine_tune_data)))
    dev_data = fine_tune_data
    sp.schema_graphs = dataset['schema']
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.run_train(fine_tune_data, dev_data)
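# Usage sketch for fine-tuning (assumption: the processed dataset produced by
# data_loader.load_processed_data(args) contains a 'fine-tune' split, as the
# function above requires):
#
#     sp = load_semantic_parser(args)
#     fine_tune(sp)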
def inference(sp):
    text_tokenize, program_tokenize, post_process, table_utils = tok.get_tokenizers(args)
    schema_graphs = schema_loader.load_schema_graphs_spider(
        args.codalab_data_dir, 'spider', db_dir=args.codalab_db_dir)
    schema_graphs.lexicalize_graphs(
        tokenize=text_tokenize, normalized=(args.model_id in [utils.BRIDGE]))
    sp.schema_graphs = schema_graphs

    text_vocab = Vocabulary('text', func_token_index=functional_token_index, tu=table_utils)
    for v in table_utils.tokenizer.vocab:
        text_vocab.index_token(v, True, table_utils.tokenizer.convert_tokens_to_ids(v))
    program_vocab = sql_reserved_tokens if args.pretrained_transformer else sql_reserved_tokens_revtok
    vocabs = {'text': text_vocab, 'program': program_vocab}

    examples = data_loader.load_data_split_spider(args.codalab_data_dir, 'dev', schema_graphs)
    print('{} {} examples loaded'.format(len(examples), 'dev'))
    for i, example in enumerate(examples):
        schema_graph = schema_graphs.get_schema(example.db_id)
        preprocess_example('dev', example, args, None, text_tokenize, program_tokenize,
                           post_process, table_utils, schema_graph, vocabs)
    print('{} {} examples processed'.format(len(examples), 'dev'))

    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    out_dict = sp.inference(examples,
                            restore_clause_order=args.process_sql_in_execution_order,
                            check_schema_consistency_=True,
                            inline_eval=False,
                            verbose=False)

    assert sp.args.prediction_path is not None
    out_txt = sp.args.prediction_path
    with open(out_txt, 'w') as o_f:
        for pred_sql in out_dict['pred_decoded']:
            o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))
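# The file written above contains one top-ranked SQL query per input example,
# in input order. A minimal sketch for reading it back (the path is whatever
# sp.args.prediction_path was set to):
#
#     with open(out_txt) as f:
#         predictions = [line.rstrip('\n') for line in f]
#     assert len(predictions) == len(examples)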
def inference(sp):
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'

    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    def evaluate(examples, out_dict):
        metrics = eval_tools.get_exact_match_metrics(
            examples, out_dict['pred_decoded'], engine=engine)
        print('Top-1 exact match: {:.3f}'.format(metrics['top_1_em']))
        print('Top-2 exact match: {:.3f}'.format(metrics['top_2_em']))
        print('Top-3 exact match: {:.3f}'.format(metrics['top_3_em']))
        print('Top-5 exact match: {:.3f}'.format(metrics['top_5_em']))
        print('Top-10 exact match: {:.3f}'.format(metrics['top_10_em']))
        if args.dataset_name == 'wikisql':
            print('Top-1 exe match: {:.3f}'.format(metrics['top_1_ex']))
            print('Top-2 exe match: {:.3f}'.format(metrics['top_2_ex']))
            print('Top-3 exe match: {:.3f}'.format(metrics['top_3_ex']))
            print('Top-5 exe match: {:.3f}'.format(metrics['top_5_ex']))
            print('Top-10 exe match: {:.3f}'.format(metrics['top_10_ex']))
            print('Table error: {:.3f}'.format(metrics['table_err']))
        performance = os.path.join(
            sp.model_dir,
            f"test_performance_{args.data_dir.split('/')[1]}_{args.beam_size}.txt")
        metric_keys = ['top_1_em', 'top_2_em', 'top_3_em', 'top_5_em', 'top_10_em',
                       'top_1_ex', 'top_2_ex', 'top_3_ex', 'top_5_ex', 'top_10_ex',
                       'table_err']
        with open(performance, 'w') as pf:
            for key in metric_keys:
                # Execution accuracy and table error are only computed when an
                # execution engine is available (WikiSQL).
                if key in metrics:
                    pf.write(f'{key}: {metrics[key]:.3f}\n')

    examples = dataset[split]
    # random.shuffle(examples)
    sp.schema_graphs = dataset['schema']
    print('{} {} examples loaded'.format(len(examples), split))

    if sp.args.use_pred_tables:
        in_table = os.path.join(sp.args.model_dir, 'predicted_tables.txt')
        with open(in_table) as f:
            content = f.readlines()
        assert len(content) == len(examples)
        for example, line in zip(examples, content):
            pred_tables = set([x.strip()[1:-1] for x in line.strip()[1:-1].split(',')])
            example.leaf_condition_vals_list = pred_tables

    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    if sp.args.augment_with_wikisql:
        examples_, examples_wikisql = [], []
        for example in examples:
            if example.dataset_id == data_utils.WIKISQL:
                examples_wikisql.append(example)
            else:
                examples_.append(example)
        examples = examples_

    pred_restored_cache = sp.load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v) for v in pred_restored_cache.values())
    # pred_restored_cache = None

    out_dict = sp.inference(examples,
                            restore_clause_order=args.process_sql_in_execution_order,
                            pred_restored_cache=pred_restored_cache,
                            check_schema_consistency_=args.sql_consistency_check,
                            engine=engine,
                            inline_eval=True,
                            verbose=True)

    if args.process_sql_in_execution_order:
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sp.save_pred_restored_cache(out_dict['pred_restored_cache'], newly_cached_size)

    out_txt = os.path.join(sp.model_dir, 'predictions.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split))
    with open(out_txt, 'w') as o_f:
        assert len(examples) == len(out_dict['pred_decoded'])
        for i, pred_sql in enumerate(out_dict['pred_decoded']):
            if args.dataset_name == 'wikisql':
                example = examples[i]
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    evaluate(examples, out_dict)

    if args.augment_with_wikisql:
        wikisql_out_dict = sp.forward(examples_wikisql, verbose=False)
        print('*** WikiSQL ***')
        evaluate(examples_wikisql, wikisql_out_dict)
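# For WikiSQL runs the prediction file written above stores one JSON object per
# line containing the decoded query and its table id. A minimal sketch for
# reading those entries back (the file name follows the format string used
# above):
#
#     import json
#     with open(out_txt) as f:
#         for line in f:
#             record = json.loads(line)
#             print(record['table_id'], record['sql'])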