def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
            pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    # --
    # note: check empty
    if BK.is_zero_shape(mask_expr):
        for inst in insts:  # still need to clear things!!
            self.helper._clear_f(inst)
    else:
        # simply labeling!
        best_labs, best_scores = self.lab_node.predict(
            input_expr, pair_expr, mask_expr, extra_score=external_extra_score)
        # put results
        self.helper.put_results(insts, best_labs, best_scores)
    # --
    # finally
    return self._finish_pred(insts, input_expr, mask_expr, pair_expr, lookup_flatten)

def _forward_max(self, repr_t: BK.Expr, dsel_seq_info):
    RDIM = 2  # reduce dim
    # --
    _all_repr_t, _ = self._aggregate_subtoks(repr_t, dsel_seq_info)
    ret = _all_repr_t.sum(RDIM) if BK.is_zero_shape(_all_repr_t) else _all_repr_t.max(RDIM)[0]  # [*, dlen, D]
    ret = BK.relu(ret)  # note: for simplicity, just make things >= 0.
    return ret

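# --
# note: illustrative sketch only (not part of the module); it assumes BK wraps torch-style
# tensors, so plain torch stands in for BK here. It shows why the zero-shape guard above is
# needed: max() over an empty dimension raises, while sum() just yields zeros of the right shape.
import torch

def _sketch_forward_max(all_repr_t: torch.Tensor, rdim: int = 2) -> torch.Tensor:
    if all_repr_t.numel() == 0:  # stand-in for BK.is_zero_shape
        ret = all_repr_t.sum(rdim)  # empty-safe: only fixes the output shape
    else:
        ret = all_repr_t.max(rdim)[0]  # [*, dlen, D]
    return torch.relu(ret)

# e.g. [bsize=2, dlen=3, M=0, D=4] -> [2, 3, 4] without calling max() on the empty dim
print(_sketch_forward_max(torch.zeros(2, 3, 0, 4)).shape)
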
def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
    conf: MySRLConf = self.conf
    slen = BK.get_shape(mask_expr, -1)
    # --
    # =====
    # evt
    _, all_evt_cfs, all_evt_raw_scores = self.evt_node.get_all_values()  # [*, slen, Le]
    all_evt_scores = [z.log_softmax(-1) for z in all_evt_raw_scores]
    final_evt_scores = self.evt_node.helper.pred(all_logprobs=all_evt_scores, all_cfs=all_evt_cfs)  # [*, slen, Le]
    if conf.evt_pred_use_all or conf.evt_pred_use_posi:  # todo(+W): not an elegant way...
        final_evt_scores[:, :, 0] += Constants.REAL_PRAC_MIN  # all pred sth!!
    pred_evt_scores, pred_evt_labels = final_evt_scores.max(-1)  # [*, slen]
    # =====
    # arg
    _, all_arg_cfs, all_arg_raw_score = self.arg_node.get_all_values()  # [*, slen, slen, La]
    all_arg_scores = [z.log_softmax(-1) for z in all_arg_raw_score]
    final_arg_scores = self.arg_node.helper.pred(all_logprobs=all_arg_scores, all_cfs=all_arg_cfs)  # [*, slen, slen, La]
    # slightly more efficient by masking valid evts??
    full_pred_shape = BK.get_shape(final_arg_scores)[:-1]  # [*, slen, slen]
    pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
    arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
    flat_arg_scores = final_arg_scores[arg_flat_mask]  # [??, slen, La]
    if not BK.is_zero_shape(flat_arg_scores):  # at least one predicate!
        if self.pred_cons_mat is not None:
            flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
            flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                flat_arg_scores, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
        else:
            flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_scores.max(-1)  # [??, slen]
        pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
        pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
    # =====
    # assign
    self.helper.put_results(insts, pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores)

def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                     cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
    cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
    cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
    split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
    split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
    raw_split_loss = BK.loss_binary(split_scores, split_oracle,
                                    label_smoothing=conf.split_label_smoothing)  # [*, slen]
    loss_split_item = LossHelper.compile_leaf_loss(
        f"split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(), loss_lambda=conf.loss_split)
    # step 2.2: feed split
    # note: when teacher-forcing, only force the good points; the others still use pred
    force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
    _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_split_decisions = (_use_force_mask * force_split_decisions
                            + (1. - _use_force_mask) * pred_split_decisions)  # [*, clen-1]
    # next
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # finally get oracles for next steps
    # todo(+N): simply select the highest scored one as oracle
    if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
        oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
    else:
        _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
        oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
        oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
    oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
    return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr

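# --
# note: illustrative sketch only (assumes BK wraps torch). It isolates the split-oracle idea used
# above: a gold split point sits between adjacent candidate tokens whose gold addresses differ,
# and the oracle is only supervised where at least one side belongs to a gold item (gaddr >= 0).
import torch

cand_gaddr = torch.tensor([[0, 0, -1, -1, 2]])  # gold item address per candidate token, -1 = none
cand_masks = torch.ones(1, 5)
g0, g1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]
split_oracle = (g0 != g1).float() * cand_masks[:, 1:]                    # -> [[0., 1., 0., 1.]]
split_oracle_mask = ((g0 >= 0) | (g1 >= 0)).float() * cand_masks[:, 1:]  # -1|-1 pairs get no supervision
print(split_oracle.tolist(), split_oracle_mask.tolist())
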
def prepare_sd_init(self, expr_main: BK.Expr, expr_pair: BK.Expr):
    if self.is_pairwise:
        sd_init_t = self.sd_init_aff(expr_pair)  # [*, hid]
    else:
        if BK.is_zero_shape(expr_main):
            sd_init_t0 = expr_main.sum(-2)  # simply make the shape!
        else:
            sd_init_t0 = self.sd_init_pool_f(expr_main, -2)  # pooling at -2: [*, Dm']
        sd_init_t = self.sd_init_aff(sd_init_t0)  # [*, hid]
    return sd_init_t

def loss(self, unary_scores: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr):
    mat_t = self.bigram.get_matrix()  # [L, L]
    if BK.is_zero_shape(unary_scores):  # note: avoid empty
        potential_t = BK.zeros(BK.get_shape(unary_scores)[:-2])  # [*]
    else:
        potential_t = BigramInferenceHelper.inference_forward(unary_scores, mat_t, input_mask, self.conf.crf_beam)  # [*]
    gold_single_scores_t = unary_scores.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
    gold_bigram_scores_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
    all_losses_t = potential_t - (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1))  # [*]
    if self.conf.loss_by_tok:
        ret_count = input_mask.sum()  # []
    else:
        ret_count = (input_mask.sum(-1) > 0).float()  # [*]
    return all_losses_t, ret_count

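# --
# note: illustrative sketch only (assumes BK wraps torch). It spells out what the log-partition
# term above is taken to compute: a standard linear-chain forward recursion in log space, so that
# loss = logZ - gold_score; the real `inference_forward` may additionally prune with a beam.
import torch

def _sketch_crf_logZ(unary: torch.Tensor, trans: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # unary: [bs, slen, L], trans: [L, L], mask: [bs, slen] (first token assumed valid)
    alpha = unary[:, 0]  # [bs, L]
    for t in range(1, unary.shape[1]):
        # score of moving from every previous label i to every current label j
        step = alpha.unsqueeze(-1) + trans.unsqueeze(0) + unary[:, t].unsqueeze(-2)  # [bs, L, L]
        new_alpha = torch.logsumexp(step, dim=-2)  # [bs, L]
        m = mask[:, t].unsqueeze(-1)
        alpha = m * new_alpha + (1. - m) * alpha  # keep old alpha at padded positions
    return torch.logsumexp(alpha, dim=-1)  # [bs]

print(_sketch_crf_logZ(torch.randn(2, 4, 3), torch.randn(3, 3), torch.ones(2, 4)).shape)
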
def _aggregate_subtoks(self, repr_t: BK.Expr, dsel_seq_info):
    conf: DSelectorConf = self.conf
    _arange_t, _sel_t, _len_t = dsel_seq_info.arange2_t, dsel_seq_info.dec_sel_idxes, dsel_seq_info.dec_sel_lens
    _max_len = 1 if BK.is_zero_shape(_len_t) else _len_t.max().item()
    _max_len = max(1, min(conf.dsel_max_subtoks, _max_len))  # truncate
    # --
    _tmp_arange_t = BK.arange_idx(_max_len)  # [M]
    _all_valids_t = (_tmp_arange_t < _len_t.unsqueeze(-1)).float()  # [*, dlen, M]
    _tmp_arange_t = _tmp_arange_t * _all_valids_t.long()  # note: pad as 0
    _all_idxes_t = _sel_t.unsqueeze(-1) + _tmp_arange_t  # [*, dlen, M]
    _all_repr_t = repr_t[_arange_t.unsqueeze(-1), _all_idxes_t]  # [*, dlen, M, D]
    while len(BK.get_shape(_all_valids_t)) < len(BK.get_shape(_all_repr_t)):
        _all_valids_t = _all_valids_t.unsqueeze(-1)
    _all_repr_t = _all_repr_t * _all_valids_t
    return _all_repr_t, _all_valids_t

def expand_ranged_idxes(widx_t: BK.Expr, wlen_t: BK.Expr, pad: int = 0, max_width: int = None):
    if max_width is None:  # if not provided
        if BK.is_zero_shape(wlen_t):
            max_width = 1  # at least one
        else:
            max_width = wlen_t.max().item()  # overall max width
    # --
    input_shape = BK.get_shape(widx_t)  # [*]
    mw_range_t = BK.arange_idx(max_width).view([1] * len(input_shape) + [-1])  # [*, MW]
    expanded_idxes = widx_t.unsqueeze(-1) + mw_range_t  # [*, MW]
    expanded_masks_bool = (mw_range_t < wlen_t.unsqueeze(-1))  # [*, MW]
    expanded_idxes.masked_fill_(~expanded_masks_bool, pad)  # [*, MW]
    return expanded_idxes, expanded_masks_bool.float()

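# --
# note: illustrative sketch only (assumes BK wraps torch). It shows the intended input/output
# contract of expand_ranged_idxes: each (start, length) pair becomes a row of token indices
# padded out to the longest span, plus a 0/1 validity mask.
import torch

widx_t = torch.tensor([2, 5])   # span starts
wlen_t = torch.tensor([3, 1])   # span lengths
max_width = int(wlen_t.max())   # = 3
rng = torch.arange(max_width).view(1, -1)       # [1, MW]
idxes = widx_t.unsqueeze(-1) + rng              # [[2, 3, 4], [5, 6, 7]]
masks = (rng < wlen_t.unsqueeze(-1))            # [[True, True, True], [True, False, False]]
idxes = idxes.masked_fill(~masks, 0)            # pad invalid cells -> [[2, 3, 4], [5, 0, 0]]
print(idxes.tolist(), masks.float().tolist())
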
def go_topk(self, input_expr: BK.Expr, input_mask: BK.Expr,  # input
            widx_expr: BK.Expr, wlen_expr: BK.Expr, span_mask: BK.Expr,
            rate: float = None, count: float = None,  # span
            gaddr_expr: BK.Expr = None, add_gold_rate: float = 0.,  # gold
            non_overlapping=False, score_prune: float = None):  # non-overlapping!
    lookup_res = self.go_lookup(input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
    # --
    with BK.no_grad_env():  # no need grad here!
        all_score_expr = lookup_res.score_expr
        # get topk score: again rate is to the original input length
        if BK.is_zero_shape(lookup_res.mask_expr):
            topk_mask = lookup_res.mask_expr.clone()  # no need to go topk since no elements
        else:
            topk_expr = self._determine_size(input_mask.sum(-1, keepdim=True), rate, count).long()  # [bsize, 1]
            if non_overlapping:
                topk_mask = select_topk_non_overlapping(all_score_expr, topk_expr, widx_expr, wlen_expr, input_mask,
                                                        mask_t=span_mask, dim=-1)
            else:
                topk_mask = select_topk(all_score_expr, topk_expr, mask_t=span_mask, dim=-1)
        # further score_prune?
        if score_prune is not None:
            topk_mask *= (all_score_expr >= score_prune).float()
    # select and add_gold
    return self._go_common(lookup_res, topk_mask, add_gold_rate)

def loss(self, all_losses: List[BK.Expr], all_cfs: List[BK.Expr], **kwargs):
    conf: IdecHelperCW2Conf = self.conf
    _temp = self.temperature.value
    # --
    stack_t = BK.stack(all_losses, -1)  # [*, NL]
    w_t = (-stack_t / _temp)  # [*, NL], smaller loss is better!
    w_t_detach = w_t.detach()
    # main loss
    apply_w_t = w_t_detach if conf.detach_weights else w_t
    ret_t = (stack_t * apply_w_t.softmax(-1)).sum(-1)  # [*]
    # cf loss
    cf_t = BK.stack(all_cfs, -1).sigmoid()  # [*, NL]
    if conf.cf_trg_rel:  # relative prob proportion?
        _max_t = (w_t_detach.sum(-1, keepdim=True) if BK.is_zero_shape(w_t_detach)
                  else w_t_detach.max(-1, keepdim=True)[0])  # [*, 1]
        _trg_t = (w_t_detach - _max_t).exp() * conf.max_cf  # [*, NL]
    else:
        _trg_t = w_t_detach.exp() * conf.max_cf
    loss_cf_t = BK.loss_binary(cf_t, _trg_t).mean(-1)  # [*]
    return [(ret_t, 1., ""), (loss_cf_t, conf.loss_cf, "_cf")]

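# --
# note: illustrative sketch only (assumes BK wraps torch). It shows the weighting scheme above in
# isolation: per-layer losses are combined with softmax(-loss / temperature), so layers that
# already fit well receive more of the weight (optionally through detached weights).
import torch

def _sketch_cw_loss(all_losses, temperature=1.0, detach_weights=True):
    stack_t = torch.stack(all_losses, -1)        # [*, NL]
    w_t = -stack_t / temperature                 # smaller loss -> larger weight
    if detach_weights:
        w_t = w_t.detach()
    return (stack_t * w_t.softmax(-1)).sum(-1)   # [*]

layer_losses = [torch.tensor([0.2, 1.5]), torch.tensor([0.8, 0.3])]  # two layers, batch of 2
print(_sketch_cw_loss(layer_losses))
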
def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
            pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: DirectExtractorConf = self.conf
    # step 1: prepare targets
    if conf.pred_use_posi:
        # step 1a: directly use provided positions
        arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, _ = \
            self.helper.prepare(insts, use_cache=False)
        cand_res = self.extract_node.go_lookup(
            input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr >= 0).float(),
            gaddr_expr=expr_gold_gaddr)
    else:
        arr_gold_items = None
        # step 1b: extract cands (topk); todo(note): assume no in-middle mask!!
        cand_widx, cand_wlen, cand_mask, _ = self.extract_node.prepare_with_lengths(
            BK.get_shape(mask_expr), mask_expr.sum(-1).long(), None, None, None)
        cand_res = self.extract_node.go_topk(
            input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
            rate=conf.span_topk_rate, count=conf.span_topk_count,
            non_overlapping=conf.pred_non_overlapping, score_prune=conf.pred_score_prune)
    # --
    # note: check empty
    if BK.is_zero_shape(cand_res.mask_expr):
        if not conf.pred_use_posi:
            for inst in insts:  # still need to clear things!!
                self.helper._clear_f(inst)
    else:
        # step 2: labeling
        # extra score
        cand_extra_score = self._get_extra_score(
            cand_res.score_expr, insts, cand_res, arr_gold_items, conf.pred_use_cons, conf.pred_use_lu)
        final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
        best_labs, best_scores = self.lab_node.predict(
            cand_res.span_expr, pair_expr, cand_res.mask_expr, extra_score=final_extra_score)
        # step 3: put results
        if conf.pred_use_posi:  # reuse the old ones, but replace label
            self.helper.put_labels(arr_gold_items, best_labs, best_scores)
        else:  # make new frames
            self.helper.put_results(
                insts, best_labs, best_scores, cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr)
    # --
    # finally
    return self._finish_pred(insts, input_expr, mask_expr, pair_expr, lookup_flatten)

def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
    conf: SoftExtractorConf = self.conf
    # --
    cand_full_scores = self.cand_scorer(input_expr).squeeze(-1) + (1. - mask_expr) * Constants.REAL_PRAC_MIN  # [*, slen]
    # decide topk count
    len_t = mask_expr.sum(-1)  # [*]
    topk_t = (len_t * conf.cand_topk_rate).clamp(max=conf.cand_topk_count).ceil().long().unsqueeze(-1)  # [*, 1]
    # get topk mask
    if BK.is_zero_shape(mask_expr):
        topk_mask = mask_expr.clone()  # no need to go topk since no elements
    else:
        topk_mask = select_topk(cand_full_scores, topk_t, mask_t=mask_expr, dim=-1)  # [*, slen]
    # thresh
    cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()  # [*, slen]
    return cand_full_scores, cand_decisions  # [*, slen]

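# --
# note: illustrative sketch only (assumes BK wraps torch; `select_topk` is simplified here to a
# plain per-row top-k loop). It shows the candidate-selection rule above: keep ceil(rate * length)
# tokens, capped at a fixed count, then additionally require each score to pass a threshold.
import torch

def _sketch_cand_select(scores, mask, rate=0.5, max_count=4, thresh=0.0):
    len_t = mask.sum(-1)                                          # [*]
    topk_t = (len_t * rate).clamp(max=max_count).ceil().long()    # [*]
    topk_mask = torch.zeros_like(scores)
    for b in range(scores.shape[0]):                              # simplified select_topk
        k = int(topk_t[b])
        if k > 0:
            _, idx = (scores[b] + (mask[b] - 1.) * 1e4).topk(k)   # push masked positions down
            topk_mask[b, idx] = 1.
    return topk_mask * (scores >= thresh).float()

scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
mask = torch.ones(1, 4)
print(_sketch_cand_select(scores, mask))  # keeps the 2 best tokens that also exceed the threshold
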
def _pred_arg(self, mask_expr, pred_evt_labels):
    conf: ZDecoderSRLConf = self.conf
    slen = BK.get_shape(mask_expr, -1)
    # --
    all_arg_raw_score = self.arg_node.buffer_scores.values()  # [*, slen, slen, La]
    all_arg_logprobs = [z.log_softmax(-1) for z in all_arg_raw_score]
    final_arg_logprobs = self.arg_node.helper.pred(all_logprobs=all_arg_logprobs)  # [*, slen, slen, La]
    # slightly more efficient by masking valid evts??
    full_pred_shape = BK.get_shape(final_arg_logprobs)[:-1]  # [*, slen, slen]
    pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
    # mask
    arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
    flat_arg_logprobs = final_arg_logprobs[arg_flat_mask]  # [??, slen, La]
    if not BK.is_zero_shape(flat_arg_logprobs):  # at least one predicate
        if self.pred_cons_mat is not None:
            flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
            flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                flat_arg_logprobs, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
        else:
            flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_logprobs.max(-1)  # [??, slen]
        pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
        pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
    return pred_arg_labels, pred_arg_scores  # [*, slen, slen]

def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                    expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1.1: bag loss
    cand_gold_mask = (expr_seq_gaddr >= 0).float() * mask_expr  # [*, slen], whether is-arg
    raw_loss_cand = BK.loss_binary(cand_full_scores, cand_gold_mask,
                                   label_smoothing=conf.cand_label_smoothing)  # [*, slen]
    # how to weight?
    extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] \
        + (1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
    if BK.is_zero_shape(extended_scores_t):
        extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((cand_full_scores - extended_scores_max_t) * _w_alpha).exp()  # [*, slen]
    if not conf.cand_loss_div_max:  # div sum-all, like doing softmax
        _weight = _weight / ((extended_scores_t - extended_scores_max_t.unsqueeze(-1)) * _w_alpha).exp().sum(-1)
    _weight = _weight * (_weight >= conf.cand_loss_weight_thresh).float()  # [*, slen]
    if conf.cand_detach_weight:
        _weight = _weight.detach()
    # pos poison (dis-encouragement)
    if conf.cand_loss_pos_poison:
        poison_loss = BK.loss_binary(cand_full_scores, 1. - cand_gold_mask,
                                     label_smoothing=conf.cand_label_smoothing)  # [*, slen]
        raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (1. - _weight)  # [*, slen]
    else:
        raw_loss_cand = raw_loss_cand * _weight
    # final weight it
    cand_loss_weights = BK.where(cand_gold_mask == 0.,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_cand_item = LossHelper.compile_leaf_loss(
        f"cand", (raw_loss_cand * final_cand_loss_weights).sum(),
        final_cand_loss_weights.sum(), loss_lambda=conf.loss_cand)
    # step 1.2: feed cand
    # todo(+N): currently only pred/sample, whether adding certain teacher-forcing?
    sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(cand_full_scores.shape)).float() * mask_expr  # [*, slen]
    _use_sample_mask = (BK.rand([bsize]) <= conf.cand_feed_sample_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_cand_decisions = (_use_sample_mask * sample_decisions
                           + (1. - _use_sample_mask) * pred_cand_decisions)  # [*, slen]
    # next
    cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
    # --
    # extra: loss_cand_entropy
    rets = [loss_cand_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                               expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() * cand_gold_mask
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        rets.append(_loss_ent_item)
    # --
    return rets, cand_widxes, cand_masks

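# --
# note: illustrative sketch only (assumes BK wraps torch; shown for a single bag rather than the
# batched [*, slen, MW] layout). It isolates the bag-style weighting above: each gold token's loss
# is scaled by how its score compares with the best-scoring token of the same gold item, either
# relative to the bag max or normalized over the whole bag like a softmax.
import torch

scores = torch.tensor([1.0, 3.0, 2.0])  # candidate scores belonging to one gold item ("bag")
alpha = 1.0
bag_max = scores.max()
w_div_max = ((scores - bag_max) * alpha).exp()                    # relative to the bag max
w_softmax = w_div_max / ((scores - bag_max) * alpha).exp().sum()  # normalized over the bag
print(w_div_max.tolist(), w_softmax.tolist())
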
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare
    arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1: label, simply scoring everything!
    _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
    all_scores_t = self.lab_node.score_all(
        _main_t, _pair_t, mask_expr, None, local_normalize=False,
        extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
    all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
    all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
    # how to weight
    extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
    if BK.is_zero_shape(extended_gprob_t):
        extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((all_gprob_t * mask_expr) / (extended_gprob_max_t.clamp(min=1e-5))) ** _w_alpha  # [*, slen]
    _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
    _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]
    _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]
    _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
    _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
    # final weight it
    cand_loss_weights = BK.where(expr_seq_labs == 0,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
        loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
    # --
    # step 1.5
    all_losses = [loss_lab_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_gprob_t  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                               expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() \
            * (expr_seq_labs > 0).float()
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        all_losses.append(_loss_ent_item)
    # --
    # step 4: extend (select topk)
    if conf.loss_ext > 0.:
        if BK.is_zero_shape(extended_gprob_t):
            flt_mask = (BK.zeros(mask_expr.shape) > 0)
        else:
            _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
            _topk_grpob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
            flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_grpob_t.min(-1)[0]) \
                & (_weight > conf.ext_loss_thresh)  # [*, slen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
        flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
        loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr,
                                           mask_expr[flt_sidx], flt_extra_weights=flt_weights)
        all_losses.append(loss_ext_item)
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(all_losses)
    return ret_loss, None

def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
            pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: SoftExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    for inst in insts:  # first clear things
        self.helper._clear_f(inst)
    # --
    # step 1: cand score and select
    cand_full_scores, cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
    cand_widxes, cand_masks = BK.mask2idx(cand_decisions)  # [*, clen]
    # step 2: split and seg
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
    cand_expr, cand_scores = input_expr[arange2_t, cand_widxes], cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
    split_scores, split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(split_decisions, cand_masks)
    seg_ext_widxes0, seg_ext_masks0 = cand_widxes[arange3_t, seg_ext_cidxes], seg_ext_masks  # [*, seglen, ORIG-MW]
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # --
    # step 3: lab
    flt_items = []
    if not BK.is_zero_shape(seg_masks):
        best_labs, best_scores = self.lab_node.predict(
            seg_weighted_expr, pair_expr, seg_masks, extra_score=external_extra_score)  # *[*, seglen]
        flt_items = self.helper.put_results(
            insts, best_labs, best_scores, seg_masks, seg_ext_widxes0, seg_ext_widxes,
            seg_ext_masks0, seg_ext_masks, cand_full_scores, cand_decisions, split_decisions)
    # --
    # step 4: final extend (in a flattened way)
    if len(flt_items) > 0 and conf.pred_ext:
        flt_mask = ((best_labs > 0) & (seg_masks > 0.))  # [*, seglen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
        flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
        self.ext_node.predict(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
    # --
    # extra:
    self.pp_node.prune(insts)
    return None

def loss(self, input_main: BK.Expr, input_pair: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr,
         loss_weight_expr: BK.Expr = None, extra_score: BK.Expr = None):
    conf: SeqLabelerConf = self.conf
    # --
    expr_main, expr_pair = self.transform_expr(input_main, input_pair)
    if self.loss_mle:
        # simply collect them all (not normalize here!)
        all_scores_t = self.score_all(expr_main, expr_pair, input_mask, gold_idxes,
                                      local_normalize=False, extra_score=extra_score)  # [*, slen, L]
        # negative log likelihood; todo(+1): repeat log-softmax here
        # all_losses_t = - all_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
        all_losses_t = BK.loss_nll(all_scores_t, gold_idxes,
                                   label_smoothing=self.conf.labeler_conf.label_smoothing)  # [*]
        all_losses_t *= input_mask
        if loss_weight_expr is not None:
            all_losses_t *= loss_weight_expr
        ret_loss = all_losses_t.sum()  # []
    elif self.loss_crf:
        # no normalization & no bigram
        single_scores_t = self.score_all(expr_main, expr_pair, input_mask, None,
                                         use_bigram=False, extra_score=extra_score)  # [*, slen, L]
        mat_t = self.bigram.get_matrix()  # [L, L]
        if BK.is_zero_shape(single_scores_t):  # note: avoid empty
            potential_t = BK.zeros(BK.get_shape(single_scores_t)[:-2])  # [*]
        else:
            potential_t = BigramInferenceHelper.inference_forward(single_scores_t, mat_t, input_mask, conf.beam_k)  # [*]
        gold_single_scores_t = single_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
        gold_bigram_scores_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
        all_losses_t = potential_t - (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1))  # [*]
        # todo(+N): also no label_smoothing for crf
        # todo(+N): for now, ignore loss_weight for crf mode!!
        # if loss_weight_expr is not None:
        #     assert BK.get_shape(loss_weight_expr, -1) == 1, "Currently CRF loss requires seq level loss_weight!!"
        #     all_losses_t *= loss_weight_expr
        ret_loss = all_losses_t.sum()  # []
    else:
        raise NotImplementedError()
    # ret_count
    if conf.loss_by_tok:  # sum all valid toks
        if conf.loss_by_tok_weighted and loss_weight_expr is not None:
            ret_count = (input_mask * loss_weight_expr).sum()
        else:
            ret_count = input_mask.sum()
    else:  # sum all valid batch items
        ret_count = input_mask.prod(-1).sum()
    return (ret_loss, ret_count)