示例#1
0
 def _get_loss_mask(self,
                    pos_t: BK.Expr,
                    valid_t: BK.Expr,
                    loss_neg_sample: float = None):
     """Build a 0/1 mask selecting which items contribute to the loss.

     Positives are always kept; negatives are sub-sampled according to
     `loss_neg_sample` (falling back to the config value when None).
     """
     conf: ZLabelConf = self.conf
     # fall back to the configured rate when the caller passes none
     neg_sample = conf.loss_neg_sample if loss_neg_sample is None else loss_neg_sample
     # --
     # rate >= 1 means "keep everything valid": nothing to sample
     if neg_sample >= 1.:
         return valid_t
     # --
     # restrict positives to the valid region as well
     pos_t = pos_t * valid_t
     if neg_sample >= 0.:
         # non-negative: interpreted directly as a keep-percentage of valid
         keep_rate = neg_sample
     else:
         # negative: magnitude is a ratio w.r.t. the number of positives
         n_pos = pos_t.sum()
         n_valid = valid_t.sum()
         # add-1 on both sides keeps the ratio finite and positive
         keep_rate = (-neg_sample) * ((n_pos + 1) / (n_valid - n_pos + 1))
     # randomly keep valid entries at keep_rate, then force-include positives
     sampled_t = (BK.rand(valid_t.shape) <= keep_rate).float() * valid_t
     return (sampled_t + pos_t).clamp(max=1.)
示例#2
0
文件: crf.py 项目: zzsfornlp/zmsp
 def loss(self, unary_scores: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr):
     """CRF negative log-likelihood: log-partition minus gold path score.

     Returns (per-sequence losses [*], normalizer count).
     """
     trans_mat = self.bigram.get_matrix()  # [L, L] transition scores
     # log-partition; guard against empty inputs
     if BK.is_zero_shape(unary_scores):
         log_z = BK.zeros(BK.get_shape(unary_scores)[:-2])  # [*]
     else:
         log_z = BigramInferenceHelper.inference_forward(
             unary_scores, trans_mat, input_mask, self.conf.crf_beam)  # [*]
     # gold path score = unary part + bigram (transition) part
     gold_unary = unary_scores.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
     gold_bigram = trans_mat[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
     all_losses_t = log_z - (gold_unary.sum(-1) + gold_bigram.sum(-1))  # [*]
     # normalizer: per-token vs per-(non-empty)-sequence
     if self.conf.loss_by_tok:
         ret_count = input_mask.sum()  # []
     else:
         ret_count = (input_mask.sum(-1) > 0).float()  # [*]
     return all_losses_t, ret_count
示例#3
0
 def prepare_sd_init(self, expr_main: BK.Expr, expr_pair: BK.Expr):
     """Compute the initial state-decoder hidden vector [*, hid]."""
     if self.is_pairwise:
         # pairwise mode: derive directly from the pair expression
         return self.sd_init_aff(expr_pair)  # [*, hid]
     # otherwise pool the main expression over dim -2 first
     if BK.is_zero_shape(expr_main):
         pooled = expr_main.sum(-2)  # empty input: just produce the shape
     else:
         pooled = self.sd_init_pool_f(expr_main, -2)  # pooling at -2: [*, Dm']
     return self.sd_init_aff(pooled)  # [*, hid]
示例#4
0
文件: direct.py 项目: zzsfornlp/zmsp
 def loss(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
          pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
     """Training loss for the direct extractor.

     Combines an optional candidate-extraction (binary) loss with the
     labeling loss over the selected candidate spans, then hands the
     combined loss to `_finish_loss`.
     """
     conf: DirectExtractorConf = self.conf
     # step 0: prepare golds
     arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, expr_loss_weight_non = \
         self.helper.prepare(insts, use_cache=True)
     # step 1: extract cands
     if conf.loss_use_posi:
         # gold positions are provided: just look them up (mask = has-gold)
         cand_res = self.extract_node.go_lookup(
             input_expr, expr_gold_widxes, expr_gold_wlens, (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
     else:  # todo(note): assume no in-middle mask!!
         cand_widx, cand_wlen, cand_mask, cand_gaddr = self.extract_node.prepare_with_lengths(
             BK.get_shape(mask_expr), mask_expr.sum(-1).long(), expr_gold_widxes, expr_gold_wlens, expr_gold_gaddr)
         if conf.span_train_sample:  # simply do sampling
             cand_res = self.extract_node.go_sample(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_train_sample_rate, count=conf.span_train_sample_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=1.0)  # note: always fully add gold for sampling!!
         else:  # beam pruner using topk
             cand_res = self.extract_node.go_topk(
                 input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
                 rate=conf.span_topk_rate, count=conf.span_topk_count,
                 gaddr_expr=cand_gaddr, add_gold_rate=conf.span_train_topk_add_gold_rate)
     # step 1+: prepare for labeling
     cand_gold_mask = (cand_res.gaddr_expr>=0).float() * cand_res.mask_expr  # [*, cand_len]
     # todo(note): add a 0 as idx=-1 to make NEG ones as 0!!
     flatten_gold_label_idxes = BK.input_idx([(0 if z is None else z.label_idx) for z in arr_gold_items.flatten()] + [0])
     gold_label_idxes = flatten_gold_label_idxes[cand_res.gaddr_expr]
     # down-weight NIL (label==0) positions by loss_weight_non
     cand_loss_weights = BK.where(gold_label_idxes==0, expr_loss_weight_non.unsqueeze(-1)*conf.loss_weight_non, cand_res.mask_expr)
     final_loss_weights = cand_loss_weights * cand_res.mask_expr
     # cand loss: binary "is-a-gold-span" loss (only when cands were extracted)
     if conf.loss_cand > 0. and not conf.loss_use_posi:
         loss_cand0 = BK.loss_binary(cand_res.score_expr, cand_gold_mask, label_smoothing=conf.cand_label_smoothing)
         loss_cand = (loss_cand0 * final_loss_weights).sum()
         loss_cand_item = LossHelper.compile_leaf_loss(f"cand", loss_cand, final_loss_weights.sum(),
                                                       loss_lambda=conf.loss_cand)
     else:
         loss_cand_item = None
     # extra score (e.g. constraints / lexical-unit scores), summed with external
     cand_extra_score = self._get_extra_score(
         cand_res.score_expr, insts, cand_res, arr_gold_items, conf.loss_use_cons, conf.loss_use_lu)
     final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
     # step 2: label; with special weights
     loss_lab, loss_count = self.lab_node.loss(
         cand_res.span_expr, pair_expr, cand_res.mask_expr, gold_label_idxes,
         loss_weight_expr=final_loss_weights, extra_score=final_extra_score)
     loss_lab_item = LossHelper.compile_leaf_loss(f"lab", loss_lab, loss_count,
                                                  loss_lambda=conf.loss_lab, gold=cand_gold_mask.sum())
     # ==
     # return loss
     ret_loss = LossHelper.combine_multiple_losses([loss_cand_item, loss_lab_item])
     return self._finish_loss(ret_loss, insts, input_expr, mask_expr, pair_expr, lookup_flatten)
示例#5
0
 def go_topk(
         self,
         input_expr: BK.Expr,
         input_mask: BK.Expr,  # input
         widx_expr: BK.Expr,
         wlen_expr: BK.Expr,
         span_mask: BK.Expr,
         rate: float = None,
         count: float = None,  # span
         gaddr_expr: BK.Expr = None,
         add_gold_rate: float = 0.,  # gold
         non_overlapping=False,
         score_prune: float = None):
     """Look spans up and keep the top-k scored ones (optionally
     non-overlapping, optionally score-pruned), then mix in gold spans
     at `add_gold_rate`."""
     lookup_res = self.go_lookup(
         input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
     # --
     # the selection itself needs no gradients
     with BK.no_grad_env():
         scores_t = lookup_res.score_expr
         if BK.is_zero_shape(lookup_res.mask_expr):
             # nothing to rank: an (empty) copy of the mask suffices
             keep_mask = lookup_res.mask_expr.clone()
         else:
             # k is determined per sequence from the ORIGINAL input length
             k_t = self._determine_size(
                 input_mask.sum(-1, keepdim=True), rate, count).long()  # [bsize, 1]
             if non_overlapping:
                 keep_mask = select_topk_non_overlapping(
                     scores_t, k_t, widx_expr, wlen_expr, input_mask,
                     mask_t=span_mask, dim=-1)
             else:
                 keep_mask = select_topk(scores_t, k_t, mask_t=span_mask, dim=-1)
         # optionally drop anything scoring below the prune threshold
         if score_prune is not None:
             keep_mask *= (scores_t >= score_prune).float()
     # final selection plus gold mixing
     return self._go_common(lookup_res, keep_mask, add_gold_rate)
示例#6
0
文件: direct.py 项目: zzsfornlp/zmsp
 def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
             pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
     """Inference for the direct extractor.

     Picks candidate spans (either the provided gold positions or
     topk-extracted ones), labels them, and writes the results back onto
     `insts` via the helper; finally delegates to `_finish_pred`.
     """
     conf: DirectExtractorConf = self.conf
     # step 1: prepare targets
     if conf.pred_use_posi:
         # step 1a: directly use provided positions
         arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, _ = self.helper.prepare(insts, use_cache=False)
         cand_res = self.extract_node.go_lookup(input_expr, expr_gold_widxes, expr_gold_wlens,
                                                (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
     else:
         arr_gold_items = None
         # step 1b: extract cands (topk); todo(note): assume no in-middle mask!!
         cand_widx, cand_wlen, cand_mask, _ = self.extract_node.prepare_with_lengths(
             BK.get_shape(mask_expr), mask_expr.sum(-1).long(), None, None, None)
         cand_res = self.extract_node.go_topk(
             input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
             rate=conf.span_topk_rate, count=conf.span_topk_count,
             non_overlapping=conf.pred_non_overlapping, score_prune=conf.pred_score_prune)
     # --
     # note: check empty -- no candidates means nothing to label
     if BK.is_zero_shape(cand_res.mask_expr):
         if not conf.pred_use_posi:
             for inst in insts:  # still need to clear things!!
                 self.helper._clear_f(inst)
     else:
         # step 2: labeling
         # extra score (e.g. constraints / lexical-unit scores), summed with external
         cand_extra_score = self._get_extra_score(
             cand_res.score_expr, insts, cand_res, arr_gold_items, conf.pred_use_cons, conf.pred_use_lu)
         final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
         best_labs, best_scores = self.lab_node.predict(
             cand_res.span_expr, pair_expr, cand_res.mask_expr, extra_score=final_extra_score)
         # step 3: put results
         if conf.pred_use_posi:  # reuse the old ones, but replace label
             self.helper.put_labels(arr_gold_items, best_labs, best_scores)
         else:  # make new frames
             self.helper.put_results(insts, best_labs, best_scores, cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr)
     # --
     # finally
     return self._finish_pred(insts, input_expr, mask_expr, pair_expr, lookup_flatten)
示例#7
0
 def go_sample(
         self,
         input_expr: BK.Expr,
         input_mask: BK.Expr,  # input
         widx_expr: BK.Expr,
         wlen_expr: BK.Expr,
         span_mask: BK.Expr,
         rate: float = None,
         count: float = None,  # span
         gaddr_expr: BK.Expr = None,
         add_gold_rate: float = 0.):
     """Look spans up and keep a random sample, then mix in gold spans
     at `add_gold_rate`."""
     lookup_res = self.go_lookup(
         input_expr, widx_expr, wlen_expr, span_mask, gaddr_expr)  # [bsize, NUM, *]
     # --
     # target size is relative to the original input length (eps avoids /0)
     eff_len = input_mask.sum(-1, keepdim=True) + 1e-5
     keep_rate = self._determine_size(eff_len, rate, count) / eff_len  # [bsize, 1]
     keep_mask = (BK.rand(span_mask.shape) < keep_rate).float()  # [bsize, NUM]
     # final selection plus gold mixing
     return self._go_common(lookup_res, keep_mask, add_gold_rate)
示例#8
0
文件: soft.py 项目: zzsfornlp/zmsp
 def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
     """Score each token as a candidate and select by topk + threshold.

     Returns (full scores [*, slen], hard 0/1 decisions [*, slen]).
     """
     conf: SoftExtractorConf = self.conf
     # --
     # scores, with padded positions pushed down to a practical minimum
     cand_full_scores = (self.cand_scorer(input_expr).squeeze(-1)
                         + (1. - mask_expr) * Constants.REAL_PRAC_MIN)  # [*, slen]
     # per-sequence k = ceil(len * rate), capped at the configured count
     len_t = mask_expr.sum(-1)  # [*]
     topk_t = (len_t * conf.cand_topk_rate).clamp(
         max=conf.cand_topk_count).ceil().long().unsqueeze(-1)  # [*, 1]
     # topk mask (skip when there are no elements at all)
     if BK.is_zero_shape(mask_expr):
         topk_mask = mask_expr.clone()
     else:
         topk_mask = select_topk(cand_full_scores, topk_t,
                                 mask_t=mask_expr, dim=-1)  # [*, slen]
     # additionally require scores to clear the threshold
     cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()  # [*, slen]
     return cand_full_scores, cand_decisions  # [*, slen]
示例#9
0
文件: soft.py 项目: zzsfornlp/zmsp
 def _split_extend(self, split_decisions: BK.Expr, cand_mask: BK.Expr):
     """Turn per-position split decisions into segments over the cands.

     Args:
         split_decisions: 1. where a new segment starts (position 0 is
             force-prepended as a start below), 0. elsewhere.
         cand_mask: validity mask of candidate positions; its row sums
             give the per-row number of cands.

     Returns:
         (seg_ext_cidxes, seg_ext_masks, seg_masks):
         2x [*, seglen, MW] expanded per-segment candidate idxes/masks,
         plus [*, seglen] segment validity.
     """
     # first augment/pad split_decisions: position 0 always starts a segment
     slice_ones = BK.constants([BK.get_shape(split_decisions, 0), 1],
                               1.)  # [*, 1]
     padded_split_decisions = BK.concat([slice_ones, split_decisions],
                                        -1)  # [*, clen]
     # segment start indices + mask over the (ragged) number of segments
     seg_cidxes, seg_masks = BK.mask2idx(
         padded_split_decisions)  # [*, seglen]
     # --
     cand_lens = cand_mask.sum(-1, keepdim=True).long()  # [*, 1]
     seg_masks *= (cand_lens > 0).float()  # for the case of no cands
     # --
     # padding slots get start index == cand_lens so their length below is 0
     seg_cidxes_special = seg_cidxes + (1. - seg_masks).long(
     ) * cand_lens  # [*, seglen], fill in for paddings
     seg_cidxes_special2 = BK.concat([seg_cidxes_special, cand_lens],
                                     -1)  # [*, seglen+1]
     # segment length = next segment's start - this segment's start
     seg_clens = seg_cidxes_special2[:,
                                     1:] - seg_cidxes_special  # [*, seglen]
     # extend the idxes: each segment expands to its member positions
     seg_ext_cidxes, seg_ext_masks = expand_ranged_idxes(
         seg_cidxes, seg_clens)  # [*, seglen, MW]
     seg_ext_masks *= seg_masks.unsqueeze(-1)
     return seg_ext_cidxes, seg_ext_masks, seg_masks  # 2x[*, seglen, MW], [*, seglen]
示例#10
0
 def loss(self,
          input_main: BK.Expr,
          input_pair: BK.Expr,
          input_mask: BK.Expr,
          gold_idxes: BK.Expr,
          loss_weight_expr: BK.Expr = None,
          extra_score: BK.Expr = None):
     """Masked (optionally weighted) NLL loss.

     Returns (loss_sum [], count []), where count is the mask total.
     """
     # raw scores without local normalization
     scores_t = self.score(input_main,
                           input_pair,
                           input_mask,
                           local_normalize=False,
                           extra_score=extra_score)  # [*, L]
     # per-position negative log-likelihood (with label smoothing)
     nll_t = BK.loss_nll(scores_t, gold_idxes,
                         label_smoothing=self.conf.label_smoothing)  # [*]
     nll_t = nll_t * input_mask
     if loss_weight_expr is not None:
         nll_t = nll_t * loss_weight_expr  # [*]
     return (nll_t.sum(), input_mask.sum())
示例#11
0
 def gather_losses(self,
                   scores: List[BK.Expr],
                   label_t: BK.Expr,
                   valid_t: BK.Expr,
                   loss_neg_sample: float = None):
     """Collect (loss, count) pairs for each score tensor in `scores`.

     Two modes: CRF (sequence-level NLL via self.crf) or per-item losses
     (full NLL, plus optional binary / all-binary terms). Negatives can be
     sub-sampled via `_get_loss_mask` in the non-CRF branch.

     Returns a list of (loss_t, count_t) tuples, one per entry of `scores`.
     """
     conf: ZLabelConf = self.conf
     _loss_do_sel = conf.loss_do_sel
     _alpha_binary, _alpha_full = conf.loss_binary_alpha, conf.loss_full_alpha
     _alpha_all_binary = conf.loss_allbinary_alpha
     # --
     if self.crf is not None:  # CRF mode!
         # binary-style losses are incompatible with CRF here
         assert _alpha_binary <= 0. and _alpha_all_binary <= 0.
         # reshape them into 3d: keep only rows that have any valid token
         valid_premask = (valid_t.sum(-1) > 0.)  # [bs, ...]
         # note: simply collect them all
         rets = []
         _pm_mask, _pm_label = valid_t[valid_premask], label_t[
             valid_premask]  # [??, slen]
         for score_t in scores:
             _one_pm_score = score_t[valid_premask]  # [??, slen, D]
             _one_fscore_t, _ = self._get_score(
                 _one_pm_score)  # [??, slen, L]
             # --
             # todo(+N): hacky fix, make it a leading NIL
             _pm_mask2 = _pm_mask.clone()
             _pm_mask2[:, 0] = 1.
             # --
             _one_loss, _one_count = self.crf.loss(_one_fscore_t, _pm_mask2,
                                                   _pm_label)  # ??
             rets.append((_one_loss * _alpha_full, _one_count))
     else:
         pos_t = (label_t > 0).float()  # 0 as NIL!!
         # sub-sample negatives (positives always kept)
         loss_mask_t = self._get_loss_mask(
             pos_t, valid_t, loss_neg_sample=loss_neg_sample)  # [bs, ...]
         if _loss_do_sel:
             # pre-select the loss positions once, shared across all scores
             _sel_mask = (loss_mask_t > 0.)  # [bs, ...]
             _sel_label = label_t[_sel_mask]  # [??]
             _sel_mask2 = BK.constants([len(_sel_label)], 1.)  # [??]
         # note: simply collect them all
         rets = []
         for score_t in scores:
             if _loss_do_sel:  # [??, ]
                 one_score_t, one_mask_t, one_label_t = score_t[
                     _sel_mask], _sel_mask2, _sel_label
             else:  # [bs, ..., D]
                 one_score_t, one_mask_t, one_label_t = score_t, loss_mask_t, label_t
             one_fscore_t, one_nilscore_t = self._get_score(one_score_t)
             # full loss: NLL over all labels
             one_loss_t = BK.loss_nll(one_fscore_t,
                                      one_label_t) * _alpha_full  # [????]
             # binary loss: NIL-vs-rest on the nil score
             if _alpha_binary > 0.:  # plus ...
                 _binary_loss = BK.loss_binary(
                     one_nilscore_t.squeeze(-1),
                     (one_label_t > 0).float()) * _alpha_binary  # [???]
                 one_loss_t = one_loss_t + _binary_loss
             # all binary: one-vs-rest per label (NIL column excluded below)
             if _alpha_all_binary > 0.:  # plus ...
                 _tmp_label_t = BK.zeros(
                     BK.get_shape(one_fscore_t))  # [???, L]
                 _tmp_label_t.scatter_(-1, one_label_t.unsqueeze(-1), 1.)
                 _ab_loss = BK.loss_binary(
                     one_fscore_t,
                     _tmp_label_t) * _alpha_all_binary  # [???, L]
                 # [..., 1:] skips the NIL (idx-0) column
                 one_loss_t = one_loss_t + _ab_loss[..., 1:].sum(-1)
             # --
             one_loss_t = one_loss_t * one_mask_t
             rets.append((one_loss_t, one_mask_t))  # tuple(loss, mask)
     return rets
示例#12
0
 def loss(self,
          input_main: BK.Expr,
          input_pair: BK.Expr,
          input_mask: BK.Expr,
          gold_idxes: BK.Expr,
          loss_weight_expr: BK.Expr = None,
          extra_score: BK.Expr = None):
     """Sequence labeling loss: either per-token MLE or CRF NLL.

     Returns (loss_sum [], count), where count normalizes by tokens or by
     batch items depending on `conf.loss_by_tok`.
     """
     conf: SeqLabelerConf = self.conf
     # --
     expr_main, expr_pair = self.transform_expr(input_main, input_pair)
     if self.loss_mle:
         # simply collect them all (not normalize here!)
         all_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             gold_idxes,
             local_normalize=False,
             extra_score=extra_score)  # [*, slen, L]
         # negative log likelihood; todo(+1): repeat log-softmax here
         # all_losses_t = - all_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         all_losses_t = BK.loss_nll(
             all_scores_t,
             gold_idxes,
             label_smoothing=self.conf.labeler_conf.label_smoothing)  # [*]
         all_losses_t *= input_mask
         if loss_weight_expr is not None:
             all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     elif self.loss_crf:
         # no normalization & no bigram (transitions are added by the CRF)
         single_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             None,
             use_bigram=False,
             extra_score=extra_score)  # [*, slen, L]
         mat_t = self.bigram.get_matrix()  # [L, L]
         if BK.is_zero_shape(single_scores_t):  # note: avoid empty
             potential_t = BK.zeros(
                 BK.get_shape(single_scores_t)[:-2])  # [*]
         else:
             potential_t = BigramInferenceHelper.inference_forward(
                 single_scores_t, mat_t, input_mask, conf.beam_k)  # [*]
         # gold path score = unary part + transition part
         gold_single_scores_t = single_scores_t.gather(
             -1,
             gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         gold_bigram_scores_t = mat_t[
             gold_idxes[:, :-1],
             gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
         all_losses_t = (
             potential_t -
             (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1))
         )  # [*]
         # todo(+N): also no label_smoothing for crf
         # todo(+N): for now, ignore loss_weight for crf mode!!
         # if loss_weight_expr is not None:
         #     assert BK.get_shape(loss_weight_expr, -1) == 1, "Currently CRF loss requires seq level loss_weight!!"
         #     all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     else:
         raise NotImplementedError()
     # ret_count
     if conf.loss_by_tok:  # sum all valid toks
         if conf.loss_by_tok_weighted and loss_weight_expr is not None:
             ret_count = (input_mask * loss_weight_expr).sum()
         else:
             ret_count = input_mask.sum()
     else:  # sum all valid batch items
         # NOTE(review): prod(-1) counts a sequence only if EVERY position is
         # unmasked; the sibling CRF impl uses (input_mask.sum(-1)>0) instead.
         # With right-padding these differ -- confirm which is intended.
         ret_count = input_mask.prod(-1).sum()
     return (ret_loss, ret_count)