def demo_preprocess(args, example, vocabs=None, schema_graph=None):
    # Build the text/program tokenizers and post-processing utilities for the given args.
    text_tokenize, program_tokenize, post_process, tu = tok.get_tokenizers(args)
    if not schema_graph:
        # No schema graph supplied: load all schema graphs and look up the example's database.
        schema_graphs = load_schema_graphs(args)
        schema_graph = schema_graphs.get_schema(example.db_id)
    # Attach tokenized field/table names to the schema graph (normalized only for the BRIDGE model).
    schema_graph.lexicalize_graph(tokenize=text_tokenize, normalized=(args.model_id in [BRIDGE]))
    # Run the standard preprocessing pipeline on this single example as a test-time instance.
    preprocess_example('test', example, args, {}, text_tokenize, program_tokenize, post_process, tu,
                       schema_graph, vocabs)
Example #2
def load_data(args):
    def load_split(in_json):
        examples = []
        with open(in_json) as f:
            content = json.load(f)
        for exp in tqdm(content):
            question = exp['question']
            question_tokens = exp['question_toks']
            db_name = exp['db_id']
            schema = schema_graphs[db_name]
            example = Example(question, schema)
            # Sub-tokenize the question and record both the tokens and their ids.
            text_tokens = bu.tokenizer.tokenize(question)
            example.text_tokens = text_tokens
            example.text_ids = bu.tokenizer.convert_tokens_to_ids(example.text_tokens)
            # Serialize the schema (with top-k picklist matches against the question) into encoder features.
            schema_features, _ = schema.get_serialization(bu, flatten_features=True,
                                                          question_encoding=question,
                                                          top_k_matches=args.top_k_picklist_matches)
            # Concatenate question and schema features into the table-aware encoder input.
            example.input_tokens, _, _, _ = get_table_aware_transformer_encoder_inputs(
                text_tokens, text_tokens, schema_features, bu)
            example.ptr_input_ids = bu.tokenizer.convert_tokens_to_ids(example.input_tokens)
            if exp['untranslatable']:
                # Untranslatable example: locate the annotated modify span within the question.
                modify_span = exp['modify_span']
                if modify_span[0] == -1:
                    # No specific span annotated: mark the entire question.
                    example.span_ids = [1, len(text_tokens)]
                else:
                    assert (modify_span[0] >= 0 and modify_span[1] >= 0)
                    # Map the word-level modify span onto sub-token indices.
                    span_ids = utils.get_sub_token_ids(question_tokens, modify_span, bu)
                    if span_ids[0] >= len(text_tokens) or span_ids[1] > len(text_tokens):
                        # Clamp the span so it stays within the sub-tokenized question.
                        a, b = span_ids
                        while (a >= len(text_tokens)):
                            a -= 1
                        while (b > len(text_tokens)):
                            b -= 1
                        span_ids = (a, b)
                    # Shift the start index by one; index 0 is reserved for the translatable sentinel below.
                    example.span_ids = [span_ids[0] + 1, span_ids[1]]
            else:
                # Translatable question: use the sentinel span [0, 0].
                example.span_ids = [0, 0]
            examples.append(example)
        print('{} examples loaded from {}'.format(len(examples), in_json))
        return examples

    data_dir = args.data_dir
    train_json = os.path.join(data_dir, 'train_ut.json')
    dev_json = os.path.join(data_dir, 'dev_ut.json')
    text_tokenize, _, _, _ = tok.get_tokenizers(args)

    # Load and lexicalize all schema graphs once; load_split looks them up by database name.
    schema_graphs = load_schema_graphs(args)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=True)
    if args.train:
        train_data = load_split(train_json)
    else:
        train_data = None
    dev_data = load_split(dev_json)
    dataset = {
        'train': train_data,
        'dev': dev_data,
        'schema': schema_graphs
    }
    return dataset
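
The span bookkeeping in load_split is the least obvious part, so below is a minimal, self-contained sketch of the same index arithmetic. The helpers toy_subtokenize, word_span_to_subtoken_span, and clamp_span are illustrative stand-ins, not the repo's bu tokenizer or utils.get_sub_token_ids; only the clamping step mirrors the while loops above.

def toy_subtokenize(words):
    """Split each word into fake 4-character sub-tokens and return the sub-token
    list plus a word -> first-sub-token alignment (with an end sentinel)."""
    sub_tokens, starts = [], []
    for w in words:
        starts.append(len(sub_tokens))
        sub_tokens.extend([w[i:i + 4] for i in range(0, len(w), 4)] or [w])
    starts.append(len(sub_tokens))  # sentinel: one past the last word
    return sub_tokens, starts


def word_span_to_subtoken_span(word_span, starts):
    """Map a word-level span (start inclusive, end exclusive) to sub-token indices."""
    ws, we = word_span
    return starts[ws], starts[we]


def clamp_span(span, max_len):
    """Clamp (start, end) so that start < max_len and end <= max_len,
    the same adjustment the while loops in load_split perform."""
    a, b = span
    return min(a, max_len - 1), min(b, max_len)


if __name__ == '__main__':
    words = 'show all untranslatable questions please'.split()
    sub_tokens, starts = toy_subtokenize(words)
    span = word_span_to_subtoken_span((2, 4), starts)   # "untranslatable questions"
    print(clamp_span(span, len(sub_tokens)))            # (2, 9)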