def embed_lens(self, query_len: int, key_len: int):
    """Build the pairwise relative-distance matrix (query position minus key
    position) and embed it with the relative-position embedder."""
    q_positions = BK.arange_idx(query_len).unsqueeze(1)   # [lenq, 1]
    k_positions = BK.arange_idx(key_len).unsqueeze(0)     # [1, lenk]
    rel_dist = q_positions - k_positions                  # [lenq, lenk]
    # embed_rposi -> (processed dists, att embeds, value embeds): [lenq, lenk], [lenq, lenk, dim] x2
    _, dist_atts, dist_values = self.embed_rposi(rel_dist)
    return rel_dist, dist_atts, dist_values
def lookup_flatten(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
                   pair_expr: BK.Expr = None):
    """Lookup item representations and flatten them over the batch, additionally building a
    full-sequence expr that marks each item's span, either by signed relative distance or by
    a 0/1 inside-span indicator (depending on conf.flatten_lookup_use_dist).

    Returns: (ret_items, ret_sidx, ret_expr, ret_full_expr)  # last is [??, slen, D]
    """
    arr_items, mlp_expr, zmask_expr, extra_info = self.lookup(insts, input_expr, mask_expr, pair_expr)
    expr_widxes, expr_wlens = extra_info
    # flatten
    ret_items, ret_sidx, ret_expr, fl_widxes, fl_wlens = LookupNode.flatten_results(
        arr_items, zmask_expr, mlp_expr, expr_widxes, expr_wlens)
    # --
    # also make full expr
    # full_masks = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).float()  # [??, slen]
    # ret_full_expr = full_masks.unsqueeze(-1) * ret_expr.unsqueeze(-2)  # [??, slen, D]
    if self.conf.flatten_lookup_use_dist:  # use posi again: [...,-2,-1,0,0,0,1,2,...]
        left_widxes = fl_widxes.unsqueeze(-1)  # [??, 1], span start
        right_widxes = (fl_widxes+fl_wlens-1).unsqueeze(-1)  # [??, 1], span end (inclusive)
        _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
        dist0 = _arange_t - left_widxes  # [??, slen]
        dist1 = _arange_t - right_widxes  # [??, slen]
        # negative left of the span, positive right of it, zero inside
        full_dist = (_arange_t < left_widxes).long() * dist0 + (_arange_t > right_widxes).long() * dist1
        ret_full_expr = self.indicator_norm(self.indicator_embed(full_dist))  # [??, slen, D]
        # # ret_full_expr = self.indicator_embed(full_dist)  # [??, slen, D]
    else:  # otherwise 0/1 inside-span indicator
        _arange_t = BK.arange_idx(BK.get_shape(mask_expr, 1)).unsqueeze(0)  # [1, slen]
        full_ind = ((_arange_t>=fl_widxes.unsqueeze(-1)) & (_arange_t<(fl_widxes+fl_wlens).unsqueeze(-1))).long()  # [??, slen]
        ret_full_expr = self.indicator_norm(self.indicator_embed(full_ind))  # [??, slen, D]
    # --
    return ret_items, ret_sidx, ret_expr, ret_full_expr  # [??, D]
def _common_prepare(self, input_shape: Tuple[int], _mask_f: Callable,
                    gold_widx_expr: BK.Expr, gold_wlen_expr: BK.Expr, gold_addr_expr: BK.Expr):
    """Enumerate all candidate spans (widx, wlen) with wlen in [min_width, max_width],
    compact the valid ones (per _mask_f), and optionally align gold span addresses onto
    the enumerated candidates.

    Args:
        input_shape: (bsize, mlen).
        _mask_f: (widx_t, wlen_t) -> [bsize, mlen*dw] validity mask for candidate spans.
        gold_*_expr: [bsize, _glen] gold span starts/lengths/addresses (addr<0 = none), or None.
    Returns: (ret_widx, ret_wlen, final_mask_t, ret_gaddr), each [bsize, ??]; gaddr is None
        when gold_addr_expr is None.
    """
    conf: SpanExtractorConf = self.conf
    min_width, max_width = conf.min_width, conf.max_width
    diff_width = max_width - min_width + 1  # number of width to extract
    # --
    bsize, mlen = input_shape
    # --
    # [bsize, mlen*(max_width-min_width)], mlen first (dim=1)
    # note: the spans are always sorted by (widx, wlen)
    _tmp_arange_t = BK.arange_idx(mlen * diff_width)  # [mlen*dw]
    widx_t0 = (_tmp_arange_t // diff_width)  # [mlen*dw]
    wlen_t0 = (_tmp_arange_t % diff_width) + min_width  # [mlen*dw]
    mask_t0 = _mask_f(widx_t0, wlen_t0)  # [bsize, mlen*dw]
    # --
    # compacting (use mask2idx and gather)
    final_idx_t, final_mask_t = BK.mask2idx(mask_t0, padding_idx=0)  # [bsize, ??]
    _tmp2_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    # no need to make valid for mask=0, since idx=0 means (0, min_width)
    # todo(+?): do we need to deal with empty ones here?
    ret_widx = widx_t0[final_idx_t]  # [bsize, ??]
    ret_wlen = wlen_t0[final_idx_t]  # [bsize, ??]
    # --
    # prepare gold (as pointer-like addresses)
    if gold_addr_expr is not None:
        gold_t0 = BK.constants_idx((bsize, mlen * diff_width), -1)  # [bsize, mlen*diff]
        # check valid of golds (flatten all)
        gold_valid_t = ((gold_addr_expr >= 0) & (gold_wlen_expr >= min_width) & (gold_wlen_expr <= max_width))
        gold_valid_t = gold_valid_t.view(-1)  # [bsize*_glen]
        _glen = BK.get_shape(gold_addr_expr, 1)
        flattened_bsize_t = BK.arange_idx(bsize * _glen) // _glen  # [bsize*_glen], batch idx per gold
        # gold's flat candidate idx: widx*dw + (wlen - min_width), mirrors the enumeration above
        flattened_fidx_t = (gold_widx_expr * diff_width + gold_wlen_expr - min_width).view(-1)  # [bsize*_glen]
        flattened_gaddr_t = gold_addr_expr.view(-1)
        # mask and assign
        gold_t0[flattened_bsize_t[gold_valid_t], flattened_fidx_t[gold_valid_t]] = flattened_gaddr_t[gold_valid_t]
        ret_gaddr = gold_t0[_tmp2_arange_t, final_idx_t]  # [bsize, ??]
        ret_gaddr.masked_fill_((final_mask_t == 0), -1)  # make invalid ones -1
    else:
        ret_gaddr = None
    # --
    return ret_widx, ret_wlen, final_mask_t, ret_gaddr
def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
            pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    """Two-step anchor prediction: (1) token-level labeling via lab_node; (2) extend the
    labeled anchors to spans in a batch-flattened pass (if conf.pred_ext); then prune via
    pp_node. Results are written onto insts (returns None).
    """
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    for inst in insts:  # first clear things
        self.helper._clear_f(inst)
    # --
    # step 1: simply labeling!
    best_labs, best_scores = self.lab_node.predict(input_expr, pair_expr, mask_expr, extra_score=external_extra_score)
    flt_items = self.helper.put_results(insts, best_labs, best_scores)  # [?]
    # --
    # step 2: final extend (in a flattened way)
    if len(flt_items) > 0 and conf.pred_ext:
        flt_mask = ((best_labs > 0) & (mask_expr > 0))  # [*, slen], predicted anchors at valid tokens
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?], batch idx per anchor
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        self.ext_node.predict(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
    # --
    # extra:
    self.pp_node.prune(insts)
    return None
def forward(self, inputs, add_bos=False, add_eos=False):
    """Embed positions.

    inputs is either a (batch_size, max_len) shape pair (as prepared by "PosiHelper"),
    or a tensor of explicit position idxes; add_bos/add_eos prepend a BOS position (0,
    shifting the rest by 1) and/or append an EOS position.

    Returns: [*, 1?+len+1?, dim] position embeddings.
    """
    conf: PosiInputEmbedderConf = self.conf
    # --
    try:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = inputs
        if add_bos:
            max_len += 1
        if add_eos:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [?len?]
        ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    except Exception:
        # fix: was a bare "except:", which also swallowed KeyboardInterrupt/SystemExit
        # input is tensor of explicit position idxes
        posi_idxes = BK.input_idx(inputs)  # [*, len]
        cur_maxlen = BK.get_shape(posi_idxes, -1)
        # --
        all_input_slices = []
        slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
        if add_bos:  # add 0 and offset
            all_input_slices.append(BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
            cur_maxlen += 1
            posi_idxes += 1
        all_input_slices.append(posi_idxes)  # [*, len]
        if add_eos:
            all_input_slices.append(BK.constants(slice_shape, cur_maxlen, dtype=posi_idxes.dtype))
        final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
        # finally
        ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
def forward_embedding(self, input_ids, attention_mask, token_type_ids, position_ids, other_embeds): input_shape = input_ids.size() # [bsize, len] seq_length = input_shape[1] if position_ids is None: position_ids = BK.arange_idx(seq_length) # [len] position_ids = position_ids.unsqueeze(0).expand( input_shape) # [bsize, len] if token_type_ids is None: token_type_ids = BK.zeros(input_shape).long() # [bsize, len] # BertEmbeddings.forward _embeddings = self.model.embeddings inputs_embeds = _embeddings.word_embeddings( input_ids) # [bsize, len, D] position_embeddings = _embeddings.position_embeddings( position_ids) # [bsize, len, D] token_type_embeddings = _embeddings.token_type_embeddings( token_type_ids) # [bsize, len, D] embeddings = inputs_embeds + position_embeddings + token_type_embeddings if other_embeds is not None: embeddings += other_embeds embeddings = _embeddings.LayerNorm(embeddings) embeddings = _embeddings.dropout(embeddings) # prepare attention_mask if attention_mask is None: attention_mask = BK.constants(input_shape, value=1.) assert attention_mask.dim() == 2 extended_attention_mask = attention_mask[:, None, None, :] extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return embeddings, extended_attention_mask
def inference_forward(scores_t: BK.Expr, mat_t: BK.Expr, mask_t: BK.Expr, beam_k: int = 0): scores_shape = BK.get_shape(scores_t) # [*, slen, L] need_topk = (beam_k > 0) and (beam_k < scores_shape[-1] ) # whether we need topk # -- score_slices = split_at_dim(scores_t, -2, True) # List[*, 1, L] mask_slices = split_at_dim(mask_t, -1, True) # List[*, 1] # the loop on slen start_shape = scores_shape[:-2] + [1] # [*, 1] last_labs_t = BK.constants_idx(start_shape, 0) # [*, K], todo(note): start with 0! last_accu_scores = BK.zeros(start_shape) # accumulated scores: [*, K] last_potential = BK.zeros( start_shape) # accumulated potentials: [*, K] full_labs_t = BK.arange_idx(scores_shape[-1]).view( [1] * (len(scores_shape) - 2) + [-1]) # [*, L] cur_step = 0 for one_score_slice, one_mask_slice in zip(score_slices, mask_slices): # [*,L],[*,1] one_mask_slice_neg = 1. - one_mask_slice # [*,1] # get current scores if cur_step == 0: # no transition at start! one_cur_scores = one_score_slice # [*, 1, L] else: one_cur_scores = one_score_slice + mat_t[ last_labs_t] # [*, K, L] # first for potentials expanded_potentials = last_potential.unsqueeze( -1) + one_cur_scores # [*, K, L] merged_potentials = log_sum_exp(expanded_potentials, -2) # [*, L] # optional for topk with merging; note: not really topk!! if need_topk: # todo(+W): another option is to directly select with potentials rather than accu_scores expanded_scores = last_accu_scores.unsqueeze( -1) + one_cur_scores # [*, K, L] # max at -2, merge same current label max_scores, max_idxes = expanded_scores.max(-2) # [*, L] # topk at current step, no need to sort! 
new_accu_scores, new_labs_t = max_scores.topk( beam_k, -1, sorted=False) # [*, K] new_potential = merged_potentials.gather(-1, new_labs_t) # [*, K] # mask and update last_potential = last_potential * one_mask_slice_neg + new_potential * one_mask_slice # [*, K] last_accu_scores = last_accu_scores * one_mask_slice_neg + new_accu_scores * one_mask_slice # [*, K] last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + new_labs_t * one_mask_slice.long() # [*, K] else: # mask and update last_potential = last_potential * one_mask_slice_neg + merged_potentials * one_mask_slice # [*, L(K)] # note: still need to mask this! last_labs_t = last_labs_t * one_mask_slice_neg.long( ) + full_labs_t * one_mask_slice.long() cur_step += 1 # finally sum all ret_potential = log_sum_exp(last_potential, -1) # [*] return ret_potential
def _prepare_full_expr(self, flt_mask: BK.Expr):
    """Build one-hot indicator embeddings [?, slen, D]: one row per selected position
    in flt_mask, marking that position with 1 along the sequence axis."""
    _, seq_len = BK.get_shape(flt_mask)
    posi_row = BK.arange_idx(seq_len).unsqueeze(0)          # [1, slen]
    hit_widxes = posi_row.expand_as(flt_mask)[flt_mask]     # [?], word idx of each selected slot
    one_hot = BK.zeros([len(hit_widxes), seq_len]).long()   # [?, slen]
    one_hot.scatter_(-1, hit_widxes.unsqueeze(-1), 1)       # set the hit position to 1
    return self.indicator_embed(one_hot)                    # [?, slen, D]
def flatten_results(arr_items: np.ndarray, mask_expr: BK.Expr, *other_exprs: BK.Expr):
    """Flatten batched results by the validity mask.

    Returns (non-None items in flat order, per-item batch idx, *flattened exprs)."""
    valid_t = (mask_expr > 0.)  # bool mask over the batched dims
    # keep only the non-None items, flat order
    kept_items = [it for it in arr_items.flatten() if it is not None]
    # select the same flat positions from every extra expr
    flat_exprs = [e[valid_t] for e in other_exprs]  # [?(flat), D]
    bsize = BK.get_shape(mask_expr, 0)
    sidx_t = BK.arange_idx(bsize).unsqueeze(-1).expand_as(valid_t)[valid_t]  # [?(flat)]
    assert all(len(kept_items) == len(e) for e in flat_exprs), "Error: dim0 not matched after flatten!"
    return kept_items, sidx_t, *flat_exprs  # [?(flat), *]
def prepare_indicators(self, flat_idxes: List, shape):
    """For each idx tensor in flat_idxes, build a 0/1 indicator tensor of the given
    (bs, dlen) shape with a single 1 per row at that idx."""
    batch_size, _ = shape
    row_t = BK.arange_idx(batch_size)  # [*]
    indicators = []
    for idxes_t in flat_idxes:
        ind_t = BK.constants_idx(shape, 0)  # [*, dlen], all zeros
        ind_t[row_t, idxes_t] = 1  # mark one position per row
        indicators.append(ind_t)
    return indicators
def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions, cand_widxes, cand_masks, cand_expr,
                     cand_scores, expr_seq_gaddr):
    """Compute the split loss with a dynamic oracle, teacher-force-feed split decisions
    (sequence-level mixing at conf.split_feed_force_rate), then extend/aggregate segments
    and derive per-segment oracle gold addresses.

    Returns: (loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks,
              seg_weighted_expr, oracle_gaddr).
    """
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
    cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
    cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
    # oracle split point = neighboring candidates map to different gold items
    split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
    split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
    raw_split_loss = BK.loss_binary(split_scores, split_oracle,
                                    label_smoothing=conf.split_label_smoothing)  # [*, slen]
    loss_split_item = LossHelper.compile_leaf_loss(
        f"split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(), loss_lambda=conf.loss_split)
    # step 2.2: feed split
    # note: when teacher-forcing, only forcing good points, others still use pred
    force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
    _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_split_decisions = (_use_force_mask * force_split_decisions +
                            (1. - _use_force_mask) * pred_split_decisions)  # [*, clen-1]
    # next
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # finally get oracles for next steps
    # todo(+N): simply select the highest scored one as oracle
    if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
        oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
    else:
        _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
        oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
        oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
        oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
    return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
def get_last_emb(self):
    """Return the last-layer embeddings (valid-idx-selected if configured); cached."""
    key = "last_emb"
    cached = self.l_caches.get(key)
    if cached is not None:
        return cached
    emb_t = self.embs[-1]
    sel_t = self.valid_idxes_t
    if sel_t is not None:
        batch_arange = BK.arange_idx(BK.get_shape(sel_t, 0)).unsqueeze(-1)  # [bsize, 1]
        emb_t = emb_t[batch_arange, sel_t]  # select!
    self.l_caches[key] = emb_t  # cache
    return emb_t
def get_stack_att(self):
    """Return the stacked attentions as [*, lenq, lenk, NL, H] (valid-idx-selected on
    both query and key axes if configured); cached."""
    key = "stack_att"
    cached = self.l_caches.get(key)
    if cached is not None:
        return cached
    # NL*[*, H, lenq, lenk] -> [*, lenq, lenk, NL, H]
    att_t = BK.stack(self.attns, -1).permute(0, 2, 3, 4, 1)
    sel_t = self.valid_idxes_t
    if sel_t is not None:
        batch_arange = BK.arange_idx(BK.get_shape(sel_t, 0)).unsqueeze(-1).unsqueeze(-1)  # [bsize, 1, 1]
        att_t = att_t[batch_arange, sel_t.unsqueeze(-1), sel_t.unsqueeze(-2)]  # select!
    self.l_caches[key] = att_t  # cache
    return att_t
def get_stack_emb(self):
    """Return the stacked per-layer embeddings as [*, slen, D, NL], excluding the
    embedding layer itself (valid-idx-selected if configured); cached."""
    key = "stack_emb"
    cached = self.l_caches.get(key)
    if cached is not None:
        return cached
    # note: excluding embeddings here to make it convenient!!
    emb_t = BK.stack(self.embs[1:], -1)  # [*, slen, D, NL]
    sel_t = self.valid_idxes_t
    if sel_t is not None:
        batch_arange = BK.arange_idx(BK.get_shape(sel_t, 0)).unsqueeze(-1)  # [bsize, 1]
        emb_t = emb_t[batch_arange, sel_t]  # select!
    self.l_caches[key] = emb_t  # cache
    return emb_t
def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr): conf: BaseSpanConf = self.conf # -- # note: check empty, otherwise error input_item_shape = BK.get_shape(widx_expr) if np.prod(input_item_shape) == 0: return BK.zeros(input_item_shape + [self.output_dim]) # return an empty but shaped tensor # -- start_idxes, end_idxes = widx_expr, widx_expr+wlen_expr # make [start, end) # get sizes bsize, slen = BK.get_shape(input_expr)[:2] # num_span = BK.get_shape(start_idxes, 1) arange2_t = BK.arange_idx(bsize).unsqueeze(-1) # [bsize, 1] # -- reprs = [] if conf.use_starts: # start [start, reprs.append(input_expr[arange2_t, start_idxes]) # [bsize, ?, D] if conf.use_ends: # simply ,end-1] reprs.append(input_expr[arange2_t, end_idxes-1]) if conf.use_softhead: # expand range all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None) # [bsize, ?, MW] # flatten flatten_all_span_idxes = all_span_idxes.view(bsize, -1) # [bsize, ?*MW] flatten_all_span_mask = all_span_mask.view(bsize, -1) # [bsize, ?*MW] # get softhead score (consider mask here) softhead_scores = self.softhead_scorer(input_expr).squeeze(-1) # [bsize, slen] flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes] # [bsize, ?*MW] flatten_all_span_scores += (1.-flatten_all_span_mask) * Constants.REAL_PRAC_MIN all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape) # [bsize, ?, MW] # reshape and (optionally topk) and softmax softhead_topk = conf.softhead_topk if softhead_topk>0 and BK.get_shape(all_span_scores,-1)>softhead_topk: # further select topk; note: this may save mem final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False) # [bsize, ?, K] final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes) # [bsize, ?, K] else: final_span_score, final_span_idxes = all_span_scores, all_span_idxes # [bsize, ?, MW] final_prob = final_span_score.softmax(-1) # [bsize, ?, ??] 
# [bsize, ?, ??, D] final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes)+[-1]) weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2) # [bsize, ?, D] reprs.append(weighted_repr) if conf.use_width: cur_width_embed = self.width_embed(wlen_expr) # [bsize, ?, DE] reprs.append(cur_width_embed) # concat concat_repr = BK.concat(reprs, -1) # [bsize, ?, SUM] if conf.use_proj: ret = self.final_proj(concat_repr) # [bsize, ?, DR] else: ret = concat_repr return ret
def __init__(self, ibatch: InputBatch, IDX_PAD: int):
    """Batch and pad the per-item sequence info of ibatch into tensors.

    Sets up encoder-side ids/masks/segids, decoder-side subtoken-selection
    idxes/lens/masks, per-position sentence idxes, and cached arange tensors;
    the dec->enc back-mapping fields are computed lazily.
    """
    # preps
    self.bsize = len(ibatch)
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # batched them
    all_seq_infos = [z.seq_info for z in ibatch.items]
    # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
    self.enc_input_ids = BK.input_idx(DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos], int(IDX_PAD)))
    self.enc_input_masks = BK.input_real(DataPadder.lengths2mask([len(z.enc_input_ids) for z in all_seq_infos]))
    self.enc_input_segids = BK.input_idx(DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos], 0))
    # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
    self.dec_sel_idxes = BK.input_idx(DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos], 0))
    self.dec_sel_lens = BK.input_idx(DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
    self.dec_sel_masks = BK.input_real(DataPadder.lengths2mask([len(z.dec_sel_idxes) for z in all_seq_infos]))
    _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
    _dec_offsets = BK.input_idx(DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos], _max_dec_len))
    # note: CLS as -1, then 0,1,2,..., PAD gets -2!
    # (count how many sentence offsets each position has passed, minus one)
    self.dec_sent_idxes = \
        (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
    self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
    # dec -> enc: [*, len_enc] (calculated on needed!)
    # note: require 1-to-1 mapping (except pads)!!
    self._enc_back_hits = None
    self._enc_back_sel_idxes = None
def _aggregate_subtoks(self, repr_t: BK.Expr, dsel_seq_info):
    """Gather each decoder token's subtoken representations into [*, dlen, M, D] plus a
    validity mask [*, dlen, M], truncating M at conf.dsel_max_subtoks; invalid slots are
    zeroed in the returned reprs.
    """
    conf: DSelectorConf = self.conf
    _arange_t, _sel_t, _len_t = dsel_seq_info.arange2_t, dsel_seq_info.dec_sel_idxes, dsel_seq_info.dec_sel_lens
    _max_len = 1 if BK.is_zero_shape(_len_t) else _len_t.max().item()
    _max_len = max(1, min(conf.dsel_max_subtoks, _max_len))  # truncate
    # --
    _tmp_arange_t = BK.arange_idx(_max_len)  # [M]
    _all_valids_t = (_tmp_arange_t < _len_t.unsqueeze(-1)).float()  # [*, dlen, M]
    _tmp_arange_t = _tmp_arange_t * _all_valids_t.long()  # note: pad as 0
    _all_idxes_t = _sel_t.unsqueeze(-1) + _tmp_arange_t  # [*, dlen, M], absolute subtoken idxes
    _all_repr_t = repr_t[_arange_t.unsqueeze(-1), _all_idxes_t]  # [*, dlen, M, D]
    # broadcast the mask up to the repr's rank before zeroing invalid slots
    while len(BK.get_shape(_all_valids_t)) < len(BK.get_shape(_all_repr_t)):
        _all_valids_t = _all_valids_t.unsqueeze(-1)
    _all_repr_t = _all_repr_t * _all_valids_t
    return _all_repr_t, _all_valids_t
def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
    """Open ssize new decoding steps: extend the running mask and positions, and prepare
    selection indices for layers that accumulate last-state."""
    assert ssize > 0
    assert self._cur_layer_idx == -1
    self._cur_layer_idx = 0
    # --
    new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
    # --
    # prepare for store_lstate selecting
    if len(self.cum_state_lset) > 0:  # any layer need to accumulat?
        self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
        # note: if no last state, simply clamp 0, otherwise, offset by 1 since we will concat later
        self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
    # prev_steps = self.steps  # previous accumulated steps
    self.steps += ssize
    self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
    self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
def mask2posi_padded(mask: BK.Expr, offset: int, cmin: int):
    """Compute per-token positions from a 0/1 mask where masked (0) tokens repeat the
    position of the last unmasked token to their left; then add offset and clamp at cmin.

    Args:
        mask: [bsize, ssize] 0/1 validity mask.
        offset: added to every position at the end.
        cmin: minimum value after clamping.
    Returns: [bsize, ssize] position idxes (no grad).
    """
    with BK.no_grad_env():
        bsize, ssize = BK.get_shape(mask)
        ret = BK.arange_idx(ssize).repeat(bsize, 1)  # [bsize, ssize] (fix: old comment said [1, ssize])
        rmask_long_t = (mask == 0.).long()  # reverse-mask [bsize, ssize]
        conti_zeros = BK.constants_idx([bsize], 0)  # [bsize], number of continous zeros
        for sidx in range(ssize):
            cur_rmask = rmask_long_t[:, sidx]  # [bsize]; fix: renamed from 'slice' (shadowed builtin)
            conti_zeros = (conti_zeros + cur_rmask) * cur_rmask  # [bsize], *cur_rmask to reset at valid steps
            ret[:, sidx] -= conti_zeros  # pull masked positions back to the last valid one
        # --
        ret += offset
        ret.clamp_(min=cmin)
        return ret
def forward(self, input_expr, mask_expr=None, **kwargs):
    """Run the torch TransformerEncoder over input_expr ([bsize, slen, D]); mask_expr uses
    this file's convention of 1=valid token (None means no padding).
    """
    conf: TTransformerConf = self.conf
    # --
    if conf.n_layers == 0:
        return input_expr  # change nothing if no layers
    # --
    if conf.use_posi:
        ssize = BK.get_shape(input_expr, 1)  # step size
        posi_embed = self.PE(BK.arange_idx(ssize)).unsqueeze(0)  # [1, step, D]
        input_x = input_expr + posi_embed
    else:
        input_x = input_expr
    if conf.norm_input:
        input_x = self.norm(input_x)
    # fix: torch's src_key_padding_mask expects True at *padding* positions, and since
    # mask_expr is 1=valid it must be inverted (previously `mask_expr>0` masked the valid
    # tokens); also tolerate mask_expr=None (previously `None>0` raised TypeError).
    key_padding_mask = None if mask_expr is None else (mask_expr == 0)
    output = self.enc(input_x.transpose(0, 1), src_key_padding_mask=key_padding_mask).transpose(0, 1).contiguous()
    return output
def expand_ranged_idxes(widx_t: BK.Expr, wlen_t: BK.Expr, pad: int = 0, max_width: int = None):
    """Expand (start-idx, length) pairs into per-position idxes [*, MW] plus a float
    validity mask; positions beyond each length are set to pad."""
    if max_width is None:  # if not provided, derive from the lengths themselves
        max_width = 1 if BK.is_zero_shape(wlen_t) else wlen_t.max().item()  # at least one
    # --
    base_shape = BK.get_shape(widx_t)  # [*]
    step_t = BK.arange_idx(max_width).view([1] * len(base_shape) + [-1])  # [*(1s), MW]
    idxes_t = widx_t.unsqueeze(-1) + step_t  # [*, MW]
    valid_b = (step_t < wlen_t.unsqueeze(-1))  # [*, MW], within-length slots
    idxes_t.masked_fill_(~valid_b, pad)  # out-of-range -> pad
    return idxes_t, valid_b.float()
def _f(_widx, _wlen):  # [mlen*dw] -> [bsize, mlen*dw]
    """Closure: mark a candidate span (start=_widx, len=_wlen) valid iff every token it
    covers is valid under the enclosing scope's input_mask (out-of-span pad slots are
    forced to 1 so they do not veto the product)."""
    _mw = self.conf.max_width  # todo(note): also rely on this
    _bsize = BK.get_shape(input_mask, 0)
    # pad the mask so that idxes up to mlen+_mw-1 are safe to gather
    _padded_input_mask = BK.pad(input_mask, (0, _mw), value=0.)  # [bsize, mlen+_mw], make idxing valid!!
    _er_idxes, _er_masks = expand_ranged_idxes(_widx, _wlen, 0, _mw)  # [mlen*dw, MW]
    _arange_t = BK.arange_idx(_bsize).unsqueeze(-1).unsqueeze(-1)  # [bsize, 1, 1]
    _idx_valid = _padded_input_mask[_arange_t, _er_idxes.unsqueeze(0)]  # [bsize, mlen*dw, MW]
    _idx_valid.masked_fill_((_er_masks == 0.).unsqueeze(0), 1.)  # make sure paddings get 1
    _ret = _idx_valid.prod(-1)  # [bsize, mlen*dw], require all non-pad ones to be valid!
    return _ret
def forward(self, med: ZMediator, **kwargs): conf: IdecConnectorPlainConf = self.conf # -- if self.do_seq_pool: # note: for pooling, use the raw emb!! mixed_emb0 = self._go_detach(med.get_raw_last_emb()) # [*, ??, D] mixed_emb = self.pool_f(mixed_emb0) # [*, D] else: if conf.use_nlayer == 1: # simply get the last one mixed_emb = self._go_detach(med.get_last_emb()) else: # mix them stacked_embs = self._go_detach(med.get_stack_emb( ))[:, :, :, -len(self.mixed_weights):] # [*, slen, D, NL] mixed_emb = BK.matmul( stacked_embs, BK.softmax(self.mixed_weights, -1).unsqueeze(-1)).squeeze(-1) # [*, slen, D] if self.do_seq_sel: _arange_t = BK.arange_idx(BK.get_shape(mixed_emb, 0)) _idx_t = med.get_cache(conf.seq_sel_key) mixed_emb = mixed_emb[_arange_t, _idx_t] # [*, D] # further affine if self.input_mask is not None: # note: special input mask!! mixed_emb = mixed_emb * self.input_mask.detach( ) # no grad for input_mask!! drop_emb = self.pre_mid_drop(mixed_emb) if conf.mid_dim > 0: # gather inputs _r = conf.mid_extra_range _detached_drop_emb = drop_emb.detach() _inputs = [] for ii in range(-_r, _r + 1): if ii < 0: _one = BK.pad(_detached_drop_emb[:, :ii], [0, 0, -ii, 0]) elif ii == 0: _one = drop_emb # no need to change! else: _one = BK.pad(_detached_drop_emb[:, ii:], [0, 0, 0, ii]) _inputs.append(_one) # -- ret_t = self.mid_aff(_inputs) # [*, slen, M] or [*, M] else: ret_t = drop_emb return ret_t
def forward(self, input_expr, mask_expr, med: ZMediator):
    """Run the mediated transformer: L0 embedding (plus optional position embeddings,
    normed afterwards), then the stacked layers until done or med signals early end."""
    conf: MyTransformerConf = self.conf
    med.start(mask_expr)
    # L0: embedding layer; normalization is applied after the optional posi-add below
    cur_expr = med.forw_emb(input_expr, norm_node=None)  # norm right later
    if conf.use_posi:
        step_size = BK.get_shape(input_expr, 1)
        posi_t = self.PE(BK.arange_idx(step_size)).unsqueeze(0)  # [1, step, D]
        cur_expr = cur_expr + posi_t
    cur_expr = self.input_norm(cur_expr)
    # L1+: transformer layers, stopping early if the mediator says so
    for tnode in self.tnodes:
        med.next()
        cur_expr = tnode(cur_expr, mask_expr, med)
        if med.is_end():
            break
    med.end()  # clean
    return cur_expr
def decode_with_scores(left_scores: BK.Expr, right_scores: BK.Expr, normalize: bool):
    """Pick the best (left, right) boundary pair with left<=right by maximizing the sum
    of the two scores (optionally log-softmax-normalized first).

    Returns: (max_scores [*], left_idxes [*], right_idxes [*])."""
    if normalize:
        left_scores = BK.log_softmax(left_scores, -1)
        right_scores = BK.log_softmax(right_scores, -1)
    # pairwise adding: [*, slen_L, slen_R]
    shape = BK.get_shape(left_scores)
    pair_t = left_scores.unsqueeze(-1) + right_scores.unsqueeze(-2)
    flat_pair_t = pair_t.view(shape[:-1] + [-1])  # [*, slen*slen]
    # mask out pairs with left > right
    slen = shape[-1]
    posi_t = BK.arange_idx(slen)
    ok_mask = (posi_t.unsqueeze(-1) <= posi_t.unsqueeze(-2)).float().view(-1)  # [slen_L*slen_R]
    best_scores, best_flat = (flat_pair_t + (1. - ok_mask) * Constants.REAL_PRAC_MIN).max(-1)  # [*]
    # decompose the flat argmax back into (left, right)
    left_idxes = best_flat // slen  # [*]
    right_idxes = best_flat % slen  # [*]
    return best_scores, left_idxes, right_idxes
def _go_common(self, res: SpanExtractorOutput, sel_mask: BK.Expr, add_gold_rate: float):
    """Finalize span selection: optionally (stochastically, in place) force-add gold
    spans into sel_mask, restrict to valid spans, compact via mask2idx, and rearrange
    res accordingly; gaddr of padded slots is set to -1.
    """
    gaddr_expr, span_mask = res.gaddr_expr, res.mask_expr
    bsize = BK.get_shape(span_mask, 0)
    # add gold?
    if add_gold_rate > 0.:
        # inplace
        gold_mask = ((gaddr_expr >= 0) & (BK.rand(sel_mask.shape) < add_gold_rate)).float()  # note: gaddr==-1 means nope
        sel_mask += gold_mask
        sel_mask.clamp_(max=1.)  # OR
    sel_mask *= span_mask  # must be valid
    # select masked
    final_idx_t, final_mask_t = BK.mask2idx(sel_mask, padding_idx=0)  # [bsize, ??]
    _tmp_arange_t = BK.arange_idx(bsize).unsqueeze(1)  # [bsize, 1]
    res.arrange(_tmp_arange_t, final_idx_t, final_mask_t)
    if res.gaddr_expr is not None:
        res.gaddr_expr.masked_fill_(final_mask_t == 0., -1)  # make invalid ones -1
    return res  # [bsize, SNUM, *]
def forward(self, input_expr, mask_expr, med: ZMediator): conf: MyTransformerConf = self.conf # for the L0 layer cur_expr = input_expr if conf.use_posi: ssize = BK.get_shape(input_expr, 1) # step size posi_embed = self.PE(BK.arange_idx(ssize)).unsqueeze(0) # [1, step, D] cur_expr = cur_expr + posi_embed add_expr, _ = med.layer_end(cur_expr, None) # no checking at L0 if add_expr is not None: cur_expr += add_expr # norm right later! cur_expr = self.input_norm(cur_expr) # for later layers: L1+ for ti, tnode in enumerate(self.tnodes): cur_expr, scores_t = tnode(cur_expr, mask_expr, med) add_expr, early_exit = med.layer_end(cur_expr, scores_t) # check if add_expr is not None: cur_expr = tnode.norm1(cur_expr + add_expr) if early_exit: break return cur_expr, {}
def forward_input(self, med: ZMediator, detach_scale: float, **kwargs):
    """Fetch encoder hidden states from med (dsel-forwarded if configured), apply optional
    sequence pooling or per-sequence position selection, then scale-detach.

    Returns: (input [*, (??), D], mask [*, ??]).
    """
    # get it
    if self.do_dsel:
        _dsel = self.dsel
        input_t0 = med.get_enc_cache_val(
            "hid", signature=_dsel.signature,
            function=(lambda x: _dsel.forward(x, med.ibatch.seq_info)))  # [*, ??, D]
    else:
        # input_t0 = med.get_enc_cache_val("hid", no_cache=True)  # [*, ??, D], note: no need for caching!
        input_t0 = med.get_enc_cache_val("hid")  # [*, ??, D]
    mask_t = med.get_mask(self.do_dsel)  # [*, ??]
    # extra processing?
    if self.do_seq_pool:
        input_t = self.seq_pool_f(input_t0)  # [*, D]
    elif self.do_seq_sel:  # pick one position per sequence
        _arange_t = BK.arange_idx(BK.get_shape(input_t0, 0))  # [*]
        _idx_t = med.get_cache(self.seq_sel_key)  # [*]
        input_t = input_t0[_arange_t, _idx_t]  # [*, D]
    else:
        input_t = input_t0
    # detach?
    ret_t = BK.go_detach(input_t, detach_scale, self.is_training())
    return ret_t, mask_t  # [*, (??), D], [*, ??]
def __init__(self, berter: BertEncoder, seq_subs: List[InputSubwordSeqField]):
    """Batch subword sequences for berter: pad ids and masks, build the sub->orig
    first-subtoken mapping, and reserve lazily-/externally-filled fields (reverse
    mapping, repl masks, token-type/position ids, other factors).
    """
    self.seq_subs = seq_subs
    self.berter = berter
    self.bsize = len(seq_subs)
    self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # --
    tokenizer = self.berter.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    # MASK_IDX = tokenizer.mask_token_id
    # CLS_IDX_l = [tokenizer.cls_token_id]
    # SEP_IDX_l = [tokenizer.sep_token_id]
    # make batched idxes
    padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
    batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
    batched_input_ids, batched_input_mask = padder.pad([s.idxes for s in seq_subs])  # [bsize, sub_len]
    self.batched_sublens_p1 = BK.input_idx(batched_sublens) + 1  # also the idx of EOS (if counting including BOS)
    self.batched_input_ids = BK.input_idx(batched_input_ids)
    self.batched_input_mask = BK.input_real(batched_input_mask)
    # make batched mappings (sub->orig)
    padder2 = DataPadder(2, pad_vals=0, mask_range=2)  # pad as 0 to avoid out-of-range
    batched_first_idxes, batched_first_mask = padder2.pad([s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
    self.batched_first_idxes = BK.input_idx(batched_first_idxes)
    self.batched_first_mask = BK.input_real(batched_first_mask)
    # reversed batched_mappings (orig->sub) (created when needed)
    self._batched_rev_idxes = None  # [bsize, sub_len]
    # --
    self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
    self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
    self.batched_position_ids = None  # [bsize, 1+sub_len+1]
    self.other_factors = {}  # name -> aug_batched_ids
def _split_aggregate(self, cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, topk: int): arange3_t = BK.arange_idx( seg_ext_cidxes.shape[0]).unsqueeze(-1).unsqueeze(-1) # [*, 1, 1] seg_ext_scores = cand_scores[arange3_t, seg_ext_cidxes] + ( 1. - seg_ext_masks) * Constants.REAL_PRAC_MIN # [*, seglen, MW] # if need further topk? if topk > 0 and BK.get_shape(seg_ext_scores, -1) > topk: # need to further topk? seg_ext_scores, _tmp_idxes = seg_ext_scores.topk( topk, dim=-1, sorted=False) # [*, seglen, K] seg_ext_cidxes = seg_ext_cidxes.gather( -1, _tmp_idxes) # [*, seglen, K] seg_ext_masks = seg_ext_masks.gather(-1, _tmp_idxes) # [*, seglen, K] # get expr and extend to full seg_ext_prob = seg_ext_scores.softmax(-1) # [*, seglen, K] _tmp_expr = cand_expr[arange3_t, seg_ext_cidxes] # [*, seglen, K, D] seg_weighted_expr = (_tmp_expr * seg_ext_prob.unsqueeze(-1)).sum( -2) # [*, seglen, D] seg_ext_widxes = cand_widxes[arange3_t, seg_ext_cidxes] # [*, seglen, K] return seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr # [*, seglen, ?]