import json
import os

from tqdm import tqdm

# Project-specific helpers referenced below (tok, bu, utils, Example, BRIDGE,
# load_schema_graphs, preprocess_example, get_table_aware_transformer_encoder_inputs)
# are assumed to be imported elsewhere in this module.


def demo_preprocess(args, example, vocabs=None, schema_graph=None):
    text_tokenize, program_tokenize, post_process, tu = tok.get_tokenizers(args)
    if not schema_graph:
        # No schema graph supplied: look it up by the example's database id and
        # lexicalize it before feature extraction.
        schema_graphs = load_schema_graphs(args)
        schema_graph = schema_graphs.get_schema(example.db_id)
        schema_graph.lexicalize_graph(tokenize=text_tokenize,
                                      normalized=(args.model_id in [BRIDGE]))
    preprocess_example('test', example, args, {}, text_tokenize, program_tokenize,
                       post_process, tu, schema_graph, vocabs)
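
# Usage sketch (hypothetical; assumes the project's parsed `args` and an
# Example constructed for the target database, as in load_data below):
#
#   example = Example(question, schema_graph)
#   demo_preprocess(args, example, vocabs=vocabs, schema_graph=schema_graph)
#
# Passing `schema_graph` explicitly skips reloading and re-lexicalizing the
# schema graphs, which helps when preprocessing many demo requests.
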
def load_data(args):

    def load_split(in_json):
        examples = []
        with open(in_json) as f:
            content = json.load(f)
        for exp in tqdm(content):
            question = exp['question']
            question_tokens = exp['question_toks']
            db_name = exp['db_id']
            schema = schema_graphs[db_name]
            example = Example(question, schema)
            # Tokenize the question with the BERT tokenizer and store both the
            # sub-tokens and their vocabulary ids.
            text_tokens = bu.tokenizer.tokenize(question)
            example.text_tokens = text_tokens
            example.text_ids = bu.tokenizer.convert_tokens_to_ids(example.text_tokens)
            # Serialize the schema (conditioned on the question) and build the
            # table-aware encoder input sequence.
            schema_features, _ = schema.get_serialization(
                bu, flatten_features=True, question_encoding=question,
                top_k_matches=args.top_k_picklist_matches)
            example.input_tokens, _, _, _ = get_table_aware_transformer_encoder_inputs(
                text_tokens, text_tokens, schema_features, bu)
            example.ptr_input_ids = bu.tokenizer.convert_tokens_to_ids(example.input_tokens)
            if exp['untranslatable']:
                modify_span = exp['modify_span']
                if modify_span[0] == -1:
                    # No specific span annotated: treat the whole question as
                    # the span to modify.
                    example.span_ids = [1, len(text_tokens)]
                else:
                    assert modify_span[0] >= 0 and modify_span[1] >= 0
                    # Map the word-level span onto sub-token indices.
                    span_ids = utils.get_sub_token_ids(question_tokens, modify_span, bu)
                    if span_ids[0] >= len(text_tokens) or span_ids[1] > len(text_tokens):
                        # Clamp out-of-range sub-token indices to the tokenized
                        # question length.
                        a, b = span_ids
                        while a >= len(text_tokens):
                            a -= 1
                        while b > len(text_tokens):
                            b -= 1
                        span_ids = (a, b)
                    # Shift the start by one so index 0 stays the null-span sentinel.
                    example.span_ids = [span_ids[0] + 1, span_ids[1]]
            else:
                # Translatable examples carry a null span.
                example.span_ids = [0, 0]
            examples.append(example)
        print('{} examples loaded from {}'.format(len(examples), in_json))
        return examples

    data_dir = args.data_dir
    train_json = os.path.join(data_dir, 'train_ut.json')
    dev_json = os.path.join(data_dir, 'dev_ut.json')
    text_tokenize, _, _, _ = tok.get_tokenizers(args)
    schema_graphs = load_schema_graphs(args)
    schema_graphs.lexicalize_graphs(tokenize=text_tokenize, normalized=True)
    if args.train:
        train_data = load_split(train_json)
    else:
        train_data = None
    dev_data = load_split(dev_json)
    dataset = dict()
    dataset['train'] = train_data
    dataset['dev'] = dev_data
    dataset['schema'] = schema_graphs
    return dataset
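
# Usage sketch (hypothetical; assumes `args` provides `data_dir`, `train` and
# `top_k_picklist_matches`, and that `<data_dir>/train_ut.json` and
# `<data_dir>/dev_ut.json` exist):
#
#   dataset = load_data(args)
#   train_examples = dataset['train']   # None when args.train is False
#   dev_examples = dataset['dev']
#   schema_graphs = dataset['schema']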