def batched_rev_idxes(self):
    """Return the batched reverse index tensor, building it on first use.

    The result is cached in ``self._batched_rev_idxes``; later calls return
    the cached object without recomputing it.

    NOTE(review): ``_transform_factors`` reads this as a plain attribute
    (``self.batched_rev_idxes``), so a ``@property`` decorator presumably
    precedes this definition -- confirm.

    Returns:
        The value produced by ``BK.input_idx`` over the padded
        ``align_info.split2orig`` lists of every entry in ``self.seq_subs``
        (``[bsize, sub_len]`` per the original author's comment).
    """
    cached = self._batched_rev_idxes
    if cached is not None:
        return cached  # [bsize, sub_len]
    # Pad with 0 again, like the other index mappings built in __init__.
    rev_padder = DataPadder(2, pad_vals=0)
    split2orig_rows = [sub.align_info.split2orig for sub in self.seq_subs]
    padded_rows, _ = rev_padder.pad(split2orig_rows)  # [bsize, sub_len]
    self._batched_rev_idxes = BK.input_idx(padded_rows)
    return self._batched_rev_idxes  # [bsize, sub_len]
def __init__(self, berter: BertEncoder, seq_subs: List[InputSubwordSeqField]):
    """Batch a list of subword sequences into padded index and mask inputs.

    Args:
        berter: Encoder whose ``tokenizer.pad_token_id`` is used as the
            padding value for the subword ids.
        seq_subs: One subword sequence per batch item; each must expose
            ``idxes`` and ``align_info.orig2begin``.
    """
    self.seq_subs = seq_subs
    self.berter = berter
    self.bsize = len(seq_subs)

    # Batch-index helpers with progressively more trailing unit dims.
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]

    # --
    # Padded subword ids plus their mask.
    pad_id = self.berter.tokenizer.pad_token_id
    ids_padder = DataPadder(2, pad_vals=pad_id, mask_range=2)
    sub_lengths = [len(sub.idxes) for sub in seq_subs]  # [bsize]
    padded_ids, padded_ids_mask = ids_padder.pad(
        [sub.idxes for sub in seq_subs]
    )  # [bsize, sub_len]
    # Length + 1: also the idx of EOS (if counting including BOS).
    self.batched_sublens_p1 = BK.input_idx(sub_lengths) + 1
    self.batched_input_ids = BK.input_idx(padded_ids)
    self.batched_input_mask = BK.input_real(padded_ids_mask)

    # Mapping from each original token to its first subword (sub->orig);
    # padded with 0 to avoid out-of-range indices.
    first_padder = DataPadder(2, pad_vals=0, mask_range=2)
    padded_first, padded_first_mask = first_padder.pad(
        [sub.align_info.orig2begin for sub in seq_subs]
    )  # [bsize, orig_len]
    self.batched_first_idxes = BK.input_idx(padded_first)
    self.batched_first_mask = BK.input_real(padded_first_mask)

    # Reversed mapping (orig->sub); created lazily when first needed.
    self._batched_rev_idxes = None  # [bsize, sub_len]

    # --
    # Optional extra inputs, left unset here.
    self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
    self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
    self.batched_position_ids = None  # [bsize, 1+sub_len+1]
    self.other_factors = {}  # name -> aug_batched_ids
def _transform_factors(self, factors: Union[List[List[int]], BK.Expr], is_orig: bool, PAD_IDX: Union[int, float]):
    """Turn per-token factor ids into a batched tensor aligned to subwords.

    Args:
        factors: Either an already-batched ``BK.Expr`` (used as is) or a
            list of per-instance id lists that still needs padding.
        is_orig: If true, ``factors`` is indexed by original tokens and is
            re-indexed through ``self.batched_rev_idxes`` onto subwords;
            otherwise it is returned without re-indexing.
        PAD_IDX: Padding value used when ``factors`` is a nested list.

    Returns:
        The batched ids (``[bsize, sub_len]`` per the original author's
        comment).
    """
    if isinstance(factors, BK.Expr):
        # Already padded and batched by the caller.
        ids_t = factors
    else:
        padded, _ = DataPadder(2, pad_vals=PAD_IDX).pad(factors)
        ids_t = BK.input_idx(padded)  # [bsize, orig-len if is_orig else sub_len]

    if not is_orig:
        return ids_t  # [bsize, sub_len]

    # Map original-token positions onto subword positions; arange2_t is
    # used as the batch index alongside the reverse indices.
    return ids_t[self.arange2_t, self.batched_rev_idxes]  # [bsize, sub_len]
def prepare_batch(insts: List, idx_f: Callable, padder: DataPadder):
    """Index every instance, pad the results, and wrap them with ``BK.input_idx``.

    Args:
        insts: Instances to batch.
        idx_f: Callable applied to each instance to obtain its index list.
        padder: Padder whose ``pad`` returns a pair; only the first element
            (the padded array) is used here.

    Returns:
        The result of ``BK.input_idx`` on the padded index array.
    """
    per_inst_idxes = list(map(idx_f, insts))  # one idx list per instance
    padded_arr, _ = padder.pad(per_inst_idxes)
    return BK.input_idx(padded_arr)