def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
    conf: MySRLConf = self.conf
    # --
    slen = BK.get_shape(mask_expr, -1)
    arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non = self.helper.prepare(insts, True)
    if conf.binary_evt:
        expr_evt_labels = (expr_evt_labels > 0).long()  # either 0 or 1
    loss_items = []
    # =====
    # evt
    # -- prepare weights and masks
    evt_not_nil = (expr_evt_labels > 0)  # [*, slen]
    evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1) * conf.evt_loss_weight_non)
    evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
    # -- get losses
    _, all_evt_cfs, all_evt_scores = self.evt_node.get_all_values()  # [*, slen]
    all_evt_losses = []
    for one_evt_scores in all_evt_scores:
        one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
        all_evt_losses.append(one_losses)
    evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses, all_cfs=all_evt_cfs)
    for loss_t, loss_alpha, loss_name in evt_loss_results:
        one_evt_item = LossHelper.compile_leaf_loss(
            "evt" + loss_name, (loss_t * evt_weights).sum(), evt_weights.sum(),
            loss_lambda=conf.loss_evt * loss_alpha, gold=evt_not_nil.float().sum())
        loss_items.append(one_evt_item)
    # =====
    # arg
    _arg_loss_evt_sample_neg = conf.arg_loss_evt_sample_neg
    if _arg_loss_evt_sample_neg > 0:
        arg_evt_masks = ((BK.rand(mask_expr.shape) < _arg_loss_evt_sample_neg) | evt_not_nil).float() * mask_expr
    else:
        arg_evt_masks = evt_not_nil.float()  # [*, slen]
    # expand/flat the dims
    arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
    flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
    flat_arg_labels = expr_arg_labels[arg_flat_mask]  # [??, slen]
    flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
    flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, conf.arg_loss_sample_neg)
    # -- get losses
    _, all_arg_cfs, all_arg_scores = self.arg_node.get_all_values()  # [*, slen, slen]
    all_arg_losses = []
    for one_arg_scores in all_arg_scores:
        one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
        # note: label_smoothing here is shared with evt (conf.evt_label_smoothing)
        one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=conf.evt_label_smoothing)
        all_arg_losses.append(one_losses)
    all_arg_cfs = [z[arg_flat_mask] for z in all_arg_cfs]  # [??, slen]
    arg_loss_results = self.arg_node.helper.loss(all_losses=all_arg_losses, all_cfs=all_arg_cfs)
    for loss_t, loss_alpha, loss_name in arg_loss_results:
        one_arg_item = LossHelper.compile_leaf_loss(
            "arg" + loss_name, (loss_t * flat_arg_weights).sum(), flat_arg_weights.sum(),
            loss_lambda=conf.loss_arg * loss_alpha, gold=flat_arg_not_nil.float().sum())
        loss_items.append(one_arg_item)
    # =====
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(loss_items)
    return ret_loss
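# --
# note (illustrative sketch, not repo code): `_prepare_loss_weights` is defined elsewhere
# in this repo; the standalone PyTorch snippet below only illustrates the negative-sampling
# pattern it is used for above (keep all gold positions, keep each negative with probability
# `sample_neg`, then apply the mask and optional extra weights). Names and exact behavior
# here are assumptions.
import torch

def sketch_prepare_loss_weights(mask, not_nil, sample_neg, extra_weights=None):
    # keep all gold positions; keep each negative position with probability `sample_neg`
    keep = not_nil | (torch.rand_like(mask) < sample_neg)
    weights = keep.float() * mask
    if extra_weights is not None:
        weights = weights * extra_weights
    return weights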
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: DirectExtractorConf = self.conf
    # step 0: prepare golds
    arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
        self.helper.prepare(insts, use_cache=True)
    # step 1: extract cands
    if conf.loss_use_posi:
        cand_res = self.extract_node.go_lookup(
            input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr >= 0).float(),
            gaddr_expr=expr_gold_gaddr)
    else:
        # todo(note): assume no in-middle mask!!
        cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
            BK.get_shape(mask_expr), mask_expr.sum(-1).long(),
            expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
        if conf.span_train_sample:  # simply do sampling
            cand_res = self.extract_node.go_sample(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
        else:  # beam pruner using topk
            cand_res = self.extract_node.go_topk(
                input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                rate=conf.span_topk_rate, count=conf.span_topk_count,
                gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
    # step 1+: prepare for labeling
    cand_gold_mask = (cand_res.gaddr_expr >= 0).float() * cand_res.mask_expr  # [*, cand_len]
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
    cand_loss_weights = BK.where(gold_label_idxes == 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 cand_res.mask_expr)
    final_loss_weights = cand_loss_weights * cand_res.mask_expr
    # cand loss
    if conf.loss_cand > 0. and not conf.loss_use_posi:
        loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)
        loss_cand = (loss_cand0 * final_loss_weights).sum()
        loss_cand_item = LossHelper.compile_leaf_loss(f"cand", loss_cand, final_loss_weights.sum(),
                                                      loss_lambda=conf.loss_cand)
    else:
        loss_cand_item = None
    # extra score
    cand_extra_score = self._get_extra_score(
        cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
    final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
    # step 2: label; with special weights
    loss_lab, loss_count = self.lab_node.loss(
        cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
        loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(f"lab", loss_lab, loss_count,
                                                 loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
    # ==
    # return loss
    ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
    return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
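# --
# note (illustrative example, not repo code): a small, self-contained illustration of the
# "add a 0 at idx=-1" trick used above; appending a 0 to the flattened gold labels lets
# gaddr == -1 (candidates aligned to no gold item) index the trailing 0, i.e. the NIL label.
import torch
_flat_labels = torch.tensor([5, 2, 7] + [0])   # gold label ids + trailing NIL
_gaddr = torch.tensor([[0, -1, 2]])            # -1 = not aligned to any gold item
assert _flat_labels[_gaddr].tolist() == [[5, 0, 7]]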
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: SeqExtractorConf = self.conf
    # step 0: prepare golds
    expr_gold_slabs, expr_loss_weight_non = self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    final_loss_weights = BK.where(expr_gold_slabs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non, mask_expr)
    # step 1: label; with special weights
    loss_lab, loss_count = self.lab_node.loss(
        input_expr, pair_expr, mask_expr, expr_gold_slabs,
        loss_weight_expr=final_loss_weights, extra_score=external_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
        gold=(expr_gold_slabs > 0).float().sum())
    # ==
    # return loss
    ret_loss = LossHelper.combine_multiple_losses([loss_lab_item])
    return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
    conf: SeqExitHelperConf = self.conf
    # --
    assert self.is_cf
    # get oracle
    oracles = [self.cf_oracle_f(ff) for ff in insts]  # bs*[NL, slen] or bs*[NL]
    rets = []
    mask_t = BK.input_real(DataPadder.lengths2mask([len(z.sent) for z in insts]))  # [bs, slen]
    for one_li, one_scores in enumerate(cf_scores):
        if conf.cf_use_seq:
            one_oracle_t = BK.input_real([z[one_li] for z in oracles])  # [bs]
            one_oracle_t *= conf.cf_scale
            one_mask_t = BK.zeros([len(one_oracle_t)]) + 1
        else:
            one_oracle_t = BK.input_real(DataPadder.go_batch_2d([z[one_li] for z in oracles], 1.))  # [bs, slen]
            one_mask_t = (BK.rand(one_oracle_t.shape) >=
                          ((one_oracle_t ** conf.cf_loss_discard_curve) * conf.cf_loss_discard)) * mask_t
            one_oracle_t *= conf.cf_scale
        # simple L2 loss
        one_loss_t = (one_scores.squeeze(-1) - one_oracle_t) ** 2
        one_loss_item = LossHelper.compile_leaf_loss(
            f"cf{one_li}", (one_loss_t * one_mask_t).sum(), one_mask_t.sum(), loss_lambda=loss_cf)
        rets.append(one_loss_item)
    return rets
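# --
# note (illustrative example, not repo code): in the non-seq branch above, a position is
# dropped from the L2 loss with probability cf_loss_discard * oracle**cf_loss_discard_curve,
# so easy (high-oracle) positions are discarded more often; the values below are only
# for illustration.
import torch
_oracle = torch.tensor([[0.0, 0.5, 1.0]])
_discard, _curve = 0.8, 1.0
_keep_prob = 1.0 - _discard * _oracle.pow(_curve)          # -> [[1.0, 0.6, 0.2]]
_keep_mask = (torch.rand_like(_oracle) >= _discard * _oracle.pow(_curve)).float()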
def loss_regs(regs: List['ParamRegHelper']):
    loss_items = []
    for ii, reg in enumerate(regs):
        if reg.reg_method_loss:
            _loss, _loss_lambda = reg.compute_loss()
            _loss_item = LossHelper.compile_leaf_loss(f'reg_{ii}', _loss, BK.input_real(1.), loss_lambda=_loss_lambda)
            loss_items.append(_loss_item)
    ret_loss = LossHelper.combine_multiple_losses(loss_items)
    return ret_loss
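# --
# note (illustrative sketch, not repo code): LossHelper is repo-internal; the snippet below
# only sketches the leaf-loss pattern assumed by the calls above, i.e. each leaf carries a
# (sum, count, lambda) triple and the combined scalar is sum_i lambda_i * sum_i / count_i.
# This is an assumption for illustration, not the repo's actual implementation.
def sketch_compile_leaf_loss(name, loss_sum, loss_count, loss_lambda=1.0, **info):
    return {"name": name, "sum": loss_sum, "count": loss_count, "lambda": loss_lambda, "info": info}

def sketch_combine_multiple_losses(items):
    items = [it for it in items if it is not None]  # some callers pass None items
    return sum(it["lambda"] * it["sum"] / (it["count"] + 1e-8) for it in items)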
def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                     cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
    cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
    cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
    split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
    split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
    raw_split_loss = BK.loss_binary(split_scores, split_oracle,
                                    label_smoothing=conf.split_label_smoothing)  # [*, clen-1]
    loss_split_item = LossHelper.compile_leaf_loss(
        f"split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(), loss_lambda=conf.loss_split)
    # step 2.2: feed split
    # note: when teacher-forcing, only forcing good points, others still use pred
    force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
    _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_split_decisions = _use_force_mask * force_split_decisions + (1. - _use_force_mask) * pred_split_decisions  # [*, clen-1]
    # next
    # *[*, seglen, MW], [*, seglen]
    seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(feed_split_decisions, cand_masks)
    seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
        cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks, conf.split_topk)  # [*, seglen, ?]
    # finally get oracles for next steps
    # todo(+N): simply select the highest scored one as oracle
    if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
        oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
    else:
        _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
        oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
        oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
    oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
    return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
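# --
# note (illustrative example, not repo code): a worked example of the split oracle above;
# a boundary is placed between adjacent candidate tokens whose gold addresses differ, and
# the split loss is only counted where at least one side belongs to a gold item (gaddr >= 0).
import torch
_gaddr = torch.tensor([[0, 0, -1, -1, 1]])                  # gold-item address per candidate
_g0, _g1 = _gaddr[:, :-1], _gaddr[:, 1:]
_split_oracle = (_g0 != _g1).float()                        # -> [[0., 1., 0., 1.]]
_split_mask = ((_g0 >= 0) | (_g1 >= 0)).float()             # -> [[1., 1., 0., 1.]]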
def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
    conf: ExtenderConf = self.conf
    _loss_lambda = conf._loss_lambda
    # --
    enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
    s_left, s_right = self.enode.score(enc_t, flt_pair_expr if conf.ext_use_finput else None, flt_mask_expr)  # [*, slen]
    # --
    gold_posi = [self.ext_span_getter(z.mention) for z in flt_items]  # List[(widx, wlen)]
    widx_t = BK.input_idx([z[0] for z in gold_posi])  # [*]
    wlen_t = BK.input_idx([z[1] for z in gold_posi])
    loss_left_t, loss_right_t = BK.loss_nll(s_left, widx_t), BK.loss_nll(s_right, widx_t + wlen_t - 1)  # [*]
    if flt_extra_weights is not None:
        loss_left_t *= flt_extra_weights
        loss_right_t *= flt_extra_weights
        loss_div = flt_extra_weights.sum()  # note: also use this!
    else:
        loss_div = BK.constants([len(flt_items)], value=1.).sum()
    loss_left_item = LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=_loss_lambda)
    loss_right_item = LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=_loss_lambda)
    ret_loss = LossHelper.combine_multiple_losses([loss_left_item, loss_right_item])
    return ret_loss
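# --
# note (illustrative example, not repo code): the boundary targets above are the left and
# right token indices of the gold span, e.g. a mention at (widx=3, wlen=2) gives targets
# left=3, right=4. The snippet assumes BK.loss_nll behaves like per-item cross-entropy over
# unnormalized scores; plain PyTorch is used here only for illustration.
import torch
import torch.nn.functional as F
_s_left = torch.randn(4, 10)                 # [*, slen] left-boundary scores
_widx_t = torch.tensor([3, 0, 5, 2])         # gold left indices
_loss_left = F.cross_entropy(_s_left, _widx_t, reduction='none')  # [*]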
def loss_from_lab(self, lab_node, score_name: str, med: ZMediator, label_t, mask_t,
                  loss_lambda: float, loss_neg_sample: float = None):
    score_cache = med.get_cache((self.name, score_name))
    loss_items = []
    # note: simply collect them all
    all_losses = lab_node.gather_losses(score_cache.vals, label_t, mask_t, loss_neg_sample=loss_neg_sample)
    for ii, vv in enumerate(all_losses):
        nn = score_cache.infos[ii]
        _loss_t, _mask_t = vv
        _loss_item = LossHelper.compile_leaf_loss(
            f'{score_name}_{nn}', _loss_t.sum(), _mask_t.sum(), loss_lambda=loss_lambda)
        loss_items.append(_loss_item)
    return loss_items
def _loss_depth(self, med: ZMediator, mask_expr, expr_depth):
    conf: ZDecoderUDEPConf = self.conf
    # --
    all_depth_scores = med.main_scores.get((self.name, "depth"))  # [*, slen]
    all_depth_losses = []
    for one_depth_scores in all_depth_scores:
        one_losses = BK.loss_binary(one_depth_scores.squeeze(-1), expr_depth,
                                    label_smoothing=conf.depth_label_smoothing)
        all_depth_losses.append(one_losses)
    depth_loss_results = self.depth_node.helper.loss(all_losses=all_depth_losses)
    loss_items = []
    for loss_t, loss_alpha, loss_name in depth_loss_results:
        one_depth_item = LossHelper.compile_leaf_loss(
            "depth" + loss_name, (loss_t * mask_expr).sum(), mask_expr.sum(),
            loss_lambda=(loss_alpha * conf.loss_depth))
        loss_items.append(one_depth_item)
    return loss_items
def _loss_upos(self, mask_expr, expr_upos_labels):
    conf: ZDecoderUPOSConf = self.conf
    all_upos_scores = self.upos_node.buffer_scores.values()  # [*, slen, L]
    all_upos_losses = []
    for one_upos_scores in all_upos_scores:
        one_losses = BK.loss_nll(one_upos_scores, expr_upos_labels, label_smoothing=conf.upos_label_smoothing)
        all_upos_losses.append(one_losses)
    upos_loss_results = self.upos_node.helper.loss(all_losses=all_upos_losses)
    loss_items = []
    for loss_t, loss_alpha, loss_name in upos_loss_results:
        one_upos_item = LossHelper.compile_leaf_loss(
            "upos" + loss_name, (loss_t * mask_expr).sum(), mask_expr.sum(),
            loss_lambda=(loss_alpha * conf.loss_upos))
        loss_items.append(one_upos_item)
    return loss_items
def _loss_udep(self, med: ZMediator, mask_expr, expr_udep):
    conf: ZDecoderUDEPConf = self.conf
    # --
    all_udep_scores = med.main_scores.get((self.name, "udep"))  # [*, slen, slen, L]
    all_udep_losses = []
    for one_udep_scores in all_udep_scores:
        one_losses = BK.loss_nll(one_udep_scores, expr_udep, label_smoothing=conf.udep_label_smoothing)
        all_udep_losses.append(one_losses)
    udep_loss_results = self.udep_node.helper.loss(all_losses=all_udep_losses)
    # --
    _loss_weights = ((BK.rand(expr_udep.shape) < conf.udep_loss_sample_neg) | (expr_udep > 0)).float() \
        * mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2)  # [*, slen, slen]
    # --
    loss_items = []
    for loss_t, loss_alpha, loss_name in udep_loss_results:
        one_udep_item = LossHelper.compile_leaf_loss(
            "udep" + loss_name, (loss_t * _loss_weights).sum(), _loss_weights.sum(),
            loss_lambda=(loss_alpha * conf.loss_udep))
        loss_items.append(one_udep_item)
    return loss_items
def _loss_evt(self, mask_expr, expr_evt_labels, expr_loss_weight_non):
    conf: ZDecoderSRLConf = self.conf
    # --
    loss_items = []
    # -- prepare weights and masks
    evt_not_nil = (expr_evt_labels > 0)  # [*, slen]
    evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1) * conf.evt_loss_weight_non)
    evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
    # -- get losses
    all_evt_scores = self.evt_node.buffer_scores.values()  # List([*, slen])
    all_evt_losses = []
    for one_evt_scores in all_evt_scores:
        one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
        all_evt_losses.append(one_losses)
    evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses)
    for loss_t, loss_alpha, loss_name in evt_loss_results:
        one_evt_item = LossHelper.compile_leaf_loss(
            "evt" + loss_name, (loss_t * evt_weights).sum(), evt_weights.sum(),
            loss_lambda=conf.loss_evt * loss_alpha, gold=evt_not_nil.float().sum())
        loss_items.append(one_evt_item)
    return loss_items
def _loss_arg(self, mask_expr, expr_evt_labels, expr_arg_labels, expr_arg2_labels):
    conf: ZDecoderSRLConf = self.conf
    # --
    loss_items = []
    slen = BK.get_shape(mask_expr, -1)
    evt_not_nil = (expr_evt_labels > 0)  # [*, slen]
    # --
    # first prepare evts to focus at
    arg_evt_masks = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.arg_loss_inc_neg_evt)  # [*, slen]
    # expand/flat the dims
    arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
    flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
    # -- get losses
    for aname, anode, glabels, sneg_rate, loss_mul, lsmooth in \
            zip(["arg", "arg2"], [self.arg_node, self.arg2_node], [expr_arg_labels, expr_arg2_labels],
                [conf.arg_loss_sample_neg, conf.arg2_loss_sample_neg], [conf.loss_arg, conf.loss_arg2],
                [conf.arg_label_smoothing, conf.arg2_label_smoothing]):
        # --
        if loss_mul <= 0.:
            continue
        # --
        # prepare
        flat_arg_labels = glabels[arg_flat_mask]  # [??, slen]
        flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
        flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, sneg_rate)
        # scores and losses
        all_arg_scores = anode.buffer_scores.values()  # [*, slen, slen]
        all_arg_losses = []
        for one_arg_scores in all_arg_scores:
            one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
            one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=lsmooth)
            all_arg_losses.append(one_losses)
        arg_loss_results = anode.helper.loss(all_losses=all_arg_losses)
        # collect losses
        for loss_t, loss_alpha, loss_name in arg_loss_results:
            one_arg_item = LossHelper.compile_leaf_loss(
                aname + loss_name, (loss_t * flat_arg_weights).sum(), flat_arg_weights.sum(),
                loss_lambda=(loss_mul * loss_alpha), gold=flat_arg_not_nil.float().sum())
            loss_items.append(one_arg_item)
    return loss_items
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: AnchorExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare
    arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1: label, simply scoring everything!
    _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
    all_scores_t = self.lab_node.score_all(
        _main_t, _pair_t, mask_expr, None, local_normalize=False,
        extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
    all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
    all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
    # how to weight
    extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
    if BK.is_zero_shape(extended_gprob_t):
        extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((all_gprob_t * mask_expr) / (extended_gprob_max_t.clamp(min=1e-5))) ** _w_alpha  # [*, slen]
    _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
    _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]
    _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]
    _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
    _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
    # final weight it
    cand_loss_weights = BK.where(expr_seq_labs == 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
        loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
    # --
    # step 1.5
    all_losses = [loss_lab_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_gprob_t  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0, expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() \
            * (expr_seq_labs > 0).float()
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        all_losses.append(_loss_ent_item)
    # --
    # step 4: extend (select topk)
    if conf.loss_ext > 0.:
        if BK.is_zero_shape(extended_gprob_t):
            flt_mask = (BK.zeros(mask_expr.shape) > 0)
        else:
            _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
            _topk_gprob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
            flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_gprob_t.min(-1)[0]) \
                & (_weight > conf.ext_loss_thresh)  # [*, slen]
        flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
        flt_expr = input_expr[flt_mask]  # [?, D]
        flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
        flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
        flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
        loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr,
                                           mask_expr[flt_sidx], flt_extra_weights=flt_weights)
        all_losses.append(loss_ext_item)
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(all_losses)
    return ret_loss, None
def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
         pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr = None):
    conf: SoftExtractorConf = self.conf
    assert not lookup_flatten
    bsize, slen = BK.get_shape(mask_expr)
    # --
    # step 0: prepare
    arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
        self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
    arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
    # --
    # step 1: cand
    cand_full_scores, pred_cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
    loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
        mask_expr, cand_full_scores, pred_cand_decisions,
        expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non)  # ~, [*, clen]
    # --
    # step 2: split
    cand_expr, cand_scores = input_expr[arange2_t, cand_widxes], cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
    split_scores, pred_split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
    loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = self._loss_feed_split(
        mask_expr, split_scores, pred_split_decisions,
        cand_widxes, cand_masks, cand_expr, cand_scores, expr_seq_gaddr)  # ~, [*, seglen, *?]
    # --
    # step 3: lab
    # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
    flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_items.flatten()] + [0])
    gold_label_idxes = flatten_gold_label_idxes[oracle_gaddr]  # [*, seglen]
    lab_loss_weights = BK.where(oracle_gaddr >= 0, expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                seg_masks)  # [*, seglen]
    final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
    # go
    loss_lab, loss_count = self.lab_node.loss(
        seg_weighted_expr, pair_expr, seg_masks, gold_label_idxes,
        loss_weight_expr=final_lab_loss_weights, extra_score=external_extra_score)
    loss_lab_item = LossHelper.compile_leaf_loss(
        f"lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
        gold=(gold_label_idxes > 0).float().sum())
    # step 4: extend
    flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
    flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
    flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
    flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
    flt_items = arr_items.flatten()[BK.get_value(oracle_gaddr[flt_mask])]  # [?]
    loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
    # --
    # return loss
    ret_loss = LossHelper.combine_multiple_losses(loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
    return ret_loss, None
def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                    expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non):
    conf: SoftExtractorConf = self.conf
    bsize, slen = BK.get_shape(mask_expr)
    arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
    # --
    # step 1.1: bag loss
    cand_gold_mask = (expr_seq_gaddr >= 0).float() * mask_expr  # [*, slen], whether is-arg
    raw_loss_cand = BK.loss_binary(cand_full_scores, cand_gold_mask,
                                   label_smoothing=conf.cand_label_smoothing)  # [*, slen]
    # how to weight?
    extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] \
        + (1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
    if BK.is_zero_shape(extended_scores_t):
        extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
    else:
        extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
    _w_alpha = conf.cand_loss_weight_alpha
    _weight = ((cand_full_scores - extended_scores_max_t) * _w_alpha).exp()  # [*, slen]
    if not conf.cand_loss_div_max:  # div sum-all, like doing softmax
        _weight = _weight / ((extended_scores_t - extended_scores_max_t.unsqueeze(-1)) * _w_alpha).exp().sum(-1)
    _weight = _weight * (_weight >= conf.cand_loss_weight_thresh).float()  # [*, slen]
    if conf.cand_detach_weight:
        _weight = _weight.detach()
    # pos poison (dis-encouragement)
    if conf.cand_loss_pos_poison:
        poison_loss = BK.loss_binary(cand_full_scores, 1. - cand_gold_mask,
                                     label_smoothing=conf.cand_label_smoothing)  # [*, slen]
        raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (1. - _weight)  # [*, slen]
    else:
        raw_loss_cand = raw_loss_cand * _weight
    # final weight it
    cand_loss_weights = BK.where(cand_gold_mask == 0., expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 mask_expr)  # [*, slen]
    final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
    loss_cand_item = LossHelper.compile_leaf_loss(
        f"cand", (raw_loss_cand * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
        loss_lambda=conf.loss_cand)
    # step 1.2: feed cand
    # todo(+N): currently only pred/sample, whether adding certain teacher-forcing?
    sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(cand_full_scores.shape)).float() * mask_expr  # [*, slen]
    _use_sample_mask = (BK.rand([bsize]) <= conf.cand_feed_sample_rate).float().unsqueeze(-1)  # [*, 1], seq-level
    feed_cand_decisions = _use_sample_mask * sample_decisions + (1. - _use_sample_mask) * pred_cand_decisions  # [*, slen]
    # next
    cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
    # --
    # extra: loss_cand_entropy
    rets = [loss_cand_item]
    _loss_cand_entropy = conf.loss_cand_entropy
    if _loss_cand_entropy > 0.:
        _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
        _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
        # [*, slen], only first one in bag
        _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0, expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() \
            * cand_gold_mask
        _loss_ent_item = LossHelper.compile_leaf_loss(
            f"cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
        rets.append(_loss_ent_item)
    # --
    return rets, cand_widxes, cand_masks
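# --
# note (illustrative example, not repo code): the bag weighting above uses _w_alpha as an
# inverse temperature; with cand_loss_div_max the best-scored token in the bag keeps weight
# 1.0, otherwise dividing by the sum over the bag turns the weights into a softmax.
import torch
_scores = torch.tensor([[2.0, 1.0, -1.0]])                                    # one bag of scores [*, MW]
_alpha = 1.0
_w_div_max = ((_scores - _scores.max(-1, keepdim=True)[0]) * _alpha).exp()    # ~[[1.00, 0.37, 0.05]]
_w_softmax = _w_div_max / _w_div_max.sum(-1, keepdim=True)                    # ~[[0.71, 0.26, 0.04]]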