Example #1
0
 def _common_prepare(self, input_shape: Tuple[int, int], _mask_f: Callable,
                     gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr,
                     gold_addr_expr: BK.Expr):
     """Enumerate candidate spans, compact them by ``_mask_f`` and align golds.

     Every (widx, wlen) pair with ``0 <= widx < mlen`` and
     ``conf.min_width <= wlen <= conf.max_width`` is enumerated widx-major,
     so the candidates (and the compacted outputs) are sorted by (widx, wlen).

     Args:
         input_shape: ``(bsize, mlen)``.
         _mask_f: called as ``_mask_f(widx_t0, wlen_t0)`` with the flat
             ``[mlen*dw]`` candidate start/width tensors (dw = number of
             widths); expected to return a ``[bsize, mlen*dw]`` mask that is
             compacted with ``BK.mask2idx``.
         gold_widx_expr: gold span starts, same shape as ``gold_addr_expr``.
         gold_wlen_expr: gold span widths, same shape as ``gold_addr_expr``.
         gold_addr_expr: ``[bsize, glen]`` gold addresses, or None to skip
             gold preparation. Gold entries with a negative address, a start
             outside ``[0, mlen)`` or a width outside
             ``[min_width, max_width]`` are ignored.

     Returns:
         ``(ret_widx, ret_wlen, final_mask_t, ret_gaddr)``, each of shape
         ``[bsize, K]`` with K the compacted candidate count. ``ret_gaddr``
         holds, per kept candidate, the address of the matching valid gold
         span or -1 (padded slots are also -1); it is None when
         ``gold_addr_expr`` is None.
     """
     conf: SpanExtractorConf = self.conf
     min_width, max_width = conf.min_width, conf.max_width
     diff_width = max_width - min_width + 1  # number of widths per start
     # --
     bsize, mlen = input_shape
     # --
     # enumerate all candidates, widx-major: flat index i <-> (widx, wlen) =
     # (i // dw, i % dw + min_width), giving [mlen*(max_width-min_width+1)]
     # candidates that are always sorted by (widx, wlen)
     _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
     widx_t0 = _tmp_arange_t // diff_width  # [mlen*dw]
     wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
     mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
     # --
     # compacting (use mask2idx and gather)
     final_idx_t, final_mask_t = BK.mask2idx(mask_t0,
                                             padding_idx=0)  # [bsize, ??]
     # no need to fix up padded slots: idx=0 decodes to (0, min_width) and the
     # slots are marked invalid via final_mask_t
     # todo(+?): do we need to deal with empty ones here?
     ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
     ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
     # --
     if gold_addr_expr is None:
         return ret_widx, ret_wlen, final_mask_t, None
     # prepare gold (as pointer-like addresses)
     # dense per-candidate table; -1 means "no gold at this (widx, wlen)"
     gold_t0 = BK.constants_idx((bsize, mlen * diff_width),
                                -1)  # [bsize, mlen*dw]
     # a gold is usable only if it has an address and maps to a real candidate
     # slot: without the widx bounds, widx >= mlen indexes out of range and
     # widx < 0 silently wraps (negative indexing) into another start's slot
     gold_valid_t = ((gold_addr_expr >= 0) &
                     (gold_widx_expr >= 0) & (gold_widx_expr < mlen) &
                     (gold_wlen_expr >= min_width) &
                     (gold_wlen_expr <= max_width))
     gold_valid_t = gold_valid_t.reshape(-1)  # [bsize*_glen]
     _glen = BK.get_shape(gold_addr_expr, 1)
     flattened_bsize_t = BK.arange_idx(
         bsize * _glen) // _glen  # [bsize*_glen], row id of each gold
     flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr -
                         min_width).reshape(-1)  # [bsize*_glen], flat slot
     # reshape (not view) so non-contiguous gold inputs are accepted
     flattened_gaddr_t = gold_addr_expr.reshape(-1)  # [bsize*_glen]
     # scatter the valid golds into the table
     # NOTE(review): if several valid golds in one row share (widx, wlen), which
     # address survives is presumably unspecified (index_put with duplicate
     # indices) — confirm golds are deduplicated upstream.
     gold_t0[flattened_bsize_t[gold_valid_t],
             flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[
                 gold_valid_t]
     # gather the table at the compacted candidates
     _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
     ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
     ret_gaddr.masked_fill_((final_mask_t == 0),
                            -1)  # make invalid (padded) ones -1
     # --
     return ret_widx, ret_wlen, final_mask_t, ret_gaddr
Example #2
0
def apply_piece_pooling(t: BK.Expr,
                        piece: int,
                        f: Union[Callable,
                                 str] = ActivationHelper.get_pool('max'),
                        dim: int = -1):
    """Pool over consecutive groups of ``piece`` elements along ``dim``.

    Dimension ``dim`` (size D) is reshaped into ``[D // piece, piece]`` and
    ``f`` is applied to the new ``piece`` axis. Presumably ``f`` reduces that
    axis, so the result has size ``D // piece`` along ``dim``; this depends
    on ``f``.

    Args:
        t: input tensor.
        piece: group size. With 1, ``t`` is returned unchanged. The size of
            ``dim`` must be divisible by ``piece``, otherwise the reshape
            raises.
        f: pooling callable invoked as ``f(tensor, dim)``, or a name that is
            resolved through ``ActivationHelper.get_pool``. The default is the
            'max' pool.
        dim: dimension to pool over; negative values count from the end.

    Returns:
        ``f(reshaped_t, dim + 1)``, or ``t`` itself when ``piece == 1``.
    """
    if piece == 1:
        return t  # nothing to do
    shape = list(BK.get_shape(t))  # own copy, since it is sliced/edited below
    if dim < 0:  # normalize: the slicing below needs a non-negative index
        dim += len(shape)
    # [..., D, ...] -> [..., D//piece, piece, ...] (new axis put before piece)
    new_shape = shape[:dim] + [-1, piece] + shape[dim + 1:]
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. after a transpose), reshape copies only when it must
    reshaped_t = t.reshape(new_shape)
    if isinstance(f, str):
        f = ActivationHelper.get_pool(f)
    return f(reshaped_t, dim + 1)  # +1 since we made a new dim