Example #1
def de_vectorize(self,
                 p_cpu,
                 out_vocab,
                 input_ptr_values,
                 schema=None,
                 table_po=None,
                 field_po=None,
                 return_tokens=False):
    """Convert an output prediction vector into a human-readable string."""
    if self.model_id in [SEQ2SEQ_PG]:
        return vec.de_vectorize_ptr(p_cpu,
                                    out_vocab,
                                    memory=input_ptr_values,
                                    post_process=self.output_post_process,
                                    return_tokens=return_tokens)
    elif self.model_id in [BRIDGE]:
        return vec.de_vectorize_field_ptr(
            p_cpu,
            out_vocab,
            memory=input_ptr_values,
            schema=schema,
            table_po=table_po,
            field_po=field_po,
            post_process=self.output_post_process,
            return_tokens=return_tokens)
    elif self.model_id == SEQ2SEQ:
        return vec.de_vectorize(p_cpu,
                                out_vocab,
                                post_process=self.output_post_process,
                                return_tokens=return_tokens)
    else:
        raise NotImplementedError
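
# --- Illustration (not part of the original module) -----------------------
# The de-vectorizers dispatched above share one convention: an output id
# below the vocabulary size is an ordinary vocabulary entry, while a larger
# id points back into the input memory, i.e. a token copied from the
# question. A minimal sketch of that convention, with hypothetical names
# (the real logic lives in the project's vec module):
def de_vectorize_ptr_sketch(ids, rev_vocab, memory):
    tokens = []
    for i in ids:
        if i < len(rev_vocab):
            tokens.append(rev_vocab[i])                # generated token
        else:
            tokens.append(memory[i - len(rev_vocab)])  # copied token
    return ' '.join(tokens)

# With a 4-word vocabulary, id 5 copies memory[1]:
#   de_vectorize_ptr_sketch([0, 5], ['SELECT', 'FROM', 'WHERE', '='],
#                           ['city', 'paris'])  ->  'SELECT paris'
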
def preprocess_example(split,
                       example,
                       args,
                       parsed_programs,
                       text_tokenize,
                       program_tokenize,
                       post_process,
                       table_utils,
                       schema_graph,
                       vocabs,
                       verbose=False):
    tu = table_utils
    text_vocab = vocabs['text']
    program_vocab = vocabs['program']

    def get_memory_values(features, raw_text, args):
        # Uncased BERT lower-cases its input; restore the original casing of
        # the question tokens so that copied values render faithfully.
        if args.pretrained_transformer.startswith(
                'bert-') and args.pretrained_transformer.endswith('-uncased'):
            return utils.restore_feature_case(features, raw_text, tu)
        else:
            # The caller unpacks three values; without case restoration there
            # are no character alignment offsets to return.
            return features, None, None

    def get_text_schema_adjacency_matrix(text_features, s_M):
        # Embed the schema adjacency matrix in the bottom-right block of the
        # joint (text + schema) adjacency matrix; text rows are left empty.
        schema_size = s_M.shape[0]
        text_size = len(text_features)
        full_size = schema_size + text_size
        # np.int was removed in NumPy 1.24; the builtin int works as a dtype.
        M = ssp.lil_matrix((full_size, full_size), dtype=int)
        M[-schema_size:, -schema_size:] = s_M
        return M
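    # For intuition (hypothetical sizes): with 2 text tokens and a 3-node
    # schema, M is 5x5 and only M[-3:, -3:] carries the schema edges:
    #     [[0 0 0 0 0]
    #      [0 0 0 0 0]
    #      [0 0 0 1 1]
    #      [0 0 1 0 0]
    #      [0 0 1 0 0]]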

    # sanity check
    ############################
    query_oov = False
    denormalized = False
    schema_truncated = False
    token_restored = True
    ############################

    # Text feature extraction and ground-truth program list setup
    if isinstance(example, Text2SQLExample):
        if args.pretrained_transformer:
            text_features = text_tokenize(example.text)
            text_tokens, token_starts, token_ends = get_memory_values(
                text_features, example.text, args)
            if not token_starts:
                token_restored = False
        else:
            text_tokens = text_tokenize(example.text, functional_tokens)
            text_features = [t.lower() for t in text_tokens]
            # No subword-to-character alignment without a pretrained tokenizer
            token_starts, token_ends = None, None
        example.text_tokens = text_features
        example.text_ptr_values = text_tokens
        example.text_token_starts = token_starts
        example.text_token_ends = token_ends
        example.text_ids = vec.vectorize(text_features, text_vocab)
        example.text_ptr_input_ids = vec.vectorize(text_features, text_vocab)
        program_list = example.program_list
        # Collect (field signature, value) pairs for string-typed condition
        # values in the gold program's WHERE clause
        example.values = [
            (schema_graph.get_field(cond[0]).signature, cond[2])
            for cond in example.program_ast_list_[0]['conds']
            if (isinstance(cond[2], str) and not is_number(cond[2]))
        ]
    else:
        text_tokens = example.example.text_ptr_values
        text_features = example.example.text_tokens
        program_list = example.example.program_list

    # Schema feature extraction
    if args.model_id in [BRIDGE]:
        question_encoding = example.text if args.use_picklist else None
        tables = sorted([schema_graph.get_table_id(t_name) for t_name in example.gt_table_names]) \
            if args.use_oracle_tables else None
        table_po, field_po = schema_graph.get_schema_perceived_order(tables)
        schema_features, matched_values = schema_graph.get_serialization(
            tu,
            flatten_features=True,
            table_po=table_po,
            field_po=field_po,
            use_typed_field_markers=args.use_typed_field_markers,
            use_graph_encoding=args.use_graph_encoding,
            question_encoding=question_encoding,
            top_k_matches=args.top_k_picklist_matches,
            num_values_per_field=args.num_values_per_field,
            no_anchor_text=args.no_anchor_text)
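        # matched_values: picklist values matched against the question text;
        # BRIDGE appends them to the serialized schema as anchor text.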
        example.matched_values = matched_values
        example.input_tokens, example.input_ptr_values, num_excluded_tables, num_excluded_fields = \
            get_table_aware_transformer_encoder_inputs(text_tokens, text_features, schema_features, table_utils)
        schema_truncated = (num_excluded_fields > 0)
        num_included_nodes = schema_graph.get_num_perceived_nodes(
            table_po) + 1 - num_excluded_tables - num_excluded_fields
        example.ptr_input_ids = vec.vectorize(example.input_tokens, text_vocab)
        if args.read_picklist:
            example.transformer_output_value_mask, value_features, value_tokens = \
                get_transformer_output_value_mask(example.input_tokens, matched_values, tu)
        example.primary_key_ids = schema_graph.get_primary_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.foreign_key_ids = schema_graph.get_foreign_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_type_ids = schema_graph.get_field_type_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.table_masks = schema_graph.get_table_masks(num_included_nodes,
                                                           table_po=table_po,
                                                           field_po=field_po)
        example.field_table_pos = schema_graph.get_field_table_pos(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.schema_M = schema_graph.adj_matrix
        example.M = get_text_schema_adjacency_matrix(text_features,
                                                     example.schema_M)
    else:
        num_included_nodes = schema_graph.num_nodes

    # Value copy feature extraction
    if args.read_picklist:
        constant_memory_features = text_features + value_features
        constant_memory = text_tokens + value_tokens
        example.text_ptr_values = constant_memory
    else:
        constant_memory_features = text_features
    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
        constant_memory_features, program_vocab)
    if isinstance(example, Text2SQLExample):
        example.text_ptr_value_ids = constant_ptr_value_ids
    example.ptr_value_ids = constant_ptr_value_ids + [
        program_vocab.size + len(constant_memory_features) + x
        for x in range(num_included_nodes)
    ]
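    # Extended output space layout:
    #   [0, program_vocab.size)                       -> program vocabulary
    #   [program_vocab.size,
    #    program_vocab.size + len(constant_memory_features))
    #                                                 -> question/value copy slots
    #   above that                                    -> schema node slots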

    if not args.leaderboard_submission:
        for j, program in enumerate(program_list):
            if isinstance(example, Text2SQLExample):
                # Model II. Bridge output
                program_singleton_field_tokens, program_singleton_field_token_types = \
                    tok.wikisql_struct_to_tokens(example.program_ast_, schema_graph, tu)
                program_singleton_field_tokens = [
                    START_TOKEN
                ] + program_singleton_field_tokens + [EOS_TOKEN]
                program_singleton_field_token_types = \
                    [RESERVED_TOKEN_TYPE] + program_singleton_field_token_types + [RESERVED_TOKEN_TYPE]
                example.program_singleton_field_tokens_list.append(
                    program_singleton_field_tokens)
                example.program_singleton_field_token_types_list.append(
                    program_singleton_field_token_types)
                program_singleton_field_input_ids = vec.vectorize_singleton(
                    program_singleton_field_tokens,
                    program_singleton_field_token_types, program_vocab)
                example.program_singleton_field_input_ids_list.append(
                    program_singleton_field_input_ids)
            else:
                # Model II. Bridge output
                example.program_singleton_field_input_ids_list.append(
                    example.example.program_singleton_field_input_ids_list[j])
                program_singleton_field_tokens = example.example.program_singleton_field_tokens_list[
                    j]
                program_singleton_field_token_types = example.example.program_singleton_field_token_types_list[
                    j]

            program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                program_singleton_field_tokens,
                program_singleton_field_token_types,
                program_vocab,
                constant_unique_input_ids,
                max_memory_size=len(constant_memory_features),
                schema=schema_graph,
                num_included_nodes=num_included_nodes)
            example.program_text_and_field_ptr_value_ids_list.append(
                program_field_ptr_value_ids)

            table_ids = [
                schema_graph.get_table_id(table_name)
                for table_name in example.gt_table_names_list[j]
            ]
            example.table_ids_list.append(table_ids)
            assert ([schema_graph.get_table(x).name
                     for x in table_ids] == example.gt_table_names)

            # sanity check
            ############################
            #   NL+Schema pointer output contains tokens that do not map to
            #   the vocabulary, the input text, or a schema field/table
            if verbose:
                if program_vocab.unk_id in program_field_ptr_value_ids:
                    unk_indices = [
                        i for i, x in enumerate(program_field_ptr_value_ids)
                        if x == program_vocab.unk_id
                    ]
                    print('OOV II: {}'.format(' '.join([
                        program_singleton_field_tokens[i] for i in unk_indices
                    ])))
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
                    query_oov = True
            if program_vocab.unk_field_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            if program_vocab.unk_table_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            ############################

            # Store the ground truth queries after preprocessing to run a relaxed evaluation or
            # to evaluate with partial queries
            if split == 'dev':
                input_tokens = text_tokens
                if args.model_id in [BRIDGE]:
                    _p = vec.de_vectorize_field_ptr(
                        program_field_ptr_value_ids,
                        program_vocab,
                        input_tokens,
                        schema=schema_graph,
                        post_process=post_process)
                else:
                    _p = program
                example.gt_program_list.append(_p)

            # sanity check
            ############################
            # try:
            #     assert(equal_ignoring_trivial_diffs(_p, program.lower(), verbose=True))
            # except Exception:
            #     print('_p:\t\t{}'.format(_p))
            #     print('program:\t{}'.format(program))
            #     print()
            #     import pdb
            #     pdb.set_trace()
            ############################

        example.run_unit_tests()

    return query_oov, denormalized, schema_truncated, token_restored
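
# --- Illustration (not part of the original module) -----------------------
# Hypothetical helper mirroring the extended output space laid out in
# example.ptr_value_ids above: vocabulary ids first, then copy slots for
# the question/value memory, then schema nodes.
def extended_output_id(kind, index, vocab_size, mem_size):
    if kind == 'vocab':      # generated SQL keyword or literal
        return index
    if kind == 'memory':     # token copied from the question/value memory
        return vocab_size + index
    if kind == 'schema':     # table or field node
        return vocab_size + mem_size + index
    raise ValueError(kind)

# With a 100-id program vocabulary and a 12-token memory, schema node 3
# is emitted as id 100 + 12 + 3 = 115:
assert extended_output_id('schema', 3, 100, 12) == 115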