Example #1
    def de_vectorize(self,
                     p_cpu,
                     out_vocab,
                     input_ptr_values,
                     schema=None,
                     table_po=None,
                     field_po=None,
                     return_tokens=False):
        # Convert the output prediction vector into a human-readable string.
        if self.model_id in [SEQ2SEQ_PG]:
            return vec.de_vectorize_ptr(p_cpu,
                                        out_vocab,
                                        memory=input_ptr_values,
                                        post_process=self.output_post_process,
                                        return_tokens=return_tokens)
        elif self.model_id in [BRIDGE]:
            return vec.de_vectorize_field_ptr(
                p_cpu,
                out_vocab,
                memory=input_ptr_values,
                schema=schema,
                table_po=table_po,
                field_po=field_po,
                post_process=self.output_post_process,
                return_tokens=return_tokens)
        elif self.model_id == SEQ2SEQ:
            return vec.de_vectorize(p_cpu,
                                    out_vocab,
                                    post_process=self.output_post_process,
                                    return_tokens=return_tokens)
        else:
            raise NotImplementedError
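
The SEQ2SEQ_PG and BRIDGE branches above decode pointer-generator outputs, where an output id either selects a token from the output vocabulary or copies a token from the input ("memory"). The following is a minimal, self-contained sketch of that decoding convention; the id layout and function name are assumptions for illustration, not the actual vec.de_vectorize_ptr implementation.

# Sketch only: assumes ids < len(vocab) are generated tokens and ids >= len(vocab)
# copy from the input memory (the usual pointer-generator layout). Not the real API.
def de_vectorize_ptr_sketch(output_ids, vocab, memory, return_tokens=False):
    tokens = []
    for idx in output_ids:
        if idx < len(vocab):
            tokens.append(vocab[idx])                 # token generated from the vocabulary
        else:
            tokens.append(memory[idx - len(vocab)])   # token copied from the input
    return tokens if return_tokens else ' '.join(tokens)

vocab = ['SELECT', 'FROM', 'WHERE', '=', 'COUNT(*)']
memory = ['concert', 'year', '2014']                  # original (raw) input tokens
print(de_vectorize_ptr_sketch([0, 4, 1, 5, 2, 6, 3, 7], vocab, memory))
# SELECT COUNT(*) FROM concert WHERE year = 2014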
Example #2
    def eval_example(self,
                     example_pred,
                     example,
                     verbose=False,
                     example_id=None,
                     schema_graph=None):
        example_pred = self.pack_sql(example_pred)
        gt_ast = example.program_ast
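        # Per-clause match flags; s/g/o/f/w/h/l most likely stand for select,
        # group-by, order-by, from, where, having, and limit.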
        s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct = self.match(
            example_pred, gt_ast)
        sql_correct = s_correct and g_correct and o_correct and f_correct and w_correct and h_correct and l_correct

        if not sql_correct and verbose:
            assert (schema_graph is not None)

            text = example.text
            text_tokens = vec.de_vectorize(example.text_ids,
                                           return_tokens=True)
            text_ptr_values = example.text_ptr_values
            print('Example {}'.format(example_id))
            print('NL:\t{}'.format(text.encode('utf-8')))
            print('NL tokens:\t{}'.format(encode_str_list(
                text_tokens, 'utf-8')))
            print('NL tokens (original):\t{}'.format(
                encode_str_list(text_ptr_values, 'utf-8')))
            print('NL tokens (recovered): {}'.format(encode_str_list(
                vec.de_vectorize_ptr(example.text_ptr_value_ids,
                                     self.value_vocab,
                                     text_ptr_values,
                                     return_tokens=True), 'utf-8')))

            for i, program in enumerate(example.program_list):
                print('Target {}'.format(i))
                print('- string: {}'.format(program.encode('utf-8')))

            badges = [
                get_badge(sql_correct),
                get_badge(s_correct),
                get_badge(g_correct),
                get_badge(o_correct),
                get_badge(f_correct),
                get_badge(w_correct),
                get_badge(h_correct),
                get_badge(l_correct)
            ]
            print(' '.join(badges))

            serializer = SQLSerializer(schema_graph, self.field_vocab,
                                       self.aggregation_ops,
                                       self.arithmetic_ops, self.condition_ops,
                                       self.logical_ops, self.value_vocab)
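            # Re-serialize the predicted clauses into SQL text so they can be
            # compared against the printed targets above.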
            print('select clause: {}'.format(
                serializer.serialize_select(example_pred['select'])))
            print('group by clause: {}'.format(
                serializer.serialize_group_by(example_pred['groupBy'])))
            print('order by clause: {}'.format(
                serializer.serialize_order_by(example_pred['orderBy'])))

        return sql_correct, s_correct, g_correct, o_correct, f_correct, w_correct, h_correct, l_correct
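
As a usage note, the seven per-clause flags returned above are typically rolled up into exact-match and per-clause accuracies over a dev set. The sketch below is hypothetical (the aggregate helper and clause keys are assumptions, not the original API); it simply mirrors the all-clauses-correct conjunction used in eval_example.

# Sketch only: hypothetical aggregation of per-clause correctness flags.
def aggregate(per_example_flags):
    # per_example_flags: one dict per example, mapping clause name -> bool
    n = len(per_example_flags)
    exact = sum(all(flags.values()) for flags in per_example_flags)
    per_clause = {
        clause: sum(flags[clause] for flags in per_example_flags) / n
        for clause in per_example_flags[0]
    }
    return exact / n, per_clause

flags = [
    {'select': True, 'groupBy': True, 'orderBy': True, 'from': True,
     'where': True, 'having': True, 'limit': True},
    {'select': True, 'groupBy': False, 'orderBy': True, 'from': True,
     'where': True, 'having': True, 'limit': True},
]
exact_acc, per_clause = aggregate(flags)
print('exact set match: {:.2f}'.format(exact_acc))                 # 0.50
print('group by accuracy: {:.2f}'.format(per_clause['groupBy']))   # 0.50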