Example #1
 def embed_lens(self, query_len: int, key_len: int):
     a_q = BK.arange_idx(query_len).unsqueeze(1)  # [query, 1]
     a_k = BK.arange_idx(key_len).unsqueeze(0)  # [1, key]
     ret_dist = a_q - a_k  # [query, key]
     _, dist_atts, dist_values = self.embed_rposi(ret_dist)
     # [lenq, lenk], [lenq, lenk, dim], [lenq, lenk, dim]
     return ret_dist, dist_atts, dist_values
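The broadcasting trick above is the heart of relative-position encoding: a column arange minus a row arange yields the full query-to-key offset matrix in one shot. A minimal standalone sketch, assuming BK.arange_idx is a thin wrapper around torch.arange:

    import torch

    def relative_distances(query_len: int, key_len: int) -> torch.Tensor:
        a_q = torch.arange(query_len).unsqueeze(1)  # [query, 1]
        a_k = torch.arange(key_len).unsqueeze(0)    # [1, key]
        return a_q - a_k  # [query, key], entry (i, j) = i - j

    # relative_distances(3, 4) ->
    # [[ 0, -1, -2, -3],
    #  [ 1,  0, -1, -2],
    #  [ 2,  1,  0, -1]]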
Example #2
 def lookup_flatten(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
                    pair_expr: BK.Expr = None):
     arr_items, mlp_expr, zmask_expr, extra_info = self.lookup(insts, input_expr, mask_expr, pair_expr)
     expr_widxes, expr_wlens = extra_info
     # flatten
     ret_items, ret_sidx, ret_expr, fl_widxes, fl_wlens = LookupNode.flatten_results(
         arr_items, zmask_expr, mlp_expr, expr_widxes, expr_wlens)
     # --
     # also make full expr
     # full_masks = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).float()  # [??, slen]
     # ret_full_expr = full_masks.unsqueeze(-1) * ret_expr.unsqueeze(-2)  # [??, slen, D]
     if self.conf.flatten_lookup_use_dist:  # use posi again: [...,-2,-1,0,0,0,1,2,...]
         left_widxes = fl_widxes.unsqueeze(-1)  # [??, 1]
         right_widxes = (fl_widxes+fl_wlens-1).unsqueeze(-1)  # [??, 1]
         _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
         dist0 = _arange_t - left_widxes  # [??, slen]
         dist1 = _arange_t - right_widxes  # [??, slen]
         full_dist = (_arange_t < left_widxes).long() * dist0 + (_arange_t > right_widxes).long() * dist1
         ret_full_expr = self.indicator_norm(self.indicator_embed(full_dist))  # [??, slen, D]
         # # ret_full_expr = self.indicator_embed(full_dist)  # [??, slen, D]
     else:  # otherwise 0/1
         _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
         full_ind = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).long()  # [??, slen]
         ret_full_expr = self.indicator_norm(self.indicator_embed(full_ind))  # [??, slen, D]
     # --
     return ret_items, ret_sidx, ret_expr, ret_full_expr  # [??, D]
Example #3
 def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable,
                     gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr,
                     gold_addr_expr: BK.Expr):
     conf: SpanExtractorConf = self.conf
     min_width, max_width = conf.min_width, conf.max_width
     diff_width = max_width - min_width + 1  # number of widths to extract
     # --
     bsize, mlen = input_shape
     # --
     # [bsize, mlen*dw] where dw = max_width-min_width+1; mlen first (dim=1)
     # note: the spans are always sorted by (widx, wlen)
     _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
     widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
     wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
     mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
     # --
     # compacting (use mask2idx and gather)
     final_idx_t, final_mask_t = BK.mask2idx(mask_t0,
                                             padding_idx=0)  # [bsize, ??]
     _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
     # no need to make valid for mask=0, since idx=0 means (0, min_width)
     # todo(+?): do we need to deal with empty ones here?
     ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
     ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
     # --
     # prepare gold (as pointer-like addresses)
     if gold_addr_expr is not None:
         gold_t0 = BK.constants_idx((bsize, mlen * diff_width),
                                    -1)  # [bsize, mlen*diff]
         # check valid of golds (flatten all)
         gold_valid_t = ((gold_addr_expr >= 0) &
                         (gold_wlen_expr >= min_width) &
                         (gold_wlen_expr <= max_width))
         gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
         _glen = BK.get_shape(gold_addr_expr, 1)
         flattened_bsize_t = BK.arange_idx(
             bsize * _glen) // _glen  # [bsize*_glen]
         flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr -
                             min_width).view(-1)  # [bsize*_glen]
         flattened_gaddr_t = gold_addr_expr.view(-1)
         # mask and assign
         gold_t0[flattened_bsize_t[gold_valid_t],
                 flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[
                     gold_valid_t]
         ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
         ret_gaddr.masked_fill_((final_mask_t == 0),
                                -1)  # make invalid ones -1
     else:
         ret_gaddr = None
     # --
     return ret_widx, ret_wlen, final_mask_t, ret_gaddr
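The span enumeration above packs every (start, width) pair into a single arange of length mlen*diff_width and decodes it with integer division and modulo; _mask_f then filters spans that run past the sentence. A small self-contained sketch of the decoding step only, in plain torch (the filtering and gold-address assignment are omitted):

    import torch

    def enumerate_spans(mlen: int, min_width: int, max_width: int):
        diff_width = max_width - min_width + 1
        flat = torch.arange(mlen * diff_width)      # [mlen*dw]
        widx = flat // diff_width                   # start index of each candidate span
        wlen = flat % diff_width + min_width        # width of each candidate span
        return widx, wlen

    # enumerate_spans(3, 1, 2) -> widx = [0, 0, 1, 1, 2, 2], wlen = [1, 2, 1, 2, 1, 2]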
Example #4
 def predict(self,
             insts: Union[List[Sent], List[Frame]],
             input_expr: BK.Expr,
             mask_expr: BK.Expr,
             pair_expr: BK.Expr = None,
             lookup_flatten=False,
             external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     for inst in insts:  # first clear things
         self.helper._clear_f(inst)
     # --
     # step 1: simply labeling!
     best_labs, best_scores = self.lab_node.predict(
         input_expr, pair_expr, mask_expr, extra_score=external_extra_score)
     flt_items = self.helper.put_results(insts, best_labs,
                                         best_scores)  # [?]
     # --
     # step 2: final extend (in a flattened way)
     if len(flt_items) > 0 and conf.pred_ext:
         flt_mask = ((best_labs > 0) & (mask_expr > 0))  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
             flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         self.ext_node.predict(flt_items, input_expr[flt_sidx], flt_expr,
                               flt_full_expr, mask_expr[flt_sidx])
     # --
     # extra:
     self.pp_node.prune(insts)
     return None
Example #5
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PosiInputEmbedderConf = self.conf
     # --
     try:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = inputs
         if add_bos:
             max_len += 1
         if add_eos:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [?len?]
         ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     except Exception:
         # input is tensor
         posi_idxes = BK.input_idx(inputs)  # [*, len]
         cur_maxlen = BK.get_shape(posi_idxes, -1)
         # --
         all_input_slices = []
         slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
         if add_bos:  # add 0 and offset
             all_input_slices.append(
                 BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
             cur_maxlen += 1
             posi_idxes += 1
         all_input_slices.append(posi_idxes)  # [*, len]
         if add_eos:
             all_input_slices.append(
                 BK.constants(slice_shape,
                              cur_maxlen,
                              dtype=posi_idxes.dtype))
         final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
         # finally
         ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
Example #6
 def forward_embedding(self, input_ids, attention_mask, token_type_ids,
                       position_ids, other_embeds):
     input_shape = input_ids.size()  # [bsize, len]
     seq_length = input_shape[1]
     if position_ids is None:
         position_ids = BK.arange_idx(seq_length)  # [len]
         position_ids = position_ids.unsqueeze(0).expand(
             input_shape)  # [bsize, len]
     if token_type_ids is None:
         token_type_ids = BK.zeros(input_shape).long()  # [bsize, len]
     # BertEmbeddings.forward
     _embeddings = self.model.embeddings
     inputs_embeds = _embeddings.word_embeddings(
         input_ids)  # [bsize, len, D]
     position_embeddings = _embeddings.position_embeddings(
         position_ids)  # [bsize, len, D]
     token_type_embeddings = _embeddings.token_type_embeddings(
         token_type_ids)  # [bsize, len, D]
     embeddings = inputs_embeds + position_embeddings + token_type_embeddings
     if other_embeds is not None:
         embeddings += other_embeds
     embeddings = _embeddings.LayerNorm(embeddings)
     embeddings = _embeddings.dropout(embeddings)
     # prepare attention_mask
     if attention_mask is None:
         attention_mask = BK.constants(input_shape, value=1.)
     assert attention_mask.dim() == 2
     extended_attention_mask = attention_mask[:, None, None, :]
     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
     return embeddings, extended_attention_mask
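The last two lines build the usual additive attention mask: a [bsize, len] 0/1 validity mask is broadcast to [bsize, 1, 1, len] and padded positions become a large negative number, so the result can simply be added to raw attention scores before softmax. A minimal reproduction in plain torch, under the assumption that 1.0 marks real tokens:

    import torch

    def make_extended_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        # attention_mask: [bsize, len] with 1.0 for real tokens, 0.0 for padding
        extended = attention_mask[:, None, None, :]  # [bsize, 1, 1, len]
        return (1.0 - extended) * -10000.0           # 0 where valid, -10000 where padded

    # make_extended_attention_mask(torch.tensor([[1., 1., 0.]]))
    # keeps the valid positions at 0 and pushes the padded one to -10000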
Example #7
 def inference_forward(scores_t: BK.Expr,
                       mat_t: BK.Expr,
                       mask_t: BK.Expr,
                       beam_k: int = 0):
     scores_shape = BK.get_shape(scores_t)  # [*, slen, L]
     need_topk = (beam_k > 0) and (beam_k < scores_shape[-1])  # whether we need topk
     # --
     score_slices = split_at_dim(scores_t, -2, True)  # List[*, 1, L]
     mask_slices = split_at_dim(mask_t, -1, True)  # List[*, 1]
     # the loop on slen
     start_shape = scores_shape[:-2] + [1]  # [*, 1]
     last_labs_t = BK.constants_idx(start_shape,
                                    0)  # [*, K], todo(note): start with 0!
     last_accu_scores = BK.zeros(start_shape)  # accumulated scores: [*, K]
     last_potential = BK.zeros(
         start_shape)  # accumulated potentials: [*, K]
     full_labs_t = BK.arange_idx(scores_shape[-1]).view(
         [1] * (len(scores_shape) - 2) + [-1])  # [*, L]
     cur_step = 0
     for one_score_slice, one_mask_slice in zip(score_slices,
                                                mask_slices):  # [*,L],[*,1]
         one_mask_slice_neg = 1. - one_mask_slice  # [*,1]
         # get current scores
         if cur_step == 0:  # no transition at start!
             one_cur_scores = one_score_slice  # [*, 1, L]
         else:
             one_cur_scores = one_score_slice + mat_t[
                 last_labs_t]  # [*, K, L]
         # first for potentials
         expanded_potentials = last_potential.unsqueeze(
             -1) + one_cur_scores  # [*, K, L]
         merged_potentials = log_sum_exp(expanded_potentials, -2)  # [*, L]
         # optional for topk with merging; note: not really topk!!
         if need_topk:
             # todo(+W): another option is to directly select with potentials rather than accu_scores
             expanded_scores = last_accu_scores.unsqueeze(
                 -1) + one_cur_scores  # [*, K, L]
             # max at -2, merge same current label
             max_scores, max_idxes = expanded_scores.max(-2)  # [*, L]
             # topk at current step, no need to sort!
             new_accu_scores, new_labs_t = max_scores.topk(
                 beam_k, -1, sorted=False)  # [*, K]
             new_potential = merged_potentials.gather(-1,
                                                      new_labs_t)  # [*, K]
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + new_potential * one_mask_slice  # [*, K]
             last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice  # [*, K]
             last_labs_t = last_labs_t * one_mask_slice_neg.long() + new_labs_t * one_mask_slice.long()  # [*, K]
         else:
             # mask and update
             last_potential = last_potential * one_mask_slice_neg + merged_potentials * one_mask_slice  # [*, L(K)]
             # note: still need to mask this!
             last_labs_t = last_labs_t * one_mask_slice_neg.long() + full_labs_t * one_mask_slice.long()
         cur_step += 1
     # finally sum all
     ret_potential = log_sum_exp(last_potential, -1)  # [*]
     return ret_potential
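With beam_k == 0 the loop above is essentially the forward algorithm of a linear-chain CRF: the previous log-potentials are expanded by emission plus transition scores and merged over the previous label with log-sum-exp. A compact single-sequence reference (no batching, no masking), useful as a sanity check against the batched version; this is my sketch, not the project's code:

    import torch

    def forward_log_partition(emit: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
        # emit: [slen, L] emission scores; trans: [L, L] transitions (prev label -> current label)
        potential = emit[0]  # [L], no transition at the first step
        for t in range(1, emit.shape[0]):
            # rows index the previous label, columns the current one; merge over rows
            scores = potential.unsqueeze(-1) + trans + emit[t].unsqueeze(0)  # [L, L]
            potential = torch.logsumexp(scores, dim=0)  # [L]
        return torch.logsumexp(potential, dim=-1)  # scalar: log partition of the chain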
Example #8
 def _prepare_full_expr(self, flt_mask: BK.Expr):
     bsize, slen = BK.get_shape(flt_mask)
     arange2_t = BK.arange_idx(slen).unsqueeze(0)  # [1, slen]
     all_widxes = arange2_t.expand_as(flt_mask)[flt_mask]  # [?]
     tmp_idxes = BK.zeros([len(all_widxes), slen]).long()  # [?, slen]
     tmp_idxes.scatter_(-1, all_widxes.unsqueeze(-1), 1)  # [?, slen]
     tmp_embs = self.indicator_embed(tmp_idxes)  # [?, slen, D]
     return tmp_embs
Example #9
 def flatten_results(arr_items: np.ndarray, mask_expr: BK.Expr, *other_exprs: BK.Expr):
     sel_mask_expr = (mask_expr > 0.)
     # flatten first dims
     ret_items = [z for z in arr_items.flatten() if z is not None]
     ret_other_exprs = [z[sel_mask_expr] for z in other_exprs]  # [?(flat), D]
     ret_sidx = BK.arange_idx(BK.get_shape(mask_expr, 0)).unsqueeze(-1).expand_as(sel_mask_expr)[sel_mask_expr]
     assert all(len(ret_items) == len(z) for z in ret_other_exprs), "Error: dim0 not matched after flatten!"
     return ret_items, ret_sidx, *ret_other_exprs  # [?(flat), *]
Example #10
 def prepare_indicators(self, flat_idxes: List, shape):
     bs, dlen = shape
     _arange_t = BK.arange_idx(bs)  # [*]
     rets = []
     for one_idxes in flat_idxes:
         one_indicator = BK.constants_idx(shape, 0)  # [*, dlen]
         one_indicator[_arange_t, one_idxes] = 1
         rets.append(one_indicator)
     return rets
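The indicator construction pairs a batch arange with one column index per row, which writes exactly one 1 per row via advanced indexing. A tiny standalone illustration, assuming BK.constants_idx and BK.arange_idx map to torch.zeros(dtype=long) and torch.arange:

    import torch

    def make_indicator(idxes: torch.Tensor, dlen: int) -> torch.Tensor:
        bs = idxes.shape[0]
        ind = torch.zeros(bs, dlen, dtype=torch.long)  # [bs, dlen]
        ind[torch.arange(bs), idxes] = 1               # one 1 per row, at column idxes[b]
        return ind

    # make_indicator(torch.tensor([2, 0]), 4) -> [[0, 0, 1, 0], [1, 0, 0, 0]]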
Example #11
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
     cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
     split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
     split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
     raw_split_loss = BK.loss_binary(
         split_scores, split_oracle,
         label_smoothing=conf.split_label_smoothing)  # [*, clen-1]
     loss_split_item = LossHelper.compile_leaf_loss(
         "split", (raw_split_loss * split_oracle_mask).sum(),
         split_oracle_mask.sum(), loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
     force_split_decisions = split_oracle_mask * split_oracle + (
         1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
     _use_force_mask = (BK.rand([bsize])
                        <= conf.split_feed_force_rate).float().unsqueeze(
                            -1)  # [*, 1], seq-level
     feed_split_decisions = (_use_force_mask * force_split_decisions +
                             (1. - _use_force_mask) * pred_split_decisions
                             )  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         feed_split_decisions, cand_masks)
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
         _, _seg_max_t = seg_ext_scores.max(-1,
                                            keepdim=True)  # [*, seglen, 1]
         oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(
             -1)  # [*, seglen]
         oracle_gaddr = expr_seq_gaddr.gather(-1,
                                              oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
Example #12
 def get_last_emb(self):
     k = "last_emb"
     ret = self.l_caches.get(k)
     if ret is None:
         ret = self.embs[-1]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange2_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1)  # [bsize, 1]
             ret = ret[arange2_t, valid_idxes_t]  # select!
         self.l_caches[k] = ret  # cache
     return ret
Example #13
 def get_stack_att(self):
     k = "stack_att"
     ret = self.l_caches.get(k)
     if ret is None:
         ret = BK.stack(self.attns, -1).permute(0,2,3,4,1)  # NL*[*, H, lenq, lenk] -> [*, lenq, lenk, NL, H]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange3_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1).unsqueeze(-1)  # [bsize, 1, 1]
             ret = ret[arange3_t, valid_idxes_t.unsqueeze(-1), valid_idxes_t.unsqueeze(-2)]  # select!
         self.l_caches[k] = ret  # cache
     return ret
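The double indexing here keeps the same subset of query and key positions from each attention map: one broadcasted batch arange plus the index tensor unsqueezed on two different axes yields a [bsize, k, k] sub-square. A minimal demonstration of that indexing pattern in plain torch:

    import torch

    att = torch.arange(2 * 4 * 4).view(2, 4, 4)            # [bsize, lenq, lenk]
    keep = torch.tensor([[0, 2, 3], [1, 2, 3]])            # [bsize, k] positions to keep
    ar = torch.arange(2).view(-1, 1, 1)                    # [bsize, 1, 1]
    sub = att[ar, keep.unsqueeze(-1), keep.unsqueeze(-2)]  # [bsize, k, k]
    # sub[b, i, j] == att[b, keep[b, i], keep[b, j]]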
Example #14
 def get_stack_emb(self):
     k = "stack_emb"
     ret = self.l_caches.get(k)
     if ret is None:
         # note: excluding embeddings here to make it convenient!!
         ret = BK.stack(self.embs[1:], -1)  # [*, slen, D, NL]
         valid_idxes_t = self.valid_idxes_t
         if valid_idxes_t is not None:
             arange2_t = BK.arange_idx(BK.get_shape(valid_idxes_t, 0)).unsqueeze(-1)  # [bsize, 1]
             ret = ret[arange2_t, valid_idxes_t]  # select!
         self.l_caches[k] = ret  # cache
     return ret
Example #15
 def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
     conf: BaseSpanConf = self.conf
     # --
     # note: check empty, otherwise error
     input_item_shape = BK.get_shape(widx_expr)
     if np.prod(input_item_shape) == 0:
         return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
     # --
     start_idxes, end_idxes = widx_expr, widx_expr+wlen_expr  # make [start, end)
     # get sizes
     bsize, slen = BK.get_shape(input_expr)[:2]
     # num_span = BK.get_shape(start_idxes, 1)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     # --
     reprs = []
     if conf.use_starts:  # start [start,
         reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
     if conf.use_ends:  # simply ,end-1]
         reprs.append(input_expr[arange2_t, end_idxes-1])
     if conf.use_softhead:
         # expand range
         all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
         # flatten
         flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
         flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
         # get softhead score (consider mask here)
         softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
         flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
         flatten_all_span_scores += (1.-flatten_all_span_mask) * Constants.REAL_PRAC_MIN
         all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
         # reshape and (optionally topk) and softmax
         softhead_topk = conf.softhead_topk
         if softhead_topk>0 and BK.get_shape(all_span_scores,-1)>softhead_topk:  # further select topk; note: this may save mem
             final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
             final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
         else:
             final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
         final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
         # [bsize, ?, ??, D]
         final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes)+[-1])
         weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
         reprs.append(weighted_repr)
     if conf.use_width:
         cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
         reprs.append(cur_width_embed)
     # concat
     concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
     if conf.use_proj:
         ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
     else:
         ret = concat_repr
     return ret
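The use_softhead branch is a small attention pooling over the tokens inside each span: gather per-token scores, push out-of-span slots to a very negative value, softmax, then take the weighted sum of the gathered token vectors. A stripped-down sketch of just that pooling in plain torch; softhead_pool and scorer are hypothetical local names (scorer could be any [D] -> 1 module such as torch.nn.Linear(D, 1)):

    import torch

    def softhead_pool(input_expr, span_idxes, span_mask, scorer):
        # input_expr: [bsize, slen, D]; span_idxes: [bsize, nspan, MW] (long); span_mask: same shape, 0./1.
        bsize = input_expr.shape[0]
        arange2 = torch.arange(bsize).unsqueeze(-1)                         # [bsize, 1]
        scores = scorer(input_expr).squeeze(-1)                             # [bsize, slen]
        flat_idxes = span_idxes.view(bsize, -1)                             # [bsize, nspan*MW]
        span_scores = scores[arange2, flat_idxes].view(span_idxes.shape)    # [bsize, nspan, MW]
        span_scores = span_scores + (1. - span_mask) * -1e30                # mask out-of-span slots
        prob = span_scores.softmax(-1)                                      # [bsize, nspan, MW]
        toks = input_expr[arange2, flat_idxes].view(*span_idxes.shape, -1)  # [bsize, nspan, MW, D]
        return (toks * prob.unsqueeze(-1)).sum(-2)                          # [bsize, nspan, D]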
Example #16
 def __init__(self, ibatch: InputBatch, IDX_PAD: int):
     # preps
     self.bsize = len(ibatch)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # batched them
     all_seq_infos = [z.seq_info for z in ibatch.items]
     # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
     self.enc_input_ids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos],
                                int(IDX_PAD)))
     self.enc_input_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.enc_input_ids) for z in all_seq_infos]))
     self.enc_input_segids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos],
                                0))
     # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
     self.dec_sel_idxes = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos],
                                0))
     self.dec_sel_lens = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
     self.dec_sel_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.dec_sel_idxes) for z in all_seq_infos]))
     _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
     _dec_offsets = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos],
                                _max_dec_len))
     # note: CLS as -1, then 0,1,2,..., PAD gets -2!
     self.dec_sent_idxes = \
         (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
     self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
     # dec -> enc: [*, len_enc] (calculated when needed!)
     # note: require 1-to-1 mapping (except pads)!!
     self._enc_back_hits = None
     self._enc_back_sel_idxes = None
Example #17
 def _aggregate_subtoks(self, repr_t: BK.Expr, dsel_seq_info):
     conf: DSelectorConf = self.conf
     _arange_t, _sel_t, _len_t = dsel_seq_info.arange2_t, dsel_seq_info.dec_sel_idxes, dsel_seq_info.dec_sel_lens
     _max_len = 1 if BK.is_zero_shape(_len_t) else _len_t.max().item()
     _max_len = max(1, min(conf.dsel_max_subtoks, _max_len))  # truncate
     # --
     _tmp_arange_t = BK.arange_idx(_max_len)  # [M]
     _all_valids_t = (_tmp_arange_t < _len_t.unsqueeze(-1)).float()  # [*, dlen, M]
     _tmp_arange_t = _tmp_arange_t * _all_valids_t.long()  # note: pad as 0
     _all_idxes_t = _sel_t.unsqueeze(-1) + _tmp_arange_t  # [*, dlen, M]
     _all_repr_t = repr_t[_arange_t.unsqueeze(-1), _all_idxes_t]  # [*, dlen, M, D]
     while len(BK.get_shape(_all_valids_t)) < len(BK.get_shape(_all_repr_t)):
         _all_valids_t = _all_valids_t.unsqueeze(-1)
     _all_repr_t = _all_repr_t * _all_valids_t
     return _all_repr_t, _all_valids_t
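Each decoder position selects up to M consecutive subtoken vectors starting at its own subtoken offset, and slots beyond the recorded length are zeroed by the validity mask. A short torch sketch of the same gather; gather_subtoks is a hypothetical name, and the caller is assumed to guarantee that sel_idxes + sel_lens stays within the subtoken length:

    import torch

    def gather_subtoks(repr_t, sel_idxes, sel_lens, max_m):
        # repr_t: [bsize, sublen, D]; sel_idxes, sel_lens: [bsize, dlen] (long)
        bsize = repr_t.shape[0]
        ar_b = torch.arange(bsize).view(-1, 1, 1)               # [bsize, 1, 1]
        ar_m = torch.arange(max_m)                              # [M]
        valids = (ar_m < sel_lens.unsqueeze(-1)).float()        # [bsize, dlen, M]
        idxes = sel_idxes.unsqueeze(-1) + ar_m * valids.long()  # invalid offsets collapse to the start idx
        out = repr_t[ar_b, idxes]                               # [bsize, dlen, M, D]
        return out * valids.unsqueeze(-1), valids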
Example #18
 def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
     assert ssize > 0
     assert self._cur_layer_idx == -1
     self._cur_layer_idx = 0
     # --
     new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
     # --
     # prepare for store_lstate selecting
     if len(self.cum_state_lset) > 0:  # does any layer need to accumulate?
         self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
         # note: if no last state, simply clamp 0, otherwise, offset by 1 since we will concat later
         self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
     # prev_steps = self.steps  # previous accumulated steps
     self.steps += ssize
     self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
     self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
Example #19
def mask2posi_padded(mask: BK.Expr, offset: int, cmin: int):
    with BK.no_grad_env():
        bsize, ssize = BK.get_shape(mask)
        ret = BK.arange_idx(ssize).repeat(bsize, 1)  # [bsize, ssize]
        rmask_long_t = (mask == 0.).long()  # reverse-mask [bsize, ssize]
        conti_zeros = BK.constants_idx([bsize], 0)  # [bsize], number of continuous zeros
        for sidx in range(ssize):
            cur_slice = rmask_long_t[:, sidx]  # [bsize]
            conti_zeros = (conti_zeros + cur_slice) * cur_slice  # [bsize], *cur_slice resets the run at valid positions
            ret[:, sidx] -= conti_zeros
        # --
        ret += offset
        ret.clamp_(min=cmin)
        return ret
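The loop subtracts the running length of each zero-run, so every masked slot ends up pointing at the index of the last valid position before it (or -1 if there is none), after which the offset and clamp are applied. A loop-free variant that should produce the same mapping, written here as an illustrative alternative rather than the library's own code:

    import torch

    def mask2posi_padded_ref(mask: torch.Tensor, offset: int, cmin: int) -> torch.Tensor:
        # mask: [bsize, ssize] with 1. for valid slots and 0. for masked ones
        bsize, ssize = mask.shape
        ar = torch.arange(ssize).expand(bsize, ssize)
        # index of the last valid position at or before each slot, -1 if none so far
        last_valid = torch.where(mask > 0, ar, torch.full_like(ar, -1)).cummax(dim=1).values
        return (last_valid + offset).clamp(min=cmin)

    # mask = [[1., 0., 0., 1.]] -> last_valid = [[0, 0, 0, 3]]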
Example #20
 def forward(self, input_expr, mask_expr=None, **kwargs):
     conf: TTransformerConf = self.conf
     # --
     if conf.n_layers == 0:
         return input_expr  # change nothing if no layers
     # --
     if conf.use_posi:
         ssize = BK.get_shape(input_expr, 1)  # step size
         posi_embed = self.PE(BK.arange_idx(ssize)).unsqueeze(0)  # [1, step, D]
         input_x = input_expr + posi_embed
     else:
         input_x = input_expr
     if conf.norm_input:
         input_x = self.norm(input_x)
     output = self.enc(input_x.transpose(0,1), src_key_padding_mask=(mask_expr==0)).transpose(0,1).contiguous()  # note: True marks padded positions to be ignored
     return output
Example #21
def expand_ranged_idxes(widx_t: BK.Expr,
                        wlen_t: BK.Expr,
                        pad: int = 0,
                        max_width: int = None):
    if max_width is None:  # if not provided
        if BK.is_zero_shape(wlen_t):
            max_width = 1  # at least one
        else:
            max_width = wlen_t.max().item()  # overall max width
    # --
    input_shape = BK.get_shape(widx_t)  # [*]
    mw_range_t = BK.arange_idx(max_width).view([1] * len(input_shape) +
                                               [-1])  # [*, MW]
    expanded_idxes = widx_t.unsqueeze(-1) + mw_range_t  # [*, MW]
    expanded_masks_bool = (mw_range_t < wlen_t.unsqueeze(-1))  # [*, MW]
    expanded_idxes.masked_fill_(~expanded_masks_bool, pad)  # [*, MW]
    return expanded_idxes, expanded_masks_bool.float()
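Given span starts and lengths, the helper materializes every in-span index up to the overall maximum width, plus a mask marking which slots are real. A concrete illustration of the contract, assuming plain torch underneath:

    import torch

    widx = torch.tensor([0, 2])                   # span starts
    wlen = torch.tensor([2, 3])                   # span lengths
    mw_range = torch.arange(3).view(1, -1)        # [1, MW], MW = max width 3
    idxes = widx.unsqueeze(-1) + mw_range         # [[0, 1, 2], [2, 3, 4]]
    masks = mw_range < wlen.unsqueeze(-1)         # [[True, True, False], [True, True, True]]
    idxes = idxes.masked_fill(~masks, 0)          # pad slot -> 0: [[0, 1, 0], [2, 3, 4]]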
Example #22
 def _f(_widx, _wlen):  # [mlen*dw] -> [bsize, mlen*dw]
     _mw = self.conf.max_width  # todo(note): also rely on this
     _bsize = BK.get_shape(input_mask, 0)
     _padded_input_mask = BK.pad(
         input_mask, (0, _mw),
         value=0.)  # [bsize, mlen+_mw], make idxing valid!!
     _er_idxes, _er_masks = expand_ranged_idxes(_widx, _wlen, 0,
                                                _mw)  # [mlen*dw, MW]
     _arange_t = BK.arange_idx(_bsize).unsqueeze(-1).unsqueeze(
         -1)  # [bsize, 1, 1]
     _idx_valid = _padded_input_mask[_arange_t,
                                     _er_idxes.unsqueeze(
                                         0)]  # [bsize, mlen*dw, MW]
     _idx_valid.masked_fill_((_er_masks == 0.).unsqueeze(0),
                             1.)  # make sure paddings get 1
     _ret = _idx_valid.prod(
         -1)  # [bsize, mlen*dw], require all non-pad ones to be valid!
     return _ret
Example #23
 def forward(self, med: ZMediator, **kwargs):
     conf: IdecConnectorPlainConf = self.conf
     # --
     if self.do_seq_pool:
         # note: for pooling, use the raw emb!!
         mixed_emb0 = self._go_detach(med.get_raw_last_emb())  # [*, ??, D]
         mixed_emb = self.pool_f(mixed_emb0)  # [*, D]
     else:
         if conf.use_nlayer == 1:  # simply get the last one
             mixed_emb = self._go_detach(med.get_last_emb())
         else:  # mix them
             stacked_embs = self._go_detach(med.get_stack_emb())[:, :, :, -len(self.mixed_weights):]  # [*, slen, D, NL]
             mixed_emb = BK.matmul(
                 stacked_embs,
                 BK.softmax(self.mixed_weights,
                            -1).unsqueeze(-1)).squeeze(-1)  # [*, slen, D]
         if self.do_seq_sel:
             _arange_t = BK.arange_idx(BK.get_shape(mixed_emb, 0))
             _idx_t = med.get_cache(conf.seq_sel_key)
             mixed_emb = mixed_emb[_arange_t, _idx_t]  # [*, D]
     # further affine
     if self.input_mask is not None:  # note: special input mask!!
         mixed_emb = mixed_emb * self.input_mask.detach()  # no grad for input_mask!!
     drop_emb = self.pre_mid_drop(mixed_emb)
     if conf.mid_dim > 0:
         # gather inputs
         _r = conf.mid_extra_range
         _detached_drop_emb = drop_emb.detach()
         _inputs = []
         for ii in range(-_r, _r + 1):
             if ii < 0:
                 _one = BK.pad(_detached_drop_emb[:, :ii], [0, 0, -ii, 0])
             elif ii == 0:
                 _one = drop_emb  # no need to change!
             else:
                 _one = BK.pad(_detached_drop_emb[:, ii:], [0, 0, 0, ii])
             _inputs.append(_one)
         # --
         ret_t = self.mid_aff(_inputs)  # [*, slen, M] or [*, M]
     else:
         ret_t = drop_emb
     return ret_t
Example #24
 def forward(self, input_expr, mask_expr, med: ZMediator):
     conf: MyTransformerConf = self.conf
     med.start(mask_expr)
     # for the L0 layer
     cur_expr = med.forw_emb(input_expr, norm_node=None)  # norm right later
     if conf.use_posi:
         ssize = BK.get_shape(input_expr, 1)  # step size
         posi_embed = self.PE(BK.arange_idx(ssize)).unsqueeze(0)  # [1, step, D]
         cur_expr = self.input_norm(cur_expr + posi_embed)
     else:
         cur_expr = self.input_norm(cur_expr)
     # L1+
     for ti, tnode in enumerate(self.tnodes):
         med.next()
         cur_expr = tnode(cur_expr, mask_expr, med)
         if med.is_end():
             break
     # clean
     med.end()
     return cur_expr
Example #25
 def decode_with_scores(left_scores: BK.Expr, right_scores: BK.Expr,
                        normalize: bool):
     if normalize:
         left_scores = BK.log_softmax(left_scores, -1)
         right_scores = BK.log_softmax(right_scores, -1)
     # pairwise adding
     score_shape = BK.get_shape(left_scores)
     pair_scores = left_scores.unsqueeze(-1) + right_scores.unsqueeze(
         -2)  # [*, slen_L, slen_R]
     flt_pair_scores = pair_scores.view(score_shape[:-1] +
                                        [-1])  # [*, slen*slen]
     # LR mask
     slen = score_shape[-1]
     arange_t = BK.arange_idx(slen)
     lr_mask = (arange_t.unsqueeze(-1) <=
                arange_t.unsqueeze(-2)).float().view(-1)  # [slen_L*slen_R]
     max_scores, max_idxes = (flt_pair_scores +
                              (1. - lr_mask) * Constants.REAL_PRAC_MIN).max(
                                  -1)  # [*]
     left_idxes, right_idxes = max_idxes // slen, max_idxes % slen  # [*]
     return max_scores, left_idxes, right_idxes
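The decoder adds every left score to every right score by broadcasting, flattens the [slen_L, slen_R] grid, masks pairs with left index greater than right index, and recovers the pair from the flat argmax with div and mod. A small check of that arithmetic in plain torch (not the project's code):

    import torch

    left = torch.tensor([0.1, 2.0, 0.3])
    right = torch.tensor([0.5, 0.2, 1.5])
    slen = left.shape[0]
    pair = left.unsqueeze(-1) + right.unsqueeze(-2)  # [slen_L, slen_R]
    lr_mask = (torch.arange(slen).unsqueeze(-1) <= torch.arange(slen).unsqueeze(-2)).float().view(-1)
    score, idx = (pair.view(-1) + (1. - lr_mask) * -1e30).max(-1)
    left_idx, right_idx = idx // slen, idx % slen    # -> (1, 2): the best pair with left <= right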
Example #26
 def _go_common(self, res: SpanExtractorOutput, sel_mask: BK.Expr,
                add_gold_rate: float):
     gaddr_expr, span_mask = res.gaddr_expr, res.mask_expr
     bsize = BK.get_shape(span_mask, 0)
     # add gold?
     if add_gold_rate > 0.:  # inplace
         gold_mask = ((gaddr_expr >= 0) &
                      (BK.rand(sel_mask.shape) < add_gold_rate)
                      ).float()  # note: gaddr==-1 means nope
         sel_mask += gold_mask
         sel_mask.clamp_(max=1.)  # OR
     sel_mask *= span_mask  # must be valid
     # select masked
     final_idx_t, final_mask_t = BK.mask2idx(sel_mask,
                                             padding_idx=0)  # [bsize, ??]
     _tmp_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
     res.arrange(_tmp_arange_t, final_idx_t, final_mask_t)
     if res.gaddr_expr is not None:
         res.gaddr_expr.masked_fill_(final_mask_t == 0.,
                                     -1)  # make invalid ones -1
     return res  # [bsize, SNUM, *]
Example #27
 def forward(self, input_expr, mask_expr, med: ZMediator):
     conf: MyTransformerConf = self.conf
     # for the L0 layer
     cur_expr = input_expr
     if conf.use_posi:
         ssize = BK.get_shape(input_expr, 1)  # step size
         posi_embed = self.PE(BK.arange_idx(ssize)).unsqueeze(0)  # [1, step, D]
         cur_expr = cur_expr + posi_embed
     add_expr, _ = med.layer_end(cur_expr, None)  # no checking at L0
     if add_expr is not None:
         cur_expr += add_expr  # norm right later!
     cur_expr = self.input_norm(cur_expr)
     # for later layers: L1+
     for ti, tnode in enumerate(self.tnodes):
         cur_expr, scores_t = tnode(cur_expr, mask_expr, med)
         add_expr, early_exit = med.layer_end(cur_expr, scores_t)  # check
         if add_expr is not None:
             cur_expr = tnode.norm1(cur_expr + add_expr)
         if early_exit:
             break
     return cur_expr, {}
Example #28
 def forward_input(self, med: ZMediator, detach_scale: float, **kwargs):
     # get it
     if self.do_dsel:
         _dsel = self.dsel
         input_t0 = med.get_enc_cache_val(
             "hid", signature=_dsel.signature, function=(lambda x: _dsel.forward(x, med.ibatch.seq_info)))  # [*, ??, D]
     else:
         # input_t0 = med.get_enc_cache_val("hid", no_cache=True)  # [*, ??, D], note: no need for caching!
         input_t0 = med.get_enc_cache_val("hid")  # [*, ??, D]
     mask_t = med.get_mask(self.do_dsel)  # [*, ??]
     # extra processing?
     if self.do_seq_pool:
         input_t = self.seq_pool_f(input_t0)  # [*, D]
     elif self.do_seq_sel:
         _arange_t = BK.arange_idx(BK.get_shape(input_t0, 0))  # [*]
         _idx_t = med.get_cache(self.seq_sel_key)  # [*]
         input_t = input_t0[_arange_t, _idx_t]  # [*, D]
     else:
         input_t = input_t0
     # detach?
     ret_t = BK.go_detach(input_t, detach_scale, self.is_training())
     return ret_t, mask_t  # [*, (??), D], [*, ??]
Example #29
 def __init__(self, berter: BertEncoder,
              seq_subs: List[InputSubwordSeqField]):
     self.seq_subs = seq_subs
     self.berter = berter
     self.bsize = len(seq_subs)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # --
     tokenizer = self.berter.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     # MASK_IDX = tokenizer.mask_token_id
     # CLS_IDX_l = [tokenizer.cls_token_id]
     # SEP_IDX_l = [tokenizer.sep_token_id]
     # make batched idxes
     padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
     batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
     batched_input_ids, batched_input_mask = padder.pad(
         [s.idxes for s in seq_subs])  # [bsize, sub_len]
     self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
     self.batched_input_ids = BK.input_idx(batched_input_ids)
     self.batched_input_mask = BK.input_real(batched_input_mask)
     # make batched mappings (sub->orig)
     padder2 = DataPadder(2, pad_vals=0,
                          mask_range=2)  # pad as 0 to avoid out-of-range
     batched_first_idxes, batched_first_mask = padder2.pad(
         [s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
     self.batched_first_idxes = BK.input_idx(batched_first_idxes)
     self.batched_first_mask = BK.input_real(batched_first_mask)
     # reversed batched_mappings (orig->sub) (created when needed)
     self._batched_rev_idxes = None  # [bsize, sub_len]
     # --
     self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
     self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
     self.batched_position_ids = None  # [bsize, 1+sub_len+1]
     self.other_factors = {}  # name -> aug_batched_ids
Example #30
 def _split_aggregate(self, cand_expr, cand_scores, cand_widxes,
                      seg_ext_cidxes, seg_ext_masks, topk: int):
     arange3_t = BK.arange_idx(
         seg_ext_cidxes.shape[0]).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
     seg_ext_scores = cand_scores[arange3_t, seg_ext_cidxes] + (
         1. - seg_ext_masks) * Constants.REAL_PRAC_MIN  # [*, seglen, MW]
     # if need further topk?
     if topk > 0 and BK.get_shape(seg_ext_scores,
                                  -1) > topk:  # need to further topk?
         seg_ext_scores, _tmp_idxes = seg_ext_scores.topk(
             topk, dim=-1, sorted=False)  # [*, seglen, K]
         seg_ext_cidxes = seg_ext_cidxes.gather(
             -1, _tmp_idxes)  # [*, seglen, K]
         seg_ext_masks = seg_ext_masks.gather(-1,
                                              _tmp_idxes)  # [*, seglen, K]
     # get expr and extend to full
     seg_ext_prob = seg_ext_scores.softmax(-1)  # [*, seglen, K]
     _tmp_expr = cand_expr[arange3_t, seg_ext_cidxes]  # [*, seglen, K, D]
     seg_weighted_expr = (_tmp_expr * seg_ext_prob.unsqueeze(-1)).sum(
         -2)  # [*, seglen, D]
     seg_ext_widxes = cand_widxes[arange3_t,
                                  seg_ext_cidxes]  # [*, seglen, K]
     return seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr  # [*, seglen, ?]