def de_vectorize(self, p_cpu, out_vocab, input_ptr_values, schema=None, table_po=None,
                 field_po=None, return_tokens=False):
    """Convert an output prediction vector into a human-readable string.

    The decoding routine in ``vec`` is chosen by ``self.model_id``:

    * ``SEQ2SEQ_PG`` -> ``vec.de_vectorize_ptr`` (``input_ptr_values`` passed as ``memory``)
    * ``BRIDGE``     -> ``vec.de_vectorize_field_ptr`` (also receives ``schema``,
      ``table_po`` and ``field_po``)
    * ``SEQ2SEQ``    -> ``vec.de_vectorize`` (``input_ptr_values`` is not used)

    Every branch forwards ``self.output_post_process`` as ``post_process`` and
    ``return_tokens`` unchanged. The result is whatever the selected ``vec`` routine
    returns; presumably a token list when ``return_tokens`` is True (confirm in ``vec``).

    Raises:
        NotImplementedError: if ``self.model_id`` is none of the three ids above.
    """
    model_id = self.model_id

    if model_id in [SEQ2SEQ_PG]:
        return vec.de_vectorize_ptr(
            p_cpu,
            out_vocab,
            memory=input_ptr_values,
            post_process=self.output_post_process,
            return_tokens=return_tokens,
        )

    if model_id in [BRIDGE]:
        return vec.de_vectorize_field_ptr(
            p_cpu,
            out_vocab,
            memory=input_ptr_values,
            schema=schema,
            table_po=table_po,
            field_po=field_po,
            post_process=self.output_post_process,
            return_tokens=return_tokens,
        )

    if model_id == SEQ2SEQ:
        return vec.de_vectorize(
            p_cpu,
            out_vocab,
            post_process=self.output_post_process,
            return_tokens=return_tokens,
        )

    raise NotImplementedError
def eval_example(self, example_pred, example, verbose=False, example_id=None, schema_graph=None):
    """Compare a predicted SQL query with an example's ground truth, clause by clause.

    Args:
        example_pred: raw prediction; it is normalised with ``self.pack_sql`` before
            being compared with ``example.program_ast`` via ``self.match``.
        example: evaluation example. ``example.program_ast`` is always read; the
            verbose diagnostics additionally read ``text``, ``text_ids``,
            ``text_ptr_values``, ``text_ptr_value_ids`` and ``program_list``.
        verbose: when True, print diagnostics for examples that are NOT fully correct.
        example_id: identifier printed in the diagnostics header.
        schema_graph: schema handed to ``SQLSerializer``; must not be None when the
            diagnostics are printed (checked with ``assert``).

    Returns:
        8-tuple ``(sql_correct, s_correct, g_correct, o_correct, f_correct,
        w_correct, h_correct, l_correct)``. The last seven are the per-clause flags
        returned by ``self.match`` (presumably select / group-by / order-by / from /
        where / having / limit; confirm against ``match``). ``sql_correct`` is the
        ``and`` of all seven.
    """
    example_pred = self.pack_sql(example_pred)
    gt_ast = example.program_ast
    s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct = self.match(
        example_pred, gt_ast)
    sql_correct = (s_correct and g_correct and o_correct and f_correct
                   and w_correct and h_correct and l_correct)

    if not sql_correct and verbose:
        assert (schema_graph is not None)
        text = example.text
        # NOTE(review): other callers of vec.de_vectorize pass a vocabulary as the second
        # positional argument; none is passed here. Confirm the signature allows this.
        text_tokens = vec.de_vectorize(example.text_ids, return_tokens=True)
        text_ptr_values = example.text_ptr_values
        print('Example {}'.format(example_id))
        print('NL:\t{}'.format(text.encode('utf-8')))
        print('NL tokens:\t{}'.format(encode_str_list(text_tokens, 'utf-8')))
        print('NL tokens (original):\t{}'.format(
            encode_str_list(text_ptr_values, 'utf-8')))
        # return_tokens=True must reach de_vectorize_ptr (str.format silently ignores
        # unused keyword arguments). The recovered token list is then encoded the same
        # way as the two token lines above.
        recovered_tokens = vec.de_vectorize_ptr(
            example.text_ptr_value_ids, self.value_vocab, text_ptr_values,
            return_tokens=True)
        print('NL tokens (recovered): {}'.format(
            encode_str_list(recovered_tokens, 'utf-8')))
        for i, program in enumerate(example.program_list):
            print('Target {}'.format(i))
            print('- string: {}'.format(program.encode('utf-8')))
        # One badge for the overall result followed by one per clause, in match() order.
        badges = [get_badge(flag) for flag in (
            sql_correct, s_correct, g_correct, o_correct,
            f_correct, w_correct, h_correct, l_correct)]
        print(' '.join(badges))
        serializer = SQLSerializer(schema_graph, self.field_vocab, self.aggregation_ops,
                                   self.arithmetic_ops, self.condition_ops,
                                   self.logical_ops, self.value_vocab)
        print('select clause: {}'.format(
            serializer.serialize_select(example_pred['select'])))
        print('group by clause: {}'.format(
            serializer.serialize_group_by(example_pred['groupBy'])))
        print('order by clause: {}'.format(
            serializer.serialize_order_by(example_pred['orderBy'])))

    return sql_correct, s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct