Beispiel #1
0
 def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc,
                  k_label) -> List[List]:
     """Expand each flattened state with its best-scoring (arc, label) choices.

     Args:
         ags: agendas being expanded; not read by this method.
         candidates: triple ``(flattened_states, cur_arc_scores,
             scoring_mask_ct)``.
         mode: ``"topk"`` to select by score, ``""`` to select nothing.
         k_arc: maximum number of arcs kept per state (clamped to the
             number of candidate arcs).
         k_label: maximum number of labels kept per selected arc (clamped
             to the number of labels).

     Returns:
         For ``"topk"``, the result of ``self._new_states``; for ``""``,
         one fresh empty list per flattened state.

     Raises:
         NotImplementedError: for any other ``mode``.
     """
     flattened_states, cur_arc_scores, scoring_mask_ct = candidates
     cur_bsize = len(flattened_states)
     if mode == "":
         # Build independent lists: ``[[]] * n`` would hand back the *same*
         # list object n times, so mutating one entry would change them all.
         return [[] for _ in range(cur_bsize)]
     if mode != "topk":
         # todo(+N): other modes like sampling to be implemented: sample, topk-sample
         raise NotImplementedError(mode)
     cur_cache = self.cache
     cur_slen = cur_cache.max_slen
     cur_arc_scores_flattend = cur_arc_scores.view([cur_bsize,
                                                    -1])  # [bs, Lm*Lh]
     # arcs [*, k]
     topk_arc_scores, topk_arc_idxes = BK.topk(
         cur_arc_scores_flattend,
         min(k_arc, BK.get_shape(cur_arc_scores_flattend, -1)),
         dim=-1,
         sorted=False)
     # Recover (m, h) from the flat index. Floor division keeps the indices
     # integral; ``/`` is true division on integer tensors in current PyTorch
     # and would yield floats that cannot be used for indexing.
     topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen  # [m, h]
     # labels [*, k, k']
     cur_label_scores = cur_cache.get_selected_label_scores(
         topk_m, topk_h, self.mw_arc, self.mw_label)
     topk_label_scores, topk_label_idxes = BK.topk(
         cur_label_scores,
         min(k_label, BK.get_shape(cur_label_scores, -1)),
         dim=-1,
         sorted=False)
     return self._new_states(flattened_states, scoring_mask_ct,
                             topk_arc_scores, topk_m, topk_h,
                             topk_label_scores, topk_label_idxes)
Beispiel #2
0
 def _select_topk(self, masked_scores, pad_mask, ratio_mask, topk_ratio,
                  thresh_k):
     """Build a selection mask from an absolute threshold and a ratio-based top-k.

     Starts from a copy of ``pad_mask`` and multiplies in up to two filters:

     * if ``thresh_k`` is not None, keep entries with score ``> thresh_k``;
     * if ``topk_ratio > 0``, keep, per row, entries whose score is at least
       the row's k-th top score, where ``k = long(ratio_mask.sum(-1) *
       topk_ratio)`` clamped into ``[1, size of the last dim]``.

     Returns:
         The filtered mask (same object that was copied from ``pad_mask``).
     """
     last_dim_size = BK.get_shape(masked_scores, -1)
     # presumably a fresh copy, so the in-place updates below leave the
     # caller's pad_mask alone -- depends on BK.copy
     keep = BK.copy(pad_mask)
     # absolute threshold first
     if thresh_k is not None:
         above_thresh = (masked_scores > thresh_k).float()
         keep *= above_thresh
     # nothing more to do unless a positive ratio was requested
     if not topk_ratio > 0.:
         return keep
     # per-row budget: at least one, at most all
     per_row_k = (ratio_mask.sum(-1) * topk_ratio).long()  # [*]
     per_row_k.clamp_(min=1, max=last_dim_size)
     # one shared topk call sized for the largest per-row budget
     largest_k = max(per_row_k.max().item(), 1)
     top_scores, _ = BK.topk(masked_scores, largest_k, dim=-1,
                             sorted=True)  # [*, k]
     # pick each row's own k-th entry (index k-1) as its cut-off score;
     # assumes BK.topk with sorted=True orders scores best-first
     kth_index = per_row_k.clamp(min=1).unsqueeze(-1) - 1
     cutoff_score = top_scores.gather(-1, kth_index)  # [*, 1]
     at_or_above_cutoff = (masked_scores >= cutoff_score).float()
     keep *= at_or_above_cutoff
     return keep
Beispiel #3
0
 def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc,
                   k_label) -> List[List]:
     """Expand each flattened state with oracle (gold) arcs and their labels.

     Args:
         ags: agendas being expanded; not read by this method.
         candidates: triple ``(flattened_states, cur_arc_scores,
             scoring_mask_ct)``.
         mode: ``"topk"`` to pick the best-scoring oracle arcs, ``""`` to
             select nothing.
         k_arc: maximum number of oracle arcs kept per state (clamped to the
             number of candidate arcs, as ``select_plain`` does).
         k_label: not read here; each selected arc gets exactly the one
             label gathered from the oracle label tensor.

     Returns:
         For ``"topk"``, the result of ``self._new_states``; for ``""``,
         one fresh empty list per flattened state.

     Raises:
         NotImplementedError: for any other ``mode``.
     """
     flattened_states, cur_arc_scores, scoring_mask_ct = candidates
     cur_bsize = len(flattened_states)
     if mode == "":
         # Build independent lists: ``[[]] * n`` would hand back the *same*
         # list object n times, so mutating one entry would change them all.
         return [[] for _ in range(cur_bsize)]
     if mode != "topk":
         # todo(+N): other modes like sampling to be implemented: sample, topk-sample, gather
         raise NotImplementedError(mode)
     cur_cache = self.cache
     cur_slen = cur_cache.max_slen
     # todo(note): there can be multiple oracles, select topk(usually top1) in this mode.
     # get and apply oracle mask
     cur_oracle_mask_t, cur_oracle_label_t = self._get_oracle_mask(
         flattened_states)
     # push non-oracle arcs down by a large negative constant, then flatten
     # [bs, Lm*Lh]
     cur_oracle_arc_scores = (cur_arc_scores + Constants.REAL_PRAC_MIN *
                              (1. - cur_oracle_mask_t)).view(
                                  [cur_bsize, -1])
     # arcs [*, k]; never ask topk for more entries than there are candidates
     topk_arc_scores, topk_arc_idxes = BK.topk(
         cur_oracle_arc_scores,
         min(k_arc, BK.get_shape(cur_oracle_arc_scores, -1)),
         dim=-1,
         sorted=False)
     # Recover (m, h) from the flat index. Floor division keeps the indices
     # integral; ``/`` is true division on integer tensors in current PyTorch
     # and would yield floats that cannot be used for the indexing below.
     topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen  # [m, h]
     # labels [*, k, 1]
     # todo(note): here we gather labels since one arc can only have one oracle label
     cur_label_scores = cur_cache.get_selected_label_scores(
         topk_m, topk_h, 0., 0.)  # [*, k, labels]
     topk_label_idxes = cur_oracle_label_t[
         cur_cache.bsize_range_t.unsqueeze(-1), topk_m,
         topk_h].unsqueeze(-1)  # [*, k, 1]
     # todo(+N): here is the trick to avoid repeated calculations, maybe not correct when using full dynamic oracle
     topk_label_scores = BK.gather(cur_label_scores, topk_label_idxes,
                                   -1) - self.mw_label
     # todo(+N): here use both masks, which may lead to no oracles! Can we simply drop the oracle_mask?
     return self._new_states(flattened_states,
                             scoring_mask_ct * cur_cache.oracle_mask_ct,
                             topk_arc_scores, topk_m, topk_h,
                             topk_label_scores, topk_label_idxes)
Beispiel #4
0
 def prune_with_scores(arc_score,
                       label_score,
                       mask_expr,
                       pconf: PruneG1Conf,
                       arc_marginals=None):
     """Compute which arcs survive pruning, by marginals and/or per-modifier top-k.

     The result starts as an all-zero mask and each enabled criterion ORs in
     the arcs it keeps:

     * ``pconf.pruning_use_marginal``: keep arcs whose marginal exceeds
       ``pconf.pruning_mthresh`` -- compared in log space relative to the
       row maximum when ``pconf.pruning_mthresh_rel`` is set, as an absolute
       value otherwise. Marginals are computed with ``nmarginal_unproj``
       when ``arc_marginals`` is not supplied.
     * ``pconf.pruning_use_topk``: per row, keep arcs scoring strictly above
       ``max(min of the top-k scores, best score - pconf.pruning_gap)``.

     Note:
         In the unlabeled top-k path the working scores come from
         ``arc_score.squeeze(-1)`` and are masked in place, so the caller's
         ``arc_score`` may be modified (the original author accepted this).

     Returns:
         ``(final_valid_mask, arc_marginals)``; ``arc_marginals`` is the
         supplied or computed marginals, or ``None`` if none were supplied
         and marginal pruning is off.
     """
     # read every pruning option up front
     use_topk = pconf.pruning_use_topk
     use_marginal = pconf.pruning_use_marginal
     labeled = pconf.pruning_labeled
     perc = pconf.pruning_perc
     topk_limit = pconf.pruning_topk
     gap = pconf.pruning_gap
     mthresh = pconf.pruning_mthresh
     mthresh_rel = pconf.pruning_mthresh_rel

     full_score = arc_score + label_score
     keep_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8)
     keep_mask = keep_mask.squeeze(-1)

     # ----- criterion 1: marginals ([*, mlen, hlen]) -----
     if use_marginal:
         if arc_marginals is None:  # not provided, calculate from scores
             if labeled:
                 # use sum of label marginals instead of max
                 label_marginals = nmarginal_unproj(full_score,
                                                    mask_expr,
                                                    None,
                                                    labeled=True)
                 arc_marginals = label_marginals.sum(-1)
             else:
                 unlabeled_marginals = nmarginal_unproj(arc_score,
                                                        mask_expr,
                                                        None,
                                                        labeled=True)
                 arc_marginals = unlabeled_marginals.squeeze(-1)
         if mthresh_rel:
             # relative value: log-ratio to the best marginal of each row
             row_best_log = arc_marginals.max(-1)[0].log().unsqueeze(-1)
             log_ratio = arc_marginals.log() - row_best_log
             by_marginal = log_ratio > float(np.log(mthresh))
         else:
             # absolute value
             by_marginal = arc_marginals > mthresh  # [*, len-m, len-h]
         keep_mask |= by_marginal

     # ----- criterion 2: "in topk" and "gap-to-top less than gap" per mod -----
     if use_topk:
         if labeled:  # take the best label for each arc
             work_score = full_score.max(-1)[0]
         else:
             # todo(note): may be modified inplaced, but does not matter since will finally be masked later
             work_score = arc_score.squeeze(-1)
         # first apply mask: padded positions on both of the last two axes,
         # then the diagonal, all in place and in this order
         neg_fill = Constants.REAL_PRAC_MIN
         pad_penalty = neg_fill * (1. - mask_expr)  # [*, len]
         work_score += pad_penalty.unsqueeze(-1)
         work_score += pad_penalty.unsqueeze(-2)
         maxlen = BK.get_shape(work_score, -1)
         work_score += neg_fill * BK.eye(maxlen)
         # effective k: configured k, percentage of the length, and the length
         effective_k = min(topk_limit, int(maxlen * perc + 1), maxlen)
         if effective_k >= maxlen:
             # everything is "in the top-k"; skip the topk call
             kept_scores = work_score
         else:
             kept_scores = BK.topk(work_score,
                                   effective_k,
                                   dim=-1,
                                   sorted=False)[0]  # [*, len, k]
         lowest_kept = kept_scores.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
         highest_kept = kept_scores.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
         cutoff = BK.max_elem(lowest_kept, highest_kept - gap)  # [*, len, 1]
         by_topk = work_score > cutoff  # [*, len-m, len-h]
         keep_mask |= by_topk
     return keep_mask, arc_marginals