Example #1
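# Assumed context, not shown in this snippet: format_batch below is a method of
# a data-processor class (it references self.args, self.mdl, self.tu,
# self.schema_graphs and self.training), and `ops`, `vec`, `tok`, `utils`,
# `random`, `torch`, `np` (numpy), `ssp` (scipy.sparse), the model-id constants
# (BRIDGE, SEQ2SEQ, SEQ2SEQ_PG), the token constants (START_TOKEN, EOS_TOKEN,
# RESERVED_TOKEN_TYPE, functional_tokens) and helpers such as is_number,
# get_table_aware_transformer_encoder_inputs, get_transformer_output_value_mask
# and Text2SQLExample are imported/defined elsewhere in the repository.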
    def format_batch(self, mini_batch):
        def get_decoder_input_ids():
            if self.training:
                if self.model_id in [BRIDGE]:
                    X = [
                        exp.program_singleton_field_input_ids
                        for exp in mini_batch
                    ]
                else:
                    X = [exp.program_input_ids for exp in mini_batch]
                return ops.pad_batch(X, self.mdl.out_vocab.pad_id)
            else:
                return None

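        # NOTE: this helper closes over `exp` and `schema_graph`, which are only
        # (re)bound inside the per-example loop further below; it must not be
        # called before those names have been assigned.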
        def get_encoder_attn_mask(table_names, table_masks):
            schema_pos = [
                schema_graph.get_schema_pos(table_name)
                for table_name in table_names
            ]
            encoder_attn_mask = [1 for _ in range(exp.num_text_tokens)]
            # asterisk marker
            encoder_attn_mask.append(1)
            is_selected_table = False
            for j in range(1, len(table_masks)):
                if j in schema_pos:
                    encoder_attn_mask.append(1)
                    is_selected_table = True
                elif table_masks[j] == 1:
                    # mask current table
                    encoder_attn_mask.append(0)
                    is_selected_table = False
                else:
                    if is_selected_table:
                        encoder_attn_mask.append(1)
                    else:
                        encoder_attn_mask.append(0)
            return encoder_attn_mask

        super().format_batch(mini_batch)
        encoder_input_ids = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                          self.mdl.in_vocab.pad_id)
        decoder_input_ids = get_decoder_input_ids()

        table_samples = []

        if self.model_id == SEQ2SEQ:
            return encoder_input_ids, decoder_input_ids
        elif self.model_id in [BRIDGE]:
            encoder_ptr_input_ids, encoder_ptr_value_ids, decoder_ptr_value_ids = [], [], []
            primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions, table_field_scopes, \
                field_table_pos, transformer_output_value_masks, schema_memory_masks = [], [], [], [], [], [], [], [], []
            for exp in mini_batch:
                schema_graph = self.schema_graphs.get_schema(exp.db_id)
                if self.training:
                    # Compute schema layout
                    if exp.gt_table_names_list:
                        gt_tables = set([
                            schema_graph.get_table_id(t_name)
                            for t_name in exp.gt_table_names
                        ])
                    else:
                        gt_table_names = [
                            token for token, t in zip(
                                exp.program_singleton_field_tokens,
                                exp.program_singleton_field_token_types)
                            if t == 0
                        ]
                        gt_tables = set([
                            schema_graph.get_table_id(t_name)
                            for t_name in gt_table_names
                        ])
                    # [Hack] The baseball database has a complex schema that does not fit
                    # BERT's input size. We keep the ground-truth tables and randomly add
                    # a few other tables for training.
                    if schema_graph.name.startswith('baseball'):
                        tables = list(gt_tables)
                        tables += random.sample(
                            [
                                i for i in range(schema_graph.num_tables)
                                if i not in gt_tables
                            ],
                            k=min(random.randint(1, 7),
                                  schema_graph.num_tables - len(gt_tables)))
                    else:
                        tables = list(range(schema_graph.num_tables))
                    if self.args.table_shuffling:
                        table_to_drop = random.choice(tables)
                        if table_to_drop not in gt_tables:
                            if random.uniform(0, 1) < 0.3:
                                tables = [
                                    x for x in tables if x != table_to_drop
                                ]
                        table_po, field_po = schema_graph.get_schema_perceived_order(
                            tables,
                            random_table_order=True,
                            random_field_order=self.args.random_field_order)
                    else:
                        table_po, field_po = schema_graph.get_schema_perceived_order(
                            tables,
                            random_table_order=False,
                            random_field_order=self.args.random_field_order)

                    # Schema feature extraction
                    question_encoding = exp.text if self.args.use_picklist else None
                    schema_features, matched_values = schema_graph.get_serialization(
                        self.tu,
                        flatten_features=True,
                        table_po=table_po,
                        field_po=field_po,
                        use_typed_field_markers=self.args.use_typed_field_markers,
                        use_graph_encoding=self.args.use_graph_encoding,
                        question_encoding=question_encoding,
                        top_k_matches=self.args.top_k_picklist_matches,
                        num_values_per_field=self.args.num_values_per_field,
                        no_anchor_text=self.args.no_anchor_text,
                        verbose=False)
                    ptr_input_tokens, ptr_input_values, num_excluded_tables, num_excluded_fields = \
                        get_table_aware_transformer_encoder_inputs(
                            exp.text_ptr_values, exp.text_tokens, schema_features, self.tu)
                    assert (len(ptr_input_tokens) <= self.tu.tokenizer.max_len)
                    if num_excluded_fields > 0:
                        print('Warning: training input truncated')
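                    # The +1 below presumably accounts for the table-independent
                    # '*' node (cf. the "asterisk marker" in get_encoder_attn_mask).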
                    num_included_nodes = schema_graph.get_num_perceived_nodes(tables) + 1 \
                                         - num_excluded_tables - num_excluded_fields
                    encoder_ptr_input_ids.append(
                        self.tu.tokenizer.convert_tokens_to_ids(
                            ptr_input_tokens))
                    if self.args.read_picklist:
                        exp.transformer_output_value_mask, value_features, value_tokens = \
                            get_transformer_output_value_mask(ptr_input_tokens, matched_values, self.tu)
                        transformer_output_value_masks.append(
                            exp.transformer_output_value_mask)
                    primary_key_ids.append(
                        schema_graph.get_primary_key_ids(
                            num_included_nodes, table_po, field_po))
                    foreign_key_ids.append(
                        schema_graph.get_foreign_key_ids(
                            num_included_nodes, table_po, field_po))
                    field_type_ids.append(
                        schema_graph.get_field_type_ids(
                            num_included_nodes, table_po, field_po))
                    table_masks.append(
                        schema_graph.get_table_masks(num_included_nodes,
                                                     table_po, field_po))

                    # Value copy feature extraction
                    if self.args.read_picklist:
                        constant_memory_features = exp.text_tokens + value_features
                        constant_memory = exp.text_ptr_values + value_tokens
                        exp.text_ptr_values = constant_memory
                    else:
                        constant_memory_features = exp.text_tokens
                    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
                        constant_memory_features, self.out_vocab)
                    encoder_ptr_value_ids.append(constant_ptr_value_ids + [
                        self.out_vocab.size + len(constant_memory_features) + x
                        for x in range(num_included_nodes)
                    ])
                    program_field_ptr_value_ids = \
                        vec.vectorize_field_ptr_out(exp.program_singleton_field_tokens,
                                                    exp.program_singleton_field_token_types,
                                                    self.out_vocab, constant_unique_input_ids,
                                                    max_memory_size=len(constant_memory_features),
                                                    schema=schema_graph,
                                                    num_included_nodes=num_included_nodes)
                    decoder_ptr_value_ids.append(program_field_ptr_value_ids)
                else:
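                    # Inference branch: reuse the features precomputed by
                    # preprocess_example (below). Note that these batch-level list
                    # comprehensions are re-evaluated on every iteration of the
                    # per-example loop; they could be hoisted out of it.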
                    encoder_ptr_input_ids = [
                        exp.ptr_input_ids for exp in mini_batch
                    ]
                    encoder_ptr_value_ids = [
                        exp.ptr_value_ids for exp in mini_batch
                    ]
                    # In this (inference) branch self.training is False, so no
                    # gold decoder targets are available.
                    decoder_ptr_value_ids = None
                    primary_key_ids = [
                        exp.primary_key_ids for exp in mini_batch
                    ]
                    foreign_key_ids = [
                        exp.foreign_key_ids for exp in mini_batch
                    ]
                    field_type_ids = [exp.field_type_ids for exp in mini_batch]
                    table_masks = [exp.table_masks for exp in mini_batch]
                    # TODO: here we assume that all nodes in the schema graph are included
                    table_pos, table_field_scope = schema_graph.get_table_scopes(
                        schema_graph.num_nodes)
                    table_positions.append(table_pos)
                    table_field_scopes.append(table_field_scope)
                    if self.args.read_picklist:
                        transformer_output_value_masks.append(
                            exp.transformer_output_value_mask)

            encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                                  self.mdl.in_vocab.pad_id)
            encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                                  self.mdl.in_vocab.pad_id)
            schema_memory_masks = ops.pad_batch(schema_memory_masks, pad_id=0) \
                if (self.args.use_pred_tables and not self.training) else (None, None)
            decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids, self.mdl.out_vocab.pad_id) \
                if self.training else None
            primary_key_ids = ops.pad_batch(primary_key_ids,
                                            self.mdl.in_vocab.pad_id)
            foreign_key_ids = ops.pad_batch(foreign_key_ids,
                                            self.mdl.in_vocab.pad_id)
            field_type_ids = ops.pad_batch(field_type_ids,
                                           self.mdl.in_vocab.pad_id)
            table_masks = ops.pad_batch(table_masks, pad_id=0)
            transformer_output_value_masks = ops.pad_batch(transformer_output_value_masks, pad_id=0, dtype=torch.uint8) \
                if self.args.read_picklist else (None, None)
            if not self.training:
                table_positions = ops.pad_batch(table_positions, pad_id=-1) \
                    if self.args.process_sql_in_execution_order else (None, None)
                table_field_scopes = ops.pad_batch_2D(table_field_scopes, pad_id=0) \
                    if self.args.process_sql_in_execution_order else (None, None)
            graphs = None
            return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
                   decoder_ptr_value_ids, transformer_output_value_masks, schema_memory_masks, graphs, \
                   (primary_key_ids, foreign_key_ids, field_type_ids, table_masks, table_positions,
                    table_field_scopes, field_table_pos), table_samples
        elif self.model_id in [SEQ2SEQ_PG]:
            encoder_ptr_input_ids = [exp.ptr_input_ids for exp in mini_batch]
            encoder_ptr_value_ids = [exp.ptr_value_ids for exp in mini_batch]
            decoder_ptr_value_ids = [
                exp.program_text_ptr_value_ids for exp in mini_batch
            ]
            encoder_ptr_input_ids = ops.pad_batch(encoder_ptr_input_ids,
                                                  self.mdl.in_vocab.pad_id)
            encoder_ptr_value_ids = ops.pad_batch(encoder_ptr_value_ids,
                                                  self.mdl.in_vocab.pad_id)
            decoder_ptr_value_ids = ops.pad_batch(decoder_ptr_value_ids,
                                                  self.mdl.out_vocab.pad_id)
            return encoder_input_ids, decoder_input_ids, encoder_ptr_input_ids, encoder_ptr_value_ids, \
                   decoder_ptr_value_ids
        else:
            raise NotImplementedError
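

# A minimal sketch of the padding helper used throughout format_batch above.
# This is an assumption about ops.pad_batch's contract (right-padding
# variable-length id sequences and returning a (values, mask) pair, as the
# `else (None, None)` fallbacks above imply), not the repository's actual
# implementation.
def pad_batch_sketch(seqs, pad_id):
    import torch
    max_len = max(len(s) for s in seqs)
    values = torch.LongTensor([s + [pad_id] * (max_len - len(s)) for s in seqs])
    # Convention assumed here: 1 marks a padded position, 0 a real token.
    masks = torch.LongTensor([[0] * len(s) + [1] * (max_len - len(s)) for s in seqs])
    return values, masks

# Example: pad_batch_sketch([[5, 6, 7], [8]], pad_id=0)
#   -> values [[5, 6, 7], [8, 0, 0]], masks [[0, 0, 0], [0, 1, 1]]
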
def preprocess_example(split,
                       example,
                       args,
                       parsed_programs,
                       text_tokenize,
                       program_tokenize,
                       post_process,
                       table_utils,
                       schema_graph,
                       vocabs,
                       verbose=False):
    tu = table_utils
    text_vocab = vocabs['text']
    program_vocab = vocabs['program']

    def get_memory_values(features, raw_text, args):
        if args.pretrained_transformer.startswith(
                'bert-') and args.pretrained_transformer.endswith('-uncased'):
            return utils.restore_feature_case(features, raw_text, tu)
        else:
            # Match the 3-tuple returned by restore_feature_case above.
            return features, None, None

    def get_text_schema_adjacency_matrix(text_features, s_M):
        schema_size = s_M.shape[0]
        text_size = len(text_features)
        full_size = schema_size + text_size
        M = ssp.lil_matrix((full_size, full_size), dtype=int)  # np.int was removed in NumPy 1.24
        M[-schema_size:, -schema_size:] = s_M
        return M

    # sanity check
    ############################
    query_oov = False
    denormalized = False
    schema_truncated = False
    token_restored = True
    ############################

    # Text feature extraction; also set the ground-truth program list
    if isinstance(example, Text2SQLExample):
        if args.pretrained_transformer:
            text_features = text_tokenize(example.text)
            text_tokens, token_starts, token_ends = get_memory_values(
                text_features, example.text, args)
            if not token_starts:
                token_restored = False
        else:
            text_tokens = text_tokenize(example.text, functional_tokens)
            text_features = [t.lower() for t in text_tokens]
            # token_starts/token_ends are only produced by the pretrained-
            # transformer branch; default them so the assignments below work.
            token_starts, token_ends = None, None
        example.text_tokens = text_features
        example.text_ptr_values = text_tokens
        example.text_token_starts = token_starts
        example.text_token_ends = token_ends
        example.text_ids = vec.vectorize(text_features, text_vocab)
        example.text_ptr_input_ids = vec.vectorize(text_features, text_vocab)
        program_list = example.program_list
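        # Collect string-valued WHERE-condition constants, paired with the
        # signature of the field they constrain.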
        example.values = [
            (schema_graph.get_field(cond[0]).signature, cond[2])
            for cond in example.program_ast_list_[0]['conds']
            if (isinstance(cond[2], str) and not is_number(cond[2]))
        ]
    else:
        text_tokens = example.example.text_ptr_values
        text_features = example.example.text_tokens
        program_list = example.example.program_list

    # Schema feature extraction
    if args.model_id in [BRIDGE]:
        question_encoding = example.text if args.use_picklist else None
        tables = sorted([schema_graph.get_table_id(t_name) for t_name in example.gt_table_names]) \
            if args.use_oracle_tables else None
        table_po, field_po = schema_graph.get_schema_perceived_order(tables)
        schema_features, matched_values = schema_graph.get_serialization(
            tu,
            flatten_features=True,
            table_po=table_po,
            field_po=field_po,
            use_typed_field_markers=args.use_typed_field_markers,
            use_graph_encoding=args.use_graph_encoding,
            question_encoding=question_encoding,
            top_k_matches=args.top_k_picklist_matches,
            num_values_per_field=args.num_values_per_field,
            no_anchor_text=args.no_anchor_text)
        example.matched_values = matched_values
        example.input_tokens, example.input_ptr_values, num_excluded_tables, num_excluded_fields = \
            get_table_aware_transformer_encoder_inputs(text_tokens, text_features, schema_features, table_utils)
        schema_truncated = (num_excluded_fields > 0)
        num_included_nodes = schema_graph.get_num_perceived_nodes(
            table_po) + 1 - num_excluded_tables - num_excluded_fields
        example.ptr_input_ids = vec.vectorize(example.input_tokens, text_vocab)
        if args.read_picklist:
            example.transformer_output_value_mask, value_features, value_tokens = \
                get_transformer_output_value_mask(example.input_tokens, matched_values, tu)
        example.primary_key_ids = schema_graph.get_primary_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.foreign_key_ids = schema_graph.get_foreign_key_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.field_type_ids = schema_graph.get_field_type_ids(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.table_masks = schema_graph.get_table_masks(num_included_nodes,
                                                           table_po=table_po,
                                                           field_po=field_po)
        example.field_table_pos = schema_graph.get_field_table_pos(
            num_included_nodes, table_po=table_po, field_po=field_po)
        example.schema_M = schema_graph.adj_matrix
        example.M = get_text_schema_adjacency_matrix(text_features,
                                                     example.schema_M)
    else:
        num_included_nodes = schema_graph.num_nodes

    # Value copy feature extraction
    if args.read_picklist:
        constant_memory_features = text_features + value_features
        constant_memory = text_tokens + value_tokens
        example.text_ptr_values = constant_memory
    else:
        constant_memory_features = text_features
    constant_ptr_value_ids, constant_unique_input_ids = vec.vectorize_ptr_in(
        constant_memory_features, program_vocab)
    if isinstance(example, Text2SQLExample):
        example.text_ptr_value_ids = constant_ptr_value_ids
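    # Pointer-id layout: ids in [0, program_vocab.size) are vocabulary tokens,
    # ids in [program_vocab.size, program_vocab.size + |memory|) point into the
    # text/value memory, and the final num_included_nodes ids point at schema
    # nodes in their perceived order.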
    example.ptr_value_ids = constant_ptr_value_ids + [
        program_vocab.size + len(constant_memory_features) + x
        for x in range(num_included_nodes)
    ]

    if not args.leaderboard_submission:
        for j, program in enumerate(program_list):
            if isinstance(example, Text2SQLExample):
                # Model II. Bridge output
                program_singleton_field_tokens, program_singleton_field_token_types = \
                    tok.wikisql_struct_to_tokens(example.program_ast_, schema_graph, tu)
                program_singleton_field_tokens = [
                    START_TOKEN
                ] + program_singleton_field_tokens + [EOS_TOKEN]
                program_singleton_field_token_types = \
                    [RESERVED_TOKEN_TYPE] + program_singleton_field_token_types + [RESERVED_TOKEN_TYPE]
                example.program_singleton_field_tokens_list.append(
                    program_singleton_field_tokens)
                example.program_singleton_field_token_types_list.append(
                    program_singleton_field_token_types)
                program_singleton_field_input_ids = vec.vectorize_singleton(
                    program_singleton_field_tokens,
                    program_singleton_field_token_types, program_vocab)
                example.program_singleton_field_input_ids_list.append(
                    program_singleton_field_input_ids)
            else:
                # Model II. Bridge output
                example.program_singleton_field_input_ids_list.append(
                    example.example.program_singleton_field_input_ids_list[j])
                program_singleton_field_tokens = example.example.program_singleton_field_tokens_list[
                    j]
                program_singleton_field_token_types = example.example.program_singleton_field_token_types_list[
                    j]

            program_field_ptr_value_ids = vec.vectorize_field_ptr_out(
                program_singleton_field_tokens,
                program_singleton_field_token_types,
                program_vocab,
                constant_unique_input_ids,
                max_memory_size=len(constant_memory_features),
                schema=schema_graph,
                num_included_nodes=num_included_nodes)
            example.program_text_and_field_ptr_value_ids_list.append(
                program_field_ptr_value_ids)

            table_ids = [
                schema_graph.get_table_id(table_name)
                for table_name in example.gt_table_names_list[j]
            ]
            example.table_ids_list.append(table_ids)
            assert ([schema_graph.get_table(x).name
                     for x in table_ids] == example.gt_table_names)

            # sanity check
            ############################
            #   The NL+Schema pointer output contains tokens that do not belong to
            #   any of the expected categories (see the unk_* checks below)
            if verbose:
                if program_vocab.unk_id in program_field_ptr_value_ids:
                    unk_indices = [
                        i for i, x in enumerate(program_field_ptr_value_ids)
                        if x == program_vocab.unk_id
                    ]
                    print('OOV II: {}'.format(' '.join([
                        program_singleton_field_tokens[i] for i in unk_indices
                    ])))
                    example.pretty_print(
                        schema=schema_graph,
                        de_vectorize_ptr=vec.de_vectorize_ptr,
                        de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                        rev_vocab=program_vocab,
                        post_process=post_process,
                        use_table_aware_te=(args.model_id in [BRIDGE]))
                    query_oov = True
            if program_vocab.unk_field_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            if program_vocab.unk_table_id in program_field_ptr_value_ids:
                example.pretty_print(
                    schema=schema_graph,
                    de_vectorize_ptr=vec.de_vectorize_ptr,
                    de_vectorize_field_ptr=vec.de_vectorize_field_ptr,
                    rev_vocab=program_vocab,
                    post_process=post_process,
                    use_table_aware_te=(args.model_id in [BRIDGE]))
            ############################

            # Store the ground truth queries after preprocessing to run a relaxed evaluation or
            # to evaluate with partial queries
            if split == 'dev':
                input_tokens = text_tokens
                if args.model_id in [BRIDGE]:
                    _p = vec.de_vectorize_field_ptr(
                        program_field_ptr_value_ids,
                        program_vocab,
                        input_tokens,
                        schema=schema_graph,
                        post_process=post_process)
                else:
                    _p = program
                example.gt_program_list.append(_p)

        example.run_unit_tests()

    return query_oov, denormalized, schema_truncated, token_restored
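

# A self-contained toy illustrating the pointer-id arithmetic above. The helper
# below is an illustrative stand-in for vec.vectorize_ptr_in, an assumption
# about its contract rather than the repo's implementation: in-vocabulary
# memory tokens keep their vocabulary id, out-of-vocabulary ones get a unique
# id offset by the vocabulary size.
def toy_vectorize_ptr_in(memory_tokens, vocab):
    return [vocab.get(token, len(vocab) + pos)
            for pos, token in enumerate(memory_tokens)]

toy_vocab = {'select': 0, 'from': 1}
memory = ['show', 'select', 'name']
ptr_ids = toy_vectorize_ptr_in(memory, toy_vocab)
# Schema-node ids start after both the vocabulary and the memory, mirroring
# `program_vocab.size + len(constant_memory_features) + x` above.
num_included_nodes = 3
schema_node_ids = [len(toy_vocab) + len(memory) + x for x in range(num_included_nodes)]
print(ptr_ids)          # [2, 0, 4]
print(schema_node_ids)  # [5, 6, 7]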