Example 1
 def predict(self,
             insts: Union[List[Sent], List[Frame]],
             input_expr: BK.Expr,
             mask_expr: BK.Expr,
             pair_expr: BK.Expr = None,
             lookup_flatten=False,
             external_extra_score: BK.Expr = None):
     # --
     # note: check empty
     if BK.is_zero_shape(mask_expr):
         for inst in insts:  # still need to clear things!!
             self.helper._clear_f(inst)
     else:
         # simply labeling!
         best_labs, best_scores = self.lab_node.predict(
             input_expr,
             pair_expr,
             mask_expr,
             extra_score=external_extra_score)
         # put results
         self.helper.put_results(insts, best_labs, best_scores)
     # --
     # finally
     return self._finish_pred(insts, input_expr, mask_expr, pair_expr,
                              lookup_flatten)
Example 2
 def _forward_max(self, repr_t: BK.Expr, dsel_seq_info):
     RDIM = 2  # reduce dim
     # --
     _all_repr_t, _ = self._aggregate_subtoks(repr_t, dsel_seq_info)
     # max over subtokens; on an empty (zero-size) dim max() would fail, so fall back to sum(), which safely yields zeros
     ret = _all_repr_t.sum(RDIM) if BK.is_zero_shape(_all_repr_t) else _all_repr_t.max(RDIM)[0]  # [*, dlen, D]
     ret = BK.relu(ret)  # note: for simplicity, just keep things >= 0
     return ret
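
A minimal standalone sketch of why this guard is needed (assuming BK.Expr wraps plain PyTorch tensors, as the call style suggests): max() over a zero-size dimension raises an error, while sum() safely returns zeros of the reduced shape.

import torch

x = torch.zeros(4, 7, 0, 16)   # [*, dlen, M=0, D]: no subtokens were selected
print(x.sum(2).shape)          # torch.Size([4, 7, 16]) -- the safe fallback, all zeros
try:
    x.max(2)                   # reducing an empty dim is an error
except RuntimeError as e:
    print("max() failed:", e)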
Example 3
 def predict(self, insts: List[Sent], input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: MySRLConf = self.conf
     slen = BK.get_shape(mask_expr, -1)
     # --
     # =====
     # evt
     _, all_evt_cfs, all_evt_raw_scores = self.evt_node.get_all_values()  # [*, slen, Le]
     all_evt_scores = [z.log_softmax(-1) for z in all_evt_raw_scores]
     final_evt_scores = self.evt_node.helper.pred(all_logprobs=all_evt_scores, all_cfs=all_evt_cfs)  # [*, slen, Le]
     if conf.evt_pred_use_all or conf.evt_pred_use_posi:  # todo(+W): not an elegant way...
         final_evt_scores[:, :, 0] += Constants.REAL_PRAC_MIN  # suppress NIL so that every position predicts something
     pred_evt_scores, pred_evt_labels = final_evt_scores.max(-1)  # [*, slen]
     # =====
     # arg
     _, all_arg_cfs, all_arg_raw_score = self.arg_node.get_all_values()  # [*, slen, slen, La]
     all_arg_scores = [z.log_softmax(-1) for z in all_arg_raw_score]
     final_arg_scores = self.arg_node.helper.pred(all_logprobs=all_arg_scores, all_cfs=all_arg_cfs)  # [*, slen, slen, La]
     # slightly more efficient: only decode args at positions predicted as events
     full_pred_shape = BK.get_shape(final_arg_scores)[:-1]  # [*, slen, slen]
     pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
     arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
     flat_arg_scores = final_arg_scores[arg_flat_mask]  # [??, slen, La]
     if not BK.is_zero_shape(flat_arg_scores):  # at least one predicate!
         if self.pred_cons_mat is not None:
             flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
             flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                 flat_arg_scores, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
         else:
             flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_scores.max(-1)  # [??, slen]
         pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
         pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
     # =====
     # assign
     self.helper.put_results(insts, pred_evt_labels, pred_evt_scores, pred_arg_labels, pred_arg_scores)
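
The masked select-compute-scatter pattern above (decode arguments only where an event was predicted, then write back into full-size tensors) in a tiny standalone form; plain PyTorch with made-up shapes and values:

import torch

scores = torch.randn(2, 5, 5, 9)              # [*, slen, slen, La]
evt_labels = torch.tensor([[0, 3, 0, 0, 1],
                           [0, 0, 0, 0, 0]])  # [*, slen]
flat_mask = evt_labels > 0                    # True only at predicted events
out_scores = torch.zeros(2, 5, 5)             # [*, slen, slen]
out_labels = torch.zeros(2, 5, 5, dtype=torch.long)
flat = scores[flat_mask]                      # [??, slen, La], ?? = 2 events here
if flat.numel() > 0:                          # same role as the is_zero_shape check
    s, lab = flat.max(-1)                     # [??, slen]
    out_scores[flat_mask] = s                 # scatter back into the full tensors
    out_labels[flat_mask] = lab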
Example 4
 def _loss_feed_split(self, mask_expr, split_scores, pred_split_decisions,
                      cand_widxes, cand_masks, cand_expr, cand_scores,
                      expr_seq_gaddr):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     # --
     # step 2.1: split loss (only on good points (excluding -1|-1 or paddings) with dynamic oracle)
     cand_gaddr = expr_seq_gaddr[arange2_t, cand_widxes]  # [*, clen]
     cand_gaddr0, cand_gaddr1 = cand_gaddr[:, :-1], cand_gaddr[:, 1:]  # [*, clen-1]
     split_oracle = (cand_gaddr0 != cand_gaddr1).float() * cand_masks[:, 1:]  # [*, clen-1]
     split_oracle_mask = ((cand_gaddr0 >= 0) | (cand_gaddr1 >= 0)).float() * cand_masks[:, 1:]  # [*, clen-1]
     raw_split_loss = BK.loss_binary(split_scores, split_oracle,
                                     label_smoothing=conf.split_label_smoothing)  # [*, clen-1]
     loss_split_item = LossHelper.compile_leaf_loss(
         "split", (raw_split_loss * split_oracle_mask).sum(), split_oracle_mask.sum(),
         loss_lambda=conf.loss_split)
     # step 2.2: feed split
     # note: when teacher-forcing, only forcing good points, others still use pred
     force_split_decisions = split_oracle_mask * split_oracle + (1. - split_oracle_mask) * pred_split_decisions  # [*, clen-1]
     _use_force_mask = (BK.rand([bsize]) <= conf.split_feed_force_rate).float().unsqueeze(-1)  # [*, 1], seq-level
     feed_split_decisions = _use_force_mask * force_split_decisions + (1. - _use_force_mask) * pred_split_decisions  # [*, clen-1]
     # next
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         feed_split_decisions, cand_masks)
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # finally get oracles for next steps
     # todo(+N): simply select the highest scored one as oracle
     if BK.is_zero_shape(seg_ext_scores):  # simply make them all -1
         oracle_gaddr = BK.constants_idx(seg_masks.shape, -1)  # [*, seglen]
     else:
         _, _seg_max_t = seg_ext_scores.max(-1, keepdim=True)  # [*, seglen, 1]
         oracle_widxes = seg_ext_widxes.gather(-1, _seg_max_t).squeeze(-1)  # [*, seglen]
         oracle_gaddr = expr_seq_gaddr.gather(-1, oracle_widxes)  # [*, seglen]
     oracle_gaddr[seg_masks <= 0] = -1  # (assign invalid ones) [*, seglen]
     return loss_split_item, seg_masks, seg_ext_widxes, seg_ext_masks, seg_weighted_expr, oracle_gaddr
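
A tiny worked example (plain PyTorch, invented values) of the split oracle above: a split is placed between adjacent candidates whose gold addresses differ, and a pair only counts when at least one side is gold (address >= 0):

import torch

gaddr = torch.tensor([[0, 0, 1, -1, -1]])  # gold addr per candidate, -1 = none
g0, g1 = gaddr[:, :-1], gaddr[:, 1:]       # adjacent pairs, [*, clen-1]
split = (g0 != g1).float()                 # tensor([[0., 1., 1., 0.]])
valid = ((g0 >= 0) | (g1 >= 0)).float()    # tensor([[1., 1., 1., 0.]])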
Example 5
 def prepare_sd_init(self, expr_main: BK.Expr, expr_pair: BK.Expr):
     if self.is_pairwise:
         sd_init_t = self.sd_init_aff(expr_pair)  # [*, hid]
     else:
         if BK.is_zero_shape(expr_main):
             sd_init_t0 = expr_main.sum(-2)  # empty input: sum() just produces the right shape
         else:
             sd_init_t0 = self.sd_init_pool_f(expr_main, -2)  # pooling at -2: [*, Dm']
         sd_init_t = self.sd_init_aff(sd_init_t0)  # [*, hid]
     return sd_init_t
Example 6
 def loss(self, unary_scores: BK.Expr, input_mask: BK.Expr, gold_idxes: BK.Expr):
     mat_t = self.bigram.get_matrix()  # [L, L]
     if BK.is_zero_shape(unary_scores):  # note: avoid empty
         potential_t = BK.zeros(BK.get_shape(unary_scores)[:-2])  # [*]
     else:
         potential_t = BigramInferenceHelper.inference_forward(unary_scores, mat_t, input_mask, self.conf.crf_beam)  # [*]
     gold_single_scores_t = unary_scores.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
     gold_bigram_scores_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
     all_losses_t = (potential_t - (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1)))  # [*]
     if self.conf.loss_by_tok:
         ret_count = input_mask.sum()  # []
     else:
         ret_count = (input_mask.sum(-1)>0).float()  # [*]
     return all_losses_t, ret_count
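
To make the CRF objective concrete, here is a standalone sketch (plain PyTorch, random values) of the gold-path score that is subtracted from the forward log-partition potential_t:

import torch

L = 4
unary = torch.randn(2, 6, L)        # [*, slen, L] unary scores
mat = torch.randn(L, L)             # [L, L] transition scores
gold = torch.randint(0, L, (2, 6))  # [*, slen] gold labels
mask = torch.ones(2, 6)             # [*, slen]
single = unary.gather(-1, gold.unsqueeze(-1)).squeeze(-1) * mask  # [*, slen]
bigram = mat[gold[:, :-1], gold[:, 1:]] * mask[:, 1:]             # [*, slen-1]
gold_path = single.sum(-1) + bigram.sum(-1)                       # [*]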
Example 7
 def _aggregate_subtoks(self, repr_t: BK.Expr, dsel_seq_info):
     conf: DSelectorConf = self.conf
     _arange_t, _sel_t, _len_t = dsel_seq_info.arange2_t, dsel_seq_info.dec_sel_idxes, dsel_seq_info.dec_sel_lens
     _max_len = 1 if BK.is_zero_shape(_len_t) else _len_t.max().item()
     _max_len = max(1, min(conf.dsel_max_subtoks, _max_len))  # truncate
     # --
     _tmp_arange_t = BK.arange_idx(_max_len)  # [M]
     _all_valids_t = (_tmp_arange_t < _len_t.unsqueeze(-1)).float()  # [*, dlen, M]
     _tmp_arange_t = _tmp_arange_t * _all_valids_t.long()  # note: pad as 0
     _all_idxes_t = _sel_t.unsqueeze(-1) + _tmp_arange_t  # [*, dlen, M]
     _all_repr_t = repr_t[_arange_t.unsqueeze(-1), _all_idxes_t]  # [*, dlen, M, D]
     while len(BK.get_shape(_all_valids_t)) < len(BK.get_shape(_all_repr_t)):
         _all_valids_t = _all_valids_t.unsqueeze(-1)
     _all_repr_t = _all_repr_t * _all_valids_t
     return _all_repr_t, _all_valids_t
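
A tiny-tensor re-enactment (plain PyTorch, invented values) of the windowed gather above: each word pulls up to M consecutive subtoken vectors starting at its first subtoken, with invalid slots re-indexing the start position and later being zeroed by the mask:

import torch

repr_t = torch.randn(1, 8, 4)      # [*, slen, D] subtoken representations
sel = torch.tensor([[0, 1, 4]])    # [*, dlen]: first-subtoken index per word
lens = torch.tensor([[1, 3, 2]])   # [*, dlen]: subtoken count per word
M = int(lens.max())                # window size
rng = torch.arange(M)              # [M]
valid = rng < lens.unsqueeze(-1)   # [*, dlen, M]
idxes = sel.unsqueeze(-1) + rng * valid.long()  # pad offset as 0
batch_idx = torch.arange(1).view(1, 1, 1)       # [*, 1, 1]
out = repr_t[batch_idx, idxes] * valid.unsqueeze(-1).float()  # [*, dlen, M, D]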
Example 8
def expand_ranged_idxes(widx_t: BK.Expr,
                        wlen_t: BK.Expr,
                        pad: int = 0,
                        max_width: int = None):
    if max_width is None:  # if not provided
        if BK.is_zero_shape(wlen_t):
            max_width = 1  # at least one
        else:
            max_width = wlen_t.max().item()  # overall max width
    # --
    input_shape = BK.get_shape(widx_t)  # [*]
    mw_range_t = BK.arange_idx(max_width).view([1] * len(input_shape) + [-1])  # [*, MW]
    expanded_idxes = widx_t.unsqueeze(-1) + mw_range_t  # [*, MW]
    expanded_masks_bool = (mw_range_t < wlen_t.unsqueeze(-1))  # [*, MW]
    expanded_idxes.masked_fill_(~expanded_masks_bool, pad)  # [*, MW]
    return expanded_idxes, expanded_masks_bool.float()
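
For concreteness, the index arithmetic on tiny tensors (plain PyTorch stand-ins for BK; values are illustrative):

import torch

widx = torch.tensor([2, 5])           # span starts
wlen = torch.tensor([3, 1])           # span lengths; overall max width is 3
rng = torch.arange(3).view(1, -1)     # [1, MW]
idxes = widx.unsqueeze(-1) + rng      # [[2, 3, 4], [5, 6, 7]]
masks = rng < wlen.unsqueeze(-1)      # [[T, T, T], [T, F, F]]
idxes = idxes.masked_fill(~masks, 0)  # [[2, 3, 4], [5, 0, 0]], pad = 0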
Example 9
 def go_topk(
         self,
         input_expr: BK.Expr,
         input_mask: BK.Expr,  # input
         widx_expr: BK.Expr,
         wlen_expr: BK.Expr,
         span_mask: BK.Expr,
         rate: float = None,
         count: float = None,  # span
         gaddr_expr: BK.Expr = None,
         add_gold_rate: float = 0.,  # gold
         non_overlapping=False,  # non-overlapping!
         score_prune: float = None):
     lookup_res = self.go_lookup(input_expr, widx_expr, wlen_expr,
                                 span_mask, gaddr_expr)  # [bsize, NUM, *]
     # --
     with BK.no_grad_env():  # no need grad here!
         all_score_expr = lookup_res.score_expr
         # get topk score: again rate is to the original input length
         if BK.is_zero_shape(lookup_res.mask_expr):
             topk_mask = lookup_res.mask_expr.clone()  # no need to go topk since no elements
         else:
             topk_expr = self._determine_size(input_mask.sum(-1, keepdim=True), rate, count).long()  # [bsize, 1]
             if non_overlapping:
                 topk_mask = select_topk_non_overlapping(all_score_expr, topk_expr, widx_expr, wlen_expr,
                                                         input_mask, mask_t=span_mask, dim=-1)
             else:
                 topk_mask = select_topk(all_score_expr, topk_expr, mask_t=span_mask, dim=-1)
         # further score_prune?
         if score_prune is not None:
             topk_mask *= (all_score_expr >= score_prune).float()
     # select and add_gold
     return self._go_common(lookup_res, topk_mask, add_gold_rate)
Example 10
 def loss(self, all_losses: List[BK.Expr], all_cfs: List[BK.Expr],
          **kwargs):
     conf: IdecHelperCW2Conf = self.conf
     _temp = self.temperature.value
     # --
     stack_t = BK.stack(all_losses, -1)  # [*, NL]
     w_t = (-stack_t / _temp)  # [*, NL], smaller loss is better!
     w_t_detach = w_t.detach()
     # main loss
     apply_w_t = w_t_detach if conf.detach_weights else w_t
     ret_t = (stack_t * apply_w_t.softmax(-1)).sum(-1)  # [*]
     # cf loss
     cf_t = BK.stack(all_cfs, -1).sigmoid()  # [*, NL]
     if conf.cf_trg_rel:  # relative prob proportion?
         _max_t = w_t_detach.sum(-1, keepdim=True) if BK.is_zero_shape(w_t_detach) \
             else w_t_detach.max(-1, keepdim=True)[0]  # [*, 1], sum() only to keep the shape when empty
         _trg_t = (w_t_detach - _max_t).exp() * conf.max_cf  # [*, NL]
     else:
         _trg_t = w_t_detach.exp() * conf.max_cf
     loss_cf_t = BK.loss_binary(cf_t, _trg_t).mean(-1)  # [*]
     return [(ret_t, 1., ""), (loss_cf_t, conf.loss_cf, "_cf")]
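
The cf_trg_rel branch is the usual max-subtraction trick; a small sketch of its effect (plain PyTorch):

import torch

w = torch.randn(3, 5)            # [*, NL], e.g. -loss / temperature
m = w.max(-1, keepdim=True)[0]   # [*, 1]
trg = (w - m).exp()              # [*, NL]: values in (0, 1], the best layer gets exactly 1.0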
Example 11
 def predict(self, insts: Union[List[Sent], List[Frame]], input_expr: BK.Expr, mask_expr: BK.Expr,
             pair_expr: BK.Expr = None, lookup_flatten=False, external_extra_score: BK.Expr=None):
     conf: DirectExtractorConf = self.conf
     # step 1: prepare targets
     if conf.pred_use_posi:
         # step 1a: directly use provided positions
         arr_gold_items, expr_gold_gaddr, expr_gold_widxes, expr_gold_wlens, _ = self.helper.prepare(insts, use_cache=False)
         cand_res = self.extract_node.go_lookup(input_expr, expr_gold_widxes, expr_gold_wlens,
                                                (expr_gold_gaddr>=0).float(), gaddr_expr=expr_gold_gaddr)
     else:
         arr_gold_items = None
         # step 1b: extract cands (topk); todo(note): assume no in-middle mask!!
         cand_widx, cand_wlen, cand_mask, _ = self.extract_node.prepare_with_lengths(
             BK.get_shape(mask_expr), mask_expr.sum(-1).long(), None, None, None)
         cand_res = self.extract_node.go_topk(
             input_expr, mask_expr, cand_widx, cand_wlen, cand_mask,
             rate=conf.span_topk_rate, count=conf.span_topk_count,
             non_overlapping=conf.pred_non_overlapping, score_prune=conf.pred_score_prune)
     # --
     # note: check empty
     if BK.is_zero_shape(cand_res.mask_expr):
         if not conf.pred_use_posi:
             for inst in insts:  # still need to clear things!!
                 self.helper._clear_f(inst)
     else:
         # step 2: labeling
         # extra score
         cand_extra_score = self._get_extra_score(
             cand_res.score_expr, insts, cand_res, arr_gold_items, conf.pred_use_cons, conf.pred_use_lu)
         final_extra_score = self._sum_scores(external_extra_score, cand_extra_score)
         best_labs, best_scores = self.lab_node.predict(
             cand_res.span_expr, pair_expr, cand_res.mask_expr, extra_score=final_extra_score)
         # step 3: put results
         if conf.pred_use_posi:  # reuse the old ones, but replace label
             self.helper.put_labels(arr_gold_items, best_labs, best_scores)
         else:  # make new frames
             self.helper.put_results(insts, best_labs, best_scores, cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr)
     # --
     # finally
     return self._finish_pred(insts, input_expr, mask_expr, pair_expr, lookup_flatten)
Example 12
 def _cand_score_and_select(self, input_expr: BK.Expr, mask_expr: BK.Expr):
     conf: SoftExtractorConf = self.conf
     # --
     cand_full_scores = self.cand_scorer(input_expr).squeeze(-1) + (1. - mask_expr) * Constants.REAL_PRAC_MIN  # [*, slen]
     # decide topk count
     len_t = mask_expr.sum(-1)  # [*]
     topk_t = (len_t * conf.cand_topk_rate).clamp(max=conf.cand_topk_count).ceil().long().unsqueeze(-1)  # [*, 1]
     # get topk mask
     if BK.is_zero_shape(mask_expr):
         topk_mask = mask_expr.clone()  # no need to go topk since no elements
     else:
         topk_mask = select_topk(cand_full_scores, topk_t, mask_t=mask_expr, dim=-1)  # [*, slen]
     # thresh
     cand_decisions = topk_mask * (cand_full_scores >= conf.cand_thresh).float()  # [*, slen]
     return cand_full_scores, cand_decisions  # [*, slen]
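
The topk-count rule above in isolation (plain PyTorch; the rate and cap values are made up): take a fraction of the valid length, cap it by an absolute count, and round up.

import torch

len_t = torch.tensor([10., 25.])  # valid length per sequence
rate, cap = 0.4, 8.               # hypothetical cand_topk_rate / cand_topk_count
topk_t = (len_t * rate).clamp(max=cap).ceil().long().unsqueeze(-1)  # [*, 1]
print(topk_t)  # tensor([[4], [8]])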
Example 13
 def _pred_arg(self, mask_expr, pred_evt_labels):
     conf: ZDecoderSRLConf = self.conf
     slen = BK.get_shape(mask_expr, -1)
     # --
     all_arg_raw_score = self.arg_node.buffer_scores.values()  # [*, slen, slen, La]
     all_arg_logprobs = [z.log_softmax(-1) for z in all_arg_raw_score]
     final_arg_logprobs = self.arg_node.helper.pred(all_logprobs=all_arg_logprobs)  # [*, slen, slen, La]
     # slightly more efficient: only decode args at positions predicted as events
     full_pred_shape = BK.get_shape(final_arg_logprobs)[:-1]  # [*, slen, slen]
     pred_arg_scores, pred_arg_labels = BK.zeros(full_pred_shape), BK.zeros(full_pred_shape).long()
     # mask
     arg_flat_mask = (pred_evt_labels > 0)  # [*, slen]
     flat_arg_logprobs = final_arg_logprobs[arg_flat_mask]  # [??, slen, La]
     if not BK.is_zero_shape(flat_arg_logprobs):  # at least one predicate
         if self.pred_cons_mat is not None:
             flat_mask_expr = mask_expr.unsqueeze(-2).expand(-1, slen, slen)[arg_flat_mask]  # [*, 1->slen, slen] => [??, slen]
             flat_pred_arg_labels, flat_pred_arg_scores = BigramInferenceHelper.inference_search(
                 flat_arg_logprobs, self.pred_cons_mat, flat_mask_expr, conf.arg_beam_k)  # [??, slen]
         else:
             flat_pred_arg_scores, flat_pred_arg_labels = flat_arg_logprobs.max(-1)  # [??, slen]
         pred_arg_scores[arg_flat_mask] = flat_pred_arg_scores
         pred_arg_labels[arg_flat_mask] = flat_pred_arg_labels
     return pred_arg_labels, pred_arg_scores  # [*, slen, slen]
Example 14
 def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                     expr_seq_gaddr, expr_group_widxes, expr_group_masks,
                     expr_loss_weight_non):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1.1: bag loss
     cand_gold_mask = (expr_seq_gaddr >= 0).float() * mask_expr  # [*, slen], whether is-arg
     raw_loss_cand = BK.loss_binary(cand_full_scores, cand_gold_mask,
                                    label_smoothing=conf.cand_label_smoothing)  # [*, slen]
     # how to weight?
     extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] \
         + (1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
     if BK.is_zero_shape(extended_scores_t):
         extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = ((cand_full_scores - extended_scores_max_t) * _w_alpha).exp()  # [*, slen]
     if not conf.cand_loss_div_max:  # div sum-all, like doing softmax
         _weight = _weight / ((extended_scores_t - extended_scores_max_t.unsqueeze(-1)) * _w_alpha).exp().sum(-1)
     _weight = _weight * (_weight >= conf.cand_loss_weight_thresh).float()  # [*, slen]
     if conf.cand_detach_weight:
         _weight = _weight.detach()
     # pos poison (dis-encouragement)
     if conf.cand_loss_pos_poison:
         poison_loss = BK.loss_binary(cand_full_scores, 1. - cand_gold_mask,
                                      label_smoothing=conf.cand_label_smoothing)  # [*, slen]
         raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (1. - _weight)  # [*, slen]
     else:
         raw_loss_cand = raw_loss_cand * _weight
     # final weight it
     cand_loss_weights = BK.where(cand_gold_mask == 0.,
                                  expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_cand_item = LossHelper.compile_leaf_loss(
         "cand", (raw_loss_cand * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_cand)
     # step 1.2: feed cand
     # todo(+N): currently only pred/sample, whether adding certain teacher-forcing?
     sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(cand_full_scores.shape)).float() * mask_expr  # [*, slen]
     _use_sample_mask = (BK.rand([bsize]) <= conf.cand_feed_sample_rate).float().unsqueeze(-1)  # [*, 1], seq-level
     feed_cand_decisions = _use_sample_mask * sample_decisions + (1. - _use_sample_mask) * pred_cand_decisions  # [*, slen]
     # next
     cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
     # --
     # extra: loss_cand_entropy
     rets = [loss_cand_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:, :1] >= 0,
                                expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]],
                               -1).float() * cand_gold_mask
         _loss_ent_item = LossHelper.compile_leaf_loss(
             "cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         rets.append(_loss_ent_item)
     # --
     return rets, cand_widxes, cand_masks
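
A tiny worked example (plain PyTorch) of the "only first one in bag" mask used for the entropy loss: a position counts if it is the first in the sequence or its gold address differs from the previous position; non-gold slots are zeroed by the gold mask afterwards.

import torch

gaddr = torch.tensor([[0, 0, 1, 1, -1]])  # gold addresses, -1 = none
first = torch.cat([gaddr[:, :1] >= 0,
                   gaddr[:, 1:] != gaddr[:, :-1]], -1).float()
print(first)  # tensor([[1., 0., 1., 0., 1.]])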
Example 15
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1: label, simply scoring everything!
     _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
     all_scores_t = self.lab_node.score_all(
         _main_t,
         _pair_t,
         mask_expr,
         None,
         local_normalize=False,
         extra_score=external_extra_score
     )  # unnormalized scores [*, slen, L]
     all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
     all_gprob_t = all_probs_t.gather(-1, expr_seq_labs.unsqueeze(-1)).squeeze(-1)  # [*, slen]
     # how to weight
     extended_gprob_t = all_gprob_t[arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
     if BK.is_zero_shape(extended_gprob_t):
         extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = ((all_gprob_t * mask_expr) / extended_gprob_max_t.clamp(min=1e-5)) ** _w_alpha  # [*, slen]
     _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
     _loss1 = BK.loss_nll(all_scores_t, expr_seq_labs, label_smoothing=_label_smoothing)  # [*, slen]
     _loss2 = BK.loss_nll(all_scores_t, BK.constants_idx([bsize, slen], 0), label_smoothing=_label_smoothing)  # [*, slen]
     _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
     _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
     # final weight it
     cand_loss_weights = BK.where(expr_seq_labs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) * conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_lab_item = LossHelper.compile_leaf_loss(
         "lab", (_raw_loss * final_cand_loss_weights).sum(), final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_lab, gold=(expr_seq_labs > 0).float().sum())
     # --
     # step 1.5
     all_losses = [loss_lab_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_gprob_t  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
                     * (expr_seq_labs>0).float()
         _loss_ent_item = LossHelper.compile_leaf_loss(
             "cand_ent", (_ent * _ent_mask).sum(), _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         all_losses.append(_loss_ent_item)
     # --
     # step 4: extend (select topk)
     if conf.loss_ext > 0.:
         if BK.is_zero_shape(extended_gprob_t):
             flt_mask = (BK.zeros(mask_expr.shape) > 0)
         else:
             _topk = min(conf.ext_loss_topk, BK.get_shape(extended_gprob_t, -1))  # number to extract
             _topk_gprob_t, _ = extended_gprob_t.topk(_topk, dim=-1)  # [*, slen, K]
             flt_mask = (expr_seq_labs > 0) & (all_gprob_t >= _topk_gprob_t.min(-1)[0]) \
                 & (_weight > conf.ext_loss_thresh)  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         flt_items = arr_items.flatten()[BK.get_value(expr_seq_gaddr[flt_mask])]  # [?]
         flt_weights = (_weight.detach() if conf.detach_weight_ext else _weight)[flt_mask]  # [?]
         loss_ext_item = self.ext_node.loss(flt_items,
                                            input_expr[flt_sidx],
                                            flt_expr,
                                            flt_full_expr,
                                            mask_expr[flt_sidx],
                                            flt_extra_weights=flt_weights)
         all_losses.append(loss_ext_item)
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss, None
Example 16
 def predict(self,
             insts: Union[List[Sent], List[Frame]],
             input_expr: BK.Expr,
             mask_expr: BK.Expr,
             pair_expr: BK.Expr = None,
             lookup_flatten=False,
             external_extra_score: BK.Expr = None):
     conf: SoftExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     for inst in insts:  # first clear things
         self.helper._clear_f(inst)
     # --
     # step 1: cand score and select
     cand_full_scores, cand_decisions = self._cand_score_and_select(input_expr, mask_expr)  # [*, slen]
     cand_widxes, cand_masks = BK.mask2idx(cand_decisions)  # [*, clen]
     # step 2: split and seg
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(-1)  # [*, 1, 1]
     cand_expr = input_expr[arange2_t, cand_widxes]  # [*, clen, D]
     cand_scores = cand_full_scores[arange2_t, cand_widxes]  # [*, clen]
     split_scores, split_decisions = self._split_score(cand_expr, cand_masks)  # [*, clen-1]
     # *[*, seglen, MW], [*, seglen]
     seg_ext_cidxes, seg_ext_masks, seg_masks = self._split_extend(
         split_decisions, cand_masks)
     seg_ext_widxes0, seg_ext_masks0 = cand_widxes[arange3_t, seg_ext_cidxes], seg_ext_masks  # [*, seglen, ORIG-MW]
     seg_ext_scores, seg_ext_cidxes, seg_ext_widxes, seg_ext_masks, seg_weighted_expr = self._split_aggregate(
         cand_expr, cand_scores, cand_widxes, seg_ext_cidxes, seg_ext_masks,
         conf.split_topk)  # [*, seglen, ?]
     # --
     # step 3: lab
     flt_items = []
     if not BK.is_zero_shape(seg_masks):
         best_labs, best_scores = self.lab_node.predict(
             seg_weighted_expr,
             pair_expr,
             seg_masks,
             extra_score=external_extra_score)  # *[*, seglen]
         flt_items = self.helper.put_results(
             insts, best_labs, best_scores, seg_masks, seg_ext_widxes0,
             seg_ext_widxes, seg_ext_masks0, seg_ext_masks,
             cand_full_scores, cand_decisions, split_decisions)
     # --
     # step 4: final extend (in a flattened way)
     if len(flt_items) > 0 and conf.pred_ext:
         flt_mask = ((best_labs > 0) & (seg_masks > 0.))  # [*, seglen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[flt_mask]  # [?]
         flt_expr = seg_weighted_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(seg_ext_widxes[flt_mask], seg_ext_masks[flt_mask], slen)  # [?, slen, D]
         self.ext_node.predict(flt_items, input_expr[flt_sidx], flt_expr,
                               flt_full_expr, mask_expr[flt_sidx])
     # --
     # extra:
     self.pp_node.prune(insts)
     return None
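
Both this method and the loss side rely on BK.mask2idx to turn 0/1 decisions into left-packed indices plus a validity mask. A minimal sketch of one way such a helper can behave (an assumption for illustration, not the actual BK implementation):

import torch

def mask2idx(mask: torch.Tensor, pad: int = 0):
    # mask: [*, slen] of 0/1 -> (idxes [*, clen], valid [*, clen])
    clen = max(1, int(mask.sum(-1).max().item())) if mask.numel() > 0 else 1
    # a stable descending sort moves the 1s to the front, keeping their order
    valid, order = mask.sort(-1, descending=True, stable=True)
    idxes = order[..., :clen].clone()
    valid = valid[..., :clen]
    idxes[valid <= 0] = pad
    return idxes, valid.float()

idxes, valid = mask2idx(torch.tensor([[0., 1., 0., 1.], [1., 0., 0., 0.]]))
# idxes == [[1, 3], [0, 0]], valid == [[1., 1.], [1., 0.]]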
Example 17
 def loss(self,
          input_main: BK.Expr,
          input_pair: BK.Expr,
          input_mask: BK.Expr,
          gold_idxes: BK.Expr,
          loss_weight_expr: BK.Expr = None,
          extra_score: BK.Expr = None):
     conf: SeqLabelerConf = self.conf
     # --
     expr_main, expr_pair = self.transform_expr(input_main, input_pair)
     if self.loss_mle:
         # simply collect them all (not normalize here!)
         all_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             gold_idxes,
             local_normalize=False,
             extra_score=extra_score)  # [*, slen, L]
         # negative log likelihood; todo(+1): repeat log-softmax here
         # all_losses_t = - all_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         all_losses_t = BK.loss_nll(
             all_scores_t,
             gold_idxes,
             label_smoothing=self.conf.labeler_conf.label_smoothing)  # [*]
         all_losses_t *= input_mask
         if loss_weight_expr is not None:
             all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     elif self.loss_crf:
         # no normalization & no bigram
         single_scores_t = self.score_all(
             expr_main,
             expr_pair,
             input_mask,
             None,
             use_bigram=False,
             extra_score=extra_score)  # [*, slen, L]
         mat_t = self.bigram.get_matrix()  # [L, L]
         if BK.is_zero_shape(single_scores_t):  # note: avoid empty
             potential_t = BK.zeros(BK.get_shape(single_scores_t)[:-2])  # [*]
         else:
             potential_t = BigramInferenceHelper.inference_forward(
                 single_scores_t, mat_t, input_mask, conf.beam_k)  # [*]
         gold_single_scores_t = single_scores_t.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1) * input_mask  # [*, slen]
         gold_bigram_scores_t = mat_t[gold_idxes[:, :-1], gold_idxes[:, 1:]] * input_mask[:, 1:]  # [*, slen-1]
         all_losses_t = potential_t - (gold_single_scores_t.sum(-1) + gold_bigram_scores_t.sum(-1))  # [*]
         # todo(+N): also no label_smoothing for crf
         # todo(+N): for now, ignore loss_weight for crf mode!!
         # if loss_weight_expr is not None:
         #     assert BK.get_shape(loss_weight_expr, -1) == 1, "Currently CRF loss requires seq level loss_weight!!"
         #     all_losses_t *= loss_weight_expr
         ret_loss = all_losses_t.sum()  # []
     else:
         raise NotImplementedError()
     # ret_count
     if conf.loss_by_tok:  # sum all valid toks
         if conf.loss_by_tok_weighted and loss_weight_expr is not None:
             ret_count = (input_mask * loss_weight_expr).sum()
         else:
             ret_count = input_mask.sum()
     else:  # sum all valid batch items
         ret_count = (input_mask.sum(-1) > 0).float().sum()  # count sequences with at least one valid token; prod(-1) would drop any padded sequence
     return (ret_loss, ret_count)
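
A small illustration (plain PyTorch, 1/0 masks) of the two normalization counts: per-token versus per-sequence, where a sequence counts if it has at least one valid token.

import torch

m = torch.tensor([[1., 1., 0.],    # a padded sequence
                  [1., 1., 1.]])
by_tok = m.sum()                        # tensor(5.): every valid token
by_seq = (m.sum(-1) > 0).float().sum()  # tensor(2.): every non-empty sequence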