def select_topk(score_t: BK.Expr, topk_t: Union[int, BK.Expr], mask_t: BK.Expr = None, dim=-1):
    """Return a float 0/1 mask selecting the top-k entries of `score_t` along `dim`.

    Instead of returning indices, this computes the k-th-best score as a threshold and
    marks every entry with score >= threshold, so `topk_t` may be a per-row tensor
    (variable k per row) as well as a single int.

    Args:
        score_t: scores to select from.
        topk_t: int k shared by all rows, or a tensor of per-row k values
            (broadcastable against `score_t` with size 1 at `dim`).
        mask_t: optional validity mask; invalid entries are pushed to ~-inf before
            top-k and zeroed in the returned mask.
        dim: dimension along which to select.
    Returns:
        Float mask of the same shape as `score_t` (1. for selected entries).
    Note: ties at the threshold score can select more than k entries.
    """
    # prepare K: unify int/tensor forms; K is the max budget needed for .topk()
    if isinstance(topk_t, int):
        K = topk_t
        tmp_shape = BK.get_shape(score_t)
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, K)
    else:
        K = topk_t.max().item()
    exact_rank_t = topk_t - 1  # 0-based rank of the threshold score, [bsize, 1]
    exact_rank_t.clamp_(min=0, max=K - 1)  # make it in valid range!
    # mask values: push invalid entries to a very small value so they never rank high
    if mask_t is not None:
        score_t = score_t + Constants.REAL_PRAC_MIN * (1. - mask_t)
    # topk (sorted so that index `rank` is the rank-th best)
    topk_vals, _ = score_t.topk(K, dim, largest=True, sorted=True)  # [*, K, *]
    # gather the per-row threshold score
    sel_thresh = topk_vals.gather(dim, exact_rank_t)  # [*, 1, *]
    # get topk_mask: everything at or above the threshold
    topk_mask = (score_t >= sel_thresh).float()  # [*, D, *]
    if mask_t is not None:
        topk_mask *= mask_t  # drop invalid entries from the result
    return topk_mask
def inference_forward(scores_t: BK.Expr, mat_t: BK.Expr, mask_t: BK.Expr, beam_k: int = 0): scores_shape = BK.get_shape(scores_t) # [*, slen, L] need_topk = (beam_k > 0) and (beam_k < scores_shape[-1] ) # whether we need topk # -- score_slices = split_at_dim(scores_t, -2, True) # List[*, 1, L] mask_slices = split_at_dim(mask_t, -1, True) # List[*, 1] # the loop on slen start_shape = scores_shape[:-2] + [1] # [*, 1] last_labs_t = BK.constants_idx(start_shape, 0) # [*, K], todo(note): start with 0! last_accu_scores = BK.zeros(start_shape) # accumulated scores: [*, K] last_potential = BK.zeros( start_shape) # accumulated potentials: [*, K] full_labs_t = BK.arange_idx(scores_shape[-1]).view( [1] * (len(scores_shape) - 2) + [-1]) # [*, L] cur_step = 0 for one_score_slice, one_mask_slice in zip(score_slices, mask_slices): # [*,L],[*,1] one_mask_slice_neg = 1. - one_mask_slice # [*,1] # get current scores if cur_step == 0: # no transition at start! one_cur_scores = one_score_slice # [*, 1, L] else: one_cur_scores = one_score_slice + mat_t[ last_labs_t] # [*, K, L] # first for potentials expanded_potentials = last_potential.unsqueeze( -1) + one_cur_scores # [*, K, L] merged_potentials = log_sum_exp(expanded_potentials, -2) # [*, L] # optional for topk with merging; note: not really topk!! if need_topk: # todo(+W): another option is to directly select with potentials rather than accu_scores expanded_scores = last_accu_scores.unsqueeze( -1) + one_cur_scores # [*, K, L] # max at -2, merge same current label max_scores, max_idxes = expanded_scores.max(-2) # [*, L] # topk at current step, no need to sort! 
new_accu_scores, new_labs_t = max_scores.topk( beam_k, -1, sorted=False) # [*, K] new_potential = merged_potentials.gather(-1, new_labs_t) # [*, K] # mask and update last_potential = last_potential * one_mask_slice_neg + new_potential * one_mask_slice # [*, K] last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice # [*, K] last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + new_labs_t * one_mask_slice.long() # [*, K] else: # mask and update last_potential = last_potential * one_mask_slice_neg + merged_potentials * one_mask_slice # [*, L(K)] # note: still need to mask this! last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + full_labs_t * one_mask_slice.long() cur_step += 1 # finally sum all ret_potential = log_sum_exp(last_potential, -1) # [*] return ret_potential
def prep_enc(self, med: ZMediator, *args, **kwargs): conf: ZDecoderMlmConf = self.conf # note: we have to use enc-mask # todo(+W): currently do mlm for full seq regardless of center or not! sinfo = med.ibatch.seq_info enc_ids, enc_mask = sinfo.enc_input_ids, sinfo.enc_input_masks # [*, elen] _shape = enc_ids.shape # sample mask mlm_mask = (BK.rand(_shape) < conf.mlm_mrate).float() * enc_mask # [*, elen] # sample repl _repl_sample = BK.rand(_shape) # [*, elen], between [0, 1) mlm_repl_ids = BK.constants_idx(_shape, self.mask_token_id) # [*, elen] [MASK] _repl_rand, _repl_origin = self.repl_ranges mlm_repl_ids = BK.where(_repl_sample > _repl_rand, (BK.rand(_shape) * self.target_size).long(), mlm_repl_ids) mlm_repl_ids = BK.where(_repl_sample > _repl_origin, enc_ids, mlm_repl_ids) # final prepare mlm_input_ids = BK.where(mlm_mask > 0., mlm_repl_ids, enc_ids) # [*, elen] med.set_cache('eff_input_ids', mlm_input_ids) med.set_cache('mlm_mask', mlm_mask)
def prepare_indicators(self, flat_idxes: List, shape):
    """Build one 0/1 indicator tensor of `shape` per entry of `flat_idxes`.

    For each idx-tensor `[*]` in `flat_idxes`, the result has a 1 at position
    `idxes[b]` of row b and 0 everywhere else.
    """
    batch_size, _ = shape
    row_arange = BK.arange_idx(batch_size)  # [*]

    def _one_hot(idxes):
        # scatter a single 1 per row at the given column index
        indicator = BK.constants_idx(shape, 0)  # [*, dlen]
        indicator[row_arange, idxes] = 1
        return indicator

    return [_one_hot(z) for z in flat_idxes]
def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions, cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr):
    """Compute the split loss and produce the (teacher-forced) segmentation for the next stage.

    The oracle says "split between adjacent candidates" wherever their gold addresses
    differ; positions where both neighbors have gaddr -1 (no gold) are excluded from
    the loss. Feeding mixes forced (oracle) and predicted decisions at seq level with
    probability `conf.split_feed_force_rate`.

    Returns:
        (loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks,
         seg_weighted_expr, oracle_gaddr) — segment-level tensors plus the per-segment
        oracle gold address (-1 for invalid segments).
    """
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
    cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
    cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
    # oracle split where adjacent gold addresses differ
    split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
    # supervise only where at least one side has a real gold address
    split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
    raw_split_loss = BK.loss_binary(split_scores, split_oracle, label_smoothing=conf.split_label_smoothing)  # [*, slen]
    loss_split_item = LossHelper.compile_leaf_loss(
        f"split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(), loss_lambda=conf.loss_split)
    # step 2.2: feed split
    # note: when teacher-forcing, only forcing good points, others still use pred
    force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
    _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_split_decisions = (_use_force_mask * force_split_decisions + (1. - _use_force_mask) * pred_split_decisions)  # [*, clen-1]
    # next
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # finally get oracles for next steps
    # todo(+N): simply select the highest scored one as oracle
    if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
        oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
    else:
        # pick the word with the highest extension score inside each segment
        _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
        oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
        oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
    oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
    return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable, gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    """Enumerate all candidate spans (widx, wlen) in the configured width range and compact them.

    Candidates are laid out widx-major then wlen, filtered by `_mask_f(widx, wlen)`,
    and compacted with mask2idx. Gold spans (if given) are scattered into the same
    flat layout so each kept candidate gets its gold address (or -1).

    Args:
        input_shape: (bsize, mlen).
        _mask_f: callable (widx_t, wlen_t) -> [bsize, mlen*dw] validity mask.
        gold_widx_expr / gold_wlen_expr / gold_addr_expr: [bsize, _glen] gold span
            starts / lengths / addresses; `gold_addr_expr is None` skips gold prep.
    Returns:
        (ret_widx, ret_wlen, final_mask_t, ret_gaddr), each [bsize, ??]
        (?? = max kept candidates per row); ret_gaddr is None without golds.
    """
    conf: SpanExtractorConf = self.conf
    min_width, max_width = conf.min_width, conf.max_width
    diff_width = max_width - min_width + 1  # number of width to extract
    # --
    bsize, mlen = input_shape
    # --
    # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
    # note: the spans are always sorted by (widx, wlen)
    _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
    widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
    wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
    mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
    # --
    # compacting (use mask2idx and gather)
    final_idx_t, final_mask_t = BK.mask2idx(mask_t0, padding_idx=0)  # [bsize, ??]
    _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    # no need to make valid for mask=0, since idx=0 means (0, min_width)
    # todo(+?): do we need to deal with empty ones here?
    ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
    ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
    # --
    # prepare gold (as pointer-like addresses)
    if gold_addr_expr is not None:
        gold_t0 = BK.constants_idx((bsize, mlen * diff_width), -1)  # [bsize, mlen*diff]
        # check valid of golds (flatten all): address present and width inside range
        gold_valid_t = ((gold_addr_expr >= 0) & (gold_wlen_expr >= min_width) & (gold_wlen_expr <= max_width))
        gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
        _glen = BK.get_shape(gold_addr_expr, 1)
        flattened_bsize_t = BK.arange_idx(bsize * _glen) // _glen  # [bsize*_glen]
        # flat candidate index of each gold span in the widx-major layout above
        flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr - min_width).view(-1)  # [bsize*_glen]
        flattened_gaddr_t = gold_addr_expr.view(-1)
        # mask and assign
        gold_t0[flattened_bsize_t[gold_valid_t], flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[gold_valid_t]
        ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
        ret_gaddr.masked_fill_((final_mask_t == 0), -1)  # make invalid ones -1
    else:
        ret_gaddr = None
    # --
    return ret_widx, ret_wlen, final_mask_t, ret_gaddr
def select_topk_non_overlapping(score_t: BK.Expr, topk_t: Union[int, BK.Expr], widx_t: BK.Expr, wlen_t: BK.Expr,
                                input_mask_t: BK.Expr, mask_t: BK.Expr = None, dim=-1):
    """Greedy non-overlapping top-k selection over scored spans (NMS-style).

    Visits candidates in descending score order and accepts each span whose word
    range [widx, widx+wlen) is still fully free, until the per-row budget runs out.

    Args:
        score_t: candidate scores; selection happens along the last dim.
        topk_t: budget; an int shared by all rows or a tensor with size 1 at `dim`.
        widx_t, wlen_t: span start indices and lengths (same shape as score_t).
        input_mask_t: word-level validity mask; spans covering an invalid word are skipped.
        mask_t: optional candidate-validity mask (0 = never selectable); None means all valid.
        dim: must denote the last dim; only last-dim selection is supported.
    Returns:
        Float 0/1 mask with the same shape as `score_t` marking selected spans.
    """
    score_shape = BK.get_shape(score_t)
    # fix: was `len(score_shape - 1)`, a TypeError (list minus int) whenever dim != -1
    assert dim == -1 or dim == len(score_shape) - 1, "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K as a per-row tensor
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # fix: the default mask_t=None used to crash in the reshape below; treat as all-valid
    if mask_t is None:
        mask_t = BK.constants(score_shape, 1.)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)  # word positions still free for this row
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0:
                break  # no budget left
            if not aslice_mask_t[_cidx].item():
                continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx + one_wlen]).item() == 0.:  # any word already taken/invalid?
                continue
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.  # claim these words
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
def mask2posi_padded(mask: BK.Expr, offset: int, cmin: int):
    """Turn a [bsize, ssize] validity mask into position indices.

    Each position starts from its arange index; within a run of masked-out (0)
    tokens the index is pulled back so padded tokens repeat the position of the
    last valid token. Finally shift everything by `offset` and clamp at `cmin`.
    """
    with BK.no_grad_env():
        n_rows, n_steps = BK.get_shape(mask)
        positions = BK.arange_idx(n_steps).repeat(n_rows, 1)  # [bsize, ssize]
        inv_mask = (mask == 0.).long()  # 1 where padded, [bsize, ssize]
        run_len = BK.constants_idx([n_rows], 0)  # current zero-run length per row
        for step in range(n_steps):
            cur = inv_mask[:, step]  # [bsize]
            # grow the run while padded; multiplying by cur resets it on valid tokens
            run_len = (run_len + cur) * cur  # [bsize]
            positions[:, step] -= run_len
        # --
        positions += offset
        positions.clamp_(min=cmin)
        return positions
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr, pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    """Training loss for the anchor extractor.

    Combines up to three terms: (1) a per-token label loss interpolating between the
    gold label and NIL, weighted by how dominant each token's gold prob is within its
    candidate group; (2) an optional entropy loss over each gold bag; (3) an optional
    extension loss on tokens selected as confident anchors.

    Returns:
        (combined_loss, None).
    """
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare gold alignment/grouping tensors (cached by the helper)
    arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1: label, simply scoring everything!
    _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
    all_scores_t = self.lab_node.score_all(_main_t, _pair_t, mask_expr, None, local_normalize=False, extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
    all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
    all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen], prob of the gold label
    # how to weight: gather gold probs over each token's candidate group
    extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
    if BK.is_zero_shape(extended_gprob_t):
        extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    # ratio of own gold-prob to the best in the group, sharpened by alpha
    _weight = ((all_gprob_t * mask_expr) / (extended_gprob_max_t.clamp(min=1e-5))) ** _w_alpha  # [*, slen]
    _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
    _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen], gold label
    _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen], NIL label
    _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
    _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
    # final weight it: down-weight NIL tokens by the per-seq non-weight
    cand_loss_weights = BK.where(expr_seq_labs == 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non, mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
        loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
    # --
    # step 1.5: optional entropy regularizer over each gold bag
    all_losses = [loss_lab_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_gprob_t  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag (gaddr changes mark bag boundaries)
        _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
            * (expr_seq_labs>0).float()
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        all_losses.append(_loss_ent_item)
    # --
    # step 4: extend (select topk): keep confident gold tokens as anchors
    if conf.loss_ext > 0.:
        if BK.is_zero_shape(extended_gprob_t):
            flt_mask = (BK.zeros(mask_expr.shape) > 0)
        else:
            _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
            _topk_grpob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
            # gold token, within group-topk, and weight above threshold
            flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_grpob_t.min(-1)[0]) & (_weight > conf.ext_loss_thresh)  # [*, slen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?], source row of each kept token
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
        flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
        loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx], flt_extra_weights=flt_weights)
        all_losses.append(loss_ext_item)
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(all_losses)
    return ret_loss, None
def inference_search(scores_t: BK.Expr, mat_t: BK.Expr, mask_t: BK.Expr, beam_k: int = 0): scores_shape = BK.get_shape(scores_t) # [*, slen, L] need_topk = (beam_k > 0) and (beam_k < scores_shape[-1] ) # whether we need topk # -- score_slices = split_at_dim(scores_t, -2, True) # List[*, 1, L] mask_slices = split_at_dim(mask_t, -1, True) # List[*, 1] # the loop on slen all_sel_labs = [] # List of [*, K] all_sel_scores = [] # List of [*, K] all_tracebacks = [] # List of [*, K] start_vals_shape = scores_shape[:-2] + [1] # [*, 1] full_idxes_shape = scores_shape[:-2] + [-1] # [*, ?] last_labs_t = BK.constants_idx(start_vals_shape, 0) # [*, K], todo(note): start with 0! last_accu_scores = BK.zeros( start_vals_shape) # accumulated scores: [*, K] full_labs_t = BK.arange_idx(scores_shape[-1]).expand( full_idxes_shape) # [*, L] cur_step = 0 for one_score_slice, one_mask_slice in zip(score_slices, mask_slices): # [*,L],[*,1] one_mask_slice_neg = 1. - one_mask_slice # [*,1] # get current scores if cur_step == 0: # no transition at start! one_cur_scores = one_score_slice # [*, 1, L] else: # len(all_sel_labs) must >0 one_cur_scores = one_score_slice + mat_t[ last_labs_t] # [*, K, L] # expand scores expanded_scores = last_accu_scores.unsqueeze( -1) + one_cur_scores # [*, K, L] # max at -2, merge same current label max_scores, max_idxes = expanded_scores.max(-2) # [*, L] # need topk? if need_topk: # topk at current step, no need to sort! new_accu_scores, new_labs_t = max_scores.topk( beam_k, -1, sorted=False) # [*, K] new_traceback = max_idxes.gather(-1, new_labs_t) # [*, K] last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + new_labs_t * one_mask_slice.long() # [*, K] else: new_accu_scores = max_scores # [*, L(K)] new_traceback = max_idxes # note: still need to mask this! 
last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + full_labs_t * one_mask_slice.long() # mask and update last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice # [*, K] default_traceback = BK.arange_idx(BK.get_shape(expanded_scores, -2))\ .view([1]*(len(scores_shape)-2) + [-1]) # [*, K(arange)] last_traceback_t = default_traceback * one_mask_slice_neg.long( ) + new_traceback * one_mask_slice.long() # [*, K] all_sel_labs.append(last_labs_t) all_tracebacks.append(last_traceback_t) one_new_scores = one_cur_scores[BK.arange_idx( scores_shape[0]).unsqueeze(-1), last_traceback_t, last_labs_t] # [*, K] one_new_scores *= one_mask_slice all_sel_scores.append(one_new_scores) cur_step += 1 # traceback _, last_idxes = last_accu_scores.max(-1) # [*] last_idxes = last_idxes.unsqueeze(-1) # [*, 1] all_preds, all_scores = [], [] for cur_step in range(len(all_tracebacks) - 1, -1, -1): all_preds.append(all_sel_labs[cur_step].gather( -1, last_idxes).squeeze(-1)) # [*] all_scores.append(all_sel_scores[cur_step].gather( -1, last_idxes).squeeze(-1)) # [*] last_idxes = all_tracebacks[cur_step].gather(-1, last_idxes) # [*, 1] # remember to reverse!! all_preds.reverse() all_scores.reverse() best_labs = BK.stack(all_preds, -1) # [*, slen] best_scores = BK.stack(all_scores, -1) # [*, slen] return best_labs, best_scores # [*, slen]
def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True): _NEG_INF = Constants.REAL_PRAC_MIN # -- cur_step = 0 cache: DecCache = None # init: keep the seq of scores rather than traceback! start_vals_shape = [batch_size, 1] # [bs, 1] all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze( -1) # [bs, K, step], todo(note): start with 0! all_scores_t = BK.zeros(start_vals_shape).unsqueeze( -1) # [bs, K, step] accu_scores_t = BK.zeros(start_vals_shape) # [bs, K] arange_t = BK.arange_idx(batch_size).unsqueeze(-1) # [bs, 1] # while loop prev_k = 1 # start with single one while not self.is_end(cur_step): # expand and score cache, scores_t, masks_t = self.step_score( cur_step, prev_k, cache) # ..., [bs*pK, L], [bs*pK] scores_t_shape = BK.get_shape(scores_t) last_dim = scores_t_shape[-1] # L # modify score to handle mask: keep previous pred for the masked items! sel_scores_t = BK.constants([batch_size, prev_k, last_dim], 1.) # [bs, pk, L] sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:], -1) # [bs, pk, L] sel_scores_t = scores_t + _NEG_INF * ( sel_scores_t.view(scores_t_shape) * (1. - masks_t).unsqueeze(-1)) # [bs*pK, L] # first select topk locally, note: here no need to sort! local_k = min(last_dim, beam_k) l_topk_scores, l_topk_idxes = sel_scores_t.topk( local_k, -1, sorted=False) # [bs*pK, lK] # then topk globally on full pK*K add_score_shape = [batch_size, prev_k, local_k] to_sel_shape = [batch_size, prev_k * local_k] global_k = min(to_sel_shape[-1], beam_k) # new k to_sel_scores, to_sel_idxes = \ (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \ l_topk_idxes.view(to_sel_shape) # [bs, pK*lK] _, g_topk_idxes = to_sel_scores.topk(global_k, -1, sorted=True) # [bs, gK] # get to know the idxes new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes) # [bs, gK] new_pk_idxes = ( g_topk_idxes // local_k ) # which previous idx (in beam) are selected? 
[bs, gK] # get current pred and scores (handling mask) scores_t3 = scores_t.view([batch_size, -1, last_dim]) # [bs, pK, L] masks_t2 = masks_t.view([batch_size, -1]) # [bs, pK] new_masks_t = masks_t2[arange_t, new_pk_idxes] # [bs, gK] # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones new_scores_t = scores_t3[arange_t, new_pk_idxes, new_preds_t] * new_masks_t # [bs, gK] # ending new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view( -1) # [bs*gK] cache.arrange_idxes(new_arrange_idxes) self.step_end(cur_step, global_k, cache, new_preds_t.view(-1)) # modify in cache # prepare next & judge ending all_preds_t = BK.concat([ all_preds_t[arange_t, new_pk_idxes], new_preds_t.unsqueeze(-1) ], -1) # [bs, gK, step] all_scores_t = BK.concat([ all_scores_t[arange_t, new_pk_idxes], new_scores_t.unsqueeze(-1) ], -1) # [bs, gK, step] accu_scores_t = accu_scores_t[ arange_t, new_pk_idxes] + new_scores_t # [bs, gK] prev_k = global_k # for next step cur_step += 1 # -- # sort and ret at a final step _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True) # [bs, K] ret_preds = all_preds_t[ arange_t, final_idxes][:, :, 1:] # [bs, K, steps], exclude dummy start! ret_scores = all_scores_t[arange_t, final_idxes][:, :, 1:] # [bs, K, steps] if ret_best: return ret_preds[:, 0], ret_scores[:, 0] # [bs, slen] else: return ret_preds, ret_scores # [bs, topk, slen]
def forward(self, inputs, vstate: VrecSteppingState = None, inc_cls=False): conf: BertEncoderConf = self.conf # -- no_bert_ft = (not conf.bert_ft ) # whether fine-tune bert (if not detach hiddens!) impl = self.impl # -- # prepare inputs if not isinstance(inputs, BerterInputBatch): inputs = self.create_input_batch(inputs) all_output_layers = [] # including embeddings # -- # get embeddings (for embeddings, we simply forward once!) mask_repl_rate = conf.bert_repl_mask_rate if self.is_training() else 0. input_ids, input_masks = inputs.get_basic_inputs( mask_repl_rate) # [bsize, 1+sub_len+1] other_embeds = None if self.other_embed_nodes is not None and len( self.other_embed_nodes) > 0: other_embeds = 0. for other_name, other_node in self.other_embed_nodes.items(): other_embeds += other_node( inputs.other_factors[other_name] ) # should be prepared correspondingly!! # -- # forward layers (for layers, we may need to split!) # todo(+N): we simply split things apart, thus middle parts may lack CLS/SEP, and not true global att # todo(+N): the lengths currently are hard-coded!! MAX_LEN = 512 # max len INBUF_LEN = 50 # in-between buffer for splits, for both sides! 
cur_sub_len = BK.get_shape(input_ids, 1) # 1+sub_len+1 needs_split = (cur_sub_len > MAX_LEN) if needs_split: # decide split and merge points split_points = self._calculate_split_points( cur_sub_len, MAX_LEN, INBUF_LEN) zwarn( f"Multi-seg for Berter: {cur_sub_len}//{len(split_points)}->{split_points}" ) # -- # todo(note): we also need split from embeddings if needs_split: all_embed_pieces = [] split_extended_attention_mask = [] for o_s, o_e, i_s, i_e in split_points: piece_embeddings, piece_extended_attention_mask = impl.forward_embedding( *[(None if z is None else z[:, o_s:o_e]) for z in [ input_ids, input_masks, inputs.batched_token_type_ids, inputs.batched_position_ids, other_embeds ]]) all_embed_pieces.append(piece_embeddings[:, i_s:i_e]) split_extended_attention_mask.append( piece_extended_attention_mask) embeddings = BK.concat(all_embed_pieces, 1) # concat back to full extended_attention_mask = None else: embeddings, extended_attention_mask = impl.forward_embedding( input_ids, input_masks, inputs.batched_token_type_ids, inputs.batched_position_ids, other_embeds) split_extended_attention_mask = None if no_bert_ft: # stop gradient embeddings = embeddings.detach() # -- cur_hidden = embeddings all_output_layers.append(embeddings) # *[bsize, 1+sub_len+1, D] # also prepare mapper idxes for sub <-> orig # todo(+N): currently only use the first sub-word! idxes_arange2 = inputs.arange2_t # [bsize, 1] batched_first_idxes_p1 = (1 + inputs.batched_first_idxes) * ( inputs.batched_first_mask.long()) # plus one for CLS offset! if inc_cls: # [bsize, 1+orig_len] idxes_sub2orig = BK.concat([ BK.constants_idx([inputs.bsize, 1], 0), batched_first_idxes_p1 ], 1) else: # [bsize, orig_len] idxes_sub2orig = batched_first_idxes_p1 _input_masks0 = None # used for vstate back, make it 0. for BOS and EOS # for ii in range(impl.num_hidden_layers): for ii in range(max(self.actual_output_layers) ): # do not need that much if does not require! 
# forward multiple times with splitting if needed if needs_split: all_pieces = [] for piece_idx, piece_points in enumerate(split_points): o_s, o_e, i_s, i_e = piece_points piece_res = impl.forward_hidden( ii, cur_hidden[:, o_s:o_e], split_extended_attention_mask[piece_idx])[:, i_s:i_e] all_pieces.append(piece_res) new_hidden = BK.concat(all_pieces, 1) # concat back to full else: new_hidden = impl.forward_hidden(ii, cur_hidden, extended_attention_mask) if no_bert_ft: # stop gradient new_hidden = new_hidden.detach() if vstate is not None: # from 1+sub_len+1 -> (inc_cls?)+orig_len new_hidden2orig = new_hidden[ idxes_arange2, idxes_sub2orig] # [bsize, 1?+orig_len, D] # update new_hidden2orig_ret = vstate.update( new_hidden2orig) # [bsize, 1?+orig_len, D] if new_hidden2orig_ret is not None: # calculate when needed if _input_masks0 is None: # [bsize, 1+sub_len+1, 1] with 1. only for real valid ones _input_masks0 = inputs._aug_ends( inputs.batched_input_mask, 0., 0., 0., BK.float32).unsqueeze(-1) # back to 1+sub_len+1; todo(+N): here we simply add and //2, and no CLS back from orig to sub!! tmp_orig2sub = new_hidden2orig_ret[ idxes_arange2, int(inc_cls) + inputs.batched_rev_idxes] # [bsize, sub_len, D] tmp_slice_size = BK.get_shape(tmp_orig2sub) tmp_slice_size[1] = 1 tmp_slice_zero = BK.zeros(tmp_slice_size) tmp_orig2sub_aug = BK.concat( [tmp_slice_zero, tmp_orig2sub, tmp_slice_zero], 1) # [bsize, 1+sub_len+1, D] new_hidden = new_hidden * (1. - _input_masks0) + ( (new_hidden + tmp_orig2sub_aug) / 2.) * _input_masks0 all_output_layers.append(new_hidden) cur_hidden = new_hidden # finally, prepare return final_output_layers = [ all_output_layers[z] for z in conf.bert_output_layers ] # *[bsize,1+sl+1,D] combined_output = self.combiner( final_output_layers) # [bsize, 1+sl+1, ??] final_ret = combined_output[idxes_arange2, idxes_sub2orig] # [bsize, 1?+orig_len, D] return final_ret