Example #1
def demo_preprocess(args, example, vocabs=None, schema_graph=None):
    text_tokenize, program_tokenize, post_process, tu = tok.get_tokenizers(args)
    # Lazily load the schema graph for the example's database when none is provided
    if not schema_graph:
        schema_graphs = load_schema_graphs(args)
        schema_graph = schema_graphs.get_schema(example.db_id)
    schema_graph.lexicalize_graph(tokenize=text_tokenize, normalized=(args.model_id in [BRIDGE]))
    preprocess_example('test', example, args, {}, text_tokenize, program_tokenize, post_process, tu, schema_graph, vocabs)
Example #2
    def __init__(self, args):
        super().__init__(args)
        vocabs = data_loader.load_vocabs(args)
        self.in_vocab = vocabs['text']
        self.out_vocab = vocabs['program']

        # Construct NN model
        if self.model_id == BRIDGE:
            self.mdl = Bridge(args, self.in_vocab, self.out_vocab)
        elif self.model_id == SEQ2SEQ_PG:
            self.mdl = PointerGenerator(args, self.in_vocab, self.out_vocab)
        elif self.model_id == SEQ2SEQ:
            self.mdl = Seq2Seq(args, self.in_vocab, self.out_vocab)
        else:
            raise NotImplementedError

        # Specify loss function
        if self.args.loss == 'cross_entropy':
            self.loss_fun = MaskedCrossEntropyLoss(self.mdl.out_vocab.pad_id)
        else:
            raise NotImplementedError

        # Optimizer
        self.define_optimizer()
        self.define_lr_scheduler()

        # Post-process
        _, _, self.output_post_process, _ = tok.get_tokenizers(args)

        print('{} module created'.format(self.model))
Example #3
def load_data(args):
    def load_split(in_json):
        examples = []
        with open(in_json) as f:
            content = json.load(f)
        for exp in tqdm(content):
            question = exp['question']
            question_tokens = exp['question_toks']
            db_name = exp['db_id']
            schema = schema_graphs[db_name]
            example = Example(question, schema)
            text_tokens = bu.tokenizer.tokenize(question)
            example.text_tokens = text_tokens
            example.text_ids = bu.tokenizer.convert_tokens_to_ids(example.text_tokens)
            schema_features, _ = schema.get_serialization(bu, flatten_features=True,
                                                          question_encoding=question,
                                                          top_k_matches=args.top_k_picklist_matches)
            example.input_tokens, _, _, _ = get_table_aware_transformer_encoder_inputs(
                text_tokens, text_tokens, schema_features, bu)
            example.ptr_input_ids = bu.tokenizer.convert_tokens_to_ids(example.input_tokens)
            # Untranslatable questions carry a span to be modified; map it to sub-token indices
            if exp['untranslatable']:
                modify_span = exp['modify_span']
                if modify_span[0] == -1:
                    example.span_ids = [1, len(text_tokens)]
                else:
                    assert (modify_span[0] >= 0 and modify_span[1] >= 0)
                    span_ids = utils.get_sub_token_ids(question_tokens, modify_span, bu)
                    if span_ids[0] >= len(text_tokens) or span_ids[1] > len(text_tokens):
                        # Clamp out-of-range sub-token indices to the tokenized question length
                        a = min(span_ids[0], len(text_tokens) - 1)
                        b = min(span_ids[1], len(text_tokens))
                        span_ids = (a, b)
                    example.span_ids = [span_ids[0] + 1, span_ids[1]]
            else:
                example.span_ids = [0, 0]
            examples.append(example)
        print('{} examples loaded from {}'.format(len(examples), in_json))
        return examples

    data_dir = args.data_dir
    train_json = os.path.join(data_dir, 'train_ut.json')
    dev_json = os.path.join(data_dir, 'dev_ut.json')
    text_tokenize, _, _, _ = tok.get_tokenizers(args)

    schema_graphs = load_schema_graphs(args)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=True)
    if args.train:
        train_data = load_split(train_json)
    else:
        train_data = None
    dev_data = load_split(dev_json)
    dataset = dict()
    dataset['train'] = train_data
    dataset['dev'] = dev_data
    dataset['schema'] = schema_graphs
    return dataset
Example #4
def ensemble():
    text_tokenize, program_tokenize, post_process, table_utils = tok.get_tokenizers(
        args)
    schema_graphs = schema_loader.load_schema_graphs_spider(
        args.codalab_data_dir, 'spider', db_dir=args.codalab_db_dir)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize,
                                    normalized=(args.model_id
                                                in [utils.BRIDGE]))
    text_vocab = Vocabulary('text',
                            func_token_index=functional_token_index,
                            tu=table_utils)
    for v in table_utils.tokenizer.vocab:
        text_vocab.index_token(v, True,
                               table_utils.tokenizer.convert_tokens_to_ids(v))
    program_vocab = sql_reserved_tokens if args.pretrained_transformer else sql_reserved_tokens_revtok
    vocabs = {'text': text_vocab, 'program': program_vocab}
    examples = data_loader.load_data_split_spider(args.codalab_data_dir, 'dev',
                                                  schema_graphs)
    print('{} {} examples loaded'.format(len(examples), 'dev'))
    for i, example in enumerate(examples):
        schema_graph = schema_graphs.get_schema(example.db_id)
        preprocess_example('dev', example, args, None, text_tokenize,
                           program_tokenize, post_process, table_utils,
                           schema_graph, vocabs)
    print('{} {} examples processed'.format(len(examples), 'dev'))

    checkpoint_paths = [
        'ensemble_models/model1.tar', 'ensemble_models/model2.tar',
        'ensemble_models/model3.tar'
    ]

    sps = [EncoderDecoderLFramework(args) for _ in checkpoint_paths]
    for i, checkpoint_path in enumerate(checkpoint_paths):
        sps[i].schema_graphs = schema_graphs
        sps[i].load_checkpoint(checkpoint_path)
        sps[i].cuda()
        sps[i].eval()

    out_dict = sps[0].inference(
        examples,
        restore_clause_order=args.process_sql_in_execution_order,
        check_schema_consistency_=args.sql_consistency_check,
        inline_eval=False,
        model_ensemble=[sp.mdl for sp in sps],
        verbose=False)

    assert (sps[0].args.prediction_path is not None)
    out_txt = sps[0].args.prediction_path
    with open(out_txt, 'w') as o_f:
        for pred_sql in out_dict['pred_decoded']:
            o_f.write('{}\n'.format(pred_sql[0]))
        print('Model predictions saved to {}'.format(out_txt))
Example #5
    def __init__(self, args):
        super(LFramework, self).__init__()
        self.model = args.model
        self.model_id = args.model_id

        self.tu = utils.get_trans_utils(args)
        self.schema_graphs = None

        # Training hyperparameters
        self.args = args
        _, _, _, self.tu = tok.get_tokenizers(args)
        self.dataset = args.dataset_name
        self.model_dir = args.model_dir
        self.train_batch_size = args.train_batch_size
        self.dev_batch_size = args.dev_batch_size

        self.start_step = args.start_step
        self.num_steps = args.num_steps
        self.num_peek_steps = args.num_peek_steps
        self.num_log_steps = args.num_log_steps
        self.num_accumulation_steps = args.num_accumulation_steps
        self.save_best_model_only = args.save_best_model_only

        self.optimizer = args.optimizer
        self.bert_finetune_rate = args.bert_finetune_rate
        self.learning_rate = args.learning_rate
        self.learning_rate_scheduler = learning_rate_scheduler_sigs[
            args.learning_rate_scheduler]
        self.ft_learning_rate_scheduler = learning_rate_scheduler_sigs[
            args.trans_learning_rate_scheduler]
        self.warmup_init_lr = args.warmup_init_lr
        self.warmup_init_ft_lr = args.warmup_init_ft_lr
        self.num_warmup_steps = args.num_warmup_steps
        self.grad_norm = args.grad_norm
        self.adam_beta1 = args.adam_beta1
        self.adam_beta2 = args.adam_beta2
        self.optim = None
        self.lr_scheduler = None

        self.decoding_algorithm = args.decoding_algorithm
        self.beam_size = args.beam_size

        self.save_all_checkpoints = args.save_all_checkpoints

        # Visualization saver
        self.vis_writer = LayerVisualizationDataWriter(log_dir=args.viz_dir)
Example #6
def preprocess(args, dataset, process_splits=('train', 'dev', 'test'), print_aggregated_stats=False, verbose=False,
               save_processed_data=True):
    """
    Data pre-processing for baselines that do only shallow processing of the schema.
    """
    text_tokenize, program_tokenize, post_process, trans_utils = tok.get_tokenizers(args)
    parsed_programs = load_parsed_sqls(args, augment_with_wikisql=args.augment_with_wikisql)
    num_parsed_programs = len(parsed_programs)

    vocabs = load_vocabs(args)

    schema_graphs = dataset['schema']
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=(args.model_id in [BRIDGE]))

    ############################
    # data statistics
    ds = DatasetStatistics()
    ############################

    # parallel data
    for split in process_splits:
        if split not in dataset:
            continue
        ds_split, sl_split = preprocess_split(dataset, split, args, parsed_programs,
                                              text_tokenize, program_tokenize, post_process, trans_utils,
                                              schema_graphs, vocabs, verbose=verbose)
        ds_split.print(split)
        sl_split.print()
        ############################
        # update data statistics
        ds.accumulate(ds_split)
        ############################

    if len(parsed_programs) > num_parsed_programs:
        save_parsed_sqls(args, parsed_programs)

    if print_aggregated_stats:
        ds.print()

    if save_processed_data:
        out_pkl = get_processed_data_path(args)
        with open(out_pkl, 'wb') as o_f:
            pickle.dump(dataset, o_f)
            print('Processed data dumped to {}'.format(out_pkl))
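A minimal driver sketch for the preprocessing entry point above, under the assumption that `args` comes from the project's own argument parser and that the dataset dict is shaped like the one returned by `load_data` in Example #3 (the 'train'/'dev'/'schema' keys that `preprocess` reads). `parse_args()` is a hypothetical placeholder, not a confirmed API of this codebase.

# Hedged sketch: build the dataset, then run the shallow preprocessing pass.
args = parse_args()        # hypothetical stand-in for the project's argument parsing
dataset = load_data(args)  # dict with 'train', 'dev' and 'schema' entries
preprocess(args, dataset,
           process_splits=('train', 'dev'),
           print_aggregated_stats=True,
           verbose=False,
           save_processed_data=True)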
Example #7
    def __init__(self, args, cs_args, schema, ensemble_model_dirs=None):
        self.args = args
        self.text_tokenize, _, _, self.tu = tok.get_tokenizers(args)

        # Vocabulary
        self.vocabs = data_loader.load_vocabs(args)

        # Confusion span detector
        self.confusion_span_detector = load_confusion_span_detector(cs_args)

        # Text-to-SQL model
        self.semantic_parsers = []
        self.model_ensemble = None
        if ensemble_model_dirs is None:
            sp = load_semantic_parser(args)
            sp.schema_graphs = SchemaGraphs()
            self.semantic_parsers.append(sp)
        else:
            sps = [EncoderDecoderLFramework(args) for _ in ensemble_model_dirs]
            for i, model_dir in enumerate(ensemble_model_dirs):
                checkpoint_path = os.path.join(model_dir, 'model-best.16.tar')
                sps[i].schema_graphs = SchemaGraphs()
                sps[i].load_checkpoint(checkpoint_path)
                sps[i].cuda()
                sps[i].eval()
            self.semantic_parsers = sps
            self.model_ensemble = [sp.mdl for sp in sps]

        if schema is not None:
            self.add_schema(schema)

        # Note: this reset overrides the ensemble assigned above, so inference runs with a single model
        self.model_ensemble = None

        # When generating SQL in execution order, cache reordered SQLs to save time
        if args.process_sql_in_execution_order:
            self.pred_restored_cache = self.semantic_parsers[0].load_pred_restored_cache()
        else:
            self.pred_restored_cache = None
Example #8
def build_vocab(args, dataset, schema_graphs):
    """
    Construct vocabularies.

    This function saves to disk:
    - text vocab: consists of tokens that appear in the natural language query and the schema
    - program vocab: consists of tokens that appear in the program
    - schema vocab: consists of table and field names from the schema
    - world vocab: consists of tokens in the program that do not come from any of the above categories
      (and thus likely need to be inferred from world knowledge)
    """
    print('Constructing vocabulary...')

    text_tokenize, program_tokenize, _, tu = tok.get_tokenizers(args)
    if args.pretrained_transformer:
        sql_reserved_vocab = sql_reserved_tokens
    else:
        sql_reserved_vocab = sql_reserved_tokens_revtok
    parsed_programs = load_parsed_sqls(
        args, augment_with_wikisql=args.augment_with_wikisql)

    schema_graphs.lexicalize_graphs(tokenize=text_tokenize,
                                    normalized=(args.model_id in [BRIDGE]))

    # compute text and program vocab
    text_hist, program_hist = collections.defaultdict(
        int), collections.defaultdict(int)
    world_vocab = Vocabulary('world')

    for split in ['train', 'dev', 'test']:
        if not split in dataset:
            continue
        data_split = dataset[split]
        for i, example in enumerate(data_split):
            if isinstance(example, AugmentedText2SQLExample):
                continue
            schema_graph = schema_graphs.get_schema(example.db_id)
            text = example.text
            if args.pretrained_transformer:
                text_tokens = text_tokenize(text)
            else:
                text_tokens = text_tokenize(text.lower(), functional_tokens)
            for word in text_tokens:
                text_hist[word] += 1
            for program in example.program_list:
                ast, _ = get_ast(program, parsed_programs,
                                 args.denormalize_sql, schema_graph)
                if ast:
                    program = ast
                program_tokens = program_tokenize(
                    program,
                    omit_from_clause=args.omit_from_clause,
                    no_join_condition=args.no_join_condition)
                for token in program_tokens:
                    program_hist[token] += 1
                    if split == 'train':
                        if not token in text_tokens and not sql_reserved_vocab.contains(
                                token):
                            world_vocab.index_token(token, in_vocab=True)
            if i > 0 and i % 5000 == 0:
                print('{} examples processed'.format(i))

    if args.pretrained_transformer.startswith(
            'bert') or args.pretrained_transformer == 'table-bert':
        text_hist = dict()
        for v in tu.tokenizer.vocab:
            text_hist[v] = tu.tokenizer.vocab[v]
        for v in tu.tokenizer.added_tokens_encoder:
            text_hist[v] = tu.tokenizer.convert_tokens_to_ids(v)
        schema_lexical_vocab = None
    elif args.pretrained_transformer.startswith('roberta'):
        text_hist = tu.tokenizer.encoder
        schema_lexical_vocab = None
    else:
        schema_lexical_vocab = schema_graphs.get_lexical_vocab()

    export_vocab(text_hist, program_hist, schema_lexical_vocab, world_vocab,
                 args)