Example #1
def run_experiment(args):
    if args.process_data:
        # Stage 1: preprocess the raw data only
        process_data()
    elif args.ensemble_inference and not args.demo:
        # Ensemble inference over multiple trained checkpoints
        get_model_dir(args)
        assert args.model in ['bridge', 'seq2seq', 'seq2seq.pg']
        ensemble()
    else:
        # Gradients are only needed in the modes that update model parameters
        enable_grad = (args.train or args.search_random_seed
                       or args.grid_search or args.fine_tune)
        with torch.set_grad_enabled(enable_grad):
            get_model_dir(args)
            if args.model in ['bridge',
                              'seq2seq',
                              'seq2seq.pg']:
                sp = EncoderDecoderLFramework(args)
            else:
                raise NotImplementedError

            # Move the parser to GPU and dispatch on the requested mode
            sp.cuda()
            if args.train:
                train(sp)
            elif args.inference:
                inference(sp)
            elif args.error_analysis:
                error_analysis(sp)
            elif args.demo:
                demo(args)
            elif args.fine_tune:
                fine_tune(sp)
            else:
                print('No experiment specified. Exit now.')
                sys.exit(1)
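
For orientation, a minimal sketch of a driver for this dispatcher. It is an assumption-laden toy: the flag names are inferred from the attribute accesses above, and the project's real CLI almost certainly defines many more options.

import argparse

parser = argparse.ArgumentParser()
# Boolean mode flags read by run_experiment (names assumed from the code above).
for flag in ['process_data', 'ensemble_inference', 'demo', 'train',
             'search_random_seed', 'grid_search', 'fine_tune',
             'inference', 'error_analysis']:
    parser.add_argument('--' + flag, action='store_true')
parser.add_argument('--model', default='bridge',
                    choices=['bridge', 'seq2seq', 'seq2seq.pg'])

args = parser.parse_args(['--train', '--model', 'bridge'])
run_experiment(args)
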
Example #2
def run_inference(args):
    if args.model in ['bridge', 'seq2seq', 'seq2seq.pg']:
        sp = EncoderDecoderLFramework(args)
    else:
        raise NotImplementedError

    sp.cuda()

    # Inference never needs gradients; disabling autograd saves memory and time
    with torch.set_grad_enabled(False):
        inference(sp)
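
Here `torch.set_grad_enabled(False)` is equivalent to `torch.no_grad()`, since the flag is a constant. A self-contained illustration of the effect:

import torch

x = torch.ones(2, 2, requires_grad=True)

with torch.set_grad_enabled(False):
    y = x * 2
print(y.requires_grad)  # False: no autograd graph is recorded inside the block

with torch.no_grad():   # equivalent spelling when the flag is always False
    z = x * 2
print(z.requires_grad)  # False
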
Example #3
def load_semantic_parser(args):
    if args.model in model_index:
        sp = EncoderDecoderLFramework(args)
    else:
        raise NotImplementedError
    sp.load_checkpoint(get_checkpoint_path(args))
    sp.cuda()  # move the model to GPU
    sp.eval()  # evaluation mode: disables dropout and other train-only behavior
    return sp
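
A short usage sketch. It assumes `args` carries the fields the loader reads, `examples` is a preprocessed dev split as in the other snippets, and the `inference` keyword arguments (borrowed from Example #4 below) have sensible defaults.

import torch

# Hypothetical call site: load the trained parser once, then run inference.
sp = load_semantic_parser(args)
with torch.set_grad_enabled(False):
    out_dict = sp.inference(examples, inline_eval=True, verbose=False)
print(len(out_dict['pred_decoded']))
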
Example #4
def ensemble():
    dataset = data_loader.load_processed_data(args)
    split = 'test' if args.test else 'dev'
    dev_examples = dataset[split]
    print('{} {} examples loaded'.format(len(dev_examples), split))
    if args.dataset_name == 'wikisql':
        engine_path = os.path.join(args.data_dir, '{}.db'.format(split))
        engine = DBEngine(engine_path)
    else:
        engine = None

    sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
    for i, model_dir in enumerate(ensemble_model_dirs):
        checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
        sps[i].schema_graphs = dataset['schema']
        sps[i].load_checkpoint(checkpoint_path)
        sps[i].cuda()
        sps[i].eval()

    pred_restored_cache = sps[0].load_pred_restored_cache()
    pred_restored_cache_size = sum(len(v)
                                   for v in pred_restored_cache.values())

    out_dict = sps[0].inference(dev_examples, restore_clause_order=args.process_sql_in_execution_order,
                                pred_restored_cache=pred_restored_cache,
                                check_schema_consistency_=args.sql_consistency_check, engine=engine,
                                inline_eval=True, model_ensemble=[sp.mdl for sp in sps], verbose=True)

    if args.process_sql_in_execution_order:
        new_pred_restored_cache_size = sum(
            len(v) for v in out_dict['pred_restored_cache'].values())
        newly_cached_size = new_pred_restored_cache_size - pred_restored_cache_size
        if newly_cached_size > 0:
            sps[0].save_pred_restored_cache(
                out_dict['pred_restored_cache'], newly_cached_size)

    out_txt = os.path.join(sps[0].model_dir, 'predictions.ens.{}.{}.{}.{}.txt'.format(
        args.beam_size, args.bs_alpha, split, len(ensemble_model_dirs)))
    with open(out_txt, 'w') as o_f:
        assert len(dev_examples) == len(out_dict['pred_decoded'])
        for i, pred_sql in enumerate(out_dict['pred_decoded']):
            if args.dataset_name == 'wikisql':
                example = dev_examples[i]
                o_f.write('{}\n'.format(json.dumps(
                    {'sql': pred_sql[0], 'table_id': example.db_name})))
            else:
                o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))

    print('{} set performance'.format(split.upper()))
    metrics = eval_tools.get_exact_match_metrics(
        dev_examples, out_dict['pred_decoded'], engine=engine)
    print('Top-1 exact match: {:.3f}'.format(metrics['top_1_em']))
    print('Top-2 exact match: {:.3f}'.format(metrics['top_2_em']))
    print('Top-3 exact match: {:.3f}'.format(metrics['top_3_em']))
    print('Top-5 exact match: {:.3f}'.format(metrics['top_5_em']))
    print('Top-10 exact match: {:.3f}'.format(metrics['top_10_em']))
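
For reference, a toy sketch of what the top-k exact-match numbers measure. The real `eval_tools.get_exact_match_metrics` presumably also normalizes SQL and can execute queries against `engine`; this version omits both, and `example.program` is a hypothetical name for the gold SQL field.

def toy_exact_match_metrics(examples, pred_decoded, ks=(1, 2, 3, 5, 10)):
    # pred_decoded[i] is the beam of SQL strings predicted for examples[i].
    metrics = {}
    for k in ks:
        hits = sum(any(pred == example.program for pred in beam[:k])
                   for example, beam in zip(examples, pred_decoded))
        metrics['top_{}_em'.format(k)] = hits / len(examples)
    return metrics
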
Example #5
def ensemble():
    text_tokenize, program_tokenize, post_process, table_utils = tok.get_tokenizers(
        args)
    schema_graphs = schema_loader.load_schema_graphs_spider(
        args.codalab_data_dir, 'spider', db_dir=args.codalab_db_dir)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize,
                                    normalized=(args.model_id
                                                in [utils.BRIDGE]))
    text_vocab = Vocabulary('text',
                            func_token_index=functional_token_index,
                            tu=table_utils)
    for v in table_utils.tokenizer.vocab:
        text_vocab.index_token(v, True,
                               table_utils.tokenizer.convert_tokens_to_ids(v))
    program_vocab = sql_reserved_tokens if args.pretrained_transformer else sql_reserved_tokens_revtok
    vocabs = {'text': text_vocab, 'program': program_vocab}
    examples = data_loader.load_data_split_spider(args.codalab_data_dir, 'dev',
                                                  schema_graphs)
    print('{} {} examples loaded'.format(len(examples), 'dev'))
    for i, example in enumerate(examples):
        schema_graph = schema_graphs.get_schema(example.db_id)
        preprocess_example('dev', example, args, None, text_tokenize,
                           program_tokenize, post_process, table_utils,
                           schema_graph, vocabs)
    print('{} {} examples processed'.format(len(examples), 'dev'))

    checkpoint_paths = [
        'ensemble_models/model1.tar', 'ensemble_models/model2.tar',
        'ensemble_models/model3.tar'
    ]

    sps = [EncoderDecoderLFramework(args) for _ in checkpoint_paths]
    for i, checkpoint_path in enumerate(checkpoint_paths):
        sps[i].schema_graphs = schema_graphs
        sps[i].load_checkpoint(checkpoint_path)
        sps[i].cuda()
        sps[i].eval()

    out_dict = sps[0].inference(
        examples,
        restore_clause_order=args.process_sql_in_execution_order,
        check_schema_consistency_=args.sql_consistency_check,
        inline_eval=False,
        model_ensemble=[sp.mdl for sp in sps],
        verbose=False)

    assert sps[0].args.prediction_path is not None
    out_txt = sps[0].args.prediction_path
    with open(out_txt, 'w') as o_f:
        for pred_sql in out_dict['pred_decoded']:
            o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))
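
The checkpoint list above is hard-coded. A hedged variant discovers the checkpoints on disk instead, assuming the same ensemble_models/ layout:

import glob

# Hypothetical alternative to the hard-coded list: pick up every .tar
# checkpoint under ensemble_models/ in a stable order.
checkpoint_paths = sorted(glob.glob('ensemble_models/*.tar'))
assert checkpoint_paths, 'no checkpoints found under ensemble_models/'
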
Example #6
    def __init__(self, args, cs_args, schema, ensemble_model_dirs=None):
        self.args = args
        self.text_tokenize, _, _, self.tu = tok.get_tokenizers(args)

        # Vocabulary
        self.vocabs = data_loader.load_vocabs(args)

        # Confusion span detector
        self.confusion_span_detector = load_confusion_span_detector(cs_args)

        # Text-to-SQL model
        self.semantic_parsers = []
        self.model_ensemble = None
        if ensemble_model_dirs is None:
            # Single-model mode: load one trained parser from the default checkpoint
            sp = load_semantic_parser(args)
            sp.schema_graphs = SchemaGraphs()
            self.semantic_parsers.append(sp)
        else:
            # Ensemble mode: load one parser per model directory
            sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
            for i, model_dir in enumerate(ensemble_model_dirs):
                checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
                sps[i].schema_graphs = SchemaGraphs()
                sps[i].load_checkpoint(checkpoint_path)
                sps[i].cuda()
                sps[i].eval()
            self.semantic_parsers = sps
            self.model_ensemble = [sp.mdl for sp in sps]

        if schema is not None:
            self.add_schema(schema)

        # When generating SQL in execution order, cache reordered SQLs to save time
        if args.process_sql_in_execution_order:
            self.pred_restored_cache = self.semantic_parsers[0].load_pred_restored_cache()
        else:
            self.pred_restored_cache = None
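
Finally, a hedged sketch of how such a wrapper might be driven. `Text2SQLWrapper` and its `process` method are assumed names, since the snippet shows only the constructor.

# Hypothetical driver: build the wrapper for one schema, then parse questions.
wrapper = Text2SQLWrapper(args, cs_args, schema=schema)
output = wrapper.process('How many singers do we have?', schema.name)
print(output)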