Example 1
0
    def _pad_batch_records(self, batch_records, is_infer, phase=None):
        """Pad a batch of records and build the model's input feed dict.

        Returns a dict with padded ``token_ids``/``type_ids``/``pos_ids``
        (plus ``role_ids`` when enabled), a bidirectional attention mask,
        and either ``label`` (training) or ``data_id`` (inference).
        """
        token_id_lists = [rec.token_ids for rec in batch_records]

        batch = {
            "token_ids": pad_batch_data(token_id_lists, pad_id=self.pad_id),
            "type_ids": pad_batch_data(
                [rec.type_ids for rec in batch_records], pad_id=0),
            "pos_ids": pad_batch_data(
                [rec.pos_ids for rec in batch_records], pad_id=0),
        }
        if self.use_role:
            batch["role_ids"] = pad_batch_data(
                [rec.role_ids for rec in batch_records], pad_id=0)

        # Fully-visible (non-causal) self-attention over the padded batch.
        batch["attention_mask"] = self._gen_self_attn_mask(
            token_id_lists, is_unidirectional=False)

        if is_infer:
            data_ids = [rec.data_id for rec in batch_records]
            batch["data_id"] = np.array(data_ids, dtype="int64").reshape(
                [-1, 1])
        else:
            labels = [rec.label for rec in batch_records]
            batch["label"] = np.array(labels, dtype="int64").reshape([-1, 1])

        return batch
Example 2
0
    def _pad_batch_records(self, batch_records, is_infer, phase=None):
        """Pad a batch of records and build the model's input feed dict.

        Subclasses may override this method when they need extra inputs.
        """
        num_records = len(batch_records)
        token_id_lists = [rec.token_ids for rec in batch_records]

        batch = {
            "token_ids": pad_batch_data(token_id_lists, pad_id=self.pad_id),
            "type_ids": pad_batch_data(
                [rec.type_ids for rec in batch_records], pad_id=0),
            "pos_ids": pad_batch_data(
                [rec.pos_ids for rec in batch_records], pad_id=0),
        }
        if self.use_role:
            batch["role_ids"] = pad_batch_data(
                [rec.role_ids for rec in batch_records], pad_id=0)

        tgt_starts = [rec.tgt_start_idx for rec in batch_records]
        # Mask for generation: positions from tgt_start_idx onward are
        # decoded, so the helper builds the appropriate attention pattern.
        batch["generation_mask"] = self._gen_self_attn_mask(
            token_id_lists, batch_tgt_start_idx=tgt_starts)

        if is_infer:
            # Decoding starts from a single BOS token per example.
            tgt_ids = np.full((num_records, 1, 1), self.bos_id, dtype="int64")
            if self.position_style == "continuous":
                tgt_pos = np.array(tgt_starts, dtype="int64")
            else:
                tgt_pos = np.zeros_like(tgt_starts, dtype="int64")
            batch["init_score"] = np.zeros_like(
                tgt_ids, dtype="float32").reshape(-1, 1).tolist()
            batch["tgt_ids"] = tgt_ids.tolist()
            batch["tgt_pos"] = tgt_pos.reshape(-1, 1, 1).tolist()
            batch["parent_idx"] = np.arange(num_records, dtype="int32")

            # First query row of the generation mask: what the next decoded
            # token is allowed to attend to.
            batch["tgt_generation_mask"] = batch[
                "generation_mask"][:, 0:1, :].astype("float32")

            batch["data_id"] = np.array(
                [rec.data_id for rec in batch_records],
                dtype="int64").reshape([-1, 1])
        else:
            # Training: build masked targets over the generation span.
            batch["tgt_label"], batch["tgt_idx"] = mask(
                batch_tokens=token_id_lists,
                vocab_size=self.vocab_size,
                tgt_starts=tgt_starts,
                bos_id=self.bos_id,
                eos_id=self.eos_id,
                mask_id=self.mask_id,
                is_unidirectional=True)

        return batch
Example 3
0
    def _pad_batch_records(self, batch_records, is_infer, phase=None):
        """Pad a batch of records and build the model's input feed dict."""
        token_id_lists = [rec.token_ids for rec in batch_records]
        type_id_lists = [rec.type_ids for rec in batch_records]
        pos_id_lists = [rec.pos_ids for rec in batch_records]
        role_id_lists = (
            [rec.role_ids for rec in batch_records] if self.use_role else None)
        tgt_starts = [rec.tgt_start_idx for rec in batch_records]
        labels = [rec.label for rec in batch_records]

        masked_token_ids, tgt_label, tgt_idx, label_idx = mask(
            batch_tokens=token_id_lists,
            vocab_size=self.vocab_size,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            mask_id=self.mask_id,
            tgt_starts=tgt_starts,
            labels=labels,
            is_unidirectional=False)
        # Training consumes the masked sequences; inference keeps the
        # original, unmasked tokens.
        if not is_infer:
            token_id_lists = masked_token_ids

        batch = {}
        batch["token_ids"] = pad_batch_data(token_id_lists,
                                            pad_id=self.pad_id)
        batch["type_ids"] = pad_batch_data(type_id_lists, pad_id=0)
        batch["pos_ids"] = pad_batch_data(pos_id_lists, pad_id=0)
        if self.use_role:
            batch["role_ids"] = pad_batch_data(role_id_lists, pad_id=0)
        batch["attention_mask"] = self._gen_self_attn_mask(
            token_id_lists, is_unidirectional=False)
        batch["label_idx"] = label_idx

        if is_infer:
            batch["data_id"] = np.array(
                [rec.data_id for rec in batch_records],
                dtype="int64").reshape([-1, 1])
        else:
            batch["label"] = np.array(labels, dtype="int64").reshape([-1, 1])
            batch["tgt_label"] = tgt_label
            batch["tgt_idx"] = tgt_idx

        return batch
Example 4
0
    def _pad_batch_records(self, batch_records, is_infer, phase=None):
        """Pad a batch of records and build the attention mask.

        NOTE(review): type_ids/pos_ids/role_ids are padded with
        ``self.pad_id`` here, while sibling implementations pad them with
        0 — confirm this is intentional (the two are only equivalent when
        ``self.pad_id == 0``).
        """
        token_id_lists = [rec.token_ids for rec in batch_records]

        batch = {
            "token_ids": pad_batch_data(token_id_lists, pad_id=self.pad_id),
            "type_ids": pad_batch_data(
                [rec.type_ids for rec in batch_records], pad_id=self.pad_id),
            "pos_ids": pad_batch_data(
                [rec.pos_ids for rec in batch_records], pad_id=self.pad_id),
        }
        if self.use_role:
            batch["role_ids"] = pad_batch_data(
                [rec.role_ids for rec in batch_records], pad_id=self.pad_id)

        # Fully-visible (non-causal) self-attention mask.
        batch["attention_mask"] = self._gen_self_attn_mask(
            token_id_lists, is_unidirectional=False)

        batch["data_id"] = np.array(
            [rec.data_id for rec in batch_records],
            dtype="int64").reshape([-1, 1])

        return batch
Example 5
0
    def _pad_batch_records(self, batch_records, is_infer, **kwargs):
        """Pad a batch of records and build the model's input feed dict."""
        token_id_lists = [rec.token_ids for rec in batch_records]
        type_id_lists = [rec.type_ids for rec in batch_records]
        pos_id_lists = [rec.pos_ids for rec in batch_records]
        role_id_lists = (
            [rec.role_ids for rec in batch_records] if self.use_role else None)
        tgt_starts = [rec.tgt_start_idx for rec in batch_records]
        num_records = len(token_id_lists)

        # Padding.
        batch = {}
        batch["token_ids"] = pad_batch_data(token_id_lists,
                                            pad_id=self.pad_id)
        batch["type_ids"] = pad_batch_data(type_id_lists, pad_id=0)
        batch["pos_ids"] = pad_batch_data(pos_id_lists, pad_id=0)
        if self.use_role:
            batch["role_ids"] = pad_batch_data(role_id_lists, pad_id=0)

        # num_aux_token=1 reserves one extra slot — presumably for the
        # latent token (see use_latent / latent_id below); confirm against
        # the mask helper.
        batch["generation_mask"] = self._gen_self_attn_mask(
            token_id_lists,
            batch_tgt_start_idx=tgt_starts,
            is_unidirectional=True,
            num_aux_token=1)
        if not is_infer:
            batch["recognition_mask"] = self._gen_self_attn_mask(
                token_id_lists,
                is_unidirectional=False,
                num_aux_token=1)

        if is_infer:
            # Decoding starts from a single BOS token per example.
            tgt_ids = np.full((num_records, 1, 1), self.bos_id, dtype="int64")
            if self.position_style == "continuous":
                tgt_pos = np.array(tgt_starts, dtype="int64")
            else:
                tgt_pos = np.zeros_like(tgt_starts, dtype="int64")
            batch["init_score"] = np.zeros_like(
                tgt_ids, dtype="float32").reshape(-1, 1).tolist()
            batch["tgt_ids"] = tgt_ids.tolist()
            batch["tgt_pos"] = tgt_pos.reshape(-1, 1, 1).tolist()
            batch["parent_idx"] = np.arange(num_records, dtype="int32")
            batch["latent_id"] = np.zeros([num_records], dtype="int32")

            # First query row of the generation mask: what the next decoded
            # token is allowed to attend to.
            batch["tgt_generation_mask"] = batch[
                "generation_mask"][:, 0:1, :].astype("float32")

            batch["data_id"] = np.array(
                [rec.data_id for rec in batch_records],
                dtype="int64").reshape([-1, 1])
        else:
            outputs = mask(
                batch_tokens=token_id_lists,
                vocab_size=self.vocab_size,
                tgt_starts=tgt_starts,
                is_unidirectional=True,
                use_latent=True,
                use_bow=self.use_bow)
            batch["tgt_label"], batch["tgt_idx"] = outputs[0], outputs[1]
            if self.use_bow:
                batch["bow_label"], batch["bow_idx"] = outputs[2], outputs[3]

        return batch
Example 6
0
    def _pad_batch_records_for_training(self, batch_records):
        """Pad batch records and mask for KAG training.

        Builds two families of inputs from ``batch_records``:

        * ``single_*`` — per-(context, knowledge) sequences, reshaped to
          [batch, max_knowledge_num, max_len, 1] for top-k generation.
        * ``dual_*`` — separate source and knowledge sequences with
          bidirectional attention masks, for the dual (retrieval) encoder.

        Also produces padded ``tgt_label``/``tgt_idx`` targets via
        ``_mask_batch_as_list_for_topk_gen``.
        """
        batch = {}

        batch_data_id = [record.data_id for record in batch_records]
        batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
            [-1, 1])

        # token_ids, [n, len, 1]
        dual_src_batch_token_ids = [
            record.dual_src_token_ids for record in batch_records
        ]
        # [n * k, len, 1] — knowledge sequences flattened across the batch
        # (k entries per record; see _get_batch_knowledge_ids).
        dual_knowledge_batch_token_ids = self._get_batch_knowledge_ids(
            batch_records, "token_ids")
        # [n, k, len, 1] — joined (context + knowledge) sequences; flattened
        # to n*k rows until the reshapes at the end restore the k axis.
        single_batch_token_ids = self._get_batch_single_item(
            batch_records, "token_ids")
        # type_ids — same three layouts as token_ids above.
        dual_src_batch_type_ids = [
            record.dual_src_type_ids for record in batch_records
        ]
        dual_knowledge_batch_type_ids = self._get_batch_knowledge_ids(
            batch_records, "type_ids")
        single_batch_type_ids = self._get_batch_single_item(
            batch_records, "type_ids")
        # pos_ids — same three layouts as token_ids above.
        dual_src_batch_pos_ids = [
            record.dual_src_pos_ids for record in batch_records
        ]
        dual_knowledge_batch_pos_ids = self._get_batch_knowledge_ids(
            batch_records, "pos_ids")
        single_batch_pos_ids = self._get_batch_single_item(
            batch_records, "pos_ids")

        if self.use_role:
            dual_src_batch_role_ids = [
                record.dual_src_role_ids for record in batch_records
            ]
            dual_knowledge_batch_role_ids = self._get_batch_knowledge_ids(
                batch_records, "role_ids")
            single_batch_role_ids = self._get_batch_single_item(
                batch_records, "role_ids")

        # Per-record generation-start positions and mask positions;
        # presumably one list per record (flattened below) — confirm
        # against the record builder.
        batch_tgt_start_idx = [
            record.tgt_start_idx for record in batch_records
        ]
        batch_tgt_mask_pos = [record.tgt_mask_pos for record in batch_records]
        batch_exact_k_lens = [record.exact_k_len for record in batch_records]

        batch_size = len(batch_records)
        # Build generation targets (labels + gather indices) over the
        # flattened single sequences for top-k generation training.
        tgt_label, tgt_idx = self._mask_batch_as_list_for_topk_gen(
            batch_size=batch_size,
            exact_k_lens=batch_exact_k_lens,
            batch_tokens=single_batch_token_ids,
            vocab_size=self.vocab_size,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            mask_id=self.mask_id,
            batch_mask_start_pos=batch_tgt_start_idx,
            batch_tgt_mask_pos=batch_tgt_mask_pos)

        # Flatten the per-record start-index lists to match the flattened
        # [n * k] layout of single_batch_token_ids.
        flatten_batch_tgt_start_idx = [
            j for i in batch_tgt_start_idx for j in i
        ]
        batch["single_attention_mask"] = self._gen_self_attn_mask(
            single_batch_token_ids,
            batch_tgt_start_idx=flatten_batch_tgt_start_idx)

        # Targets are padded to a fixed length so they can be reshaped to
        # [-1, max_knowledge_num, given_len, 1] below.
        given_len = self.max_tgt_len
        batch["tgt_label"] = self._pad_batch_data_to_len(tgt_label,
                                                         pad_id=self.bos_id,
                                                         given_len=given_len)
        batch["tgt_idx"] = self._pad_batch_data_to_len_for_topk(
            tgt_idx, pad_id=self.pad_id, given_len=given_len)

        batch["single_token_ids"] = pad_batch_data(single_batch_token_ids,
                                                   pad_id=self.pad_id)
        batch["single_type_ids"] = pad_batch_data(single_batch_type_ids,
                                                  pad_id=self.pad_id)
        batch["single_pos_ids"] = pad_batch_data(single_batch_pos_ids,
                                                 pad_id=self.pad_id)
        if self.use_role:
            batch["single_role_ids"] = pad_batch_data(single_batch_role_ids,
                                                      pad_id=self.pad_id)

        # max_len must match the padded length used by pad_batch_data —
        # presumably both rely on to_optimized_size; verify.
        max_len = to_optimized_size(max(map(len, single_batch_token_ids)))
        # Restore the knowledge axis: [n * k, ...] -> [n, k, ...].
        batch["tgt_label"] = batch["tgt_label"].reshape(
            [-1, self.max_knowledge_num, given_len, 1])
        batch["single_attention_mask"] = batch[
            "single_attention_mask"].reshape(
                [-1, self.max_knowledge_num, max_len, max_len])
        batch["single_token_ids"] = batch["single_token_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        batch["single_type_ids"] = batch["single_type_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        batch["single_pos_ids"] = batch["single_pos_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])
        if self.use_role:
            batch["single_role_ids"] = batch["single_role_ids"].reshape(
                [-1, self.max_knowledge_num, max_len, 1])

        # for dual: pad source and knowledge sides independently and give
        # each a fully-visible (bidirectional) attention mask.
        batch["dual_src_token_ids"] = pad_batch_data(dual_src_batch_token_ids,
                                                     pad_id=self.pad_id)
        batch["dual_knowledge_token_ids"] = pad_batch_data(
            dual_knowledge_batch_token_ids, pad_id=self.pad_id)
        batch["dual_src_type_ids"] = pad_batch_data(dual_src_batch_type_ids,
                                                    pad_id=self.pad_id)
        batch["dual_knowledge_type_ids"] = pad_batch_data(
            dual_knowledge_batch_type_ids, pad_id=self.pad_id)
        batch["dual_src_pos_ids"] = pad_batch_data(dual_src_batch_pos_ids,
                                                   pad_id=self.pad_id)
        batch["dual_knowledge_pos_ids"] = pad_batch_data(
            dual_knowledge_batch_pos_ids, pad_id=self.pad_id)
        if self.use_role:
            batch["dual_src_role_ids"] = pad_batch_data(
                dual_src_batch_role_ids, pad_id=self.pad_id)
            batch["dual_knowledge_role_ids"] = pad_batch_data(
                dual_knowledge_batch_role_ids, pad_id=self.pad_id)
        batch["dual_src_attention_mask"] = self._gen_self_attn_mask(
            dual_src_batch_token_ids, is_unidirectional=False)
        batch["dual_knowledge_attention_mask"] = self._gen_self_attn_mask(
            dual_knowledge_batch_token_ids, is_unidirectional=False)

        return batch