示例#1
0
    def sort_key(self, record):
        """Return the sorting key for `record`.

        Records are sorted before batching so that similarly sized
        sequences land in the same batch, which reduces padding and
        speeds up training.
        """
        token_count = len(record.token_ids)
        return [to_optimized_size(token_count)]
示例#2
0
 def _pad_batch_data_to_len(self, insts, pad_id=0, given_len=0):
     """Pad each instance in the batch to `given_len` tokens.

     Args:
         insts: A batch of token-id sequences.
         pad_id: The id used to fill padding positions.
         given_len: Target length; must be >= the longest instance.

     Returns:
         An int64 numpy array of shape [batch_size, given_len, 1].

     Raises:
         ValueError: If `given_len` is smaller than the longest instance.
     """
     max_len = to_optimized_size(max(map(len, insts)))
     if given_len < max_len:
         # Bug fix: the message previously hardcoded "max_len = 20,678";
         # report the actual max_len so the error is diagnosable.
         raise ValueError(
         f"given_len = {given_len}, max_len = {max_len}, given_len should be larger than max_len in batch data."
         )
     inst_data = np.array([
         list(inst) + [pad_id] * (given_len - len(inst)) for inst in insts
     ])
     return inst_data.astype("int64").reshape([-1, given_len, 1])
示例#3
0
    def _gen_self_attn_mask(self,
                            batch_token_ids,
                            batch_tgt_start_idx=None,
                            is_unidirectional=True,
                            num_aux_token=0):
        """Build a self-attention masking matrix for a batch.

        Three masking schemes are supported:
        1. Bi-directional: every token attends to every other token.
        2. Uni-directional: each token attends only to earlier tokens.
        3. Seq2seq: source tokens attend to each other; target tokens
           attend to the source and to earlier target tokens.

        Args:
            batch_token_ids: A batch of token-id sequences.
            batch_tgt_start_idx: Per-example start index of the target
                sequence. When given together with `is_unidirectional=True`,
                a seq2seq mask is produced (bi-directional over the source,
                uni-directional over the target).
            is_unidirectional: Whether to apply causal (uni-directional)
                masking.
            num_aux_token: Number of auxiliary tokens prepended to each
                sequence; they are treated as part of the source.

        Returns:
            A float32 array of shape
            [batch, max_len + num_aux_token, max_len + num_aux_token].
        """
        max_len = to_optimized_size(max(map(len, batch_token_ids)))
        side = max_len + num_aux_token
        mask = np.zeros((len(batch_token_ids), side, side))
        if is_unidirectional:
            for i, example_mask in enumerate(mask):
                if batch_tgt_start_idx is None:
                    start = 0
                else:
                    start = batch_tgt_start_idx[i]
                end = len(batch_token_ids[i])
                # All tokens (plus aux tokens) may attend to the source span.
                example_mask[:end + num_aux_token, :start + num_aux_token] = 1.0
                # Causal pattern over the target span: lower-triangular ones.
                span = end - start
                causal = np.tril(np.ones([span, span]), 0)
                example_mask[start + num_aux_token:end + num_aux_token,
                             start + num_aux_token:end + num_aux_token] = causal
        else:
            for i, token_ids in enumerate(batch_token_ids):
                visible = len(token_ids) + num_aux_token
                mask[i, :visible, :visible] = 1.0
        return mask.astype("float32")
示例#4
0
    def _pad_batch_data_to_len_for_topk(self, insts, pad_id=0, given_len=0):
        """Pad top-k index instances to `given_len` entries in batch.

        Each instance is a sequence of [b, k, pos] triples. Padding entries
        reuse the first triple's (b, k) values with `pad_id` as the third
        element, so padded rows still address a valid (batch, knowledge)
        slot.

        Args:
            insts: A batch of sequences of [b, k, pos] triples.
            pad_id: Value used for the third element of padding triples.
            given_len: Target length; must be >= the longest instance.

        Returns:
            An int64 numpy array of shape
            [-1, self.max_knowledge_num, given_len, 3].

        Raises:
            ValueError: If `given_len` is smaller than the longest instance.
        """
        max_len = to_optimized_size(max(map(len, insts)))
        if given_len < max_len:
            # Bug fix: the message previously hardcoded "max_len = 20,678";
            # report the actual max_len so the error is diagnosable.
            raise ValueError(
                f"given_len = {given_len}, max_len = {max_len}, given_len should be larger than max_len in batch data."
            )
        inst_data = []
        for inst in insts:
            # Reuse the instance's first (b, k) pair for padding triples.
            b, k = inst[0][0], inst[0][1]
            pad_item = [b, k, pad_id]
            inst_data.extend(inst)
            inst_data.extend([pad_item] * (given_len - len(inst)))

        inst_data = np.array(inst_data)
        # 4d: [batch, knowledge, given_len, 3]
        return inst_data.astype("int64").reshape(
            [-1, self.max_knowledge_num, given_len, 3])
示例#5
0
    def _pad_batch_records_for_training(self, batch_records):
        """Pad batch records and mask for KAG training.

        Assembles the full feed dict for one training batch. Three views of
        the records are produced:
          * "single_*": per-(example, knowledge) ids reshaped to
            [-1, max_knowledge_num, max_len, 1], with a seq2seq attention
            mask.
          * "dual_src_*" / "dual_knowledge_*": separate source and knowledge
            inputs with bi-directional attention masks.
          * "tgt_label" / "tgt_idx": masked-target labels and indices from
            `_mask_batch_as_list_for_topk_gen`, padded to `self.max_tgt_len`.

        Args:
            batch_records: Records exposing data_id, dual_src_*_ids,
                tgt_start_idx, tgt_mask_pos and exact_k_len attributes.

        Returns:
            dict mapping input names to numpy arrays (int64 ids/labels,
            float32 attention masks).
        """
        batch = {}

        batch_data_id = [record.data_id for record in batch_records]
        batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
            [-1, 1])

        # token_ids, [n, len, 1]
        dual_src_batch_token_ids = [
            record.dual_src_token_ids for record in batch_records
        ]
        # [n * k, len, 1]
        dual_knowledge_batch_token_ids = self._get_batch_knowledge_ids(
            batch_records, "token_ids")
        # [n, k, len, 1]
        single_batch_token_ids = self._get_batch_single_item(
            batch_records, "token_ids")
        # type_ids
        dual_src_batch_type_ids = [
            record.dual_src_type_ids for record in batch_records
        ]
        dual_knowledge_batch_type_ids = self._get_batch_knowledge_ids(
            batch_records, "type_ids")
        single_batch_type_ids = self._get_batch_single_item(
            batch_records, "type_ids")
        # pos_ids
        dual_src_batch_pos_ids = [
            record.dual_src_pos_ids for record in batch_records
        ]
        dual_knowledge_batch_pos_ids = self._get_batch_knowledge_ids(
            batch_records, "pos_ids")
        single_batch_pos_ids = self._get_batch_single_item(
            batch_records, "pos_ids")

        # role_ids are only collected when the task uses roles.
        if self.use_role:
            dual_src_batch_role_ids = [
                record.dual_src_role_ids for record in batch_records
            ]
            dual_knowledge_batch_role_ids = self._get_batch_knowledge_ids(
                batch_records, "role_ids")
            single_batch_role_ids = self._get_batch_single_item(
                batch_records, "role_ids")

        batch_tgt_start_idx = [
            record.tgt_start_idx for record in batch_records
        ]
        batch_tgt_mask_pos = [record.tgt_mask_pos for record in batch_records]
        batch_exact_k_lens = [record.exact_k_len for record in batch_records]

        # Build masked-LM style targets over the "single" view.
        batch_size = len(batch_records)
        tgt_label, tgt_idx = self._mask_batch_as_list_for_topk_gen(
            batch_size=batch_size,
            exact_k_lens=batch_exact_k_lens,
            batch_tokens=single_batch_token_ids,
            vocab_size=self.vocab_size,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            mask_id=self.mask_id,
            batch_mask_start_pos=batch_tgt_start_idx,
            batch_tgt_mask_pos=batch_tgt_mask_pos)

        # single_batch_token_ids is flat over (example, knowledge), so the
        # per-example start-index lists are flattened to match it.
        flatten_batch_tgt_start_idx = [
            j for i in batch_tgt_start_idx for j in i
        ]
        batch["single_attention_mask"] = self._gen_self_attn_mask(
            single_batch_token_ids,
            batch_tgt_start_idx=flatten_batch_tgt_start_idx)

        # Pad targets/indices to the fixed target length.
        given_len = self.max_tgt_len
        batch["tgt_label"] = self._pad_batch_data_to_len(tgt_label,
                                                         pad_id=self.bos_id,
                                                         given_len=given_len)
        batch["tgt_idx"] = self._pad_batch_data_to_len_for_topk(
            tgt_idx, pad_id=self.pad_id, given_len=given_len)

        batch["single_token_ids"] = pad_batch_data(single_batch_token_ids,
                                                   pad_id=self.pad_id)
        batch["single_type_ids"] = pad_batch_data(single_batch_type_ids,
                                                  pad_id=self.pad_id)
        batch["single_pos_ids"] = pad_batch_data(single_batch_pos_ids,
                                                 pad_id=self.pad_id)
        if self.use_role:
            batch["single_role_ids"] = pad_batch_data(single_batch_role_ids,
                                                      pad_id=self.pad_id)

        # Reshape the flat "single" tensors back to
        # [-1, max_knowledge_num, ...] so the knowledge axis is explicit.
        max_len = to_optimized_size(max(map(len, single_batch_token_ids)))
        batch["tgt_label"] = batch["tgt_label"].reshape(
            [-1, self.max_knowledge_num, given_len, 1])
        batch["single_attention_mask"] = batch[
            "single_attention_mask"].reshape(
                [-1, self.max_knowledge_num, max_len, max_len])
        batch["single_token_ids"] = batch["single_token_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        batch["single_type_ids"] = batch["single_type_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        batch["single_pos_ids"] = batch["single_pos_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        if self.use_role:
            batch["single_role_ids"] = batch["single_role_ids"].reshape(
                [-1, self.max_knowledge_num, max_len, 1])

        # for dual: source and knowledge are padded independently and use
        # bi-directional attention (is_unidirectional=False).
        batch["dual_src_token_ids"] = pad_batch_data(dual_src_batch_token_ids,
                                                     pad_id=self.pad_id)
        batch["dual_knowledge_token_ids"] = pad_batch_data(
            dual_knowledge_batch_token_ids, pad_id=self.pad_id)
        batch["dual_src_type_ids"] = pad_batch_data(dual_src_batch_type_ids,
                                                    pad_id=self.pad_id)
        batch["dual_knowledge_type_ids"] = pad_batch_data(
            dual_knowledge_batch_type_ids, pad_id=self.pad_id)
        batch["dual_src_pos_ids"] = pad_batch_data(dual_src_batch_pos_ids,
                                                   pad_id=self.pad_id)
        batch["dual_knowledge_pos_ids"] = pad_batch_data(
            dual_knowledge_batch_pos_ids, pad_id=self.pad_id)
        if self.use_role:
            batch["dual_src_role_ids"] = pad_batch_data(
                dual_src_batch_role_ids, pad_id=self.pad_id)
            batch["dual_knowledge_role_ids"] = pad_batch_data(
                dual_knowledge_batch_role_ids, pad_id=self.pad_id)
        batch["dual_src_attention_mask"] = self._gen_self_attn_mask(
            dual_src_batch_token_ids, is_unidirectional=False)
        batch["dual_knowledge_attention_mask"] = self._gen_self_attn_mask(
            dual_knowledge_batch_token_ids, is_unidirectional=False)

        return batch