Example #1
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: SeqExtractorConf = self.conf
     # step 0: prepare golds
     expr_gold_slabs, expr_loss_weight_non = self.helper.prepare(
         insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     final_loss_weights = BK.where(
         expr_gold_slabs == 0,
         expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
         mask_expr)
     # step 1: label; with special weights
     loss_lab, loss_count = self.lab_node.loss(
         input_expr,
         pair_expr,
         mask_expr,
         expr_gold_slabs,
         loss_weight_expr=final_loss_weights,
         extra_score=external_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(
         "lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
         gold=(expr_gold_slabs > 0).float().sum())
     # ==
     # return loss
     ret_loss = LossHelper.combine_multiple_losses([loss_lab_item])
     return self._finish_loss(ret_loss, insts, input_expr, mask_expr,
                              pair_expr, lookup_flatten)
Example #2
 def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MyFramerConf = self.conf
     # --
     all_losses = []
     # evt
     if conf.loss_evt > 0.:
         evt_loss, evt_res = self.evt_extractor.loss(insts, input_expr, mask_expr)
         one_loss = LossHelper.compile_component_loss("evt", [evt_loss], loss_lambda=conf.loss_evt)
         all_losses.append(one_loss)
     else:
         evt_res = None
     # arg
     if conf.loss_arg > 0.:
         if evt_res is None:
             evt_res = self.evt_extractor.lookup_flatten(insts, input_expr, mask_expr)
         flt_items, flt_sidx, flt_expr, flt_full_expr = evt_res  # flatten to make dim0 -> frames
         flt_input_expr, flt_mask_expr = input_expr[flt_sidx], mask_expr[flt_sidx]
         flt_fenc_expr = self._forward_fenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [**, slen, D]
         arg_loss, _ = self.arg_extractor.loss(
             flt_items, flt_fenc_expr, flt_mask_expr, pair_expr=(flt_expr if conf.arg_use_finput else None),
             external_extra_score=self._get_arg_external_extra_score(flt_items))
         one_loss = LossHelper.compile_component_loss("arg", [arg_loss], loss_lambda=conf.loss_arg)
         all_losses.append(one_loss)
     # --
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss
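The lookup_flatten step above re-batches sentence-level tensors at the frame level: flt_sidx stores, for each frame, the row index of its source sentence, so plain advanced indexing duplicates sentence rows per frame. A toy illustration (tensor values are made up):

import torch

input_expr = torch.arange(6.).view(3, 2)  # [bs=3, D=2], one row per sentence
flt_sidx = torch.tensor([0, 0, 2])        # sent 0 contributes two frames, sent 2 one
flt_input_expr = input_expr[flt_sidx]     # [n_frames=3, D=2]: rows 0, 0, 2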
Example #3
 def loss_regs(regs: List['ParamRegHelper']):
     loss_items = []
     for ii, reg in enumerate(regs):
         if reg.reg_method_loss:
             _loss, _loss_lambda = reg.compute_loss()
             _loss_item = LossHelper.compile_leaf_loss(f'reg_{ii}', _loss, BK.input_real(1.), loss_lambda=_loss_lambda)
             loss_items.append(_loss_item)
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
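Example #3 treats each regularizer as its own leaf loss with a constant count of 1. To make the interface concrete, here is a hypothetical ParamRegHelper with an L2 pull toward frozen reference parameters; this is an illustrative assumption, not the framework's actual regularizer:

import torch

class ParamRegHelper:
    def __init__(self, params, ref_params, reg_lambda: float, reg_method_loss: bool = True):
        self.params = list(params)
        self.refs = [r.detach().clone() for r in ref_params]  # frozen references
        self.reg_lambda = reg_lambda
        self.reg_method_loss = reg_method_loss  # loss-based (vs. e.g. hard projection)

    def compute_loss(self):
        # squared L2 distance to the reference parameters
        loss = sum(((p - r) ** 2).sum() for p, r in zip(self.params, self.refs))
        return loss, self.reg_lambda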
Example #4
 def loss(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     # --
     slen = BK.get_shape(mask_expr, -1)
     arr_items, expr_evt_labels, expr_arg_labels, expr_loss_weight_non = self.helper.prepare(insts, True)
     if conf.binary_evt:
         expr_evt_labels = (expr_evt_labels>0).long()  # either 0 or 1
     loss_items = []
     # =====
     # evt
     # -- prepare weights and masks
     evt_not_nil = (expr_evt_labels>0)  # [*, slen]
     evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1)*conf.evt_loss_weight_non)
     evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
     # -- get losses
     _, all_evt_cfs, all_evt_scores = self.evt_node.get_all_values()  # [*, slen]
     all_evt_losses = []
     for one_evt_scores in all_evt_scores:
         one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
         all_evt_losses.append(one_losses)
     evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses, all_cfs=all_evt_cfs)
     for loss_t, loss_alpha, loss_name in evt_loss_results:
         one_evt_item = LossHelper.compile_leaf_loss("evt"+loss_name, (loss_t*evt_weights).sum(), evt_weights.sum(),
                                                     loss_lambda=conf.loss_evt*loss_alpha, gold=evt_not_nil.float().sum())
         loss_items.append(one_evt_item)
     # =====
     # arg
     _arg_loss_evt_sample_neg = conf.arg_loss_evt_sample_neg
     if _arg_loss_evt_sample_neg > 0:
         arg_evt_masks = ((BK.rand(mask_expr.shape)<_arg_loss_evt_sample_neg) | evt_not_nil).float() * mask_expr
     else:
         arg_evt_masks = evt_not_nil.float()  # [*, slen]
     # expand/flat the dims
     arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
     flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
     flat_arg_labels = expr_arg_labels[arg_flat_mask]  # [??, slen]
     flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
     flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, conf.arg_loss_sample_neg)
     # -- get losses
     _, all_arg_cfs, all_arg_scores = self.arg_node.get_all_values()  # [*, slen, slen]
     all_arg_losses = []
     for one_arg_scores in all_arg_scores:
         one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
         one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=conf.evt_label_smoothing)  # note: reuses the evt smoothing setting (an arg-specific one may be intended)
         all_arg_losses.append(one_losses)
     all_arg_cfs = [z[arg_flat_mask] for z in all_arg_cfs]  # [??, slen]
     arg_loss_results = self.arg_node.helper.loss(all_losses=all_arg_losses, all_cfs=all_arg_cfs)
     for loss_t, loss_alpha, loss_name in arg_loss_results:
         one_arg_item = LossHelper.compile_leaf_loss("arg"+loss_name, (loss_t*flat_arg_weights).sum(), flat_arg_weights.sum(),
                                                     loss_lambda=conf.loss_arg*loss_alpha, gold=flat_arg_not_nil.float().sum())
         loss_items.append(one_arg_item)
     # =====
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
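_prepare_loss_weights shows up here and in Examples #20/#21, and the arg branch above inlines the same idea for arg_evt_masks. A sketch reconstructed from those call sites (an assumption, not the framework method): keep every gold position, sub-sample negatives at the given rate, and optionally fold in extra weights.

import torch

def _prepare_loss_weights(mask_expr: torch.Tensor, not_nil: torch.Tensor,
                          sample_neg: float, extra_weights: torch.Tensor = None):
    # mask_expr: [*, slen] float validity mask; not_nil: [*, slen] bool gold indicator
    if sample_neg > 0.:
        # all gold positions plus a sample_neg fraction of the negatives
        ret = ((torch.rand_like(mask_expr) < sample_neg) | not_nil).float() * mask_expr
    else:
        ret = not_nil.float()  # gold positions only
    return ret if extra_weights is None else ret * extra_weights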
Example #5
 def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
          pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
     conf: DirectExtractorConf = self.conf
     # step 0: prepare golds
     arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
         self.helper.prepare(insts, use_cache=True)
     # step 1: extract cands
     if conf.loss_use_posi:
         cand_res = self.extract_node.go_lookup(
             input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
     else:  # todo(note): assume no in-middle mask!!
         cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
             BK.get_shape(mask_expr), mask_expr.sum(-1).long(), expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
         if conf.span_train_sample:  # simply do sampling
             cand_res = self.extract_node.go_sample(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
         else:  # beam pruner using topk
             cand_res = self.extract_node.go_topk(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_topk_rate, count=conf.span_topk_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
     # step 1+: prepare for labeling
     cand_gold_mask = (cand_res.gaddr_expr>=0).float() * cand_res.mask_expr  # [*, cand_len]
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
     cand_loss_weights = BK.where(gold_label_idxes==0, expr_loss_weight_non.unsqueeze(-1)*conf.loss_weight_non, cand_res.mask_expr)
     final_loss_weights = cand_loss_weights * cand_res.mask_expr
     # cand loss
     if conf.loss_cand > 0. and not conf.loss_use_posi:
         loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)
         loss_cand = (loss_cand0 * final_loss_weights).sum()
         loss_cand_item = LossHelper.compile_leaf_loss("cand", loss_cand, final_loss_weights.sum(),
                                                       loss_lambda=conf.loss_cand)
     else:
         loss_cand_item = None
     # extra score
     cand_extra_score = self._get_extra_score(
         cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
     final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
     # step 2: label; with special weights
     loss_lab, loss_count = self.lab_node.loss(
         cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
         loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss("lab", loss_lab, loss_count,
                                                  loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
     # ==
     # return loss
     ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
     return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
Example #6
 def loss(self, med: ZMediator, *args, **kwargs):
     conf: ZDecoderUdepConf = self.conf
     # --
     # prepare info
     ibatch = med.ibatch
     expr_udep_labels, expr_isroot = self.prepare(ibatch)  # [bs, dlen, dlen], [bs, dlen]
     base_mask_t = self.get_dec_mask(ibatch, conf.msent_loss_center)  # [bs, dlen]
     # get losses
     loss_items = []
     _loss_udep_lab = conf.loss_udep_lab
     if _loss_udep_lab > 0.:
         # extra masks: force same sent!
         _dec_sent_idxes = ibatch.seq_info.dec_sent_idxes  # [bs, dlen]
         _mask_t = (_dec_sent_idxes.unsqueeze(-1) == _dec_sent_idxes.unsqueeze(-2)).float()  # [bs, dlen, dlen]
         _mask_t *= base_mask_t.unsqueeze(-1)
         _mask_t *= base_mask_t.unsqueeze(-2)  # [bs, dlen, dlen]
         # special handlings
         _mask_t *= (1.-self.udep_ignore_masks[expr_udep_labels])  # [bs, dlen, dlen]
         expr_udep_labels2 = (expr_udep_labels * (1.-self.udep_nilout_masks[expr_udep_labels])).long()  # [bs, dlen, dlen]
         # --
         loss_items.extend(self.loss_from_lab(self.lab_udep, 'udep', med, expr_udep_labels2, _mask_t, _loss_udep_lab))
     _loss_udep_root = conf.loss_udep_root
     if _loss_udep_root > 0.:
         loss_items.extend(self.loss_from_lab(self.lab_root, 'root', med, expr_isroot, base_mask_t, _loss_udep_root))
     # --
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss, {}
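The "force same sent" mask above pairs every token with every other token and keeps only pairs whose sentence indices agree. On a toy batch:

import torch

dec_sent_idxes = torch.tensor([[0, 0, 1, 1]])  # [bs=1, dlen=4]
same_sent = (dec_sent_idxes.unsqueeze(-1) == dec_sent_idxes.unsqueeze(-2)).float()
# same_sent[0]:
# [[1., 1., 0., 0.],
#  [1., 1., 0., 0.],
#  [0., 0., 1., 1.],
#  [0., 0., 1., 1.]]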
Example #7
 def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
     conf: SeqExitHelperConf = self.conf
     # --
     assert self.is_cf
     # get oracle
     oracles = [self.cf_oracle_f(ff) for ff in insts]  # bs*[NL, slen] or bs*[NL]
     rets = []
     mask_t = BK.input_real(DataPadder.lengths2mask([len(z.sent) for z in insts]))  # [bs, slen]
     for one_li, one_scores in enumerate(cf_scores):
         if conf.cf_use_seq:
             one_oracle_t = BK.input_real([z[one_li] for z in oracles])  # [bs]
             one_oracle_t *= conf.cf_scale
             one_mask_t = BK.zeros([len(one_oracle_t)]) + 1
         else:
             one_oracle_t = BK.input_real(DataPadder.go_batch_2d([z[one_li] for z in oracles], 1.))  # [bs, slen]
             one_mask_t = (BK.rand(one_oracle_t.shape) >=
                           ((one_oracle_t ** conf.cf_loss_discard_curve) * conf.cf_loss_discard)) * mask_t
             one_oracle_t *= conf.cf_scale
         # simple L2 loss
         one_loss_t = (one_scores.squeeze(-1) - one_oracle_t) ** 2
         one_loss_item = LossHelper.compile_leaf_loss(
             f"cf{one_li}", (one_loss_t * one_mask_t).sum(), one_mask_t.sum(), loss_lambda=loss_cf)
         rets.append(one_loss_item)
     return rets
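DataPadder.lengths2mask above only needs to turn per-sentence lengths into a padded 0/1 matrix; an equivalent one-liner for reference (a sketch; the framework helper presumably covers more cases):

import numpy as np

def lengths2mask(lengths):
    # [bs] lengths -> [bs, max_len] float mask
    max_len = max(lengths) if len(lengths) > 0 else 0
    return (np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]).astype(np.float32)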
Example #8
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
     cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
     split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
     split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
     raw_split_loss = BK.loss_binary(
         split_scores, split_oracle, label_smoothing=conf.split_label_smoothing)  # [*, clen-1]
     loss_split_item = LossHelper.compile_leaf_loss(
         "split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(),
         loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
     force_split_decisions = split_oracle_mask * split_oracle \
         + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
     _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
     feed_split_decisions = _use_force_mask * force_split_decisions \
         + (1. - _use_force_mask) * pred_split_decisions  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         feed_split_decisions, cand_masks)
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
         _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
         oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
         oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
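The dynamic split oracle of step 2.1 is easiest to see on a toy gold-address sequence: a split is forced wherever adjacent candidates map to different gold items, and pairs where both sides are -1 (no gold item) are excluded from the loss. Assuming the same gaddr convention as above:

import torch

cand_gaddr = torch.tensor([[0, 0, -1, -1, 1]])  # gold item address per candidate, -1 = none
cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]
split_oracle = (cand_gaddr0 != cand_gaddr1).float()                    # [[0., 1., 0., 1.]]
split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float()  # [[1., 1., 0., 1.]], -1|-1 masked out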
Example #9
 def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
     conf: ExtenderConf = self.conf
     _loss_lambda = conf._loss_lambda
     # --
     enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
     s_left, s_right = self.enode.score(enc_t, flt_pair_expr if conf.ext_use_finput else None, flt_mask_expr)  # [*, slen]
     # --
     gold_posi = [self.ext_span_getter(z.mention) for z in flt_items]  # List[(widx, wlen)]
     widx_t = BK.input_idx([z[0] for z in gold_posi])  # [*]
     wlen_t = BK.input_idx([z[1] for z in gold_posi])
     loss_left_t, loss_right_t = BK.loss_nll(s_left, widx_t), BK.loss_nll(s_right, widx_t+wlen_t-1)  # [*]
     if flt_extra_weights is not None:
         loss_left_t *= flt_extra_weights
         loss_right_t *= flt_extra_weights
         loss_div = flt_extra_weights.sum()  # note: also use this!
     else:
         loss_div = BK.constants([len(flt_items)], value=1.).sum()
     loss_left_item = LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=_loss_lambda)
     loss_right_item = LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=_loss_lambda)
     ret_loss = LossHelper.combine_multiple_losses([loss_left_item, loss_right_item])
     return ret_loss
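The gold targets here follow directly from the span convention: a mention at (widx, wlen) supervises the left pointer at widx and the right pointer at widx + wlen - 1 (inclusive end). For instance:

widx, wlen = 3, 2             # span covering tokens 3..4
left_gold = widx              # 3
right_gold = widx + wlen - 1  # 4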
Example #10
 def loss(self, med: ZMediator):
     conf: ZDecoderUPOSConf = self.conf
     insts, mask_expr = med.insts, med.get_mask_t()
     # --
     # first prepare golds
     expr_upos_labels = self.helper.prepare(insts, True)
     loss_items = []
     if conf.loss_upos > 0.:
         loss_items.extend(self._loss_upos(mask_expr, expr_upos_labels))
     # =====
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
Example #11
 def loss_from_lab(self, lab_node, score_name: str, med: ZMediator, label_t, mask_t, loss_lambda: float,
                   loss_neg_sample: float = None):
     score_cache = med.get_cache((self.name, score_name))
     loss_items = []
     # note: simply collect them all
     all_losses = lab_node.gather_losses(score_cache.vals, label_t, mask_t, loss_neg_sample=loss_neg_sample)
     for ii, vv in enumerate(all_losses):
         nn = score_cache.infos[ii]
         _loss_t, _mask_t = vv
         _loss_item = LossHelper.compile_leaf_loss(
             f'{score_name}_{nn}', _loss_t.sum(), _mask_t.sum(), loss_lambda=loss_lambda)
         loss_items.append(_loss_item)
     return loss_items
Example #12
 def loss(self, med: ZMediator, *args, **kwargs):
     conf: ZDecoderMlmConf = self.conf
     # --
     loss_items = []
     if conf.loss_mlm > 0.:
         origin_ids = med.ibatch.seq_info.enc_input_ids
         mlm_mask = med.get_cache('mlm_mask')
         loss_items.extend(
             self.loss_from_lab(self.lab_mlm, 'mlm', med, origin_ids,
                                mlm_mask, conf.loss_mlm))
     # --
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss, {}
Example #13
 def loss(self, med: ZMediator):
     conf: ZDecoderUDEPConf = self.conf
     insts, mask_expr = med.insts, med.get_mask_t()
     # --
     # first prepare golds
     expr_depth, expr_udep = self.helper.prepare(insts, conf.udep_train_use_cache)
     loss_items = []
     if conf.loss_depth > 0.:
         loss_items.extend(self._loss_depth(med, mask_expr, expr_depth))
     if conf.loss_udep > 0.:
         loss_items.extend(self._loss_udep(med, mask_expr, expr_udep))
     # --
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
Example #14
 def _finish_loss(self, core_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten: bool):
     conf: BaseExtractorConf = self.conf
     if conf.loss_ext>0. or lookup_flatten:
         lookup_res = self.lookup_flatten(insts, input_expr, mask_expr, pair_expr)
     else:
         lookup_res = None
     if conf.loss_ext>0.:
         flt_items, flt_sidx, flt_expr, flt_full_expr = lookup_res  # flatten to make dim0 -> frames
         flt_input_expr, flt_mask_expr = input_expr[flt_sidx], mask_expr[flt_sidx]
         ext_loss = self.ext_node.loss(flt_items, flt_input_expr, flt_expr, flt_full_expr, flt_mask_expr)
         ret_loss = LossHelper.combine_multiple_losses([core_loss, ext_loss])  # with another loss
     else:
         ret_loss = core_loss
     return ret_loss, lookup_res
Example #15
 def _loss_depth(self, med: ZMediator, mask_expr, expr_depth):
     conf: ZDecoderUDEPConf = self.conf
     # --
     all_depth_scores = med.main_scores.get((self.name, "depth"))  # [*, slen]
     all_depth_losses = []
     for one_depth_scores in all_depth_scores:
         one_losses = BK.loss_binary(one_depth_scores.squeeze(-1), expr_depth, label_smoothing=conf.depth_label_smoothing)
         all_depth_losses.append(one_losses)
     depth_loss_results = self.depth_node.helper.loss(all_losses=all_depth_losses)
     loss_items = []
     for loss_t, loss_alpha, loss_name in depth_loss_results:
         one_depth_item = LossHelper.compile_leaf_loss(
             "depth"+loss_name, (loss_t*mask_expr).sum(), mask_expr.sum(), loss_lambda=(loss_alpha*conf.loss_depth))
         loss_items.append(one_depth_item)
     return loss_items
Example #16
 def loss(self, med: ZMediator, *args, **kwargs):
     conf: ZDecoderUposConf = self.conf
     # --
     # prepare info
     ibatch = med.ibatch
     expr_upos_labels = self.prepare(ibatch)
     mask_t = self.get_dec_mask(ibatch, conf.msent_loss_center)
     # get losses
     loss_items = []
     _loss_upos = conf.loss_upos
     if _loss_upos > 0.:
         loss_items.extend(
             self.loss_from_lab(self.lab_upos, 'upos', med,
                                expr_upos_labels, mask_t, _loss_upos))
     # --
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss, {}
Example #17
 def loss(self, med: ZMediator):
     conf: ZDecoderSRLConf = self.conf
     insts, mask_expr = med.insts, med.get_mask_t()
     # --
     # first prepare golds
     arr_items, expr_evt_labels, expr_arg_labels, expr_arg2_labels, expr_loss_weight_non = self.helper.prepare(insts, True)
     if conf.binary_evt:
         expr_evt_labels = (expr_evt_labels>0).long()  # either 0 or 1
     # --
     loss_items = []
     if conf.loss_evt > 0.:
         loss_items.extend(self._loss_evt(mask_expr, expr_evt_labels, expr_loss_weight_non))
     if conf.loss_arg > 0. or conf.loss_arg2 > 0.:
         loss_items.extend(self._loss_arg(mask_expr, expr_evt_labels, expr_arg_labels, expr_arg2_labels))
     # =====
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(loss_items)
     return ret_loss
Example #18
 def _loss_upos(self, mask_expr, expr_upos_labels):
     conf: ZDecoderUPOSConf = self.conf
     all_upos_scores = self.upos_node.buffer_scores.values()  # [*, slen, L]
     all_upos_losses = []
     for one_upos_scores in all_upos_scores:
         one_losses = BK.loss_nll(one_upos_scores, expr_upos_labels, label_smoothing=conf.upos_label_smoothing)
         all_upos_losses.append(one_losses)
     upos_loss_results = self.upos_node.helper.loss(all_losses=all_upos_losses)
     loss_items = []
     for loss_t, loss_alpha, loss_name in upos_loss_results:
         one_upos_item = LossHelper.compile_leaf_loss(
             "upos" + loss_name, (loss_t * mask_expr).sum(), mask_expr.sum(),
             loss_lambda=(loss_alpha * conf.loss_upos))
         loss_items.append(one_upos_item)
     return loss_items
Example #19
 def _loss_udep(self, med: ZMediator, mask_expr, expr_udep):
     conf: ZDecoderUDEPConf = self.conf
     # --
     all_udep_scores = med.main_scores.get((self.name, "udep"))  # [*, slen, slen, L]
     all_udep_losses = []
     for one_udep_scores in all_udep_scores:
         one_losses = BK.loss_nll(one_udep_scores, expr_udep, label_smoothing=conf.udep_label_smoothing)
         all_udep_losses.append(one_losses)
     udep_loss_results = self.udep_node.helper.loss(all_losses=all_udep_losses)
     # --
     _loss_weights = ((BK.rand(expr_udep.shape) < conf.udep_loss_sample_neg) | (expr_udep>0)).float() \
                     * mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2)  # [*, slen, slen]
     # --
     loss_items = []
     for loss_t, loss_alpha, loss_name in udep_loss_results:
         one_udep_item = LossHelper.compile_leaf_loss(
             "udep"+loss_name, (loss_t*_loss_weights).sum(), _loss_weights.sum(), loss_lambda=(loss_alpha*conf.loss_udep))
         loss_items.append(one_udep_item)
     return loss_items
Example #20
 def _loss_evt(self, mask_expr, expr_evt_labels, expr_loss_weight_non):
     conf: ZDecoderSRLConf = self.conf
     # --
     loss_items = []
     # -- prepare weights and masks
     evt_not_nil = (expr_evt_labels > 0)  # [*, slen]
     evt_extra_weights = BK.where(evt_not_nil, mask_expr, expr_loss_weight_non.unsqueeze(-1) * conf.evt_loss_weight_non)
     evt_weights = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.evt_loss_sample_neg, evt_extra_weights)
     # -- get losses
     all_evt_scores = self.evt_node.buffer_scores.values()  # List([*,slen])
     all_evt_losses = []
     for one_evt_scores in all_evt_scores:
         one_losses = BK.loss_nll(one_evt_scores, expr_evt_labels, label_smoothing=conf.evt_label_smoothing)
         all_evt_losses.append(one_losses)
     evt_loss_results = self.evt_node.helper.loss(all_losses=all_evt_losses)
     for loss_t, loss_alpha, loss_name in evt_loss_results:
         one_evt_item = LossHelper.compile_leaf_loss(
             "evt" + loss_name, (loss_t * evt_weights).sum(), evt_weights.sum(),
             loss_lambda=conf.loss_evt * loss_alpha, gold=evt_not_nil.float().sum())
         loss_items.append(one_evt_item)
     return loss_items
Example #21
 def _loss_arg(self, mask_expr, expr_evt_labels, expr_arg_labels, expr_arg2_labels):
     conf: ZDecoderSRLConf = self.conf
     # --
     loss_items = []
     slen = BK.get_shape(mask_expr, -1)
     evt_not_nil = (expr_evt_labels > 0)  # [*, slen]
     # --
     # first prepare evts to focus at
     arg_evt_masks = self._prepare_loss_weights(mask_expr, evt_not_nil, conf.arg_loss_inc_neg_evt)  # [*, slen]
     # expand/flat the dims
     arg_flat_mask = (arg_evt_masks > 0)  # [*, slen]
     flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
     # -- get losses
     for aname, anode, glabels, sneg_rate, loss_mul, lsmooth in \
             zip(["arg", "arg2"], [self.arg_node, self.arg2_node], [expr_arg_labels, expr_arg2_labels],
                 [conf.arg_loss_sample_neg, conf.arg2_loss_sample_neg], [conf.loss_arg, conf.loss_arg2],
                 [conf.arg_label_smoothing, conf.arg2_label_smoothing]):
         # --
         if loss_mul <= 0.:
             continue
         # --
         # prepare
         flat_arg_labels = glabels[arg_flat_mask]  # [??, slen]
         flat_arg_not_nil = (flat_arg_labels > 0)  # [??, slen]
         flat_arg_weights = self._prepare_loss_weights(flat_mask_expr, flat_arg_not_nil, sneg_rate)
         # scores and losses
         all_arg_scores = anode.buffer_scores.values()  # [*, slen, slen]
         all_arg_losses = []
         for one_arg_scores in all_arg_scores:
             one_flat_arg_scores = one_arg_scores[arg_flat_mask]  # [??, slen]
             one_losses = BK.loss_nll(one_flat_arg_scores, flat_arg_labels, label_smoothing=lsmooth)
             all_arg_losses.append(one_losses)
         arg_loss_results = anode.helper.loss(all_losses=all_arg_losses)
         # collect losses
         for loss_t, loss_alpha, loss_name in arg_loss_results:
             one_arg_item = LossHelper.compile_leaf_loss(
                 aname + loss_name, (loss_t * flat_arg_weights).sum(), flat_arg_weights.sum(),
                 loss_lambda=(loss_mul*loss_alpha), gold=flat_arg_not_nil.float().sum())
             loss_items.append(one_arg_item)
     return loss_items
Example #22
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1: label, simply scoring everything!
     _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
     all_scores_t = self.lab_node.score_all(
         _main_t, _pair_t, mask_expr, None,
         local_normalize=False, extra_score=external_extra_score)  # unnormalized scores [*, slen, L]
     all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
     all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
     # how to weight
     extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
     if BK.is_zero_shape(extended_gprob_t):
         extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = ((all_gprob_t * mask_expr) / extended_gprob_max_t.clamp(min=1e-5)) ** _w_alpha  # [*, slen]
     _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
     _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]
     _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]
     _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
     _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
     # final weighting
     cand_loss_weights = BK.where(expr_seq_labs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_lab_item = LossHelper.compile_leaf_loss(
         "lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
     # --
     # step 1.5
     all_losses = [loss_lab_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_gprob_t  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                                expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() \
                     * (expr_seq_labs > 0).float()
         _loss_ent_item = LossHelper.compile_leaf_loss(
             "cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
         all_losses.append(_loss_ent_item)
     # --
     # step 4: extend (select topk)
     if conf.loss_ext > 0.:
         if BK.is_zero_shape(extended_gprob_t):
             flt_mask = (BK.zeros(mask_expr.shape) > 0)
         else:
             _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
             _topk_gprob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
             flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_gprob_t.min(-1)[0]) \
                        & (_weight > conf.ext_loss_thresh)  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
         flt_weights = _weight.detach()[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
         loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr,
                                            mask_expr[flt_sidx], flt_extra_weights=flt_weights)
         all_losses.append(loss_ext_item)
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss, None
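EntropyHelper.self_entropy (used for the bag-entropy loss here and in Example #24) only needs to be an entropy over the last axis; a minimal sketch, with an assumed epsilon for numerical safety:

import torch

def self_entropy(prob: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # H(p) = -sum_i p_i * log(p_i), over the last dimension
    return -(prob * (prob + eps).log()).sum(-1)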
Example #23
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: SoftExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 1: cand
     cand_full_scores, pred_cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
     loss_cand_items, cand_widxes, cand_masks = self._loss_feed_cand(
         mask_expr, cand_full_scores, pred_cand_decisions, expr_seq_gaddr,
         expr_group_widxes, expr_group_masks, expr_loss_weight_non)  # ~, [*, clen]
     # --
     # step 2: split
     cand_expr = input_expr[arange2_t, cand_widxes]  # [*, clen, D]
     cand_scores = cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
     split_scores, pred_split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
     loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr = \
         self._loss_feed_split(mask_expr, split_scores, pred_split_decisions, cand_widxes,
                               cand_masks, cand_expr, cand_scores, expr_seq_gaddr)  # ~, [*, seglen, *?]
     # --
     # step 3: lab
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[oracle_gaddr]  # [*, seglen]
     lab_loss_weights = BK.where(oracle_gaddr >= 0,
                                 expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                 seg_masks)  # [*, seglen]
     final_lab_loss_weights = lab_loss_weights * seg_masks  # [*, seglen]
     # go
     loss_lab, loss_count = self.lab_node.loss(
         seg_weighted_expr, pair_expr, seg_masks, gold_label_idxes,
         loss_weight_expr=final_lab_loss_weights, extra_score=external_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(
         "lab", loss_lab, loss_count, loss_lambda=conf.loss_lab,
         gold=(gold_label_idxes > 0).float().sum())
     # step 4: extend
     flt_mask = ((gold_label_idxes > 0) & (seg_masks > 0.))  # [*, seglen]
     flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
     flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
     flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
     flt_items = arr_items.flatten()[BK.get_value(oracle_gaddr[flt_mask])]  # [?]
     loss_ext_item = self.ext_node.loss(flt_items, input_expr[flt_sidx], flt_expr, flt_full_expr, mask_expr[flt_sidx])
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(
         loss_cand_items + [loss_split_item, loss_lab_item, loss_ext_item])
     return ret_loss, None
Example #24
 def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                     expr_seq_gaddr, expr_group_widxes, expr_group_masks,
                     expr_loss_weight_non):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1.1: bag loss
     cand_gold_mask = (expr_seq_gaddr >= 0).float() * mask_expr  # [*, slen], whether is-arg
     raw_loss_cand = BK.loss_binary(
         cand_full_scores, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)  # [*, slen]
     # how to weight?
     extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] \
         + (1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
     if BK.is_zero_shape(extended_scores_t):
         extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = ((cand_full_scores - extended_scores_max_t) * _w_alpha).exp()  # [*, slen]
     if not conf.cand_loss_div_max:  # div by sum-all, like a softmax
         _weight = _weight / ((extended_scores_t - extended_scores_max_t.unsqueeze(-1)) * _w_alpha).exp().sum(-1)
     _weight = _weight * (_weight >= conf.cand_loss_weight_thresh).float()  # [*, slen]
     if conf.cand_detach_weight:
         _weight = _weight.detach()
     # pos poison (discouragement): push low-weight gold positions toward the negative label
     if conf.cand_loss_pos_poison:
         poison_loss = BK.loss_binary(
             cand_full_scores, 1. - cand_gold_mask, label_smoothing=conf.cand_label_smoothing)  # [*, slen]
         raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (1. - _weight)  # [*, slen]
     else:
         raw_loss_cand = raw_loss_cand * _weight
     # final weighting
     cand_loss_weights = BK.where(cand_gold_mask == 0.,
                                  expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_cand_item = LossHelper.compile_leaf_loss(
         "cand", (raw_loss_cand * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_cand)
     # step 1.2: feed cand
     # todo(+N): currently only pred/sample; consider adding some teacher-forcing?
     sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(cand_full_scores.shape)).float() * mask_expr  # [*, slen]
     _use_sample_mask = (BK.rand([bsize]) <= conf.cand_feed_sample_rate).float().unsqueeze(-1)  # [*, 1], seq-level
     feed_cand_decisions = _use_sample_mask * sample_decisions \
         + (1. - _use_sample_mask) * pred_cand_decisions  # [*, slen]
     # next
     cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
     # --
     # extra: loss_cand_entropy
     rets = [loss_cand_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                                expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]], -1).float() * cand_gold_mask
         _loss_ent_item = LossHelper.compile_leaf_loss(
             "cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(), loss_lambda=_loss_cand_entropy)
         rets.append(_loss_ent_item)
     # --
     return rets, cand_widxes, cand_masks
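BK.mask2idx (step 1.2 above) compacts a 0/1 decision mask into padded index lists plus a validity mask; a loop-based PyTorch sketch of the assumed behavior (the framework version is presumably vectorized):

import torch

def mask2idx(mask: torch.Tensor, pad: int = 0):
    # mask: [bs, slen] of 0./1. -> (idxes [bs, max_count], valid [bs, max_count])
    bs = mask.shape[0]
    counts = mask.sum(-1).long()
    max_count = max(int(counts.max().item()), 1)
    idxes = torch.full((bs, max_count), pad, dtype=torch.long)
    valid = torch.zeros(bs, max_count)
    for b in range(bs):
        ii = mask[b].nonzero(as_tuple=True)[0]
        idxes[b, :len(ii)] = ii
        valid[b, :len(ii)] = 1.
    return idxes, valid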