def expand_ranged_idxes(widx_t: BK.Expr, wlen_t: BK.Expr, pad: int = 0, max_width: int = None):
    """Expand (start, length) spans into explicit per-position index tensors.

    For each start index in ``widx_t``, build ``start + [0, 1, ..., max_width - 1]``
    along a new trailing dimension. Positions whose offset is not smaller than
    the matching length in ``wlen_t`` are overwritten with ``pad``.

    :param widx_t: start indexes, shape [*].
    :param wlen_t: span lengths; presumably the same shape as ``widx_t``
        (it is unsqueezed and compared against the offsets, so it must at least
        broadcast to [*, MW]) -- TODO confirm against callers.
    :param pad: value written into the index tensor at out-of-span positions.
    :param max_width: size MW of the new trailing dimension. If None, it is
        ``wlen_t.max().item()``, or 1 when ``BK.is_zero_shape(wlen_t)`` holds.
    :return: tuple ``(indexes, mask)``, both [*, MW]; ``mask`` is the float
        version of the in-span test (1. inside the span, 0. outside).
    """
    if max_width is None:
        # Derive the width from the data; fall back to 1 when wlen_t is zero-shaped.
        max_width = 1 if BK.is_zero_shape(wlen_t) else wlen_t.max().item()
    # --
    n_dims = len(BK.get_shape(widx_t))
    offsets = BK.arange_idx(max_width).view([1] * n_dims + [-1])  # [1, ..., 1, MW]
    in_span = offsets < wlen_t.unsqueeze(-1)  # [*, MW]
    idxes = widx_t.unsqueeze(-1) + offsets  # [*, MW], a fresh tensor, safe to edit in place
    idxes.masked_fill_(~in_span, pad)
    return idxes, in_span.float()
def decode_upos(self, ibatch, logprobs_t: BK.Expr):
    """Attach argmax UPOS predictions from ``logprobs_t`` to the sentences of ``ibatch``.

    Takes the argmax over the last dimension of ``logprobs_t``, then, for every
    item and sentence, slices that sentence's span out of the item's row. The
    slice starts at ``item.seq_info.dec_offsets[sidx]`` and has length ``len(sent)``.
    The indexes are mapped to labels via ``self.voc.seq_idx2word`` and stored
    with ``sent.build_uposes``.

    When ``self.conf.msent_pred_center`` is true, only the sentence whose index
    equals ``item.center_sidx`` receives predictions.

    :param ibatch: batch whose ``items`` line up with the first axis of ``logprobs_t``.
    :param logprobs_t: label scores, presumably [bs, dlen, num_labels] -- TODO confirm.
    """
    conf: ZDecoderUposConf = self.conf
    # Only the argmax indexes are used; the max scores are discarded.
    _, best_label_t = logprobs_t.max(-1)  # [*, dlen]
    best_label_arr = BK.get_value(best_label_t)
    vocab = self.voc
    for bidx, item in enumerate(ibatch.items):
        offsets = item.seq_info.dec_offsets
        for sidx, sent in enumerate(item.sents):
            # In center-only mode, every sentence except the center one is skipped.
            if conf.msent_pred_center and sidx != item.center_sidx:
                continue
            begin = offsets[sidx]
            idx_list = best_label_arr[bidx][begin:begin + len(sent)].tolist()
            sent.build_uposes(vocab.seq_idx2word(idx_list))
def log_sum_exp(t: BK.Expr, dim: int, t_max: BK.Expr = None):
    """Compute ``log(sum(exp(t), dim))`` with the usual max-shift for stability.

    :param t: input tensor.
    :param dim: dimension to reduce.
    :param t_max: optional precomputed max of ``t`` along ``dim`` (shape of ``t``
        with ``dim`` removed). Any finite shift is mathematically valid; the
        max is what keeps ``exp`` from overflowing.
    :return: tensor with ``dim`` reduced.
    """
    if t_max is None:
        t_max, _ = t.max(dim)  # get maximum value; [*, *]
    # If a slice's max is infinite (e.g. every entry is -inf because of masking),
    # ``t - t_max`` would be ``-inf - (-inf) = nan``. Shift those slices by 0 instead:
    # an all -inf slice then yields log(0) = -inf and a +inf max yields +inf.
    # Finite maxima are untouched, so results elsewhere are unchanged.
    shift = t_max.masked_fill(t_max.isinf(), 0.)  # out-of-place: caller's t_max is not modified
    ret = shift + (t - shift.unsqueeze(dim)).exp().sum(dim).log()  # [*, *]
    return ret