Example #1
def select_topk(score_t: BK.Expr,
                topk_t: Union[int, BK.Expr],
                mask_t: BK.Expr = None,
                dim=-1):
    # prepare K
    if isinstance(topk_t, int):
        K = topk_t
        tmp_shape = BK.get_shape(score_t)
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, K)
    else:
        K = topk_t.max().item()
    exact_rank_t = topk_t - 1  # [bsize, 1]
    exact_rank_t.clamp_(min=0, max=K - 1)  # make it in valid range!
    # mask values
    if mask_t is not None:
        score_t = score_t + Constants.REAL_PRAC_MIN * (1. - mask_t)
    # topk
    topk_vals, _ = score_t.topk(K, dim, largest=True, sorted=True)  # [*, K, *]
    # gather score
    sel_thresh = topk_vals.gather(dim, exact_rank_t)  # [*, 1, *]
    # get topk_mask
    topk_mask = (score_t >= sel_thresh).float()  # [*, D, *]
    if mask_t is not None:
        topk_mask *= mask_t
    return topk_mask
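The trick above is to take the K-th largest value per row as a threshold and turn it into a 0/1 mask with a >= comparison (exact ties with the threshold all pass). A minimal plain-PyTorch sketch of just that part with toy numbers, assuming BK.Expr wraps a torch.Tensor; the per-row topk_t and mask_t handling is omitted:

import torch

score_t = torch.tensor([[3.0, 1.0, 2.0, 0.5],
                        [0.1, 0.4, 0.2, 0.3]])
K = 2
topk_vals, _ = score_t.topk(K, dim=-1, largest=True, sorted=True)  # [bs, K]
sel_thresh = topk_vals[:, K - 1:K]           # [bs, 1], the K-th largest value per row
topk_mask = (score_t >= sel_thresh).float()  # [bs, D]
# topk_mask -> [[1., 0., 1., 0.], [0., 1., 0., 1.]]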
Example #2
 def inference_forward(scores_t: BK.Expr,
                       mat_t: BK.Expr,
                       mask_t: BK.Expr,
                       beam_k: int = 0):
     scores_shape = BK.get_shape(scores_t)  # [*, slen, L]
     need_topk = (beam_k > 0) and (beam_k < scores_shape[-1]
                                   )  # whether we need topk
     # --
     score_slices = split_at_dim(scores_t, -2, True)  # List[*, 1, L]
     mask_slices = split_at_dim(mask_t, -1, True)  # List[*, 1]
     # the loop on slen
     start_shape = scores_shape[:-2] + [1]  # [*, 1]
     last_labs_t = BK.constants_idx(start_shape,
                                    0)  # [*, K], todo(note): start with 0!
     last_accu_scores = BK.zeros(start_shape)  # accumulated scores: [*, K]
     last_potential = BK.zeros(
         start_shape)  # accumulated potentials: [*, K]
     full_labs_t = BK.arange_idx(scores_shape[-1]).view(
         [1] * (len(scores_shape) - 2) + [-1])  # [*, L]
     cur_step = 0
     for one_score_slice, one_mask_slice in zip(score_slices,
                                                mask_slices):  # [*,L],[*,1]
         one_mask_slice_neg = 1. - one_mask_slice  # [*,1]
         # get current scores
         if cur_step == 0:  # no transition at start!
             one_cur_scores = one_score_slice  # [*, 1, L]
         else:
             one_cur_scores = one_score_slice + mat_t[
                 last_labs_t]  # [*, K, L]
         # first for potentials
         expanded_potentials = last_potential.unsqueeze(
             -1) + one_cur_scores  # [*, K, L]
         merged_potentials = log_sum_exp(expanded_potentials, -2)  # [*, L]
         # optional for topk with merging; note: not really topk!!
         if need_topk:
             # todo(+W): another option is to directly select with potentials rather than accu_scores
             expanded_scores = last_accu_scores.unsqueeze(
                 -1) + one_cur_scores  # [*, K, L]
             # max at -2, merge same current label
             max_scores, max_idxes = expanded_scores.max(-2)  # [*, L]
             # topk at current step, no need to sort!
             new_accu_scores, new_labs_t = max_scores.topk(
                 beam_k, -1, sorted=False)  # [*, K]
             new_potential = merged_potentials.gather(-1,
                                                      new_labs_t)  # [*, K]
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + new_potential * one_mask_slice  # [*, K]
             last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice  # [*, K]
             last_labs_t = last_labs_t * one_mask_slice_neg.long(
             ) + new_labs_t * one_mask_slice.long()  # [*, K]
         else:
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + merged_potentials * one_mask_slice  # [*, L(K)]
             # note: still need to mask this!
             last_labs_t = last_labs_t * one_mask_slice_neg.long(
             ) + full_labs_t * one_mask_slice.long()
         cur_step += 1
     # finally sum all
     ret_potential = log_sum_exp(last_potential, -1)  # [*]
     return ret_potential
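Without the beam pruning and masking, the loop above is the standard linear-chain forward recursion alpha_t = logsumexp(alpha_{t-1} + transition + emission_t). A toy single-sequence sketch in plain PyTorch (hypothetical shapes; trans[i, j] scores label i -> j, matching the mat_t[last_labs_t] indexing above):

import torch

L, T = 3, 4                           # number of labels, sequence length
emit = torch.randn(T, L)              # per-step emission scores
trans = torch.randn(L, L)             # trans[i, j]: score of moving from label i to label j

alpha = emit[0]                       # [L]; no transition score at the first step
for t in range(1, T):
    alpha = torch.logsumexp(alpha.unsqueeze(-1) + trans + emit[t], dim=0)  # [L]
log_partition = torch.logsumexp(alpha, dim=-1)  # scalar; plays the role of ret_potential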
Example #3
 def prep_enc(self, med: ZMediator, *args, **kwargs):
     conf: ZDecoderMlmConf = self.conf
     # note: we have to use enc-mask
     # todo(+W): currently do mlm for full seq regardless of center or not!
     sinfo = med.ibatch.seq_info
     enc_ids, enc_mask = sinfo.enc_input_ids, sinfo.enc_input_masks  # [*, elen]
     _shape = enc_ids.shape
     # sample mask
     mlm_mask = (BK.rand(_shape) <
                 conf.mlm_mrate).float() * enc_mask  # [*, elen]
     # sample repl
     _repl_sample = BK.rand(_shape)  # [*, elen], between [0, 1)
     mlm_repl_ids = BK.constants_idx(_shape,
                                     self.mask_token_id)  # [*, elen] [MASK]
     _repl_rand, _repl_origin = self.repl_ranges
     mlm_repl_ids = BK.where(_repl_sample > _repl_rand,
                             (BK.rand(_shape) * self.target_size).long(),
                             mlm_repl_ids)
     mlm_repl_ids = BK.where(_repl_sample > _repl_origin, enc_ids,
                             mlm_repl_ids)
     # final prepare
     mlm_input_ids = BK.where(mlm_mask > 0., mlm_repl_ids,
                              enc_ids)  # [*, elen]
     med.set_cache('eff_input_ids', mlm_input_ids)
     med.set_cache('mlm_mask', mlm_mask)
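self.repl_ranges is not shown here; the two thresholds split the sampled positions into [MASK] / random-token / keep-original buckets. A minimal plain-PyTorch sketch with hypothetical BERT-style 0.8/0.9 thresholds and toy sizes:

import torch

vocab_size, mask_id, mrate = 30000, 103, 0.15
enc_ids = torch.randint(0, vocab_size, (2, 8))           # [*, elen]
shape = enc_ids.shape

mlm_mask = (torch.rand(shape) < mrate).float()           # positions to corrupt
r = torch.rand(shape)                                    # decides how each position is corrupted
repl_ids = torch.full(shape, mask_id, dtype=torch.long)  # default: [MASK] (80%)
repl_ids = torch.where(r > 0.8, torch.randint(0, vocab_size, shape), repl_ids)  # 10%: random token
repl_ids = torch.where(r > 0.9, enc_ids, repl_ids)       # 10%: keep the original id
mlm_input_ids = torch.where(mlm_mask > 0., repl_ids, enc_ids)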
Example #4
 def prepare_indicators(self, flat_idxes: List, shape):
     bs, dlen = shape
     _arange_t = BK.arange_idx(bs)  # [*]
     rets = []
     for one_idxes in flat_idxes:
         one_indicator = BK.constants_idx(shape, 0)  # [*, dlen]
         one_indicator[_arange_t, one_idxes] = 1
         rets.append(one_indicator)
     return rets
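Each returned indicator marks, per batch row, the position given by the corresponding entry of flat_idxes. A toy plain-PyTorch sketch, assuming each element of flat_idxes is a [bs] index tensor:

import torch

bs, dlen = 3, 5
one_idxes = torch.tensor([0, 4, 2])                      # [bs], one target position per row
one_indicator = torch.zeros(bs, dlen, dtype=torch.long)
one_indicator[torch.arange(bs), one_idxes] = 1
# one_indicator -> [[1,0,0,0,0], [0,0,0,0,1], [0,0,1,0,0]]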
Example #5
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
     cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
     split_oracle = (cand_gaddr0 !=
                     cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
     split_oracle_mask = (
         (cand_gaddr0 >= 0) |
         (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
     raw_split_loss = BK.loss_binary(
         split_scores,
         split_oracle,
         label_smoothing=conf.split_label_smoothing)  # [*, slen]
     loss_split_item = LossHelper.compile_leaf_loss(
         f"split", (raw_split_loss * split_oracle_mask).sum(),
         split_oracle_mask.sum(),
         loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
     force_split_decisions = split_oracle_mask * split_oracle + (
         1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
     _use_force_mask = (BK.rand([bsize])
                        <= conf.split_feed_force_rate).float().unsqueeze(
                            -1)  # [*, 1], seq-level
     feed_split_decisions = (_use_force_mask * force_split_decisions +
                             (1. - _use_force_mask) * pred_split_decisions
                             )  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         feed_split_decisions, cand_masks)
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
         _, _seg_max_t = seg_ext_scores.max(-1,
                                            keepdim=True)  # [*, seglen, 1]
         oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(
             -1)  # [*, seglen]
         oracle_gaddr = expr_seq_gaddr.gather(-1,
                                              oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
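The dynamic split oracle in step 2.1 marks a boundary between two adjacent kept candidates whenever their gold addresses differ, and only counts pairs where at least one side is a real mention (gaddr >= 0). A toy sketch of just that computation in plain PyTorch:

import torch

cand_gaddr = torch.tensor([[0, 0, 1, -1, -1]])   # [*, clen]: gold address of each kept candidate
cand_masks = torch.ones(1, 5)
g0, g1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]   # adjacent pairs, [*, clen-1]
split_oracle = (g0 != g1).float() * cand_masks[:, 1:]                    # boundary where the address changes
split_oracle_mask = ((g0 >= 0) | (g1 >= 0)).float() * cand_masks[:, 1:]  # ignore pairs that are both non-mentions
# split_oracle      -> [[0., 1., 1., 0.]]
# split_oracle_mask -> [[1., 1., 1., 0.]]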
Example #6
 def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable,
                     gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr,
                     gold_addr_expr: BK.Expr):
     conf: SpanExtractorConf = self.conf
     min_width, max_width = conf.min_width, conf.max_width
     diff_width = max_width - min_width + 1  # number of width to extract
     # --
     bsize, mlen = input_shape
     # --
     # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
     # note: the spans are always sorted by (widx, wlen)
     _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
     widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
     wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
     mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
     # --
     # compacting (use mask2idx and gather)
     final_idx_t, final_mask_t = BK.mask2idx(mask_t0,
                                             padding_idx=0)  # [bsize, ??]
     _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
     # no need to make valid for mask=0, since idx=0 means (0, min_width)
     # todo(+?): do we need to deal with empty ones here?
     ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
     ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
     # --
     # prepare gold (as pointer-like addresses)
     if gold_addr_expr is not None:
         gold_t0 = BK.constants_idx((bsize, mlen * diff_width),
                                    -1)  # [bsize, mlen*diff]
         # check valid of golds (flatten all)
         gold_valid_t = ((gold_addr_expr >= 0) &
                         (gold_wlen_expr >= min_width) &
                         (gold_wlen_expr <= max_width))
         gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
         _glen = BK.get_shape(gold_addr_expr, 1)
         flattened_bsize_t = BK.arange_idx(
             bsize * _glen) // _glen  # [bsize*_glen]
         flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr -
                             min_width).view(-1)  # [bsize*_glen]
         flattened_gaddr_t = gold_addr_expr.view(-1)
         # mask and assign
         gold_t0[flattened_bsize_t[gold_valid_t],
                 flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[
                     gold_valid_t]
         ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
         ret_gaddr.masked_fill_((final_mask_t == 0),
                                -1)  # make invalid ones -1
     else:
         ret_gaddr = None
     # --
     return ret_widx, ret_wlen, final_mask_t, ret_gaddr
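The span grid is enumerated widx-major from a single arange using integer div/mod. A toy plain-PyTorch sketch with min_width=1, max_width=3, mlen=4 (the in-bounds check is only an illustration of the kind of condition a _mask_f callback might apply):

import torch

min_width, max_width, mlen = 1, 3, 4
diff_width = max_width - min_width + 1
_tmp_arange_t = torch.arange(mlen * diff_width)   # [mlen*dw]
widx_t0 = _tmp_arange_t // diff_width             # 0,0,0, 1,1,1, 2,2,2, 3,3,3
wlen_t0 = _tmp_arange_t % diff_width + min_width  # 1,2,3, 1,2,3, 1,2,3, 1,2,3
in_bounds = (widx_t0 + wlen_t0) <= mlen           # one check _mask_f might enforce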
Example #7
def select_topk_non_overlapping(score_t: BK.Expr,
                                topk_t: Union[int, BK.Expr],
                                widx_t: BK.Expr,
                                wlen_t: BK.Expr,
                                input_mask_t: BK.Expr,
                                mask_t: BK.Expr = None,
                                dim=-1):
    score_shape = BK.get_shape(score_t)
    assert dim == -1 or dim == len(score_shape) - 1, \
        "Currently only support last-dim!!"  # todo(+2): we can permute to allow any dim!
    # --
    # prepare K
    if isinstance(topk_t, int):
        tmp_shape = score_shape.copy()
        tmp_shape[dim] = 1  # set it as 1
        topk_t = BK.constants_idx(tmp_shape, topk_t)
    # --
    reshape_trg = [np.prod(score_shape[:-1]).item(), -1]  # [*, ?]
    _, sorted_idxes_t = score_t.sort(dim, descending=True)
    # --
    # put it as CPU and use loop; todo(+N): more efficient ways?
    arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t = \
        [BK.get_value(z.reshape(reshape_trg)) for z in [sorted_idxes_t, topk_t, widx_t, wlen_t, input_mask_t, mask_t]]
    _bsize, _cnum = BK.get_shape(arr_sorted_idxes_t)  # [bsize, NUM]
    arr_topk_mask = np.full([_bsize, _cnum], 0.)  # [bsize, NUM]
    _bidx = 0
    for aslice_sorted_idxes_t, aslice_topk_t, aslice_widx_t, aslice_wlen_t, aslice_input_mask_t, aslice_mask_t \
            in zip(arr_sorted_idxes_t, arr_topk_t, arr_widx_t, arr_wlen_t, arr_input_mask_t, arr_mask_t):
        aslice_topk_mask = arr_topk_mask[_bidx]
        # --
        cur_ok_mask = np.copy(aslice_input_mask_t)
        cur_budget = aslice_topk_t.item()
        for _cidx in aslice_sorted_idxes_t:
            _cidx = _cidx.item()
            if cur_budget <= 0: break  # no budget left
            if not aslice_mask_t[_cidx].item(): continue  # non-valid candidate
            one_widx, one_wlen = aslice_widx_t[_cidx].item(
            ), aslice_wlen_t[_cidx].item()
            if np.prod(cur_ok_mask[one_widx:one_widx +
                                   one_wlen]).item() == 0.:  # any hit one?
                continue
            # ok! add it!
            cur_budget -= 1
            cur_ok_mask[one_widx:one_widx + one_wlen] = 0.
            aslice_topk_mask[_cidx] = 1.
        _bidx += 1
    # note: no need to *=mask_t again since already check in the loop
    return BK.input_real(arr_topk_mask).reshape(score_shape)
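The CPU loop is a greedy scan in descending score order: a candidate span is kept only while budget remains and none of its words is already covered. A toy numpy sketch of a single row:

import numpy as np

# candidates (widx, wlen) over a length-6 sentence
widx = np.array([0, 1, 3, 4])
wlen = np.array([2, 2, 2, 1])
order = np.array([1, 0, 2, 3])            # candidate indices sorted by descending score (hypothetical)
budget, ok = 2, np.ones(6)
keep = np.zeros(4)
for c in order:
    if budget <= 0:
        break
    if np.prod(ok[widx[c]:widx[c] + wlen[c]]) == 0.:  # overlaps an already-kept span
        continue
    keep[c] = 1.
    ok[widx[c]:widx[c] + wlen[c]] = 0.
    budget -= 1
# keep -> [0., 1., 1., 0.]: span (1,2) wins first, span (0,2) overlaps it, span (3,2) still fits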
Example #8
def mask2posi_padded(mask: BK.Expr, offset: int, cmin: int):
    with BK.no_grad_env():
        bsize, ssize = BK.get_shape(mask)
        ret = BK.arange_idx(ssize).repeat(bsize, 1)  # [bsize, ssize]
        rmask_long_t = (mask == 0.).long()  # reverse-mask [bsize, ssize]
        conti_zeros = BK.constants_idx([bsize],
                                       0)  # [bsize], number of continous zeros
        for sidx in range(ssize):
            one_slice = rmask_long_t[:, sidx]  # [bsize]
            conti_zeros = (conti_zeros + one_slice) * one_slice  # [bsize], reset to 0 at valid positions
            ret[:, sidx] -= conti_zeros
        # --
        ret += offset
        ret.clamp_(min=cmin)
        return ret
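The running conti_zeros counter makes every masked position inherit the position index of the last unmasked token before it; offset then shifts and cmin floors the result. A toy plain-PyTorch trace with offset=0 and cmin=0:

import torch

mask = torch.tensor([[1., 0., 0., 1., 1.]])      # [bsize, ssize]
ret = torch.arange(5).repeat(1, 1)               # [bsize, ssize]
conti_zeros = torch.zeros(1, dtype=torch.long)
for sidx in range(5):
    z = (mask[:, sidx] == 0.).long()
    conti_zeros = (conti_zeros + z) * z
    ret[:, sidx] -= conti_zeros
# ret -> [[0, 0, 0, 3, 4]]: each masked slot repeats the previous valid position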
Example #9
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1: label, simply scoring everything!
     _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
     all_scores_t = self.lab_node.score_all(
         _main_t,
         _pair_t,
         mask_expr,
         None,
         local_normalize=False,
         extra_score=external_extra_score
     )  # unnormalized scores [*, slen, L]
     all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
     all_gprob_t = all_probs_t.gather(-1,
                                      expr_seq_labs.unsqueeze(-1)).squeeze(
                                          -1)  # [*, slen]
     # how to weight
     extended_gprob_t = all_gprob_t[
         arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
     if BK.is_zero_shape(extended_gprob_t):
         extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = (
         (all_gprob_t * mask_expr) /
         (extended_gprob_max_t.clamp(min=1e-5)))**_w_alpha  # [*, slen]
     _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
     _loss1 = BK.loss_nll(all_scores_t,
                          expr_seq_labs,
                          label_smoothing=_label_smoothing)  # [*, slen]
     _loss2 = BK.loss_nll(all_scores_t,
                          BK.constants_idx([bsize, slen], 0),
                          label_smoothing=_label_smoothing)  # [*, slen]
     _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
     _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
     # final weight it
     cand_loss_weights = BK.where(expr_seq_labs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) *
                                  conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_lab_item = LossHelper.compile_leaf_loss(
         f"lab", (_raw_loss * final_cand_loss_weights).sum(),
         final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_lab,
         gold=(expr_seq_labs > 0).float().sum())
     # --
     # step 1.5
     all_losses = [loss_lab_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_gprob_t  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
                     * (expr_seq_labs>0).float()
         _loss_ent_item = LossHelper.compile_leaf_loss(
             f"cand_ent", (_ent * _ent_mask).sum(),
             _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         all_losses.append(_loss_ent_item)
     # --
     # step 4: extend (select topk)
     if conf.loss_ext > 0.:
         if BK.is_zero_shape(extended_gprob_t):
             flt_mask = (BK.zeros(mask_expr.shape) > 0)
         else:
             _topk = min(conf.ext_loss_topk,
                         BK.get_shape(extended_gprob_t,
                                      -1))  # number to extract
              _topk_gprob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
              flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_gprob_t.min(-1)[0]) \
                         & (_weight > conf.ext_loss_thresh)  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
             flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         flt_items = arr_items.flatten()[BK.get_value(
             expr_seq_gaddr[flt_mask])]  # [?]
         flt_weights = _weight.detach(
         )[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
         loss_ext_item = self.ext_node.loss(flt_items,
                                            input_expr[flt_sidx],
                                            flt_expr,
                                            flt_full_expr,
                                            mask_expr[flt_sidx],
                                            flt_extra_weights=flt_weights)
         all_losses.append(loss_ext_item)
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss, None
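The per-token label loss above interpolates between the gold label and the NIL label (index 0), weighted by how dominant the token's gold-label probability is within its word group, raised to cand_loss_weight_alpha. A worked toy instance with hypothetical numbers:

alpha = 1.0                                     # conf.cand_loss_weight_alpha
p_gold, p_group_max = 0.30, 0.60                # gold-label prob vs. best prob in the word group
w = (p_gold / max(p_group_max, 1e-5)) ** alpha  # -> 0.5
loss_gold, loss_nil = 1.20, 0.05                # hypothetical NLL for the gold label vs. the NIL label
loss_tok = w * loss_gold + (1. - w) * loss_nil  # -> 0.625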
Example #10
 def inference_search(scores_t: BK.Expr,
                      mat_t: BK.Expr,
                      mask_t: BK.Expr,
                      beam_k: int = 0):
     scores_shape = BK.get_shape(scores_t)  # [*, slen, L]
     need_topk = (beam_k > 0) and (beam_k < scores_shape[-1]
                                   )  # whether we need topk
     # --
     score_slices = split_at_dim(scores_t, -2, True)  # List[*, 1, L]
     mask_slices = split_at_dim(mask_t, -1, True)  # List[*, 1]
     # the loop on slen
     all_sel_labs = []  # List of [*, K]
     all_sel_scores = []  # List of [*, K]
     all_tracebacks = []  # List of [*, K]
     start_vals_shape = scores_shape[:-2] + [1]  # [*, 1]
     full_idxes_shape = scores_shape[:-2] + [-1]  # [*, ?]
     last_labs_t = BK.constants_idx(start_vals_shape,
                                    0)  # [*, K], todo(note): start with 0!
     last_accu_scores = BK.zeros(
         start_vals_shape)  # accumulated scores: [*, K]
     full_labs_t = BK.arange_idx(scores_shape[-1]).expand(
         full_idxes_shape)  # [*, L]
     cur_step = 0
     for one_score_slice, one_mask_slice in zip(score_slices,
                                                mask_slices):  # [*,L],[*,1]
         one_mask_slice_neg = 1. - one_mask_slice  # [*,1]
         # get current scores
         if cur_step == 0:  # no transition at start!
             one_cur_scores = one_score_slice  # [*, 1, L]
         else:  # len(all_sel_labs) must >0
             one_cur_scores = one_score_slice + mat_t[
                 last_labs_t]  # [*, K, L]
         # expand scores
         expanded_scores = last_accu_scores.unsqueeze(
             -1) + one_cur_scores  # [*, K, L]
         # max at -2, merge same current label
         max_scores, max_idxes = expanded_scores.max(-2)  # [*, L]
         # need topk?
         if need_topk:
             # topk at current step, no need to sort!
             new_accu_scores, new_labs_t = max_scores.topk(
                 beam_k, -1, sorted=False)  # [*, K]
             new_traceback = max_idxes.gather(-1, new_labs_t)  # [*, K]
             last_labs_t = last_labs_t * one_mask_slice_neg.long(
             ) + new_labs_t * one_mask_slice.long()  # [*, K]
         else:
             new_accu_scores = max_scores  # [*, L(K)]
             new_traceback = max_idxes
             # note: still need to mask this!
             last_labs_t = last_labs_t * one_mask_slice_neg.long(
             ) + full_labs_t * one_mask_slice.long()
         # mask and update
         last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice  # [*, K]
         default_traceback = BK.arange_idx(BK.get_shape(expanded_scores, -2))\
             .view([1]*(len(scores_shape)-2) + [-1])  # [*, K(arange)]
         last_traceback_t = default_traceback * one_mask_slice_neg.long(
         ) + new_traceback * one_mask_slice.long()  # [*, K]
         all_sel_labs.append(last_labs_t)
         all_tracebacks.append(last_traceback_t)
         one_new_scores = one_cur_scores[BK.arange_idx(
             scores_shape[0]).unsqueeze(-1), last_traceback_t,
                                         last_labs_t]  # [*, K]
         one_new_scores *= one_mask_slice
         all_sel_scores.append(one_new_scores)
         cur_step += 1
     # traceback
     _, last_idxes = last_accu_scores.max(-1)  # [*]
     last_idxes = last_idxes.unsqueeze(-1)  # [*, 1]
     all_preds, all_scores = [], []
     for cur_step in range(len(all_tracebacks) - 1, -1, -1):
         all_preds.append(all_sel_labs[cur_step].gather(
             -1, last_idxes).squeeze(-1))  # [*]
         all_scores.append(all_sel_scores[cur_step].gather(
             -1, last_idxes).squeeze(-1))  # [*]
         last_idxes = all_tracebacks[cur_step].gather(-1,
                                                      last_idxes)  # [*, 1]
     # remember to reverse!!
     all_preds.reverse()
     all_scores.reverse()
     best_labs = BK.stack(all_preds, -1)  # [*, slen]
     best_scores = BK.stack(all_scores, -1)  # [*, slen]
     return best_labs, best_scores  # [*, slen]
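This is Viterbi decoding with per-step tracebacks; stripped of masking and the optional beam, the recursion for a single sequence looks like the plain-PyTorch sketch below (hypothetical shapes; trans[i, j] scores label i -> j, as mat_t[last_labs_t] above):

import torch

L, T = 3, 4
emit = torch.randn(T, L)
trans = torch.randn(L, L)

score, back = emit[0], []                 # [L]; no transition at the first step
for t in range(1, T):
    best, idx = (score.unsqueeze(-1) + trans + emit[t]).max(dim=0)  # [L], best previous label per current label
    score, back = best, back + [idx]
labs = [int(score.argmax())]
for idx in reversed(back):                # walk the backpointers from the end
    labs.append(int(idx[labs[-1]]))
labs.reverse()                            # best-scoring label sequence, length T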
Example #11
 def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True):
     _NEG_INF = Constants.REAL_PRAC_MIN
     # --
     cur_step = 0
     cache: DecCache = None
     # init: keep the seq of scores rather than traceback!
     start_vals_shape = [batch_size, 1]  # [bs, 1]
     all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze(
         -1)  # [bs, K, step], todo(note): start with 0!
     all_scores_t = BK.zeros(start_vals_shape).unsqueeze(
         -1)  # [bs, K, step]
     accu_scores_t = BK.zeros(start_vals_shape)  # [bs, K]
     arange_t = BK.arange_idx(batch_size).unsqueeze(-1)  # [bs, 1]
     # while loop
     prev_k = 1  # start with single one
     while not self.is_end(cur_step):
         # expand and score
         cache, scores_t, masks_t = self.step_score(
             cur_step, prev_k, cache)  # ..., [bs*pK, L], [bs*pK]
         scores_t_shape = BK.get_shape(scores_t)
         last_dim = scores_t_shape[-1]  # L
         # modify score to handle mask: keep previous pred for the masked items!
         sel_scores_t = BK.constants([batch_size, prev_k, last_dim],
                                     1.)  # [bs, pk, L]
         sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:],
                               -1)  # [bs, pk, L]
         sel_scores_t = scores_t + _NEG_INF * (
             sel_scores_t.view(scores_t_shape) *
             (1. - masks_t).unsqueeze(-1))  # [bs*pK, L]
         # first select topk locally, note: here no need to sort!
         local_k = min(last_dim, beam_k)
         l_topk_scores, l_topk_idxes = sel_scores_t.topk(
             local_k, -1, sorted=False)  # [bs*pK, lK]
         # then topk globally on full pK*K
         add_score_shape = [batch_size, prev_k, local_k]
         to_sel_shape = [batch_size, prev_k * local_k]
         global_k = min(to_sel_shape[-1], beam_k)  # new k
         to_sel_scores, to_sel_idxes = \
             (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \
             l_topk_idxes.view(to_sel_shape)  # [bs, pK*lK]
         _, g_topk_idxes = to_sel_scores.topk(global_k, -1,
                                              sorted=True)  # [bs, gK]
         # get to know the idxes
         new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes)  # [bs, gK]
         new_pk_idxes = (
             g_topk_idxes // local_k
         )  # which previous idx (in beam) are selected? [bs, gK]
         # get current pred and scores (handling mask)
         scores_t3 = scores_t.view([batch_size, -1,
                                    last_dim])  # [bs, pK, L]
         masks_t2 = masks_t.view([batch_size, -1])  # [bs, pK]
         new_masks_t = masks_t2[arange_t, new_pk_idxes]  # [bs, gK]
         # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones
         new_scores_t = scores_t3[arange_t, new_pk_idxes,
                                  new_preds_t] * new_masks_t  # [bs, gK]
         # ending
         new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view(
             -1)  # [bs*gK]
         cache.arrange_idxes(new_arrange_idxes)
         self.step_end(cur_step, global_k, cache,
                       new_preds_t.view(-1))  # modify in cache
         # prepare next & judge ending
         all_preds_t = BK.concat([
             all_preds_t[arange_t, new_pk_idxes],
             new_preds_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         all_scores_t = BK.concat([
             all_scores_t[arange_t, new_pk_idxes],
             new_scores_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         accu_scores_t = accu_scores_t[
             arange_t, new_pk_idxes] + new_scores_t  # [bs, gK]
         prev_k = global_k  # for next step
         cur_step += 1
     # --
     # sort and ret at a final step
     _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True)  # [bs, K]
     ret_preds = all_preds_t[
         arange_t, final_idxes][:, :,
                                1:]  # [bs, K, steps], exclude dummy start!
     ret_scores = all_scores_t[arange_t, final_idxes][:, :,
                                                      1:]  # [bs, K, steps]
     if ret_best:
         return ret_preds[:, 0], ret_scores[:, 0]  # [bs, slen]
     else:
         return ret_preds, ret_scores  # [bs, topk, slen]
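The core selection step is a two-stage top-k: first the best local_k labels per surviving beam item, then the best beam_k of the prev_k*local_k combined scores, recovering which beam item each winner extends via integer division. A toy plain-PyTorch sketch of one such step (hypothetical sizes):

import torch

bs, prev_k, L, beam_k = 2, 3, 10, 4
scores_t = torch.randn(bs * prev_k, L)        # current-step scores per (batch, beam) pair
accu_scores_t = torch.randn(bs, prev_k)       # accumulated score of each beam item

local_k = min(L, beam_k)
l_scores, l_idxes = scores_t.topk(local_k, -1, sorted=False)                      # [bs*pK, lK]
to_sel_scores = (l_scores.view(bs, prev_k, local_k)
                 + accu_scores_t.unsqueeze(-1)).view(bs, -1)                      # [bs, pK*lK]
to_sel_idxes = l_idxes.view(bs, -1)                                               # [bs, pK*lK]
_, g_idxes = to_sel_scores.topk(min(prev_k * local_k, beam_k), -1, sorted=True)   # [bs, gK]
new_preds = to_sel_idxes.gather(-1, g_idxes)  # [bs, gK]: label chosen for each new beam item
new_pk_idxes = g_idxes // local_k             # [bs, gK]: which previous beam item it extends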
Example #12
 def forward(self, inputs, vstate: VrecSteppingState = None, inc_cls=False):
     conf: BertEncoderConf = self.conf
     # --
     no_bert_ft = (not conf.bert_ft
                   )  # whether fine-tune bert (if not detach hiddens!)
     impl = self.impl
     # --
     # prepare inputs
     if not isinstance(inputs, BerterInputBatch):
         inputs = self.create_input_batch(inputs)
     all_output_layers = []  # including embeddings
     # --
     # get embeddings (for embeddings, we simply forward once!)
     mask_repl_rate = conf.bert_repl_mask_rate if self.is_training() else 0.
     input_ids, input_masks = inputs.get_basic_inputs(
         mask_repl_rate)  # [bsize, 1+sub_len+1]
     other_embeds = None
     if self.other_embed_nodes is not None and len(
             self.other_embed_nodes) > 0:
         other_embeds = 0.
         for other_name, other_node in self.other_embed_nodes.items():
             other_embeds += other_node(
                 inputs.other_factors[other_name]
             )  # should be prepared correspondingly!!
     # --
     # forward layers (for layers, we may need to split!)
     # todo(+N): we simply split things apart, thus middle parts may lack CLS/SEP, and not true global att
     # todo(+N): the lengths currently are hard-coded!!
     MAX_LEN = 512  # max len
     INBUF_LEN = 50  # in-between buffer for splits, for both sides!
     cur_sub_len = BK.get_shape(input_ids, 1)  # 1+sub_len+1
     needs_split = (cur_sub_len > MAX_LEN)
     if needs_split:  # decide split and merge points
         split_points = self._calculate_split_points(
             cur_sub_len, MAX_LEN, INBUF_LEN)
         zwarn(
             f"Multi-seg for Berter: {cur_sub_len}//{len(split_points)}->{split_points}"
         )
     # --
     # todo(note): we also need split from embeddings
     if needs_split:
         all_embed_pieces = []
         split_extended_attention_mask = []
         for o_s, o_e, i_s, i_e in split_points:
             piece_embeddings, piece_extended_attention_mask = impl.forward_embedding(
                 *[(None if z is None else z[:, o_s:o_e]) for z in [
                     input_ids, input_masks, inputs.batched_token_type_ids,
                     inputs.batched_position_ids, other_embeds
                 ]])
             all_embed_pieces.append(piece_embeddings[:, i_s:i_e])
             split_extended_attention_mask.append(
                 piece_extended_attention_mask)
         embeddings = BK.concat(all_embed_pieces, 1)  # concat back to full
         extended_attention_mask = None
     else:
         embeddings, extended_attention_mask = impl.forward_embedding(
             input_ids, input_masks, inputs.batched_token_type_ids,
             inputs.batched_position_ids, other_embeds)
         split_extended_attention_mask = None
     if no_bert_ft:  # stop gradient
         embeddings = embeddings.detach()
     # --
     cur_hidden = embeddings
     all_output_layers.append(embeddings)  # *[bsize, 1+sub_len+1, D]
     # also prepare mapper idxes for sub <-> orig
     # todo(+N): currently only use the first sub-word!
     idxes_arange2 = inputs.arange2_t  # [bsize, 1]
     batched_first_idxes_p1 = (1 + inputs.batched_first_idxes) * (
         inputs.batched_first_mask.long())  # plus one for CLS offset!
     if inc_cls:  # [bsize, 1+orig_len]
         idxes_sub2orig = BK.concat([
             BK.constants_idx([inputs.bsize, 1], 0), batched_first_idxes_p1
         ], 1)
     else:  # [bsize, orig_len]
         idxes_sub2orig = batched_first_idxes_p1
     _input_masks0 = None  # used for vstate back, make it 0. for BOS and EOS
     # for ii in range(impl.num_hidden_layers):
     for ii in range(max(self.actual_output_layers)
                     ):  # do not need that much if does not require!
         # forward multiple times with splitting if needed
         if needs_split:
             all_pieces = []
             for piece_idx, piece_points in enumerate(split_points):
                 o_s, o_e, i_s, i_e = piece_points
                 piece_res = impl.forward_hidden(
                     ii, cur_hidden[:, o_s:o_e],
                     split_extended_attention_mask[piece_idx])[:, i_s:i_e]
                 all_pieces.append(piece_res)
             new_hidden = BK.concat(all_pieces, 1)  # concat back to full
         else:
             new_hidden = impl.forward_hidden(ii, cur_hidden,
                                              extended_attention_mask)
         if no_bert_ft:  # stop gradient
             new_hidden = new_hidden.detach()
         if vstate is not None:
             # from 1+sub_len+1 -> (inc_cls?)+orig_len
             new_hidden2orig = new_hidden[
                 idxes_arange2, idxes_sub2orig]  # [bsize, 1?+orig_len, D]
             # update
             new_hidden2orig_ret = vstate.update(
                 new_hidden2orig)  # [bsize, 1?+orig_len, D]
             if new_hidden2orig_ret is not None:
                 # calculate when needed
                 if _input_masks0 is None:  # [bsize, 1+sub_len+1, 1] with 1. only for real valid ones
                     _input_masks0 = inputs._aug_ends(
                         inputs.batched_input_mask, 0., 0., 0.,
                         BK.float32).unsqueeze(-1)
                 # back to 1+sub_len+1; todo(+N): here we simply add and //2, and no CLS back from orig to sub!!
                 tmp_orig2sub = new_hidden2orig_ret[
                     idxes_arange2,
                     int(inc_cls) +
                     inputs.batched_rev_idxes]  # [bsize, sub_len, D]
                 tmp_slice_size = BK.get_shape(tmp_orig2sub)
                 tmp_slice_size[1] = 1
                 tmp_slice_zero = BK.zeros(tmp_slice_size)
                 tmp_orig2sub_aug = BK.concat(
                     [tmp_slice_zero, tmp_orig2sub, tmp_slice_zero],
                     1)  # [bsize, 1+sub_len+1, D]
                 new_hidden = new_hidden * (1. - _input_masks0) + (
                     (new_hidden + tmp_orig2sub_aug) / 2.) * _input_masks0
         all_output_layers.append(new_hidden)
         cur_hidden = new_hidden
     # finally, prepare return
     final_output_layers = [
         all_output_layers[z] for z in conf.bert_output_layers
     ]  # *[bsize,1+sl+1,D]
     combined_output = self.combiner(
         final_output_layers)  # [bsize, 1+sl+1, ??]
     final_ret = combined_output[idxes_arange2,
                                 idxes_sub2orig]  # [bsize, 1?+orig_len, D]
     return final_ret
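The final sub-word to original-token mapping is a plain gather on the first sub-word index of each original token, shifted by 1 for the CLS offset. A toy plain-PyTorch sketch (hypothetical sizes, assuming BK.Expr wraps a torch.Tensor):

import torch

bsize, D = 1, 4
hidden = torch.randn(bsize, 6, D)            # [bsize, 1+sub_len+1, D]: CLS, 4 sub-words, SEP
first_idxes = torch.tensor([[0, 1, 3]])      # first sub-word index of each original token
first_mask = torch.ones(1, 3, dtype=torch.long)
idxes_sub2orig = (1 + first_idxes) * first_mask  # +1 accounts for the CLS offset
arange2 = torch.arange(bsize).unsqueeze(-1)      # [bsize, 1]
orig_hidden = hidden[arange2, idxes_sub2orig]    # [bsize, orig_len, D]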