def sort_key(self, record):
    """Return the sorting key for a record.

    Records are sorted by (optimized) token length before batching, so
    that similarly-sized records land in the same batch — less padding,
    faster training.
    """
    token_count = len(record.token_ids)
    return [to_optimized_size(token_count)]
def _pad_batch_data_to_len(self, insts, pad_id=0, given_len=0):
    """Pad every instance in the batch to exactly ``given_len`` tokens.

    Args:
        insts: Iterable of token-id sequences (one per instance).
        pad_id: Token id appended as padding.
        given_len: Target length; must be >= the longest instance
            (after ``to_optimized_size`` rounding).

    Returns:
        int64 ndarray of shape [batch_size, given_len, 1].

    Raises:
        ValueError: If ``given_len`` is smaller than the longest
            instance in the batch.
    """
    max_len = to_optimized_size(max(map(len, insts)))
    if given_len < max_len:
        # Bug fix: the message previously hardcoded "max_len = 15,614";
        # report the actual batch max_len instead.
        raise ValueError(
            f"given_len = {given_len}, max_len = {max_len}, "
            f"given_len should be larger than max_len in batch data.")
    inst_data = np.array([
        list(inst) + [pad_id] * (given_len - len(inst)) for inst in insts
    ])
    return inst_data.astype("int64").reshape([-1, given_len, 1])
def _gen_self_attn_mask(self, batch_token_ids, batch_tgt_start_idx=None, is_unidirectional=True, num_aux_token=0):
    """Generate a self-attention masking matrix for a batch.

    Supports three masking schemes:

    1. Bi-directional: every token attends to every other token.
    2. Uni-directional: each token attends only to earlier tokens.
    3. Seq2seq: source tokens attend to each other; target tokens attend
       to the whole source plus earlier target tokens. Selected when
       `batch_tgt_start_idx` is given together with
       `is_unidirectional=True`.

    Args:
        batch_token_ids: A batch of token-id sequences.
        batch_tgt_start_idx: Per-instance start index of the target
            sequence (enables the seq2seq scheme).
        is_unidirectional: Whether masking is causal.
        num_aux_token: Number of auxiliary tokens prepended to each
            sequence; treated as part of the source.

    Returns:
        float32 ndarray of shape [batch, L, L] with 1.0 where attention
        is allowed, where L = max sequence length + num_aux_token.
    """
    longest = to_optimized_size(max(map(len, batch_token_ids)))
    side = longest + num_aux_token
    mask = np.zeros((len(batch_token_ids), side, side))
    if not is_unidirectional:
        # Full (bi-directional) attention within each sequence.
        for row, token_ids in enumerate(batch_token_ids):
            seq_end = len(token_ids) + num_aux_token
            mask[row, :seq_end, :seq_end] = 1.0
    else:
        for row, token_ids in enumerate(batch_token_ids):
            tgt_start = 0 if batch_tgt_start_idx is None else batch_tgt_start_idx[row]
            src_end = tgt_start + num_aux_token
            seq_end = len(token_ids) + num_aux_token
            # All valid positions may attend to the source segment
            # (including the auxiliary tokens).
            mask[row, :seq_end, :src_end] = 1.0
            # Target positions attend causally among themselves: a lower
            # triangular block over the target span.
            span = seq_end - src_end
            mask[row, src_end:seq_end, src_end:seq_end] = np.tril(
                np.ones([span, span]), 0)
    return mask.astype("float32")
def _pad_batch_data_to_len_for_topk(self, insts, pad_id=0, given_len=0):
    """Pad top-k index instances to ``given_len`` triples each.

    Each instance is a sequence of ``[b, k, token]`` triples. Padding
    triples reuse the instance's first ``(b, k)`` pair with ``pad_id``
    as the token, so padded rows still index a valid (batch, knowledge)
    slot.

    Args:
        insts: Iterable of instances; each a non-empty sequence of
            3-element ``[b, k, token]`` entries.
        pad_id: Token id used in padding triples.
        given_len: Target length per instance; must be >= the longest
            instance.

    Returns:
        int64 ndarray of shape [-1, self.max_knowledge_num, given_len, 3].

    Raises:
        ValueError: If ``given_len`` is smaller than the longest
            instance in the batch.
    """
    max_len = to_optimized_size(max(map(len, insts)))
    if given_len < max_len:
        # Bug fix: the message previously hardcoded "max_len = 15,614";
        # report the actual batch max_len instead.
        raise ValueError(
            f"given_len = {given_len}, max_len = {max_len}, "
            f"given_len should be larger than max_len in batch data.")
    rows = []
    for inst in insts:
        # Padding entries inherit (b, k) from the instance's first triple.
        b, k = inst[0][0], inst[0][1]
        rows.extend(inst)
        rows.extend([[b, k, pad_id]] * (given_len - len(inst)))
    inst_data = np.array(rows)  # 4d after reshape
    return inst_data.astype("int64").reshape(
        [-1, self.max_knowledge_num, given_len, 3])
def _pad_batch_records_for_training(self, batch_records):
    """Pad batch records and mask for KAG training.

    Assembles the full training feed dict from a list of records:
    "single" features (one row per (sample, knowledge) pair, seq2seq
    masked for generation) and "dual" features (separate source and
    knowledge towers with bi-directional masks, presumably for
    retrieval-style matching — confirm against the model definition).

    Returns:
        dict mapping feature names to int64/float32 ndarrays.
    """
    batch = {}
    batch_data_id = [record.data_id for record in batch_records]
    batch["data_id"] = np.array(batch_data_id).astype("int64").reshape(
        [-1, 1])

    # token_ids, [n, len, 1]
    dual_src_batch_token_ids = [
        record.dual_src_token_ids for record in batch_records
    ]
    # [n * k, len, 1]
    dual_knowledge_batch_token_ids = self._get_batch_knowledge_ids(
        batch_records, "token_ids")
    # [n, k, len, 1]
    single_batch_token_ids = self._get_batch_single_item(
        batch_records, "token_ids")
    # type_ids
    dual_src_batch_type_ids = [
        record.dual_src_type_ids for record in batch_records
    ]
    dual_knowledge_batch_type_ids = self._get_batch_knowledge_ids(
        batch_records, "type_ids")
    single_batch_type_ids = self._get_batch_single_item(
        batch_records, "type_ids")
    # pos_ids
    dual_src_batch_pos_ids = [
        record.dual_src_pos_ids for record in batch_records
    ]
    dual_knowledge_batch_pos_ids = self._get_batch_knowledge_ids(
        batch_records, "pos_ids")
    single_batch_pos_ids = self._get_batch_single_item(
        batch_records, "pos_ids")
    # role_ids are optional, gated on self.use_role.
    if self.use_role:
        dual_src_batch_role_ids = [
            record.dual_src_role_ids for record in batch_records
        ]
        dual_knowledge_batch_role_ids = self._get_batch_knowledge_ids(
            batch_records, "role_ids")
        single_batch_role_ids = self._get_batch_single_item(
            batch_records, "role_ids")

    batch_tgt_start_idx = [
        record.tgt_start_idx for record in batch_records
    ]
    batch_tgt_mask_pos = [record.tgt_mask_pos for record in batch_records]
    batch_exact_k_lens = [record.exact_k_len for record in batch_records]
    batch_size = len(batch_records)

    # Build generation labels/indices by masking target positions.
    tgt_label, tgt_idx = self._mask_batch_as_list_for_topk_gen(
        batch_size=batch_size,
        exact_k_lens=batch_exact_k_lens,
        batch_tokens=single_batch_token_ids,
        vocab_size=self.vocab_size,
        bos_id=self.bos_id,
        eos_id=self.eos_id,
        mask_id=self.mask_id,
        batch_mask_start_pos=batch_tgt_start_idx,
        batch_tgt_mask_pos=batch_tgt_mask_pos)

    # tgt_start_idx is nested (per sample, per knowledge); flatten to
    # align with the flattened "single" rows.
    flatten_batch_tgt_start_idx = [
        j for i in batch_tgt_start_idx for j in i
    ]
    # Seq2seq mask: source bi-directional, target causal.
    batch["single_attention_mask"] = self._gen_self_attn_mask(
        single_batch_token_ids,
        batch_tgt_start_idx=flatten_batch_tgt_start_idx)

    given_len = self.max_tgt_len
    # NOTE(review): tgt_label is padded with bos_id, not pad_id —
    # presumably intentional; confirm against the loss masking.
    batch["tgt_label"] = self._pad_batch_data_to_len(tgt_label,
                                                     pad_id=self.bos_id,
                                                     given_len=given_len)
    batch["tgt_idx"] = self._pad_batch_data_to_len_for_topk(
        tgt_idx, pad_id=self.pad_id, given_len=given_len)

    batch["single_token_ids"] = pad_batch_data(single_batch_token_ids,
                                               pad_id=self.pad_id)
    batch["single_type_ids"] = pad_batch_data(single_batch_type_ids,
                                              pad_id=self.pad_id)
    batch["single_pos_ids"] = pad_batch_data(single_batch_pos_ids,
                                             pad_id=self.pad_id)
    if self.use_role:
        batch["single_role_ids"] = pad_batch_data(single_batch_role_ids,
                                                  pad_id=self.pad_id)

    # Reshape flattened "single" features back to
    # [n, max_knowledge_num, ...].
    max_len = to_optimized_size(max(map(len, single_batch_token_ids)))
    batch["tgt_label"] = batch["tgt_label"].reshape(
        [-1, self.max_knowledge_num, given_len, 1])
    batch["single_attention_mask"] = batch[
        "single_attention_mask"].reshape(
            [-1, self.max_knowledge_num, max_len, max_len])
    batch["single_token_ids"] = batch["single_token_ids"].reshape(
        [-1, self.max_knowledge_num, max_len, 1])
    batch["single_type_ids"] = batch["single_type_ids"].reshape(
        [-1, self.max_knowledge_num, max_len, 1])
    batch["single_pos_ids"] = batch["single_pos_ids"].reshape(
        [-1, self.max_knowledge_num, max_len, 1])
    if self.use_role:
        batch["single_role_ids"] = batch["single_role_ids"].reshape(
            [-1, self.max_knowledge_num, max_len, 1])

    # for dual
    batch["dual_src_token_ids"] = pad_batch_data(dual_src_batch_token_ids,
                                                 pad_id=self.pad_id)
    batch["dual_knowledge_token_ids"] = pad_batch_data(
        dual_knowledge_batch_token_ids, pad_id=self.pad_id)
    batch["dual_src_type_ids"] = pad_batch_data(dual_src_batch_type_ids,
                                                pad_id=self.pad_id)
    batch["dual_knowledge_type_ids"] = pad_batch_data(
        dual_knowledge_batch_type_ids, pad_id=self.pad_id)
    batch["dual_src_pos_ids"] = pad_batch_data(dual_src_batch_pos_ids,
                                               pad_id=self.pad_id)
    batch["dual_knowledge_pos_ids"] = pad_batch_data(
        dual_knowledge_batch_pos_ids, pad_id=self.pad_id)
    if self.use_role:
        batch["dual_src_role_ids"] = pad_batch_data(
            dual_src_batch_role_ids, pad_id=self.pad_id)
        batch["dual_knowledge_role_ids"] = pad_batch_data(
            dual_knowledge_batch_role_ids, pad_id=self.pad_id)
    # Dual towers use full bi-directional attention.
    batch["dual_src_attention_mask"] = self._gen_self_attn_mask(
        dual_src_batch_token_ids, is_unidirectional=False)
    batch["dual_knowledge_attention_mask"] = self._gen_self_attn_mask(
        dual_knowledge_batch_token_ids, is_unidirectional=False)
    return batch