def de_vectorize(self, p_cpu, out_vocab, input_ptr_values, schema=None, table_po=None, field_po=None,
                 return_tokens=False):
    """Convert an output prediction id sequence back into a human-readable program.

    The de-vectorizer is picked from ``self.model_id``:

    * ``SEQ2SEQ_PG`` -- ``vec.de_vectorize_ptr`` with ``input_ptr_values`` as the copy memory.
    * ``BRIDGE``     -- ``vec.de_vectorize_field_ptr`` with the copy memory plus ``schema``,
      ``table_po`` and ``field_po``.
    * ``SEQ2SEQ``    -- ``vec.de_vectorize``; the pointer/schema arguments are unused.

    In every case ``self.output_post_process`` is forwarded as ``post_process`` and
    ``return_tokens`` is passed through unchanged.

    :raises NotImplementedError: if ``self.model_id`` is none of the above.
    """
    model_id = self.model_id
    if model_id in [SEQ2SEQ_PG]:
        decoder = vec.de_vectorize_ptr
        extra_kwargs = {'memory': input_ptr_values}
    elif model_id in [BRIDGE]:
        decoder = vec.de_vectorize_field_ptr
        extra_kwargs = {
            'memory': input_ptr_values,
            'schema': schema,
            'table_po': table_po,
            'field_po': field_po,
        }
    elif model_id == SEQ2SEQ:
        decoder = vec.de_vectorize
        extra_kwargs = {}
    else:
        raise NotImplementedError
    return decoder(p_cpu, out_vocab, post_process=self.output_post_process,
                   return_tokens=return_tokens, **extra_kwargs)
def eval_example(self, example_pred, example, verbose=False, example_id=None, schema_graph=None):
    """Compare a predicted SQL query with the example's ground-truth AST, component by component.

    ``example_pred`` is packed with ``self.pack_sql`` and matched against ``example.program_ast`` with
    ``self.match``, which yields seven per-component flags. Judging by the names and the verbose dump,
    these are presumably select / group-by / order-by / from / where / having / limit (confirm against
    ``self.match``). The whole query counts as correct only when all seven are truthy.

    If the query is wrong and ``verbose`` is set, a diagnostic dump is printed. It covers the question,
    its tokens, the target programs, one badge per flag and the serialized select / group-by / order-by
    clauses of the prediction.

    :param example_pred: raw prediction, passed through ``self.pack_sql``.
    :param example: example providing ``program_ast`` (plus text/token fields for the verbose dump).
    :param verbose: print diagnostics for incorrect predictions.
    :param example_id: identifier printed in the verbose dump.
    :param schema_graph: schema used to serialize the prediction; required when ``verbose`` is set.
    :return: ``(sql_correct, s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct)``.
    """
    example_pred = self.pack_sql(example_pred)
    gt_ast = example.program_ast
    s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct = self.match(
        example_pred, gt_ast)
    # Plain ``and`` chain (not ``all``) so the original operand values are returned unchanged.
    sql_correct = s_correct and g_correct and o_correct and f_correct and w_correct and h_correct and l_correct

    if not sql_correct and verbose:
        assert (schema_graph is not None)
        text = example.text
        # TODO(review): the other ``vec.de_vectorize`` calls in this file pass a vocabulary as the second
        # positional argument; this call does not. Confirm whether a text vocabulary is required here.
        text_tokens = vec.de_vectorize(example.text_ids, return_tokens=True)
        text_ptr_values = example.text_ptr_values
        print('Example {}'.format(example_id))
        print('NL:\t{}'.format(text.encode('utf-8')))
        print('NL tokens:\t{}'.format(encode_str_list(text_tokens, 'utf-8')))
        print('NL tokens (original):\t{}'.format(encode_str_list(text_ptr_values, 'utf-8')))
        # ``return_tokens`` is deliberately not requested: the result is a string, hence ``.encode``.
        recovered = vec.de_vectorize_ptr(example.text_ptr_value_ids, self.value_vocab, text_ptr_values)
        print('NL tokens (recovered): {}'.format(recovered.encode('utf-8')))
        for i, program in enumerate(example.program_list):
            print('Target {}'.format(i))
            print('- string: {}'.format(program.encode('utf-8')))
        flags = (sql_correct, s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct)
        print(' '.join(get_badge(flag) for flag in flags))
        serializer = SQLSerializer(schema_graph, self.field_vocab, self.aggregation_ops, self.arithmetic_ops,
                                   self.condition_ops, self.logical_ops, self.value_vocab)
        print('select clause: {}'.format(serializer.serialize_select(example_pred['select'])))
        print('group by clause: {}'.format(serializer.serialize_group_by(example_pred['groupBy'])))
        print('order by clause: {}'.format(serializer.serialize_order_by(example_pred['orderBy'])))

    return sql_correct, s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct
def preprocess_example(split, example, args, parsed_programs, text_tokenize, program_tokenize, post_process,
                       trans_utils, schema_graph, vocabs, verbose=False):
    """Featurize one example in place and build the encodings of its ground-truth programs.

    ``example`` is mutated. It receives text features, BRIDGE-specific schema-aware encoder inputs
    (when ``args.model_id`` is BRIDGE) and pointer value ids. Unless ``args.leaderboard_submission``
    is set, it also receives per-program target id sequences.

    :param split: dataset split name; for ``'dev'`` the de-vectorized ground-truth programs are also
        stored in ``example.gt_program_list`` (for relaxed / partial-query evaluation).
    :param example: a ``Text2SQLExample``, or a wrapper whose already-processed example is available as
        ``example.example`` and whose text/program features are reused.
    :param args: experiment configuration namespace.
    :param parsed_programs: forwarded to ``get_ast``.
    :param text_tokenize: tokenizer applied to the natural-language question.
    :param program_tokenize: tokenizer applied to program ASTs.
    :param post_process: forwarded to the de-vectorizers and to ``example.pretty_print``.
    :param trans_utils: transformer utilities (aliased locally as ``tu``).
    :param schema_graph: schema graph of the example's database.
    :param vocabs: dict providing the ``'text'`` and ``'program'`` vocabularies.
    :param verbose: print OOV diagnostics for the BRIDGE targets.
    :return: sanity-check flags ``(query_oov, denormalized, schema_truncated, token_restored)``.
    """
    tu = trans_utils
    text_vocab = vocabs['text']
    program_vocab = vocabs['program']

    def needs_case_restoration(args):
        # Only uncased BERT checkpoints go through ``utils.restore_feature_case``.
        return args.pretrained_transformer.startswith('bert-') and \
            args.pretrained_transformer.endswith('-uncased')

    def get_memory_values(features, raw_text, args):
        # Always returns a 3-tuple ``(tokens, token_starts, token_ends)`` so the caller can unpack it.
        # Without case restoration the features are used as-is and no start/end positions exist.
        if needs_case_restoration(args):
            return utils.restore_feature_case(features, raw_text, tu)
        return features, None, None

    def get_text_schema_adjacency_matrix(text_features, s_M):
        # Square (text + schema) matrix with the schema adjacency in the bottom-right corner;
        # rows/columns belonging to text tokens are left empty.
        schema_size = s_M.shape[0]
        text_size = len(text_features)
        full_size = schema_size + text_size
        # Builtin ``int``: the ``np.int`` alias was deprecated in NumPy 1.20 and removed in 1.24.
        M = ssp.lil_matrix((full_size, full_size), dtype=int)
        M[-schema_size:, -schema_size:] = s_M
        return M

    def pretty_print_example():
        example.pretty_print(
            schema=schema_graph,
            de_vectorize_ptr=vec.de_vectorize_ptr,
            de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
            rev_vocab=program_vocab,
            post_process=post_process,
            use_table_aware_te=(args.model_id in [BRIDGE]))

    # sanity-check flags returned to the caller
    query_oov = False
    denormalized = False
    schema_truncated = False
    token_restored = True

    # Text feature extraction and set program ground truth list
    if isinstance(example, Text2SQLExample):
        if args.pretrained_transformer:
            text_features = text_tokenize(example.text)
            text_tokens, token_starts, token_ends = get_memory_values(text_features, example.text, args)
            if needs_case_restoration(args) and not token_starts:
                token_restored = False
        else:
            text_tokens = text_tokenize(example.text, functional_tokens)
            text_features = [t.lower() for t in text_tokens]
            # No transformer token positions on this path (these names were previously unbound here).
            token_starts, token_ends = None, None
        example.text_tokens = text_features
        example.text_ptr_values = text_tokens
        example.text_token_starts = token_starts
        example.text_token_ends = token_ends
        example.text_ids = vec.vectorize(text_features, text_vocab)
        example.text_ptr_input_ids = vec.vectorize(text_features, text_vocab)
        program_list = example.program_list
    else:
        text_tokens = example.example.text_ptr_values
        text_features = example.example.text_tokens
        program_list = example.example.program_list

    # Schema feature extraction
    if args.model_id in [BRIDGE]:
        question_encoding = example.text if args.use_picklist else None
        tables = sorted([schema_graph.get_table_id(t_name) for t_name in example.gt_table_names]) \
            if args.use_oracle_tables else None
        table_po, field_po = schema_graph.get_schema_perceived_order(tables)
        schema_features, matched_values = schema_graph.get_serialization(
            tu, flatten_features=True, table_po=table_po, field_po=field_po,
            use_typed_field_markers=args.use_typed_field_markers,
            use_graph_encoding=args.use_graph_encoding,
            question_encoding=question_encoding,
            top_k_matches=args.top_k_picklist_matches,
            match_threshold=args.anchor_text_match_threshold,
            num_values_per_field=args.num_values_per_field,
            no_anchor_text=args.no_anchor_text)
        example.matched_values = matched_values
        example.input_tokens, example.input_ptr_values, num_excluded_tables, num_excluded_fields = \
            get_table_aware_transformer_encoder_inputs(text_tokens, text_features, schema_features, trans_utils)
        schema_truncated = (num_excluded_fields > 0)
        num_included_nodes = schema_graph.get_num_perceived_nodes(table_po) + 1 \
            - num_excluded_tables - num_excluded_fields
        example.ptr_input_ids = vec.vectorize(example.input_tokens, text_vocab)
        if args.read_picklist:
            example.transformer_output_value_mask, value_features, value_tokens = \
                get_transformer_output_value_mask(example.input_tokens, matched_values, tu)
        example.primary_key_ids = schema_graph.get_primary_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.foreign_key_ids = schema_graph.get_foreign_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_type_ids = schema_graph.get_field_type_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.table_masks = schema_graph.get_table_masks(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_table_pos = schema_graph.get_field_table_pos(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.schema_M = schema_graph.adj_matrix
        example.M = get_text_schema_adjacency_matrix(text_features, example.schema_M)
    else:
        num_included_nodes = schema_graph.num_nodes

    # Value copy feature extraction
    if args.read_picklist:
        constant_memory_features = text_features + value_features
        constant_memory = text_tokens + value_tokens
        example.text_ptr_values = constant_memory
    else:
        constant_memory_features = text_features
    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
        constant_memory_features, program_vocab)
    if isinstance(example, Text2SQLExample):
        example.text_ptr_value_ids = constant_ptr_value_ids
    # Schema-node pointer ids are laid out after the vocabulary and the constant (copy) memory.
    example.ptr_value_ids = constant_ptr_value_ids + [
        program_vocab.size + len(constant_memory_features) + x for x in range(num_included_nodes)]

    if not args.leaderboard_submission:
        for j, program in enumerate(program_list):
            if isinstance(example, Text2SQLExample):
                ast, denormalized = get_ast(program, parsed_programs, args.denormalize_sql, schema_graph)
                if ast:
                    example.program_ast_list.append(ast)
                    program_tokens = program_tokenize(
                        ast, schema=schema_graph,
                        omit_from_clause=args.omit_from_clause,
                        no_join_condition=args.no_join_condition,
                        in_execution_order=args.process_sql_in_execution_order)
                    assert (len(program_tokens) > 0)
                else:
                    program_tokens = ['from']
                program_tokens = [START_TOKEN] + program_tokens + [EOS_TOKEN]
                program_input_ids = vec.vectorize(program_tokens, program_vocab)
                example.program_input_ids_list.append(program_input_ids)
                if ast:
                    example.values = extract_values(ast, schema_graph)
                else:
                    example.values = []

                # Model I. Vanilla pointer-generator output
                if args.model_id in [SEQ2SEQ_PG]:
                    program_text_ptr_value_ids = vec.vectorize_ptr_out(
                        program_tokens, program_vocab, constant_unique_input_ids)
                    example.program_text_ptr_value_ids_list.append(program_text_ptr_value_ids)
                    # Sanity check: a target token is neither reserved, nor in the NL input, nor from
                    # the environment (e.g. table schema).
                    if program_vocab.unk_id in program_text_ptr_value_ids:
                        query_oov = True

                # Model II. Bridge output
                if ast:
                    denormalized_ast, _ = denormalize(ast, schema_graph, return_parse_tree=True)
                    example.program_denormalized_ast_list.append(denormalized_ast)
                    tokenizer_output = program_tokenize(
                        denormalized_ast,
                        return_token_types=True,
                        schema=schema_graph,
                        keep_singleton_fields=True,
                        omit_from_clause=args.omit_from_clause,
                        no_join_condition=args.no_join_condition,
                        atomic_value=False,
                        num_token=NUM_TOKEN,
                        str_token=STR_TOKEN,
                        in_execution_order=args.process_sql_in_execution_order)
                    program_singleton_field_tokens, program_singleton_field_token_types = tokenizer_output[:2]
                else:
                    program_singleton_field_tokens = ['from']
                    program_singleton_field_token_types = [RESERVED_TOKEN_TYPE]
                program_singleton_field_tokens = \
                    [START_TOKEN] + program_singleton_field_tokens + [EOS_TOKEN]
                program_singleton_field_token_types = \
                    [RESERVED_TOKEN_TYPE] + program_singleton_field_token_types + [RESERVED_TOKEN_TYPE]
                example.program_singleton_field_tokens_list.append(program_singleton_field_tokens)
                example.program_singleton_field_token_types_list.append(program_singleton_field_token_types)
                program_singleton_field_input_ids = vec.vectorize_singleton(
                    program_singleton_field_tokens, program_singleton_field_token_types, program_vocab)
                example.program_singleton_field_input_ids_list.append(program_singleton_field_input_ids)
            else:
                # Model II. Bridge output, reused from the already-processed example
                example.program_singleton_field_input_ids_list.append(
                    example.example.program_singleton_field_input_ids_list[j])
                program_singleton_field_tokens = example.example.program_singleton_field_tokens_list[j]
                program_singleton_field_token_types = \
                    example.example.program_singleton_field_token_types_list[j]

            program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                program_singleton_field_tokens,
                program_singleton_field_token_types,
                program_vocab,
                constant_unique_input_ids,
                max_memory_size=len(constant_memory_features),
                schema=schema_graph,
                num_included_nodes=num_included_nodes)
            example.program_text_and_field_ptr_value_ids_list.append(program_field_ptr_value_ids)
            if example.gt_table_names_list:
                table_ids = [schema_graph.get_table_id(table_name)
                             for table_name in example.gt_table_names_list[j]]
                example.table_ids_list.append(table_ids)
                # NOTE(review): the ids come from ``gt_table_names_list[j]`` but are checked against
                # ``gt_table_names``; confirm these agree for every j.
                assert ([schema_graph.get_table(x).name for x in table_ids] == example.gt_table_names)

            # Sanity check: NL+Schema pointer output contains tokens outside every known category.
            # NOTE(review): ``query_oov`` is only raised here when ``verbose`` is set.
            if verbose:
                if program_vocab.unk_id in program_field_ptr_value_ids:
                    unk_indices = [i for i, x in enumerate(program_field_ptr_value_ids)
                                   if x == program_vocab.unk_id]
                    print('OOV II: {}'.format(' '.join([program_singleton_field_tokens[i]
                                                        for i in unk_indices])))
                    pretty_print_example()
                    query_oov = True
                if program_vocab.unk_field_id in program_field_ptr_value_ids:
                    pretty_print_example()
                if program_vocab.unk_table_id in program_field_ptr_value_ids:
                    pretty_print_example()

            # Store the ground truth queries after preprocessing to run a relaxed evaluation or
            # to evaluate with partial queries
            if split == 'dev':
                input_tokens = text_tokens
                if args.model_id in [BRIDGE]:
                    _p = vec.de_vectorize_field_ptr(program_field_ptr_value_ids, program_vocab, input_tokens,
                                                    schema=schema_graph, post_process=post_process)
                elif args.model_id in [SEQ2SEQ_PG]:
                    _p = vec.de_vectorize_ptr(program_text_ptr_value_ids, program_vocab, input_tokens,
                                              post_process=post_process)
                else:
                    _p = program
                example.gt_program_list.append(_p)

        example.run_unit_tests()

    return query_oov, denormalized, schema_truncated, token_restored