def forward(self, expr_t: BK.Expr, mask_t: BK.Expr, scores_t=None, **kwargs):
    conf: IdecConnectorAttConf = self.conf
    # --
    # prepare input
    _d_bs, _dq, _dk, _d_nl, _d_nh = BK.get_shape(scores_t)
    in1_t = scores_t[:, :, :, self.lstart:, :self.head_end].reshape([_d_bs, _dq, _dk, self.d_in])  # [*, lenq, lenk, din]
    in2_t = in1_t.transpose(-3, -2)  # [*, lenk, lenq, din]
    final_input_t = BK.concat([in1_t, in2_t], -1)  # [*, lenq, lenk, din*2]
    # forward
    node_ret_t = self.node.forward(final_input_t, mask_t, self.feed_output, self.lidx, **kwargs)  # [*, lenq, lenk, head_end]
    if self.feed_output:
        # pad zeros along the head dim if necessary
        if self.head_end < _d_nh:
            pad_t = BK.zeros([_d_bs, _dq, _dk, _d_nh - self.head_end])
            node_ret_t = BK.concat([node_ret_t, pad_t], -1)  # [*, lenq, lenk, Hin]
        return node_ret_t
    else:
        return None
def forward(self, inputs, add_bos=False, add_eos=False):
    conf: PlainInputEmbedderConf = self.conf
    # --
    voc = self.voc
    input_t = BK.input_idx(inputs)  # [*, len]
    # rare unk in training
    if self.is_training() and self.use_rare_unk:
        rare_unk_rate = conf.rare_unk_rate
        cur_unk_imask = (self.rare_unk_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
        input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
    # bos and eos
    all_input_slices = []
    slice_shape = BK.get_shape(input_t)[:-1] + [1]
    if add_bos:
        all_input_slices.append(BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
    all_input_slices.append(input_t)  # [*, len]
    if add_eos:
        all_input_slices.append(BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
    final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
    # finally
    ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
def forward(self, input_map: Dict):
    mask_expr, expr_map = self.eg.forward(input_map)
    exprs = list(expr_map.values())  # follow the order in OrderedDict
    # concat and final
    concat_expr = BK.concat(exprs, -1)  # [*, len, SUM]
    final_expr = self.final_layer(concat_expr)
    return mask_expr, final_expr
def forward(self, inputs, add_bos=False, add_eos=False):
    conf: PosiInputEmbedderConf = self.conf
    # --
    try:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = inputs
        if add_bos:
            max_len += 1
        if add_eos:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [?len?]
        ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    except:
        # input is tensor
        posi_idxes = BK.input_idx(inputs)  # [*, len]
        cur_maxlen = BK.get_shape(posi_idxes, -1)
        # --
        all_input_slices = []
        slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
        if add_bos:  # add 0 and offset
            all_input_slices.append(BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
            cur_maxlen += 1
            posi_idxes += 1
        all_input_slices.append(posi_idxes)  # [*, len]
        if add_eos:
            all_input_slices.append(BK.constants(slice_shape, cur_maxlen, dtype=posi_idxes.dtype))
        final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
        # finally
        ret = self.E(final_input_t)  # [*, ??, dim]
    return ret
def ss_add_new_layer(self, layer_idx: int, expr: BK.Expr):
    assert layer_idx == self._cur_layer_idx
    self._cur_layer_idx += 1
    # --
    if self.states is None:
        self.states = []
    _cur_cum_state = (layer_idx in self.cum_state_lset)
    added_expr = expr
    if len(self.states) == layer_idx:  # first group of calls
        if _cur_cum_state:  # direct select!
            added_expr = added_expr[self._arange2_t, self._arange_sel_t]
        self.states.append(added_expr)  # directly add as first-time adding
    else:
        prev_state_all = self.states[layer_idx]  # [bsize, old_step]
        if _cur_cum_state:  # concat last and select
            added_expr = BK.concat([prev_state_all[:, -1].unsqueeze(1), added_expr], 1)[self._arange2_t, self._arange_sel_t]
        self.states[layer_idx] = BK.concat([prev_state_all, added_expr], 1)  # [*, old+new, D]
    return added_expr, self.states[layer_idx]  # q; kv
def forward(self, med: ZMediator):
    conf: Idec2Conf = self.conf
    cur_lidx = med.lidx
    assert cur_lidx == self.max_app_lidx
    seq_info = med.ibatch.seq_info
    # --
    # get att values
    # todo(+N): modify med to make (cached) things more flexible!
    v_att_final = med.get_cache(self.gatt_key)
    if v_att_final is None:
        v_att = BK.concat(med.get_enc_cache(conf.gatt_name).vals[self.min_gatt_lidx:cur_lidx], 1)  # [*, L*H, h, m]
        v_att_rm = self.dsel_rm(v_att.permute(0, 3, 2, 1), seq_info)  # first reduce m: [*,m',h,L*H]
        v_att_rh = self.dsel_rh(v_att_rm.transpose(1, 2), seq_info)  # then reduce h: [*,h',m',L*H]
        v_att_final = BK.concat([v_att_rh, v_att_rh.transpose(1, 2)], -1)  # final concat: [*,h',m',L*H*2]
        med.set_cache(self.gatt_key, v_att_final)
    hid_inputs = [self.gatt_drop(v_att_final)]
    # --
    # get hid values
    if conf.ghid_m or conf.ghid_h:
        _dsel = self.dsel_hid
        v_hid = med.get_enc_cache_val(  # [*, len', D]
            "hid", signature=_dsel.signature, function=(lambda x: _dsel.forward(x, seq_info)))
        if conf.ghid_h:
            hid_inputs.append(v_hid.unsqueeze(-2))  # [*, h, 1, D]
        if conf.ghid_m:
            hid_inputs.append(v_hid.unsqueeze(-3))  # [*, 1, m, D]
    # --
    # go
    ret = self.aff_hid(hid_inputs)
    if self.aff_final is not None:
        ret = self.aff_final(ret)
    return ret, None  # currently no feed!
def _aug_ends(self, t: BK.Expr, BOS, PAD, EOS, dtype):
    # add BOS(CLS) and EOS(SEP) for a tensor (sub_len -> 1+sub_len+1)
    slice_shape = [self.bsize, 1]
    slices = [BK.constants(slice_shape, BOS, dtype=dtype), t, BK.constants(slice_shape, PAD, dtype=dtype)]
    aug_batched_ids = BK.concat(slices, -1)  # [bsize, 1+sub_len+1]
    aug_batched_ids[self.arange1_t, self.batched_sublens_p1] = EOS  # assign EOS
    return aug_batched_ids
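# A minimal, standalone sketch of the augmentation above, assuming BK is a thin wrapper over
# PyTorch (so BK.Expr behaves like torch.Tensor): wrap each row as [BOS, ..., PAD] and then
# overwrite position sub_len+1 with EOS. The names here (`sublens`, the BOS/PAD/EOS ids) are
# illustrative stand-ins for the class fields used above, not the actual implementation.
import torch

def aug_ends_sketch(t: torch.Tensor, sublens: torch.Tensor, BOS: int, PAD: int, EOS: int) -> torch.Tensor:
    bsize = t.shape[0]
    bos_col = torch.full((bsize, 1), BOS, dtype=t.dtype)
    pad_col = torch.full((bsize, 1), PAD, dtype=t.dtype)
    aug = torch.cat([bos_col, t, pad_col], -1)  # [bsize, 1+sub_len+1]
    arange1 = torch.arange(bsize)  # [bsize]
    aug[arange1, sublens + 1] = EOS  # EOS right after the last real sub-token
    return aug

# e.g. t=[[5,6,7],[8,9,0]], sublens=[3,2] -> [[BOS,5,6,7,EOS],[BOS,8,9,EOS,PAD]]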
def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
    conf: BaseSpanConf = self.conf
    # --
    # note: check empty, otherwise error
    input_item_shape = BK.get_shape(widx_expr)
    if np.prod(input_item_shape) == 0:
        return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
    # --
    start_idxes, end_idxes = widx_expr, widx_expr + wlen_expr  # make [start, end)
    # get sizes
    bsize, slen = BK.get_shape(input_expr)[:2]
    # num_span = BK.get_shape(start_idxes, 1)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    # --
    reprs = []
    if conf.use_starts:  # start [start,
        reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
    if conf.use_ends:  # simply ,end-1]
        reprs.append(input_expr[arange2_t, end_idxes - 1])
    if conf.use_softhead:
        # expand range
        all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
        # flatten
        flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
        flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
        # get softhead score (consider mask here)
        softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
        flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
        flatten_all_span_scores += (1. - flatten_all_span_mask) * Constants.REAL_PRAC_MIN
        all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
        # reshape and (optionally topk) and softmax
        softhead_topk = conf.softhead_topk
        if softhead_topk > 0 and BK.get_shape(all_span_scores, -1) > softhead_topk:
            # further select topk; note: this may save mem
            final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
            final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
        else:
            final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
        final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
        # [bsize, ?, ??, D]
        final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes) + [-1])
        weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
        reprs.append(weighted_repr)
    if conf.use_width:
        cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
        reprs.append(cur_width_embed)
    # concat
    concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
    if conf.use_proj:
        ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
    else:
        ret = concat_repr
    return ret
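# Rough sketch of what an expand_ranged_idxes-style helper could return, inferred only from how
# its outputs are consumed above (indices gathered from the sequence plus a width mask); this is
# an assumption, not the actual helper, and the extra arguments in the call above (e.g. padding
# value / max-width cap) are omitted here.
import torch

def expand_ranged_idxes_sketch(widx: torch.Tensor, wlen: torch.Tensor):
    mw = max(int(wlen.max().item()), 1)  # MW: max span width in the batch
    offsets = torch.arange(mw)  # [MW]
    idxes = widx.unsqueeze(-1) + offsets  # [bsize, ?, MW]: widx, widx+1, ..., widx+MW-1
    mask = (offsets < wlen.unsqueeze(-1)).float()  # [bsize, ?, MW]: first wlen entries valid
    idxes = idxes * mask.long()  # clamp padded cells to a safe index (0)
    return idxes, mask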
def _split_extend(self, split_decisions: BK.Expr, cand_mask: BK.Expr):
    # first augment/pad split_decisions
    slice_ones = BK.constants([BK.get_shape(split_decisions, 0), 1], 1.)  # [*, 1]
    padded_split_decisions = BK.concat([slice_ones, split_decisions], -1)  # [*, clen]
    seg_cidxes, seg_masks = BK.mask2idx(padded_split_decisions)  # [*, seglen]
    # --
    cand_lens = cand_mask.sum(-1, keepdim=True).long()  # [*, 1]
    seg_masks *= (cand_lens > 0).float()  # for the case of no cands
    # --
    seg_cidxes_special = seg_cidxes + (1. - seg_masks).long() * cand_lens  # [*, seglen], fill in for paddings
    seg_cidxes_special2 = BK.concat([seg_cidxes_special, cand_lens], -1)  # [*, seglen+1]
    seg_clens = seg_cidxes_special2[:, 1:] - seg_cidxes_special  # [*, seglen]
    # extend the idxes
    seg_ext_cidxes, seg_ext_masks = expand_ranged_idxes(seg_cidxes, seg_clens)  # [*, seglen, MW]
    seg_ext_masks *= seg_masks.unsqueeze(-1)
    return seg_ext_cidxes, seg_ext_masks, seg_masks  # 2x[*, seglen, MW], [*, seglen]
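# Rough sketch of a mask2idx-style helper, again an assumption inferred from the usages above
# (here and in the cand-feeding / seq-labeling code): compact the positions where mask==1 to the
# left, pad the rest with a safe index, and return the gathered indices plus a validity mask.
# Not the actual BK.mask2idx implementation.
import torch

def mask2idx_sketch(mask: torch.Tensor, pad_val: int = 0):
    counts = mask.sum(-1).long()  # [*,] number of active positions per row
    max_count = max(int(counts.max().item()), 1)  # keep at least one slot
    _, sort_idxes = torch.sort(-mask, dim=-1, stable=True)  # active positions first, original order kept
    sel_idxes = sort_idxes[..., :max_count]  # [*, max_count]
    sel_valid = (torch.arange(max_count) < counts.unsqueeze(-1)).float()  # [*, max_count]
    sel_idxes = sel_idxes * sel_valid.long() + pad_val * (1 - sel_valid.long())
    return sel_idxes, sel_valid

# e.g. mask=[[1,0,1,1],[0,1,0,0]] -> idxes=[[0,2,3],[1,0,0]], valid=[[1,1,1],[1,0,0]]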
def assign_boundaries(self, items: List, boundary_node, flat_mask_t: BK.Expr, flat_hid_t: BK.Expr, indicators: List):
    flat_indicators = boundary_node.prepare_indicators(indicators, BK.get_shape(flat_mask_t))
    # --
    _bsize, _dlen = BK.get_shape(flat_mask_t)  # [???, dlen]
    _once_bsize = max(1, int(self.conf.boundary_bsize / max(1, _dlen)))
    # --
    if _once_bsize >= _bsize:
        _, _left_idxes, _right_idxes = boundary_node.decode(flat_hid_t, flat_mask_t, flat_indicators)  # [???]
    else:
        _all_left_idxes, _all_right_idxes = [], []
        for ii in range(0, _bsize, _once_bsize):
            _, _one_left_idxes, _one_right_idxes = boundary_node.decode(
                flat_hid_t[ii:ii + _once_bsize], flat_mask_t[ii:ii + _once_bsize],
                [z[ii:ii + _once_bsize] for z in flat_indicators])
            _all_left_idxes.append(_one_left_idxes)
            _all_right_idxes.append(_one_right_idxes)
        _left_idxes, _right_idxes = BK.concat(_all_left_idxes, 0), BK.concat(_all_right_idxes, 0)
    _arr_left, _arr_right = BK.get_value(_left_idxes), BK.get_value(_right_idxes)
    for ii, item in enumerate(items):
        _mention = item.mention
        _start = item._tmp_sstart  # need to minus this!!
        _left_widx, _right_widx = _arr_left[ii].item() - _start, _arr_right[ii].item() - _start
        # todo(+N): sometimes we can have repeated ones, currently simply over-write!
        if _mention.get_span()[1] == 1:
            _mention.set_span(*(_mention.get_span()), shead=True)  # first move to shead!
        _mention.set_span(_left_widx, _right_widx - _left_widx + 1)
def forward(self, med: ZMediator, **kwargs):
    conf: IdecConnectorAttConf = self.conf
    # --
    # get stack att: already transposed by zmed
    scores_t = med.get_stack_att()  # [*, len_q, len_k, NL, H]
    _d_bs, _dq, _dk, _d_nl, _d_nh = BK.get_shape(scores_t)
    in1_t = scores_t[:, :, :, self.lstart:, :self.head_end].reshape([_d_bs, _dq, _dk, self.d_in])  # [*, lenq, lenk, din]
    in2_t = in1_t.transpose(-3, -2)  # [*, lenk, lenq, din]
    cat_t = self._go_detach(BK.concat([in1_t, in2_t], -1))  # [*, lenk, lenq, din*2]
    # further affine
    cat_drop_t = self.pre_mid_drop(cat_t)  # [*, lenk, lenq, din*2]
    ret_t = self.mid_aff(cat_drop_t)  # [*, lenk, lenq, M]
    return ret_t
def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
    assert ssize > 0
    assert self._cur_layer_idx == -1
    self._cur_layer_idx = 0
    # --
    new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
    # --
    # prepare for store_lstate selecting
    if len(self.cum_state_lset) > 0:  # any layer need to accumulate?
        self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
        # note: if no last state, simply clamp 0, otherwise, offset by 1 since we will concat later
        self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
    # prev_steps = self.steps  # previous accumulated steps
    self.steps += ssize
    self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
    self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
def forward(self, inputs, add_bos=False, add_eos=False):
    conf: CharCnnInputEmbedderConf = self.conf
    # --
    voc = self.voc
    char_input_t = BK.input_idx(inputs)  # [*, len, clen]
    # todo(note): no need for replacing to unk for char!!
    # bos and eos
    all_input_slices = []
    slice_shape = BK.get_shape(char_input_t)
    slice_shape[-2] = 1  # [*, 1, clen]
    if add_bos:
        all_input_slices.append(BK.constants(slice_shape, voc.bos, dtype=char_input_t.dtype))
    all_input_slices.append(char_input_t)  # [*, len, clen]
    if add_eos:
        all_input_slices.append(BK.constants(slice_shape, voc.eos, dtype=char_input_t.dtype))
    final_input_t = BK.concat(all_input_slices, -2)  # [*, 1?+len+1?, clen]
    # char embeddings
    char_embed_expr = self.E(final_input_t)  # [*, ??, dim]
    # char cnn
    ret = self.cnn(char_embed_expr)
    return ret
def forward(self, input_map: Dict):
    add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
    ret = OrderedDict()  # [*, len, ?]
    for key, embedder_pack in self.embedders.items():  # according to REG order!!
        embedder, input_name = embedder_pack
        one_expr = embedder(input_map[input_name], add_bos=add_bos, add_eos=add_eos)
        ret[key] = one_expr
    # mask expr
    mask_expr = input_map.get("mask")
    if mask_expr is not None:
        all_input_slices = []
        mask_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1], 1, dtype=mask_expr.dtype)  # [*, 1]
        if add_bos:
            all_input_slices.append(mask_slice)
        all_input_slices.append(mask_expr)
        if add_eos:
            all_input_slices.append(mask_slice)
        mask_expr = BK.concat(all_input_slices, -1)  # [*, ?+len+?]
    return mask_expr, ret
def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True):
    _NEG_INF = Constants.REAL_PRAC_MIN
    # --
    cur_step = 0
    cache: DecCache = None
    # init: keep the seq of scores rather than traceback!
    start_vals_shape = [batch_size, 1]  # [bs, 1]
    all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze(-1)  # [bs, K, step], todo(note): start with 0!
    all_scores_t = BK.zeros(start_vals_shape).unsqueeze(-1)  # [bs, K, step]
    accu_scores_t = BK.zeros(start_vals_shape)  # [bs, K]
    arange_t = BK.arange_idx(batch_size).unsqueeze(-1)  # [bs, 1]
    # while loop
    prev_k = 1  # start with single one
    while not self.is_end(cur_step):
        # expand and score
        cache, scores_t, masks_t = self.step_score(cur_step, prev_k, cache)  # ..., [bs*pK, L], [bs*pK]
        scores_t_shape = BK.get_shape(scores_t)
        last_dim = scores_t_shape[-1]  # L
        # modify score to handle mask: keep previous pred for the masked items!
        sel_scores_t = BK.constants([batch_size, prev_k, last_dim], 1.)  # [bs, pk, L]
        sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:], -1)  # [bs, pk, L]
        sel_scores_t = scores_t + _NEG_INF * (sel_scores_t.view(scores_t_shape) * (1. - masks_t).unsqueeze(-1))  # [bs*pK, L]
        # first select topk locally, note: here no need to sort!
        local_k = min(last_dim, beam_k)
        l_topk_scores, l_topk_idxes = sel_scores_t.topk(local_k, -1, sorted=False)  # [bs*pK, lK]
        # then topk globally on full pK*K
        add_score_shape = [batch_size, prev_k, local_k]
        to_sel_shape = [batch_size, prev_k * local_k]
        global_k = min(to_sel_shape[-1], beam_k)  # new k
        to_sel_scores, to_sel_idxes = \
            (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \
            l_topk_idxes.view(to_sel_shape)  # [bs, pK*lK]
        _, g_topk_idxes = to_sel_scores.topk(global_k, -1, sorted=True)  # [bs, gK]
        # get to know the idxes
        new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes)  # [bs, gK]
        new_pk_idxes = (g_topk_idxes // local_k)  # which previous idx (in beam) are selected? [bs, gK]
        # get current pred and scores (handling mask)
        scores_t3 = scores_t.view([batch_size, -1, last_dim])  # [bs, pK, L]
        masks_t2 = masks_t.view([batch_size, -1])  # [bs, pK]
        new_masks_t = masks_t2[arange_t, new_pk_idxes]  # [bs, gK]
        # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones
        new_scores_t = scores_t3[arange_t, new_pk_idxes, new_preds_t] * new_masks_t  # [bs, gK]
        # ending
        new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view(-1)  # [bs*gK]
        cache.arrange_idxes(new_arrange_idxes)
        self.step_end(cur_step, global_k, cache, new_preds_t.view(-1))  # modify in cache
        # prepare next & judge ending
        all_preds_t = BK.concat([all_preds_t[arange_t, new_pk_idxes], new_preds_t.unsqueeze(-1)], -1)  # [bs, gK, step]
        all_scores_t = BK.concat([all_scores_t[arange_t, new_pk_idxes], new_scores_t.unsqueeze(-1)], -1)  # [bs, gK, step]
        accu_scores_t = accu_scores_t[arange_t, new_pk_idxes] + new_scores_t  # [bs, gK]
        prev_k = global_k  # for next step
        cur_step += 1
    # --
    # sort and ret at a final step
    _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True)  # [bs, K]
    ret_preds = all_preds_t[arange_t, final_idxes][:, :, 1:]  # [bs, K, steps], exclude dummy start!
    ret_scores = all_scores_t[arange_t, final_idxes][:, :, 1:]  # [bs, K, steps]
    if ret_best:
        return ret_preds[:, 0], ret_scores[:, 0]  # [bs, slen]
    else:
        return ret_preds, ret_scores  # [bs, topk, slen]
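# Standalone sketch of the two-stage top-k used above, in plain PyTorch (assuming BK wraps torch):
# per beam pick local_k candidates, then pick global_k over all prev_k*local_k continuations;
# `g_idxes // local_k` recovers which beam a winner extends, and gathering from the flattened
# local indices recovers its token id. All sizes below are illustrative.
import torch

bs, prev_k, L, beam_k = 2, 2, 5, 2
scores = torch.randn(bs * prev_k, L)  # step scores per (batch, beam)
accu = torch.randn(bs, prev_k)  # accumulated beam scores

local_k = min(L, beam_k)
l_scores, l_idxes = scores.topk(local_k, -1, sorted=False)  # [bs*pK, lK]
to_sel_scores = (l_scores.view(bs, prev_k, local_k) + accu.unsqueeze(-1)).view(bs, prev_k * local_k)
to_sel_idxes = l_idxes.view(bs, prev_k * local_k)  # [bs, pK*lK]
global_k = min(prev_k * local_k, beam_k)
_, g_idxes = to_sel_scores.topk(global_k, -1, sorted=True)  # [bs, gK]
new_preds = to_sel_idxes.gather(-1, g_idxes)  # token ids of the winners
new_pk = g_idxes // local_k  # which previous beam each winner extends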
def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                    expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1.1: bag loss
    cand_gold_mask = (expr_seq_gaddr >= 0).float() * mask_expr  # [*, slen], whether is-arg
    raw_loss_cand = BK.loss_binary(cand_full_scores, cand_gold_mask,
                                   label_smoothing=conf.cand_label_smoothing)  # [*, slen]
    # how to weight?
    extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] + \
        (1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
    if BK.is_zero_shape(extended_scores_t):
        extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((cand_full_scores - extended_scores_max_t) * _w_alpha).exp()  # [*, slen]
    if not conf.cand_loss_div_max:  # div sum-all, like doing softmax
        _weight = _weight / ((extended_scores_t - extended_scores_max_t.unsqueeze(-1)) * _w_alpha).exp().sum(-1)
    _weight = _weight * (_weight >= conf.cand_loss_weight_thresh).float()  # [*, slen]
    if conf.cand_detach_weight:
        _weight = _weight.detach()
    # pos poison (dis-encouragement)
    if conf.cand_loss_pos_poison:
        poison_loss = BK.loss_binary(cand_full_scores, 1. - cand_gold_mask,
                                     label_smoothing=conf.cand_label_smoothing)  # [*, slen]
        raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (1. - _weight)  # [*, slen]
    else:
        raw_loss_cand = raw_loss_cand * _weight
    # final weight it
    cand_loss_weights = BK.where(cand_gold_mask == 0.,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_cand_item = LossHelper.compile_leaf_loss(
        f"cand", (raw_loss_cand * final_cand_loss_weights).sum(),
        final_cand_loss_weights.sum(), loss_lambda=conf.loss_cand)
    # step 1.2: feed cand
    # todo(+N): currently only pred/sample, whether adding certain teacher-forcing?
    sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(cand_full_scores.shape)).float() * mask_expr  # [*, slen]
    _use_sample_mask = (BK.rand([bsize]) <= conf.cand_feed_sample_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_cand_decisions = (_use_sample_mask * sample_decisions +
                           (1. - _use_sample_mask) * pred_cand_decisions)  # [*, slen]
    # next
    cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
    # --
    # extra: loss_cand_entropy
    rets = [loss_cand_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                               expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() * cand_gold_mask
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        rets.append(_loss_ent_item)
    # --
    return rets, cand_widxes, cand_masks
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare
    arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1: label, simply scoring everything!
    _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
    all_scores_t = self.lab_node.score_all(
        _main_t, _pair_t, mask_expr, None, local_normalize=False,
        extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
    all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
    all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
    # how to weight
    extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
    if BK.is_zero_shape(extended_gprob_t):
        extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((all_gprob_t * mask_expr) / (extended_gprob_max_t.clamp(min=1e-5))) ** _w_alpha  # [*, slen]
    _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
    _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]
    _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]
    _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
    _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
    # final weight it
    cand_loss_weights = BK.where(expr_seq_labs == 0,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
        loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
    # --
    # step 1.5
    all_losses = [loss_lab_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_gprob_t  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                               expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() \
            * (expr_seq_labs > 0).float()
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        all_losses.append(_loss_ent_item)
    # --
    # step 4: extend (select topk)
    if conf.loss_ext > 0.:
        if BK.is_zero_shape(extended_gprob_t):
            flt_mask = (BK.zeros(mask_expr.shape) > 0)
        else:
            _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
            _topk_gprob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
            flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_gprob_t.min(-1)[0]) & \
                (_weight > conf.ext_loss_thresh)  # [*, slen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
        flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
        loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr,
                                           mask_expr[flt_sidx], flt_extra_weights=flt_weights)
        all_losses.append(loss_ext_item)
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(all_losses)
    return ret_loss, None
def forward(self, inputs, vstate: VrecSteppingState = None, inc_cls=False):
    conf: BertEncoderConf = self.conf
    # --
    no_bert_ft = (not conf.bert_ft)  # whether fine-tune bert (if not, detach hiddens!)
    impl = self.impl
    # --
    # prepare inputs
    if not isinstance(inputs, BerterInputBatch):
        inputs = self.create_input_batch(inputs)
    all_output_layers = []  # including embeddings
    # --
    # get embeddings (for embeddings, we simply forward once!)
    mask_repl_rate = conf.bert_repl_mask_rate if self.is_training() else 0.
    input_ids, input_masks = inputs.get_basic_inputs(mask_repl_rate)  # [bsize, 1+sub_len+1]
    other_embeds = None
    if self.other_embed_nodes is not None and len(self.other_embed_nodes) > 0:
        other_embeds = 0.
        for other_name, other_node in self.other_embed_nodes.items():
            other_embeds += other_node(inputs.other_factors[other_name])  # should be prepared correspondingly!!
    # --
    # forward layers (for layers, we may need to split!)
    # todo(+N): we simply split things apart, thus middle parts may lack CLS/SEP, and not true global att
    # todo(+N): the lengths currently are hard-coded!!
    MAX_LEN = 512  # max len
    INBUF_LEN = 50  # in-between buffer for splits, for both sides!
    cur_sub_len = BK.get_shape(input_ids, 1)  # 1+sub_len+1
    needs_split = (cur_sub_len > MAX_LEN)
    if needs_split:
        # decide split and merge points
        split_points = self._calculate_split_points(cur_sub_len, MAX_LEN, INBUF_LEN)
        zwarn(f"Multi-seg for Berter: {cur_sub_len}//{len(split_points)}->{split_points}")
    # --
    # todo(note): we also need split from embeddings
    if needs_split:
        all_embed_pieces = []
        split_extended_attention_mask = []
        for o_s, o_e, i_s, i_e in split_points:
            piece_embeddings, piece_extended_attention_mask = impl.forward_embedding(
                *[(None if z is None else z[:, o_s:o_e]) for z in
                  [input_ids, input_masks, inputs.batched_token_type_ids, inputs.batched_position_ids, other_embeds]])
            all_embed_pieces.append(piece_embeddings[:, i_s:i_e])
            split_extended_attention_mask.append(piece_extended_attention_mask)
        embeddings = BK.concat(all_embed_pieces, 1)  # concat back to full
        extended_attention_mask = None
    else:
        embeddings, extended_attention_mask = impl.forward_embedding(
            input_ids, input_masks, inputs.batched_token_type_ids, inputs.batched_position_ids, other_embeds)
        split_extended_attention_mask = None
    if no_bert_ft:  # stop gradient
        embeddings = embeddings.detach()
    # --
    cur_hidden = embeddings
    all_output_layers.append(embeddings)  # *[bsize, 1+sub_len+1, D]
    # also prepare mapper idxes for sub <-> orig
    # todo(+N): currently only use the first sub-word!
    idxes_arange2 = inputs.arange2_t  # [bsize, 1]
    batched_first_idxes_p1 = (1 + inputs.batched_first_idxes) * (inputs.batched_first_mask.long())  # plus one for CLS offset!
    if inc_cls:  # [bsize, 1+orig_len]
        idxes_sub2orig = BK.concat([BK.constants_idx([inputs.bsize, 1], 0), batched_first_idxes_p1], 1)
    else:  # [bsize, orig_len]
        idxes_sub2orig = batched_first_idxes_p1
    _input_masks0 = None  # used for vstate back, make it 0. for BOS and EOS
    # for ii in range(impl.num_hidden_layers):
    for ii in range(max(self.actual_output_layers)):  # no need to compute more layers than required!
        # forward multiple times with splitting if needed
        if needs_split:
            all_pieces = []
            for piece_idx, piece_points in enumerate(split_points):
                o_s, o_e, i_s, i_e = piece_points
                piece_res = impl.forward_hidden(
                    ii, cur_hidden[:, o_s:o_e], split_extended_attention_mask[piece_idx])[:, i_s:i_e]
                all_pieces.append(piece_res)
            new_hidden = BK.concat(all_pieces, 1)  # concat back to full
        else:
            new_hidden = impl.forward_hidden(ii, cur_hidden, extended_attention_mask)
        if no_bert_ft:  # stop gradient
            new_hidden = new_hidden.detach()
        if vstate is not None:
            # from 1+sub_len+1 -> (inc_cls?)+orig_len
            new_hidden2orig = new_hidden[idxes_arange2, idxes_sub2orig]  # [bsize, 1?+orig_len, D]
            # update
            new_hidden2orig_ret = vstate.update(new_hidden2orig)  # [bsize, 1?+orig_len, D]
            if new_hidden2orig_ret is not None:  # calculate when needed
                if _input_masks0 is None:
                    # [bsize, 1+sub_len+1, 1] with 1. only for real valid ones
                    _input_masks0 = inputs._aug_ends(inputs.batched_input_mask, 0., 0., 0., BK.float32).unsqueeze(-1)
                # back to 1+sub_len+1; todo(+N): here we simply add and //2, and no CLS back from orig to sub!!
                tmp_orig2sub = new_hidden2orig_ret[idxes_arange2, int(inc_cls) + inputs.batched_rev_idxes]  # [bsize, sub_len, D]
                tmp_slice_size = BK.get_shape(tmp_orig2sub)
                tmp_slice_size[1] = 1
                tmp_slice_zero = BK.zeros(tmp_slice_size)
                tmp_orig2sub_aug = BK.concat([tmp_slice_zero, tmp_orig2sub, tmp_slice_zero], 1)  # [bsize, 1+sub_len+1, D]
                new_hidden = new_hidden * (1. - _input_masks0) + ((new_hidden + tmp_orig2sub_aug) / 2.) * _input_masks0
        all_output_layers.append(new_hidden)
        cur_hidden = new_hidden
    # finally, prepare return
    final_output_layers = [all_output_layers[z] for z in conf.bert_output_layers]  # *[bsize,1+sl+1,D]
    combined_output = self.combiner(final_output_layers)  # [bsize, 1+sl+1, ??]
    final_ret = combined_output[idxes_arange2, idxes_sub2orig]  # [bsize, 1?+orig_len, D]
    return final_ret
def score_all(self, expr_main: BK.Expr, expr_pair: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr,
              local_normalize: bool = None, use_bigram: bool = True, extra_score: BK.Expr = None):
    conf: SeqLabelerConf = self.conf
    # first collect basic scores
    if conf.use_seqdec:
        # first prepare init hidden
        sd_init_t = self.prepare_sd_init(expr_main, expr_pair)  # [*, hid]
        # init cache: no mask at batch level
        sd_cache = self.seqdec.go_init(sd_init_t, init_mask=None)  # and no need to cum_state here!
        # prepare inputs at once
        if conf.sd_skip_non:
            gold_valid_mask = (gold_idxes > 0).float() * input_mask  # [*, slen], todo(note): fix 0 as non here!
            gv_idxes, gv_masks = BK.mask2idx(gold_valid_mask)  # [*, ?]
            bsize = BK.get_shape(gold_idxes, 0)
            arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
            # select and forward
            gv_embeds = self.laber.lookup(gold_idxes[arange_t, gv_idxes])  # [*, ?, E]
            gv_input_t = self.sd_input_aff([expr_main[arange_t, gv_idxes], gv_embeds])  # [*, ?, hid]
            gv_hid_t = self.seqdec.go_feed(sd_cache, gv_input_t, gv_masks)  # [*, ?, hid]
            # select back and output_aff
            aug_hid_t = BK.concat([sd_init_t.unsqueeze(-2), gv_hid_t], -2)  # [*, 1+?, hid]
            sel_t = BK.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0), value=0.).long()  # [*, 1+(slen-1)]
            shifted_hid_t = aug_hid_t[arange_t, sel_t]  # [*, slen, hid]
        else:
            gold_idx_embeds = self.laber.lookup(gold_idxes)  # [*, slen, E]
            all_input_t = self.sd_input_aff([expr_main, gold_idx_embeds])  # inputs to dec, [*, slen, hid]
            all_hid_t = self.seqdec.go_feed(sd_cache, all_input_t, input_mask)  # output-hids, [*, slen, hid]
            shifted_hid_t = BK.concat([sd_init_t.unsqueeze(-2), all_hid_t[:, :-1]], -2)  # [*, slen, hid]
        # scorer
        pre_labeler_t = self.sd_output_aff([expr_main, shifted_hid_t])  # [*, slen, hid]
    else:
        pre_labeler_t = expr_main  # [*, slen, Dm']
    # score with labeler (no norm here since we may need to add other scores)
    scores_t = self.laber.score(pre_labeler_t, None if expr_pair is None else expr_pair.unsqueeze(-2),
                                input_mask, extra_score=extra_score, local_normalize=False)  # [*, slen, L]
    # bigram score addition
    if conf.use_bigram and use_bigram:
        bigram_scores_t = self.bigram.get_matrix()[gold_idxes[:, :-1]]  # [*, slen-1, L]
        score_shape = BK.get_shape(bigram_scores_t)
        score_shape[1] = 1
        slice_t = BK.constants(score_shape, 0.)  # fix 0., no transition from BOS (and EOS) for simplicity!
        bigram_scores_t = BK.concat([slice_t, bigram_scores_t], 1)  # [*, slen, L]
        scores_t += bigram_scores_t  # [*, slen]
    # local normalization?
    scores_t = self.laber.output_score(scores_t, local_normalize)
    return scores_t
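# Tiny sketch of the bigram/transition addition above, in plain PyTorch (assuming BK wraps torch
# and the bigram module holds an [L, L] matrix of transition scores): indexing the matrix with
# the previous gold label at each position gives a per-position row of "scores of moving from
# gold[t-1] to any label", which is shifted right by one step (a zero row for the first position)
# and added to the unary label scores. Sizes below are illustrative.
import torch

L, bs, slen = 4, 2, 5
trans = torch.randn(L, L)  # [from, to] transition scores
gold = torch.randint(0, L, (bs, slen))  # gold label sequence
bigram = trans[gold[:, :-1]]  # [bs, slen-1, L]
first = torch.zeros(bs, 1, L)  # no transition into the first position
bigram_full = torch.cat([first, bigram], 1)  # [bs, slen, L], aligned with the unary scores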