def format_batch(self, mini_batch):
    """Convert a mini-batch of preprocessed examples into padded id tensors.

    Dispatches on ``self.model_id``:
      - SEQ2SEQ: returns (encoder_input_ids, decoder_input_ids).
      - BRIDGE: additionally builds schema-aware pointer inputs/outputs and
        schema feature tensors, returning a 10-element tuple (see the BRIDGE
        return statement below).
      - SEQ2SEQ_PG: returns the plain pointer-generator 5-tuple.
    Raises NotImplementedError for any other model id.

    All ``ops.pad_batch`` outputs are (padded_tensor, mask/length) pairs as
    produced by that helper — presumably; confirm against ``ops``.
    """
    def get_decoder_input_ids():
        # Teacher-forcing decoder inputs exist only at training time.
        if self.training:
            if self.model_id in [BRIDGE]:
                X = [
                    exp.program_singleton_field_input_ids
                    for exp in mini_batch
                ]
            else:
                X = [exp.program_input_ids for exp in mini_batch]
            return ops.pad_batch(X, self.mdl.out_vocab.pad_id)
        else:
            return None

    def get_encoder_attn_mask(table_names, table_masks):
        # Build a 0/1 attention mask over the serialized input: all text
        # tokens are visible; schema positions are visible only for the
        # tables in `table_names` (and their trailing field positions).
        # NOTE(review): this closure reads `schema_graph` and `exp` from the
        # enclosing loop but is never called in the code visible here —
        # presumably used by a predicted-tables inference path not shown;
        # confirm before removing.
        schema_pos = [
            schema_graph.get_schema_pos(table_name)
            for table_name in table_names
        ]
        encoder_attn_mask = [1 for _ in range(exp.num_text_tokens)]
        # asterisk marker
        encoder_attn_mask.append(1)
        is_selected_table = False
        for j in range(1, len(table_masks)):
            if j in schema_pos:
                encoder_attn_mask.append(1)
                is_selected_table = True
            elif table_masks[j] == 1:
                # mask current table
                encoder_attn_mask.append(0)
                is_selected_table = False
            else:
                # Field position: inherits visibility from the table that
                # precedes it in the serialization.
                if is_selected_table:
                    encoder_attn_mask.append(1)
                else:
                    encoder_attn_mask.append(0)
        return encoder_attn_mask

    super().format_batch(mini_batch)
    encoder_input_ids = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                      self.mdl.in_vocab.pad_id)
    decoder_input_ids = get_decoder_input_ids()
    table_samples = []
    if self.model_id == SEQ2SEQ:
        return encoder_input_ids, decoder_input_ids
    elif self.model_id in [BRIDGE]:
        # Accumulators filled per example below. NOTE(review):
        # `field_table_pos` and `schema_memory_masks` are never appended to
        # in the code visible here — presumably populated by paths not shown;
        # confirm before relying on their contents.
        encoder_ptr_input_ids, encoder_ptr_value_ids, decoder_ptr_value_ids = [], [], []
        primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions, table_field_scopes, \
            field_table_pos, transformer_output_value_masks, schema_memory_masks = [], [], [], [], [], [], [], [], []
        for exp in mini_batch:
            schema_graph = self.schema_graphs.get_schema(exp.db_id)
            if self.training:
                # Compute schema layout: recover the set of ground-truth
                # table ids either from the annotated list or from the
                # program tokens typed as tables (type == 0).
                if exp.gt_table_names_list:
                    gt_tables = set([
                        schema_graph.get_table_id(t_name)
                        for t_name in exp.gt_table_names
                    ])
                else:
                    gt_table_names = [
                        token for token, t in zip(
                            exp.program_singleton_field_tokens,
                            exp.program_singleton_field_token_types)
                        if t == 0
                    ]
                    gt_tables = set([
                        schema_graph.get_table_id(t_name)
                        for t_name in gt_table_names
                    ])
                # [Hack] Baseball database has a complex schema which does not fit the input size of BERT. We select
                # the ground truth tables and randomly add a few other tables for training.
                if schema_graph.name.startswith('baseball'):
                    tables = list(gt_tables)
                    tables += random.sample(
                        [
                            i for i in range(schema_graph.num_tables)
                            if i not in gt_tables
                        ],
                        k=min(random.randint(1, 7),
                              schema_graph.num_tables - len(gt_tables)))
                else:
                    tables = list(range(schema_graph.num_tables))
                if self.args.table_shuffling:
                    # With probability 0.3, drop one randomly chosen table
                    # that is not in the ground truth (schema augmentation).
                    table_to_drop = random.choice(tables)
                    if table_to_drop not in gt_tables:
                        if random.uniform(0, 1) < 0.3:
                            tables = [
                                x for x in tables if x != table_to_drop
                            ]
                    table_po, field_po = schema_graph.get_schema_perceived_order(
                        tables,
                        random_table_order=True,
                        random_field_order=self.args.random_field_order)
                else:
                    table_po, field_po = schema_graph.get_schema_perceived_order(
                        tables,
                        random_table_order=False,
                        random_field_order=self.args.random_field_order)
                # Schema feature extraction
                question_encoding = exp.text if self.args.use_picklist else None
                schema_features, matched_values = schema_graph.get_serialization(
                    self.tu,
                    flatten_features=True,
                    table_po=table_po,
                    field_po=field_po,
                    use_typed_field_markers=self.args.use_typed_field_markers,
                    use_graph_encoding=self.args.use_graph_encoding,
                    question_encoding=question_encoding,
                    top_k_matches=self.args.top_k_picklist_matches,
                    num_values_per_field=self.args.num_values_per_field,
                    no_anchor_text=self.args.no_anchor_text,
                    verbose=False)
                ptr_input_tokens, ptr_input_values, num_excluded_tables, num_excluded_fields = \
                    get_table_aware_transformer_encoder_inputs(
                        exp.text_ptr_values, exp.text_tokens, schema_features, self.tu)
                assert (len(ptr_input_tokens) <= self.tu.tokenizer.max_len)
                if num_excluded_fields > 0:
                    print('Warning: training input truncated')
                # +1 accounts for the asterisk node; excluded tables/fields
                # were truncated to fit the transformer input size.
                num_included_nodes = schema_graph.get_num_perceived_nodes(tables) + 1 \
                    - num_excluded_tables - num_excluded_fields
                encoder_ptr_input_ids.append(
                    self.tu.tokenizer.convert_tokens_to_ids(
                        ptr_input_tokens))
                if self.args.read_picklist:
                    exp.transformer_output_value_mask, value_features, value_tokens = \
                        get_transformer_output_value_mask(ptr_input_tokens, matched_values, self.tu)
                    transformer_output_value_masks.append(
                        exp.transformer_output_value_mask)
                primary_key_ids.append(
                    schema_graph.get_primary_key_ids(
                        num_included_nodes, table_po, field_po))
                foreign_key_ids.append(
                    schema_graph.get_foreign_key_ids(
                        num_included_nodes, table_po, field_po))
                field_type_ids.append(
                    schema_graph.get_field_type_ids(
                        num_included_nodes, table_po, field_po))
                table_masks.append(
                    schema_graph.get_table_masks(num_included_nodes,
                                                 table_po, field_po))
                # Value copy feature extraction: the copy memory is the text
                # tokens, optionally extended with matched picklist values.
                if self.args.read_picklist:
                    constant_memory_features = exp.text_tokens + value_features
                    constant_memory = exp.text_ptr_values + value_tokens
                    exp.text_ptr_values = constant_memory
                else:
                    constant_memory_features = exp.text_tokens
                constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
                    constant_memory_features, self.out_vocab)
                # Schema nodes are addressed past the vocab + copy memory.
                encoder_ptr_value_ids.append(constant_ptr_value_ids + [
                    self.out_vocab.size + len(constant_memory_features) + x
                    for x in range(num_included_nodes)
                ])
                program_field_ptr_value_ids = \
                    vec.vectorize_field_ptr_out(exp.program_singleton_field_tokens,
                                                exp.program_singleton_field_token_types,
                                                self.out_vocab,
                                                constant_unique_input_ids,
                                                max_memory_size=len(constant_memory_features),
                                                schema=schema_graph,
                                                num_included_nodes=num_included_nodes)
                decoder_ptr_value_ids.append(program_field_ptr_value_ids)
            else:
                # Inference: reuse features precomputed at preprocessing
                # time. NOTE(review): these whole-batch list comprehensions
                # are re-evaluated on every loop iteration while the
                # *.append calls below accumulate per example — mirrors the
                # original structure; confirm intent before flattening.
                encoder_ptr_input_ids = [
                    exp.ptr_input_ids for exp in mini_batch
                ]
                encoder_ptr_value_ids = [
                    exp.ptr_value_ids for exp in mini_batch
                ]
                decoder_ptr_value_ids = [exp.program_text_and_field_ptr_value_ids for exp in mini_batch] \
                    if self.training else None
                primary_key_ids = [
                    exp.primary_key_ids for exp in mini_batch
                ]
                foreign_key_ids = [
                    exp.foreign_key_ids for exp in mini_batch
                ]
                field_type_ids = [exp.field_type_ids for exp in mini_batch]
                table_masks = [exp.table_masks for exp in mini_batch]
                # TODO: here we assume that all nodes in the schema graph are included
                table_pos, table_field_scope = schema_graph.get_table_scopes(
                    schema_graph.num_nodes)
                table_positions.append(table_pos)
                table_field_scopes.append(table_field_scope)
                if self.args.read_picklist:
                    transformer_output_value_masks.append(
                        exp.transformer_output_value_mask)
        # Pad all accumulated per-example sequences to batch tensors.
        encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                              self.mdl.in_vocab.pad_id)
        encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                              self.mdl.in_vocab.pad_id)
        schema_memory_masks = ops.pad_batch(schema_memory_masks, pad_id=0) \
            if (self.args.use_pred_tables and not self.training) else (None, None)
        decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids, self.mdl.out_vocab.pad_id) \
            if self.training else None
        primary_key_ids = ops.pad_batch(primary_key_ids,
                                        self.mdl.in_vocab.pad_id)
        foreign_key_ids = ops.pad_batch(foreign_key_ids,
                                        self.mdl.in_vocab.pad_id)
        field_type_ids = ops.pad_batch(field_type_ids,
                                       self.mdl.in_vocab.pad_id)
        table_masks = ops.pad_batch(table_masks, pad_id=0)
        transformer_output_value_masks = ops.pad_batch(transformer_output_value_masks, pad_id=0, dtype=torch.uint8) \
            if self.args.read_picklist else (None, None)
        if not self.training:
            # Table position features are only needed when decoding SQL in
            # execution order.
            table_positions = ops.pad_batch(table_positions, pad_id=-1) \
                if self.args.process_sql_in_execution_order else (None, None)
            table_field_scopes = ops.pad_batch_2D(table_field_scopes, pad_id=0) \
                if self.args.process_sql_in_execution_order else (None, None)
        graphs = None
        return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
            decoder_ptr_value_ids, transformer_output_value_masks, schema_memory_masks, graphs, \
            (primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions,
             table_field_scopes, field_table_pos), table_samples
    elif self.model_id in [SEQ2SEQ_PG]:
        # Plain pointer-generator: text-only copy memory, no schema features.
        encoder_ptr_input_ids = [exp.ptr_input_ids for exp in mini_batch]
        encoder_ptr_value_ids = [exp.ptr_value_ids for exp in mini_batch]
        decoder_ptr_value_ids = [
            exp.program_text_ptr_value_ids for exp in mini_batch
        ]
        encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                              self.mdl.in_vocab.pad_id)
        encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                              self.mdl.in_vocab.pad_id)
        decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids,
                                              self.mdl.out_vocab.pad_id)
        return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
            decoder_ptr_value_ids
    else:
        raise NotImplementedError
def preprocess_example(split, example, args, parsed_programs, text_tokenize,
                       program_tokenize, post_process, table_utils,
                       schema_graph, vocabs, verbose=False):
    """Vectorize a single example in place: text, schema and program features.

    Mutates ``example`` (text token ids, pointer ids, schema feature ids,
    program pointer outputs, ...) and returns four sanity-check flags:
    (query_oov, denormalized, schema_truncated, token_restored).

    :param split: dataset split name; 'dev' additionally stores
        post-processed ground-truth programs for relaxed evaluation.
    :param example: a ``Text2SQLExample`` or a wrapper exposing ``.example``.
    :param args: experiment configuration namespace.
    :param vocabs: dict with 'text' and 'program' vocabularies.
    :param verbose: print/pretty-print OOV diagnostics when True.
    """
    tu = table_utils
    text_vocab = vocabs['text']
    program_vocab = vocabs['program']

    def get_memory_values(features, raw_text, args):
        # Restore original casing for uncased BERT tokenizations so the copy
        # memory holds surface forms; otherwise pass features through.
        if args.pretrained_transformer.startswith(
                'bert-') and args.pretrained_transformer.endswith('-uncased'):
            return utils.restore_feature_case(features, raw_text, tu)
        else:
            return features

    def get_text_schema_adjacency_matrix(text_features, s_M):
        # Embed the schema adjacency matrix into the bottom-right corner of a
        # larger (text + schema) adjacency matrix; text rows are left empty.
        schema_size = s_M.shape[0]
        text_size = len(text_features)
        full_size = schema_size + text_size
        # BUGFIX: `dtype=np.int` was removed in NumPy 1.24 (it was an alias
        # for the builtin `int`); use `int` directly for identical behavior.
        M = ssp.lil_matrix((full_size, full_size), dtype=int)
        M[-schema_size:, -schema_size:] = s_M
        return M

    # sanity check
    ############################
    query_oov = False
    denormalized = False
    schema_truncated = False
    token_restored = True
    ############################

    # Text feature extraction and set program ground truth list
    if isinstance(example, Text2SQLExample):
        if args.pretrained_transformer:
            text_features = text_tokenize(example.text)
            text_tokens, token_starts, token_ends = get_memory_values(
                text_features, example.text, args)
            if not token_starts:
                token_restored = False
        else:
            # NOTE(review): this branch never defines token_starts/token_ends,
            # so the assignments below would raise NameError — presumably a
            # pretrained transformer is always configured here; confirm.
            text_tokens = text_tokenize(example.text, functional_tokens)
            text_features = [t.lower() for t in text_tokens]
        example.text_tokens = text_features
        example.text_ptr_values = text_tokens
        example.text_token_starts = token_starts
        example.text_token_ends = token_ends
        example.text_ids = vec.vectorize(text_features, text_vocab)
        example.text_ptr_input_ids = vec.vectorize(text_features, text_vocab)
        program_list = example.program_list
        # Collect string-valued (non-numeric) condition values paired with
        # their field signatures from the first program AST.
        example.values = [
            (schema_graph.get_field(cond[0]).signature, cond[2])
            for cond in example.program_ast_list_[0]['conds']
            if (isinstance(cond[2], str) and not is_number(cond[2]))
        ]
    else:
        # Wrapped example: reuse the features of the underlying example.
        text_tokens = example.example.text_ptr_values
        text_features = example.example.text_tokens
        program_list = example.example.program_list

    # Schema feature extraction
    if args.model_id in [BRIDGE]:
        question_encoding = example.text if args.use_picklist else None
        tables = sorted([schema_graph.get_table_id(t_name) for t_name in example.gt_table_names]) \
            if args.use_oracle_tables else None
        table_po, field_po = schema_graph.get_schema_perceived_order(tables)
        schema_features, matched_values = schema_graph.get_serialization(
            tu,
            flatten_features=True,
            table_po=table_po,
            field_po=field_po,
            use_typed_field_markers=args.use_typed_field_markers,
            use_graph_encoding=args.use_graph_encoding,
            question_encoding=question_encoding,
            top_k_matches=args.top_k_picklist_matches,
            num_values_per_field=args.num_values_per_field,
            no_anchor_text=args.no_anchor_text)
        example.matched_values = matched_values
        example.input_tokens, example.input_ptr_values, num_excluded_tables, num_excluded_fields = \
            get_table_aware_transformer_encoder_inputs(text_tokens, text_features, schema_features, table_utils)
        schema_truncated = (num_excluded_fields > 0)
        # +1 accounts for the asterisk node; subtract nodes truncated away.
        num_included_nodes = schema_graph.get_num_perceived_nodes(
            table_po) + 1 - num_excluded_tables - num_excluded_fields
        example.ptr_input_ids = vec.vectorize(example.input_tokens, text_vocab)
        if args.read_picklist:
            example.transformer_output_value_mask, value_features, value_tokens = \
                get_transformer_output_value_mask(example.input_tokens, matched_values, tu)
        example.primary_key_ids = schema_graph.get_primary_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.foreign_key_ids = schema_graph.get_foreign_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_type_ids = schema_graph.get_field_type_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.table_masks = schema_graph.get_table_masks(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_table_pos = schema_graph.get_field_table_pos(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.schema_M = schema_graph.adj_matrix
        example.M = get_text_schema_adjacency_matrix(text_features,
                                                     example.schema_M)
    else:
        num_included_nodes = schema_graph.num_nodes

    # Value copy feature extraction: the copy memory is the text tokens,
    # optionally extended with matched picklist values.
    if args.read_picklist:
        constant_memory_features = text_features + value_features
        constant_memory = text_tokens + value_tokens
        example.text_ptr_values = constant_memory
    else:
        constant_memory_features = text_features
    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
        constant_memory_features, program_vocab)
    if isinstance(example, Text2SQLExample):
        example.text_ptr_value_ids = constant_ptr_value_ids
    # Schema nodes are addressed past the vocab + copy memory.
    example.ptr_value_ids = constant_ptr_value_ids + [
        program_vocab.size + len(constant_memory_features) + x
        for x in range(num_included_nodes)
    ]

    if not args.leaderboard_submission:
        for j, program in enumerate(program_list):
            if isinstance(example, Text2SQLExample):
                # Model II. Bridge output
                program_singleton_field_tokens, program_singleton_field_token_types = \
                    tok.wikisql_struct_to_tokens(example.program_ast_, schema_graph, tu)
                program_singleton_field_tokens = [
                    START_TOKEN
                ] + program_singleton_field_tokens + [EOS_TOKEN]
                program_singleton_field_token_types = \
                    [RESERVED_TOKEN_TYPE] + program_singleton_field_token_types + [RESERVED_TOKEN_TYPE]
                example.program_singleton_field_tokens_list.append(
                    program_singleton_field_tokens)
                example.program_singleton_field_token_types_list.append(
                    program_singleton_field_token_types)
                program_singleton_field_input_ids = vec.vectorize_singleton(
                    program_singleton_field_tokens,
                    program_singleton_field_token_types, program_vocab)
                example.program_singleton_field_input_ids_list.append(
                    program_singleton_field_input_ids)
            else:
                # Model II. Bridge output
                example.program_singleton_field_input_ids_list.append(
                    example.example.program_singleton_field_input_ids_list[j])
                program_singleton_field_tokens = example.example.program_singleton_field_tokens_list[
                    j]
                program_singleton_field_token_types = example.example.program_singleton_field_token_types_list[
                    j]
            program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                program_singleton_field_tokens,
                program_singleton_field_token_types,
                program_vocab,
                constant_unique_input_ids,
                max_memory_size=len(constant_memory_features),
                schema=schema_graph,
                num_included_nodes=num_included_nodes)
            example.program_text_and_field_ptr_value_ids_list.append(
                program_field_ptr_value_ids)
            table_ids = [
                schema_graph.get_table_id(table_name)
                for table_name in example.gt_table_names_list[j]
            ]
            example.table_ids_list.append(table_ids)
            assert ([schema_graph.get_table(x).name
                     for x in table_ids] == example.gt_table_names)

            # sanity check
            ############################
            # NL+Schema pointer output contains tokens that does not belong to any of the following categories
            if verbose:
                if program_vocab.unk_id in program_field_ptr_value_ids:
                    unk_indices = [
                        i for i, x in enumerate(program_field_ptr_value_ids)
                        if x == program_vocab.unk_id
                    ]
                    print('OOV II: {}'.format(' '.join([
                        program_singleton_field_tokens[i] for i in unk_indices
                    ])))
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
                    query_oov = True
                if program_vocab.unk_field_id in program_field_ptr_value_ids:
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
                if program_vocab.unk_table_id in program_field_ptr_value_ids:
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
            ############################

            # Store the ground truth queries after preprocessing to run a relaxed evaluation or
            # to evaluate with partial queries
            if split == 'dev':
                input_tokens = text_tokens
                if args.model_id in [BRIDGE]:
                    _p = vec.de_vectorize_field_ptr(
                        program_field_ptr_value_ids,
                        program_vocab,
                        input_tokens,
                        schema=schema_graph,
                        post_process=post_process)
                else:
                    _p = program
                example.gt_program_list.append(_p)

    example.run_unit_tests()
    return query_oov, denormalized, schema_truncated, token_restored