def forward(self, inputs, add_bos=False, add_eos=False):
    conf: PosiInputEmbedderConf = self.conf
    # --
    try:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = inputs
        if add_bos:
            max_len += 1
        if add_eos:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [?len?]
        ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    except (TypeError, ValueError):  # note: narrowed from a bare "except", which would also swallow unrelated errors
        # input is a tensor of explicit position idxes
        posi_idxes = BK.input_idx(inputs)  # [*, len]
        cur_maxlen = BK.get_shape(posi_idxes, -1)
        # --
        all_input_slices = []
        slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
        if add_bos:  # add 0 and offset
            all_input_slices.append(BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
            cur_maxlen += 1
            posi_idxes += 1
        all_input_slices.append(posi_idxes)  # [*, len]
        if add_eos:
            all_input_slices.append(BK.constants(slice_shape, cur_maxlen, dtype=posi_idxes.dtype))
        final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
        # finally
        ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
def forward(self, inputs, add_bos=False, add_eos=False):
    conf: PlainInputEmbedderConf = self.conf
    # --
    voc = self.voc
    input_t = BK.input_idx(inputs)  # [*, len]
    # randomly replace rare words with UNK in training
    if self.is_training() and self.use_rare_unk:
        rare_unk_rate = conf.rare_unk_rate
        cur_unk_imask = (self.rare_unk_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
        input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
    # bos and eos
    all_input_slices = []
    slice_shape = BK.get_shape(input_t)[:-1] + [1]
    if add_bos:
        all_input_slices.append(BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
    all_input_slices.append(input_t)  # [*, len]
    if add_eos:
        all_input_slices.append(BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
    final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
    # finally
    ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
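# -- Illustration (not part of the original code): the rare-UNK replacement in
# plain PyTorch, assuming BK wraps torch tensors. rare_flags is a hypothetical
# per-vocab-id 0/1 table playing the role of self.rare_unk_mask above; flagged
# tokens flip to UNK (id 1 here, also hypothetical) with probability 0.5.
def _demo_rare_unk():
    import torch
    input_t = torch.tensor([[5, 9, 2]])
    rare_flags = torch.zeros(10)
    rare_flags[9] = 1.  # mark vocab id 9 as rare
    imask = (rare_flags[input_t] * (torch.rand(input_t.shape) < 0.5)).long()
    return input_t * (1 - imask) + 1 * imask  # id 9 becomes 1 (UNK) about half the time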
def _aug_ends(self, t: BK.Expr, BOS, PAD, EOS, dtype):
    # add BOS(CLS) and EOS(SEP) for a tensor (sub_len -> 1+sub_len+1)
    slice_shape = [self.bsize, 1]
    slices = [BK.constants(slice_shape, BOS, dtype=dtype), t, BK.constants(slice_shape, PAD, dtype=dtype)]
    aug_batched_ids = BK.concat(slices, -1)  # [bsize, 1+sub_len+1]
    # the appended slice is PAD; overwrite the slot right after each seq with EOS
    aug_batched_ids[self.arange1_t, self.batched_sublens_p1] = EOS  # assign EOS
    return aug_batched_ids
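# -- Illustration (not part of the original code): a toy run of the same
# BOS/PAD-then-EOS trick in plain PyTorch (assuming BK wraps torch tensors;
# the concrete ids below are hypothetical).
def _demo_aug_ends():
    import torch
    t = torch.tensor([[7, 8, 9], [7, 8, 0]])  # [bsize=2, sub_len=3], 0 is padding
    BOS, PAD, EOS = 101, 0, 102
    bsize = t.size(0)
    aug = torch.cat([torch.full((bsize, 1), BOS, dtype=t.dtype), t,
                     torch.full((bsize, 1), PAD, dtype=t.dtype)], -1)  # [bsize, 1+3+1]
    sublens_p1 = torch.tensor([3, 2]) + 1  # per-row true length, +1 for the BOS offset
    aug[torch.arange(bsize), sublens_p1] = EOS  # overwrite the first pad slot with EOS
    return aug  # [[101, 7, 8, 9, 102], [101, 7, 8, 102, 0]]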
def _get_extra_score(self, cand_score, insts, cand_res, arr_gold_items, use_cons: bool, use_lu: bool):
    # conf: DirectExtractorConf = self.conf
    # --
    # first the cand score
    cand_score = self._extend_cand_score(cand_score)
    # then the cons_lex score
    cons_lex_node = self.cons_lex_node
    if use_cons and cons_lex_node is not None:
        cons_lex = cons_lex_node.cons
        flt_arr_gold_items = arr_gold_items.flatten()
        _shape = BK.get_shape(cand_res.mask_expr)
        if cand_res.gaddr_expr is None:
            gaddr_expr = BK.constants(_shape, -1, dtype=BK.long)
        else:
            gaddr_expr = cand_res.gaddr_expr
        all_arrs = [BK.get_value(z) for z in [cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr, gaddr_expr]]
        arr_feats = np.full(_shape, None, dtype=object)
        for bidx, inst in enumerate(insts):
            one_arr_feats = arr_feats[bidx]
            for _ii, (one_widx, one_wlen, one_mask, one_gaddr) in enumerate(zip(*[z[bidx] for z in all_arrs])):
                if one_mask == 0.:
                    continue  # skip invalid ones
                if use_lu and one_gaddr >= 0:
                    one_feat = cons_lex.lu2feat(flt_arr_gold_items[one_gaddr].info["luName"])
                else:
                    one_feat = cons_lex.span2feat(inst, one_widx, one_wlen)
                one_arr_feats[_ii] = one_feat
        cons_valids = cons_lex_node.lookup_with_feats(arr_feats)
        cons_score = (1. - cons_valids) * Constants.REAL_PRAC_MIN
    else:
        cons_score = None
    # sum
    return self._sum_scores(cand_score, cons_score)
def forward_embedding(self, input_ids, attention_mask, token_type_ids, position_ids, other_embeds):
    input_shape = input_ids.size()  # [bsize, len]
    seq_length = input_shape[1]
    if position_ids is None:
        position_ids = BK.arange_idx(seq_length)  # [len]
        position_ids = position_ids.unsqueeze(0).expand(input_shape)  # [bsize, len]
    if token_type_ids is None:
        token_type_ids = BK.zeros(input_shape).long()  # [bsize, len]
    # replicate BertEmbeddings.forward, with room to inject extra input embeddings
    _embeddings = self.model.embeddings
    inputs_embeds = _embeddings.word_embeddings(input_ids)  # [bsize, len, D]
    position_embeddings = _embeddings.position_embeddings(position_ids)  # [bsize, len, D]
    token_type_embeddings = _embeddings.token_type_embeddings(token_type_ids)  # [bsize, len, D]
    embeddings = inputs_embeds + position_embeddings + token_type_embeddings
    if other_embeds is not None:
        embeddings += other_embeds
    embeddings = _embeddings.LayerNorm(embeddings)
    embeddings = _embeddings.dropout(embeddings)
    # prepare attention_mask
    if attention_mask is None:
        attention_mask = BK.constants(input_shape, value=1.)
    assert attention_mask.dim() == 2
    extended_attention_mask = attention_mask[:, None, None, :]  # [bsize, 1, 1, len]
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0  # additive mask for attention logits
    return embeddings, extended_attention_mask
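# -- Illustration (not part of the original code): the additive mask trick in
# isolation, in plain PyTorch. Positions with mask 0 get a large negative bias,
# so a downstream softmax over attention logits assigns them ~0 probability.
def _demo_extended_attention_mask():
    import torch
    mask = torch.tensor([[1., 1., 0.]])  # [bsize=1, len=3]
    ext = (1.0 - mask[:, None, None, :]) * -10000.0  # [bsize, 1, 1, len]
    probs = (torch.zeros(1, 1, 1, 3) + ext).softmax(-1)
    return probs  # ~[[0.5, 0.5, 0.0]]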
def _determine_size(self, length: BK.Expr, rate: float, count: float):
    if rate is None:
        ret_size = BK.constants(length.shape, count)  # make it constant
    else:
        ret_size = length * rate
        if count is not None:
            ret_size.clamp_(max=count)
        ret_size.ceil_()
    return ret_size
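# -- Illustration (not part of the original code): worked numbers for the rate
# branch, in plain PyTorch. Lengths 7 and 20 with rate=0.3 give 2.1 and 6.0;
# clamping at count=4 and then ceiling yields 3 and 4.
def _demo_determine_size():
    import torch
    length = torch.tensor([7., 20.])
    ret_size = (length * 0.3).clamp_(max=4.).ceil_()
    return ret_size  # tensor([3., 4.])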
def forward(self, inputs, add_bos=False, add_eos=False):
    conf: CharCnnInputEmbedderConf = self.conf
    # --
    voc = self.voc
    char_input_t = BK.input_idx(inputs)  # [*, len, clen]
    # todo(note): no need to replace with unk for chars!!
    # bos and eos
    all_input_slices = []
    slice_shape = BK.get_shape(char_input_t)
    slice_shape[-2] = 1  # [*, 1, clen]
    if add_bos:
        all_input_slices.append(BK.constants(slice_shape, voc.bos, dtype=char_input_t.dtype))
    all_input_slices.append(char_input_t)  # [*, len, clen]
    if add_eos:
        all_input_slices.append(BK.constants(slice_shape, voc.eos, dtype=char_input_t.dtype))
    final_input_t = BK.concat(all_input_slices, -2)  # [*, 1?+len+1?, clen]
    # char embeddings
    char_embed_expr = self.E(final_input_t)  # [*, ??, clen, dim]
    # char cnn pools over the clen dimension
    ret = self.cnn(char_embed_expr)
    return ret
def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
    assert ssize > 0
    assert self._cur_layer_idx == -1
    self._cur_layer_idx = 0
    # --
    new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
    # --
    # prepare for store_lstate selecting
    if len(self.cum_state_lset) > 0:  # does any layer need to accumulate?
        self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
        # note: if no last state, simply clamp at 0; otherwise offset by 1 since we will concat later
        self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
    # prev_steps = self.steps  # previously accumulated steps
    self.steps += ssize
    self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
    self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out the diagonal (no self-loop heads)
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and take the max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
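# -- Illustration (not part of the original code): flattening the (head, label)
# dims lets a single argmax pick both jointly, and divmod by the label-dim size
# recovers the pair. Plain-PyTorch sketch with hypothetical shapes.
def _demo_joint_argmax():
    import torch
    L = 3
    scores = torch.randn(2, 4, 4, L)  # [bsize, slen(mod), slen(head), L]
    flat = scores.view(2, 4, -1)  # [bsize, slen, slen*L]
    best_scores, idxes = flat.max(-1)  # joint max over heads and labels
    heads, labels = idxes // L, idxes % L  # [bsize, slen] each
    return heads, labels, best_scores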
def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
    conf: ExtenderConf = self.conf
    _loss_lambda = conf._loss_lambda
    # --
    enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
    s_left, s_right = self.enode.score(
        enc_t, flt_pair_expr if conf.ext_use_finput else None, flt_mask_expr)  # [*, slen]
    # --
    gold_posi = [self.ext_span_getter(z.mention) for z in flt_items]  # List[(widx, wlen)]
    widx_t = BK.input_idx([z[0] for z in gold_posi])  # [*]
    wlen_t = BK.input_idx([z[1] for z in gold_posi])  # [*]
    loss_left_t = BK.loss_nll(s_left, widx_t)  # [*]
    loss_right_t = BK.loss_nll(s_right, widx_t + wlen_t - 1)  # [*], right boundary is inclusive
    if flt_extra_weights is not None:
        loss_left_t *= flt_extra_weights
        loss_right_t *= flt_extra_weights
        loss_div = flt_extra_weights.sum()  # note: also use this!
    else:
        loss_div = BK.constants([len(flt_items)], value=1.).sum()
    loss_left_item = LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=_loss_lambda)
    loss_right_item = LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=_loss_lambda)
    ret_loss = LossHelper.combine_multiple_losses([loss_left_item, loss_right_item])
    return ret_loss
def _split_extend(self, split_decisions: BK.Expr, cand_mask: BK.Expr):
    # first augment/pad split_decisions: the first position always starts a segment
    slice_ones = BK.constants([BK.get_shape(split_decisions, 0), 1], 1.)  # [*, 1]
    padded_split_decisions = BK.concat([slice_ones, split_decisions], -1)  # [*, clen]
    seg_cidxes, seg_masks = BK.mask2idx(padded_split_decisions)  # [*, seglen]
    # --
    cand_lens = cand_mask.sum(-1, keepdim=True).long()  # [*, 1]
    seg_masks *= (cand_lens > 0).float()  # handle the case of no cands
    # --
    # fill cand_lens in for paddings so that adjacent differences give segment lengths
    seg_cidxes_special = seg_cidxes + (1. - seg_masks).long() * cand_lens  # [*, seglen]
    seg_cidxes_special2 = BK.concat([seg_cidxes_special, cand_lens], -1)  # [*, seglen+1]
    seg_clens = seg_cidxes_special2[:, 1:] - seg_cidxes_special  # [*, seglen]
    # extend the idxes
    seg_ext_cidxes, seg_ext_masks = expand_ranged_idxes(seg_cidxes, seg_clens)  # [*, seglen, MW]
    seg_ext_masks *= seg_masks.unsqueeze(-1)
    return seg_ext_cidxes, seg_ext_masks, seg_masks  # 2x[*, seglen, MW], [*, seglen]
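# -- Illustration (not part of the original code): the start/length arithmetic
# for a single row, in plain PyTorch. torch.nonzero stands in for BK.mask2idx
# here (an assumption about its behavior: left-packed indices of the 1s).
def _demo_split_lens():
    import torch
    dec = torch.tensor([1., 0., 1.])  # split decisions between 4 candidate tokens
    padded = torch.cat([torch.ones(1), dec])  # [1, 1, 0, 1]: 1s mark segment starts
    starts = padded.nonzero().squeeze(-1)  # tensor([0, 1, 3])
    lens = torch.cat([starts[1:], torch.tensor([4])]) - starts  # tensor([1, 2, 1])
    return starts, lens  # segments cover tokens [0], [1, 2], [3]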
def forward(self, input_map: Dict):
    add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
    ret = OrderedDict()  # [*, len, ?]
    for key, embedder_pack in self.embedders.items():  # according to REG order!!
        embedder, input_name = embedder_pack
        one_expr = embedder(input_map[input_name], add_bos=add_bos, add_eos=add_eos)
        ret[key] = one_expr
    # mask expr
    mask_expr = input_map.get("mask")
    if mask_expr is not None:
        all_input_slices = []
        mask_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1], 1, dtype=mask_expr.dtype)  # [*, 1]
        if add_bos:
            all_input_slices.append(mask_slice)
        all_input_slices.append(mask_expr)
        if add_eos:
            all_input_slices.append(mask_slice)
        mask_expr = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
    return mask_expr, ret
def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True):
    _NEG_INF = Constants.REAL_PRAC_MIN
    # --
    cur_step = 0
    cache: DecCache = None
    # init: keep the seq of scores rather than traceback!
    start_vals_shape = [batch_size, 1]  # [bs, 1]
    all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze(-1)  # [bs, K, step], todo(note): start with 0!
    all_scores_t = BK.zeros(start_vals_shape).unsqueeze(-1)  # [bs, K, step]
    accu_scores_t = BK.zeros(start_vals_shape)  # [bs, K]
    arange_t = BK.arange_idx(batch_size).unsqueeze(-1)  # [bs, 1]
    # the main loop
    prev_k = 1  # start with a single hypothesis
    while not self.is_end(cur_step):
        # expand and score
        cache, scores_t, masks_t = self.step_score(cur_step, prev_k, cache)  # ..., [bs*pK, L], [bs*pK]
        scores_t_shape = BK.get_shape(scores_t)
        last_dim = scores_t_shape[-1]  # L
        # modify scores to handle mask: keep the previous pred for masked (finished) items!
        sel_scores_t = BK.constants([batch_size, prev_k, last_dim], 1.)  # [bs, pK, L]
        sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:], -1)  # [bs, pK, L]
        sel_scores_t = scores_t + _NEG_INF * (
            sel_scores_t.view(scores_t_shape) * (1. - masks_t).unsqueeze(-1))  # [bs*pK, L]
        # first select topk locally, note: no need to sort here!
        local_k = min(last_dim, beam_k)
        l_topk_scores, l_topk_idxes = sel_scores_t.topk(local_k, -1, sorted=False)  # [bs*pK, lK]
        # then topk globally over the full pK*lK
        add_score_shape = [batch_size, prev_k, local_k]
        to_sel_shape = [batch_size, prev_k * local_k]
        global_k = min(to_sel_shape[-1], beam_k)  # new k
        to_sel_scores, to_sel_idxes = \
            (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \
            l_topk_idxes.view(to_sel_shape)  # [bs, pK*lK]
        _, g_topk_idxes = to_sel_scores.topk(global_k, -1, sorted=True)  # [bs, gK]
        # map back to the actual idxes
        new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes)  # [bs, gK]
        new_pk_idxes = (g_topk_idxes // local_k)  # which previous beam items are selected? [bs, gK]
        # get current preds and scores (handling mask)
        scores_t3 = scores_t.view([batch_size, -1, last_dim])  # [bs, pK, L]
        masks_t2 = masks_t.view([batch_size, -1])  # [bs, pK]
        new_masks_t = masks_t2[arange_t, new_pk_idxes]  # [bs, gK]
        # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones
        new_scores_t = scores_t3[arange_t, new_pk_idxes, new_preds_t] * new_masks_t  # [bs, gK]
        # rearrange the cache and record this step
        new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view(-1)  # [bs*gK]
        cache.arrange_idxes(new_arrange_idxes)
        self.step_end(cur_step, global_k, cache, new_preds_t.view(-1))  # modify in cache
        # prepare next & judge ending
        all_preds_t = BK.concat([all_preds_t[arange_t, new_pk_idxes], new_preds_t.unsqueeze(-1)], -1)  # [bs, gK, step]
        all_scores_t = BK.concat([all_scores_t[arange_t, new_pk_idxes], new_scores_t.unsqueeze(-1)], -1)  # [bs, gK, step]
        accu_scores_t = accu_scores_t[arange_t, new_pk_idxes] + new_scores_t  # [bs, gK]
        prev_k = global_k  # for the next step
        cur_step += 1
    # --
    # sort and return at the final step
    _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True)  # [bs, K]
    ret_preds = all_preds_t[arange_t, final_idxes][:, :, 1:]  # [bs, K, steps], exclude the dummy start!
    ret_scores = all_scores_t[arange_t, final_idxes][:, :, 1:]  # [bs, K, steps]
    if ret_best:
        return ret_preds[:, 0], ret_scores[:, 0]  # [bs, slen]
    else:
        return ret_preds, ret_scores  # [bs, topk, slen]
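# -- Illustration (not part of the original code): the two-stage top-k at the
# heart of the search, in plain PyTorch with hypothetical sizes. Taking top-k
# per beam item first, then top-k over the pooled pK*lK candidates, avoids
# sorting the full pK*L grid when the vocabulary L is large.
def _demo_two_stage_topk():
    import torch
    bs, pK, L, K = 2, 3, 10, 3
    scores = torch.randn(bs * pK, L)  # one-step scores per beam item
    accu = torch.randn(bs, pK)  # accumulated scores per beam item
    lK = min(L, K)
    l_sc, l_idx = scores.topk(lK, -1, sorted=False)  # [bs*pK, lK] local
    to_sel = (l_sc.view(bs, pK, lK) + accu.unsqueeze(-1)).view(bs, pK * lK)
    g_sc, g_idx = to_sel.topk(min(pK * lK, K), -1, sorted=True)  # [bs, K] global
    preds = l_idx.view(bs, pK * lK).gather(-1, g_idx)  # chosen token ids
    beam_src = g_idx // lK  # which previous beam item each survivor extends
    return preds, beam_src, g_sc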
def gather_losses(self, scores: List[BK.Expr], label_t: BK.Expr, valid_t: BK.Expr, loss_neg_sample: float = None):
    conf: ZLabelConf = self.conf
    _loss_do_sel = conf.loss_do_sel
    _alpha_binary, _alpha_full = conf.loss_binary_alpha, conf.loss_full_alpha
    _alpha_all_binary = conf.loss_allbinary_alpha
    # --
    if self.crf is not None:  # CRF mode!
        assert _alpha_binary <= 0. and _alpha_all_binary <= 0.
        # reshape them into 3d
        valid_premask = (valid_t.sum(-1) > 0.)  # [bs, ...]
        # note: simply collect them all
        rets = []
        _pm_mask, _pm_label = valid_t[valid_premask], label_t[valid_premask]  # [??, slen]
        for score_t in scores:
            _one_pm_score = score_t[valid_premask]  # [??, slen, D]
            _one_fscore_t, _ = self._get_score(_one_pm_score)  # [??, slen, L]
            # --
            # todo(+N): hacky fix, make it a leading NIL
            _pm_mask2 = _pm_mask.clone()
            _pm_mask2[:, 0] = 1.
            # --
            _one_loss, _one_count = self.crf.loss(_one_fscore_t, _pm_mask2, _pm_label)  # ??
            rets.append((_one_loss * _alpha_full, _one_count))
    else:
        pos_t = (label_t > 0).float()  # 0 as NIL!!
        loss_mask_t = self._get_loss_mask(pos_t, valid_t, loss_neg_sample=loss_neg_sample)  # [bs, ...]
        if _loss_do_sel:
            _sel_mask = (loss_mask_t > 0.)  # [bs, ...]
            _sel_label = label_t[_sel_mask]  # [??]
            _sel_mask2 = BK.constants([len(_sel_label)], 1.)  # [??]
        # note: simply collect them all
        rets = []
        for score_t in scores:
            if _loss_do_sel:  # [??, ]
                one_score_t, one_mask_t, one_label_t = score_t[_sel_mask], _sel_mask2, _sel_label
            else:  # [bs, ..., D]
                one_score_t, one_mask_t, one_label_t = score_t, loss_mask_t, label_t
            one_fscore_t, one_nilscore_t = self._get_score(one_score_t)
            # full loss
            one_loss_t = BK.loss_nll(one_fscore_t, one_label_t) * _alpha_full  # [????]
            # binary loss
            if _alpha_binary > 0.:  # plus ...
                _binary_loss = BK.loss_binary(one_nilscore_t.squeeze(-1), (one_label_t > 0).float()) * _alpha_binary  # [???]
                one_loss_t = one_loss_t + _binary_loss
            # all binary
            if _alpha_all_binary > 0.:  # plus ...
                _tmp_label_t = BK.zeros(BK.get_shape(one_fscore_t))  # [???, L]
                _tmp_label_t.scatter_(-1, one_label_t.unsqueeze(-1), 1.)
                _ab_loss = BK.loss_binary(one_fscore_t, _tmp_label_t) * _alpha_all_binary  # [???, L]
                one_loss_t = one_loss_t + _ab_loss[..., 1:].sum(-1)
            # --
            one_loss_t = one_loss_t * one_mask_t
            rets.append((one_loss_t, one_mask_t))  # tuple(loss, mask)
    return rets
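# -- Illustration (not part of the original code): what loss_do_sel buys, in
# plain PyTorch. Boolean-mask indexing flattens [bs, ...] down to only the
# positions where the loss mask fires, so per-element losses skip the rest.
def _demo_loss_do_sel():
    import torch
    label_t = torch.tensor([[3, 0, 2], [0, 1, 0]])
    loss_mask_t = torch.tensor([[1., 0., 1.], [0., 1., 0.]])
    sel_label = label_t[loss_mask_t > 0.]  # tensor([3, 2, 1]), the flattened [??]
    return sel_label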
def score_all(self, expr_main: BK.Expr, expr_pair: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr,
              local_normalize: bool = None, use_bigram: bool = True, extra_score: BK.Expr = None):
    conf: SeqLabelerConf = self.conf
    # first collect basic scores
    if conf.use_seqdec:
        # first prepare the init hidden state
        sd_init_t = self.prepare_sd_init(expr_main, expr_pair)  # [*, hid]
        # init cache: no mask at batch level
        sd_cache = self.seqdec.go_init(sd_init_t, init_mask=None)  # and no need for cum_state here!
        # prepare all the inputs at once
        if conf.sd_skip_non:
            gold_valid_mask = (gold_idxes > 0).float() * input_mask  # [*, slen], todo(note): fix 0 as non here!
            gv_idxes, gv_masks = BK.mask2idx(gold_valid_mask)  # [*, ?]
            bsize = BK.get_shape(gold_idxes, 0)
            arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
            # select and forward
            gv_embeds = self.laber.lookup(gold_idxes[arange_t, gv_idxes])  # [*, ?, E]
            gv_input_t = self.sd_input_aff([expr_main[arange_t, gv_idxes], gv_embeds])  # [*, ?, hid]
            gv_hid_t = self.seqdec.go_feed(sd_cache, gv_input_t, gv_masks)  # [*, ?, hid]
            # select back and output_aff
            aug_hid_t = BK.concat([sd_init_t.unsqueeze(-2), gv_hid_t], -2)  # [*, 1+?, hid]
            sel_t = BK.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0), value=0.).long()  # [*, 1+(slen-1)]
            shifted_hid_t = aug_hid_t[arange_t, sel_t]  # [*, slen, hid]
        else:
            gold_idx_embeds = self.laber.lookup(gold_idxes)  # [*, slen, E]
            all_input_t = self.sd_input_aff([expr_main, gold_idx_embeds])  # inputs to dec, [*, slen, hid]
            all_hid_t = self.seqdec.go_feed(sd_cache, all_input_t, input_mask)  # output-hids, [*, slen, hid]
            shifted_hid_t = BK.concat([sd_init_t.unsqueeze(-2), all_hid_t[:, :-1]], -2)  # [*, slen, hid]
        # scorer
        pre_labeler_t = self.sd_output_aff([expr_main, shifted_hid_t])  # [*, slen, hid]
    else:
        pre_labeler_t = expr_main  # [*, slen, Dm']
    # score with the labeler (no norm here since we may need to add other scores)
    scores_t = self.laber.score(
        pre_labeler_t, None if expr_pair is None else expr_pair.unsqueeze(-2), input_mask,
        extra_score=extra_score, local_normalize=False)  # [*, slen, L]
    # bigram score addition
    if conf.use_bigram and use_bigram:
        bigram_scores_t = self.bigram.get_matrix()[gold_idxes[:, :-1]]  # [*, slen-1, L]
        score_shape = BK.get_shape(bigram_scores_t)
        score_shape[1] = 1
        slice_t = BK.constants(score_shape, 0.)  # fix 0.: no transition from BOS (and EOS) for simplicity!
        bigram_scores_t = BK.concat([slice_t, bigram_scores_t], 1)  # [*, slen, L]
        scores_t += bigram_scores_t  # [*, slen, L]
    # local normalization?
    scores_t = self.laber.output_score(scores_t, local_normalize)
    return scores_t
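# -- Illustration (not part of the original code): the cumsum re-indexing used
# by sd_skip_non, in plain PyTorch. Each position selects the hidden state of
# the last valid gold label before it, with index 0 reserved for the init state.
def _demo_shifted_selection():
    import torch
    gold_valid_mask = torch.tensor([[0., 1., 1., 0., 1.]])  # [*, slen]
    sel_t = torch.nn.functional.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0)).long()
    return sel_t  # tensor([[0, 0, 1, 2, 2]]) -> indexes into [init; valid hids]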