def _common_prepare(self, input_shape: Tuple[int, int], _mask_f: Callable, gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    """Enumerate candidate spans, keep those selected by `_mask_f`, and map gold spans onto them.

    Candidates are every (start, width) pair with start in [0, mlen) and width in
    [conf.min_width, conf.max_width]. They are laid out start-major, so flat index i
    corresponds to (i // dw, i % dw + min_width) and the spans are sorted by (widx, wlen).

    Args:
        input_shape: (bsize, mlen).
        _mask_f: called as ``_mask_f(widx, wlen)`` on the flat [mlen*dw] candidate tensors;
            expected to return a [bsize, mlen*dw] mask of candidates to keep.
        gold_widx_expr, gold_wlen_expr, gold_addr_expr: gold span starts, widths and
            addresses, each shaped [bsize, glen] (they are flattened batch-major below).
            Entries with a negative address, a start outside [0, mlen), or a width outside
            [min_width, max_width] are ignored. Pass ``gold_addr_expr=None`` to skip gold.

    Returns:
        (ret_widx, ret_wlen, final_mask_t, ret_gaddr), each [bsize, K]. K is decided by
        ``BK.mask2idx`` (presumably the max number of kept candidates over the batch).
        ret_gaddr holds the gold address mapped to each kept candidate: -1 if no gold maps
        there or the slot is padding. ret_gaddr is None when no gold is given.
    """
    conf: SpanExtractorConf = self.conf
    min_width, max_width = conf.min_width, conf.max_width
    diff_width = max_width - min_width + 1  # number of distinct widths per start position
    bsize, mlen = input_shape
    # -- enumerate all candidates once (shared across the batch)
    # note: the spans are always sorted by (widx, wlen)
    _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
    widx_t0 = _tmp_arange_t // diff_width  # [mlen*dw]
    wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
    mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
    # -- compact the kept candidates (use mask2idx and gather)
    final_idx_t, final_mask_t = BK.mask2idx(mask_t0, padding_idx=0)  # [bsize, ??]
    # no need to make valid for mask=0, since idx=0 means (0, min_width)
    # todo(+?): do we need to deal with empty ones here?
    ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
    ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
    # -- prepare gold (as pointer-like addresses)
    if gold_addr_expr is None:
        return ret_widx, ret_wlen, final_mask_t, None
    gold_t0 = BK.constants_idx((bsize, mlen * diff_width), -1)  # [bsize, mlen*dw], -1 = no gold
    # A gold span is usable only if its flat candidate index lands inside [0, mlen*dw).
    # Without the start bounds, a negative start would silently wrap around (negative
    # indexing writes into another candidate's slot), and a start >= mlen would index
    # out of range.
    gold_valid_t = ((gold_addr_expr >= 0)
                    & (gold_widx_expr >= 0) & (gold_widx_expr < mlen)
                    & (gold_wlen_expr >= min_width) & (gold_wlen_expr <= max_width))
    gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
    _glen = BK.get_shape(gold_addr_expr, 1)
    flattened_bsize_t = BK.arange_idx(bsize * _glen) // _glen  # [bsize*_glen], batch id of each flat entry
    flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr - min_width).view(-1)  # [bsize*_glen]
    flattened_gaddr_t = gold_addr_expr.view(-1)  # [bsize*_glen]
    # mask and assign
    # NOTE(review): if several gold entries map to the same candidate, which address wins
    # presumably depends on the backend's indexed-assignment semantics for duplicates.
    gold_t0[flattened_bsize_t[gold_valid_t], flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[gold_valid_t]
    _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
    ret_gaddr.masked_fill_((final_mask_t == 0), -1)  # padding slots (filled with idx=0) get -1
    return ret_widx, ret_wlen, final_mask_t, ret_gaddr
def apply_piece_pooling(t: BK.Expr, piece: int, f: Union[Callable, str] = ActivationHelper.get_pool('max'), dim: int = -1):
    """Pool each run of `piece` consecutive entries along axis `dim`.

    The axis of size D is viewed as two axes [D // piece, piece], and the pooling function
    is applied over the inner (size-`piece`) axis. Assuming `f` reduces that axis, the
    result has size D // piece along `dim`.

    Args:
        t: input tensor. Its size along `dim` must be divisible by `piece` for the view to
            succeed. Because ``t.view`` is used, `t` presumably also needs to be contiguous.
        piece: group size; 1 returns `t` unchanged.
        f: pooling callable invoked as ``f(tensor, axis)``, or a name resolved through
            ``ActivationHelper.get_pool``. The default is resolved once, at definition time.
        dim: axis to pool over; negative values count from the end.
    """
    if piece == 1:
        return t  # grouping by 1 is a no-op
    shape = BK.get_shape(t)
    axis = dim + len(shape) if dim < 0 else dim  # normalize a negative axis
    shape[axis] = piece  # the original axis becomes the inner group axis ...
    # ... preceded by an inferred group-count axis: [..., -1, piece, ...]
    grouped_t = t.view([*shape[:axis], -1, *shape[axis:]])
    pool_f = ActivationHelper.get_pool(f) if isinstance(f, str) else f
    return pool_f(grouped_t, axis + 1)  # shift by one for the inserted axis