def demo_table(args, sp):
    """
    Run the semantic parser from the standard input.
    """
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    vocabs = data_loader.load_vocabs(args)

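    # A tiny in-memory table used for the interactive demo; the questions typed
    # below are parsed against its schema.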
    table_array = [
        ['name', 'age', 'gender'],
        ['John', 18, 'male'],
        ['Kate', 19, 'female']
    ]
    table_name = 'employees'
    schema_graph = schema_loader.SchemaGraph(table_name)
    schema_graph.load_data_from_2d_array(table_array)

    sys.stdout.write('Enter a natural language question: ')
    sys.stdout.write('> ')
    sys.stdout.flush()
    text = sys.stdin.readline()

    while text:
        example = data_utils.Text2SQLExample(0, table_name, 0)
        example.text = text
        demo_preprocess(args, example, vocabs, schema_graph)
        output = sp.forward([example])
        for i, sql in enumerate(output['pred_decoded'][0]):
            print('Top {}: {}'.format(i, sql))
        sys.stdout.flush()
        sys.stdout.write('\nEnter a natural language question: ')
        sys.stdout.write('> ')
        text = sys.stdin.readline()
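

# Hedged usage sketch (not part of the original module): `run_table_demo` is a
# hypothetical helper showing how `demo_table` could be driven once an `args`
# namespace has been parsed elsewhere.
def run_table_demo(args):
    # Build the parser and move it to the GPU; `demo_table` itself restores the
    # checkpoint and switches the model to eval mode.
    sp = EncoderDecoderLFramework(args)
    sp.cuda()
    demo_table(args, sp)

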
def load_semantic_parser(args):
    if args.model in model_index:
        sp = EncoderDecoderLFramework(args)
    else:
        raise NotImplementedError
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.cuda()
    sp.eval()
    return sp
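

# Fine-tune a previously trained checkpoint on the dataset's 'fine-tune' split;
# the same examples double as the dev set during fine-tuning.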
def fine_tune(sp):
    dataset = data_loader.load_processed_data(args)
    fine_tune_data = dataset['fine-tune']

    print('{} fine-tuning examples loaded'.format(len(fine_tune_data)))
    dev_data = fine_tune_data

    sp.schema_graphs = dataset['schema']
    sp.load_checkpoint(get_checkpoint_path(args))

    sp.run_train(fine_tune_data, dev_data)
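

# Spider inference as used in the CodaLab evaluation setting: the dev split is
# preprocessed from scratch and one predicted SQL query per example is written
# to `sp.args.prediction_path`.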
def inference(sp):
    text_tokenize, program_tokenize, post_process, table_utils = tok.get_tokenizers(
        args)
    schema_graphs = schema_loader.load_schema_graphs_spider(
        args.codalab_data_dir, 'spider', db_dir=args.codalab_db_dir)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize,
                                    normalized=(args.model_id
                                                in [utils.BRIDGE]))
    sp.schema_graphs = schema_graphs
    text_vocab = Vocabulary('text',
                            func_token_index=functional_token_index,
                            tu=table_utils)
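    # Index every token in the pretrained tokenizer's vocabulary so that
    # text-side token ids match the transformer tokenizer's ids.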
    for v in table_utils.tokenizer.vocab:
        text_vocab.index_token(v, True,
                               table_utils.tokenizer.convert_tokens_to_ids(v))
    program_vocab = sql_reserved_tokens if args.pretrained_transformer else sql_reserved_tokens_revtok
    vocabs = {'text': text_vocab, 'program': program_vocab}
    examples = data_loader.load_data_split_spider(args.codalab_data_dir, 'dev',
                                                  schema_graphs)
    print('{} {} examples loaded'.format(len(examples), 'dev'))

    for i, example in enumerate(examples):
        schema_graph = schema_graphs.get_schema(example.db_id)
        preprocess_example('dev', example, args, None, text_tokenize,
                           program_tokenize, post_process, table_utils,
                           schema_graph, vocabs)
    print('{} {} examples processed'.format(len(examples), 'dev'))

    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    out_dict = sp.inference(
        examples,
        restore_clause_order=args.process_sql_in_execution_order,
        check_schema_consistency_=True,
        inline_eval=False,
        verbose=False)

    assert (sp.args.prediction_path is not None)
    out_txt = sp.args.prediction_path
    with open(out_txt, 'w') as o_f:
        for pred_sql in out_dict['pred_decoded']:
            o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))
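

# General inference over a preprocessed dataset split (e.g. Spider or WikiSQL).
# Note that this definition shadows the CodaLab-specific `inference` above when
# both are kept in the same module.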
def inference(sp):
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'
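    # WikiSQL ships per-split SQLite databases, so predictions can additionally
    # be checked by execution (the 'exe match' metrics below).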
    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    def evaluate(examples, out_dict):
        metrics = eval_tools.get_exact_match_metrics(
            examples, out_dict['pred_decoded'], engine=engine)
        print('Top-1 exact match: {:.3f}'.format(metrics['top_1_em']))
        print('Top-2 exact match: {:.3f}'.format(metrics['top_2_em']))
        print('Top-3 exact match: {:.3f}'.format(metrics['top_3_em']))
        print('Top-5 exact match: {:.3f}'.format(metrics['top_5_em']))
        print('Top-10 exact match: {:.3f}'.format(metrics['top_10_em']))
        if args.dataset_name == 'wikisql':
            print('Top-1 exe match: {:.3f}'.format(metrics['top_1_ex']))
            print('Top-2 exe match: {:.3f}'.format(metrics['top_2_ex']))
            print('Top-3 exe match: {:.3f}'.format(metrics['top_3_ex']))
            print('Top-5 exe match: {:.3f}'.format(metrics['top_5_ex']))
            print('Top-10 exe match: {:.3f}'.format(metrics['top_10_ex']))
        print('Table error: {:.3f}'.format(metrics['table_err']))
        performance = os.path.join(
            sp.model_dir,
            f"test_performance_{args.data_dir.split('/')[1]}_{args.beam_size}.txt")
        metric_keys = ['top_1_em', 'top_2_em', 'top_3_em', 'top_5_em', 'top_10_em',
                       'top_1_ex', 'top_2_ex', 'top_3_ex', 'top_5_ex', 'top_10_ex',
                       'table_err']
        with open(performance, 'w') as pf:
            for key in metric_keys:
                # Execution-accuracy metrics may be absent when no DB engine is
                # used, so skip keys that were not computed.
                if key in metrics:
                    pf.write(f'{key}: {metrics[key]:.3f}\n')

    examples = dataset[split]
    # random.shuffle(examples)
    sp.schema_graphs = dataset['schema']
    print('{} {} examples loaded'.format(len(examples), split))

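    # Optionally attach table predictions produced by an earlier run
    # (predicted_tables.txt) to the examples before decoding.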
    if sp.args.use_pred_tables:
        in_table = os.path.join(sp.args.model_dir, 'predicted_tables.txt')
        with open(in_table) as f:
            content = f.readlines()
        assert(len(content) == len(examples))
        for example, line in zip(examples, content):
            pred_tables = {x.strip()[1:-1]
                           for x in line.strip()[1:-1].split(',')}
            example.leaf_condition_vals_list = pred_tables

    sp.load_checkpoint(get_checkpoint_path(args))
    sp.eval()

    if sp.args.augment_with_wikisql:
        examples_, examples_wikisql = [], []
        for example in examples:
            if example.dataset_id == data_utils.WIKISQL:
                examples_wikisql.append(example)
            else:
                examples_.append(example)
        examples = examples_

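    # When SQL is decoded in execution order, restored (canonically re-ordered)
    # predictions are cached across runs; the cache is saved back below if new
    # entries were added.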
    pred_restored_cache = sp.load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v)
                                   for v in pred_restored_cache.values())
    # pred_restored_cache = None
    out_dict = sp.inference(examples, restore_clause_order=args.process_sql_in_execution_order,
                            pred_restored_cache=pred_restored_cache,
                            check_schema_consistency_=args.sql_consistency_check,
                            engine=engine, inline_eval=True, verbose=True)
    if args.process_sql_in_execution_order:
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sp.save_pred_restored_cache(
                out_dict['pred_restored_cache'], newly_cached_size)

    out_txt = os.path.join(sp.model_dir, 'predictions.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split))
    with open(out_txt, 'w') as o_f:
        assert(len(examples) == len(out_dict['pred_decoded']))
        for i, pred_sql in enumerate(out_dict['pred_decoded']):
            if args.dataset_name == 'wikisql':
                example = examples[i]
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    evaluate(examples, out_dict)
    
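    # When training was augmented with WikiSQL, evaluate the held-out WikiSQL
    # examples separately from the main split.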
    if args.augment_with_wikisql:
        wikisql_out_dict = sp.forward(examples_wikisql, verbose=False)
        print('*** WikiSQL ***')
        evaluate(examples_wikisql, wikisql_out_dict)
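

# Hedged entry-point sketch (not part of the original module). `parse_args()`
# stands in for the project's real argument parser and is an assumption, as are
# the flag names it would have to define:
#
#     if __name__ == '__main__':
#         args = parse_args()
#         sp = load_semantic_parser(args)
#         inference(sp)  # or fine_tune(sp) / demo_table(args, sp)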