Example #1
0
def expand_ranged_idxes(widx_t: BK.Expr,
                        wlen_t: BK.Expr,
                        pad: int = 0,
                        max_width: int = None):
    """Expand (start index, length) pairs into per-offset index grids.

    For every entry of ``widx_t`` a trailing axis of size ``max_width`` is
    appended that holds ``start + 0, start + 1, ...``. Offsets that are not
    smaller than the matching entry of ``wlen_t`` are overwritten with ``pad``.

    Args:
        widx_t: start indices, shape ``[*]``.
        wlen_t: widths, compared against the offsets after gaining a
            trailing axis.
        pad: value written into the positions outside each width.
        max_width: size of the new trailing axis. When ``None`` it is
            ``wlen_t.max().item()``, or ``1`` if ``BK.is_zero_shape(wlen_t)``.

    Returns:
        A pair ``(idxes, mask)``: the expanded indices ``[*, MW]`` and the
        validity mask of the same shape converted with ``.float()``.
    """
    if max_width is None:
        # A zero-shape input has no maximum to take, so use a width of one.
        max_width = 1 if BK.is_zero_shape(wlen_t) else wlen_t.max().item()
    # --
    n_lead_dims = len(BK.get_shape(widx_t))  # rank of [*]
    offsets_t = BK.arange_idx(max_width).view([1] * n_lead_dims + [-1])  # [*, MW]
    valid_t = offsets_t < wlen_t.unsqueeze(-1)  # [*, MW], boolean
    idxes_t = widx_t.unsqueeze(-1) + offsets_t  # [*, MW]
    # The sum above is a fresh tensor, so the in-place fill leaves inputs alone.
    idxes_t.masked_fill_(~valid_t, pad)
    return idxes_t, valid_t.float()
Example #2
0
 def decode_upos(self, ibatch, logprobs_t: BK.Expr):
     """Store the argmax UPOS labels on the sentences of ``ibatch``.

     The best label index along the last axis of ``logprobs_t`` is taken for
     every position, converted through ``BK.get_value`` and mapped back to
     label strings with ``self.voc.seq_idx2word``. Each sentence receives the
     labels sliced from its own row, starting at its decoding offset.

     Args:
         ibatch: batch object exposing ``items``; every item provides
             ``sents``, ``seq_info.dec_offsets`` and ``center_sidx``.
         logprobs_t: label scores whose last axis ranges over the labels.
     """
     conf: ZDecoderUposConf = self.conf
     # Only the argmax indices are used; the max scores are discarded.
     _, best_label_t = logprobs_t.max(-1)  # [*, dlen]
     label_arr = BK.get_value(best_label_t)
     vocab = self.voc
     for batch_i, inst in enumerate(ibatch.items):
         offsets = inst.seq_info.dec_offsets
         for sent_i, cur_sent in enumerate(inst.sents):
             # With msent_pred_center on, only the center sentence is labeled.
             if not conf.msent_pred_center or sent_i == inst.center_sidx:
                 begin = offsets[sent_i]
                 end = begin + len(cur_sent)
                 label_ids = label_arr[batch_i][begin:end].tolist()
                 cur_sent.build_uposes(vocab.seq_idx2word(label_ids))
Example #3
0
def log_sum_exp(t: BK.Expr, dim: int, t_max: BK.Expr = None):
    """Compute ``log(sum(exp(t)))`` along ``dim`` with a max-shift for stability.

    Args:
        t: input tensor.
        dim: axis to reduce; it is removed from the result.
        t_max: optional precomputed maximum of ``t`` along ``dim`` (same shape
            as the result). Computed with ``t.max(dim)`` when ``None``.

    Returns:
        Tensor shaped like ``t`` without axis ``dim``. Slices whose maximum is
        ``-inf`` (e.g. fully masked log-probabilities) yield ``-inf`` and
        slices whose maximum is ``+inf`` yield ``+inf``.
    """
    if t_max is None:
        t_max, _ = t.max(dim)  # get maximum value; [*, *]
    # Subtracting an infinite maximum gives ``inf - inf = NaN``. Shift by 0 in
    # that case: exp/sum/log then produce the correct -inf / +inf limit.
    shift = t_max.masked_fill(t_max.abs() == float("inf"), 0)  # [*, *]
    return shift + (t - shift.unsqueeze(dim)).exp().sum(dim).log()  # [*, *]