def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]:
    """Select arc/label expansions for every flattened state from the model scores.

    :param ags: agendas of the current batch; not read by this method.
    :param candidates: triple ``(flattened_states, cur_arc_scores, scoring_mask_ct)``.
    :param mode: ``"topk"`` keeps the best-scoring arcs and labels, ``""`` selects nothing.
    :param k_arc: number of arcs to keep per state (clamped to the number of arc candidates).
    :param k_label: number of labels to keep per arc (clamped to the number of label scores).
    :return: for ``"topk"`` the result of ``self._new_states``; for ``""`` one fresh
        empty list per flattened state.
    :raises NotImplementedError: for any other ``mode``.
    """
    flattened_states, cur_arc_scores, scoring_mask_ct = candidates
    cur_bsize = len(flattened_states)
    if mode == "":
        # one independent list per state: ``[[]] * n`` would alias a single list n times
        return [[] for _ in range(cur_bsize)]
    if mode != "topk":
        # todo(+N): other modes like sampling to be implemented: sample, topk-sample
        raise NotImplementedError(mode)
    cur_cache = self.cache
    cur_slen = cur_cache.max_slen
    cur_arc_scores_flattend = cur_arc_scores.view([cur_bsize, -1])  # [bs, Lm*Lh]
    # arcs [*, k]
    arc_k = min(k_arc, BK.get_shape(cur_arc_scores_flattend, -1))
    topk_arc_scores, topk_arc_idxes = BK.topk(cur_arc_scores_flattend, arc_k, dim=-1, sorted=False)
    # decode the flat index into [m, h]; ``//`` keeps the result an integer index
    # (``/`` is true division on recent PyTorch and would yield floats)
    topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen
    # labels [*, k, k']
    cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, self.mw_arc, self.mw_label)
    label_k = min(k_label, BK.get_shape(cur_label_scores, -1))
    topk_label_scores, topk_label_idxes = BK.topk(cur_label_scores, label_k, dim=-1, sorted=False)
    return self._new_states(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h,
                            topk_label_scores, topk_label_idxes)
def _select_topk(self, masked_scores, pad_mask, ratio_mask, topk_ratio, thresh_k):
    """Build a 0/1 selection mask over the last dimension of ``masked_scores``.

    Two filters are multiplied onto a copy of ``pad_mask``:
    an optional absolute threshold (``score > thresh_k``) and, when
    ``topk_ratio > 0``, a per-row top-k cut where k is
    ``ratio_mask.sum(-1) * topk_ratio`` truncated to an integer and clamped to [1, slen].

    :param masked_scores: scores to select from; the last dimension is the candidate axis.
    :param pad_mask: starting mask; it is copied via ``BK.copy`` before being updated.
    :param ratio_mask: mask whose row sums give the base count for the ratio.
    :param topk_ratio: fraction of the row count to keep; values not ``> 0`` skip this step.
    :param thresh_k: absolute score threshold, or ``None`` to skip it.
    :return: the resulting selection mask.
    """
    last_dim = BK.get_shape(masked_scores, -1)
    keep_mask = BK.copy(pad_mask)
    # stage 1: absolute threshold
    if thresh_k is not None:
        above_thresh = masked_scores > thresh_k
        keep_mask *= above_thresh.float()
    # stage 2: ratio-based top-k (skipped entirely when the ratio is not positive)
    if not (topk_ratio > 0.):
        return keep_mask
    # how many entries each row may keep: at least one, at most all  # [*]
    per_row_k = (ratio_mask.sum(-1) * topk_ratio).long()
    per_row_k.clamp_(min=1, max=last_dim)
    # one top-k call with the largest k that any row needs
    largest_k = max(per_row_k.max().item(), 1)
    sorted_top, _ = BK.topk(masked_scores, largest_k, dim=-1, sorted=True)  # [*, k]
    # position of every row's own k-th best score inside the sorted scores
    kth_pos = per_row_k.clamp(min=1).unsqueeze(-1) - 1
    kth_score = sorted_top.gather(-1, kth_pos)  # [*, 1]
    # keep everything that reaches the row's k-th best score (ties included)
    keep_mask *= (masked_scores >= kth_score).float()
    return keep_mask
def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]:
    """Select arc/label expansions for every flattened state, restricted to oracle arcs.

    :param ags: agendas of the current batch; not read by this method.
    :param candidates: triple ``(flattened_states, cur_arc_scores, scoring_mask_ct)``.
    :param mode: ``"topk"`` keeps the best-scoring oracle arcs, ``""`` selects nothing.
    :param k_arc: number of arcs to keep per state (clamped to the number of arc candidates).
    :param k_label: not read here; each selected arc gets exactly its one oracle label.
    :return: for ``"topk"`` the result of ``self._new_states``; for ``""`` one fresh
        empty list per flattened state.
    :raises NotImplementedError: for any other ``mode``.
    """
    flattened_states, cur_arc_scores, scoring_mask_ct = candidates
    cur_bsize = len(flattened_states)
    if mode == "":
        # one independent list per state: ``[[]] * n`` would alias a single list n times
        return [[] for _ in range(cur_bsize)]
    if mode != "topk":
        # todo(+N): other modes like sampling to be implemented: sample, topk-sample, gather
        raise NotImplementedError(mode)
    cur_cache = self.cache
    cur_slen = cur_cache.max_slen
    # todo(note): there can be multiple oracles, select topk(usually top1) in this mode.
    # get and apply oracle mask: non-oracle arcs are pushed down by REAL_PRAC_MIN
    cur_oracle_mask_t, cur_oracle_label_t = self._get_oracle_mask(flattened_states)
    # [bs, Lm*Lh]
    cur_oracle_arc_scores = (cur_arc_scores + Constants.REAL_PRAC_MIN * (1. - cur_oracle_mask_t)).view(
        [cur_bsize, -1])
    # arcs [*, k]; clamp k like select_plain does so an over-large k_arc cannot break topk
    arc_k = min(k_arc, BK.get_shape(cur_oracle_arc_scores, -1))
    topk_arc_scores, topk_arc_idxes = BK.topk(cur_oracle_arc_scores, arc_k, dim=-1, sorted=False)
    # decode the flat index into [m, h]; ``//`` keeps the result an integer index
    # (``/`` is true division on recent PyTorch and would yield floats)
    topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen
    # labels [*, k, 1]
    # todo(note): here we gather labels since one arc can only have one oracle label
    cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, 0., 0.)  # [*, k, labels]
    topk_label_idxes = cur_oracle_label_t[
        cur_cache.bsize_range_t.unsqueeze(-1), topk_m, topk_h].unsqueeze(-1)  # [*, k, 1]
    # todo(+N): here is the trick to avoid repeated calculations, maybe not correct when using full dynamic oracle
    topk_label_scores = BK.gather(cur_label_scores, topk_label_idxes, -1) - self.mw_label
    # todo(+N): here use both masks, which may lead to no oracles! Can we simply drop the oracle_mask?
    return self._new_states(flattened_states, scoring_mask_ct * cur_cache.oracle_mask_ct, topk_arc_scores,
                            topk_m, topk_h, topk_label_scores, topk_label_idxes)
def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None):
    """Compute a mask of arcs that survive pruning, by marginals and/or by top-k scores.

    The result is the union (``|=``) of the criteria enabled in ``pconf``:
    marginal-based (absolute or relative-to-row-max threshold) and score-based
    (within the top-k of the row and within ``pruning_gap`` of the row's best score).
    If neither is enabled, the returned mask is all zeros.

    NOTE: in the unlabeled top-k case the padding/diagonal penalties are added in place
    to ``arc_score.squeeze(-1)``; if that is a view (as with torch tensors), the
    caller's ``arc_score`` is modified as well.

    :param arc_score: arc scores; assumed to carry a trailing size-1 dim -- TODO confirm.
    :param label_score: label scores, added to ``arc_score`` to form the full score.
    :param mask_expr: 1/0 validity mask over tokens.  # [*, len]
    :param pconf: pruning configuration (the ``pruning_*`` fields are read).
    :param arc_marginals: precomputed arc marginals; computed from the scores when ``None``
        and marginal pruning is enabled.  # [*, mlen, hlen]
    :return: ``(final_valid_mask, arc_marginals)``; the marginals are returned as given
        or as computed here (``None`` if never needed).
    """
    # read every pruning option up front
    use_topk = pconf.pruning_use_topk
    use_marginal = pconf.pruning_use_marginal
    labeled = pconf.pruning_labeled
    perc = pconf.pruning_perc
    topk_limit = pconf.pruning_topk
    gap = pconf.pruning_gap
    mthresh = pconf.pruning_mthresh
    mthresh_rel = pconf.pruning_mthresh_rel
    full_score = arc_score + label_score
    keep_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1)
    # --- criterion 1: arc marginals
    if use_marginal:
        if arc_marginals is None:
            # not provided by the caller, so derive them from the scores
            if labeled:
                # sum (rather than max) of the label marginals
                arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
            else:
                arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1)
        if mthresh_rel:
            # relative: compare in log space against the best marginal of each row
            log_row_max = arc_marginals.max(-1)[0].log().unsqueeze(-1)
            log_ratio = arc_marginals.log() - log_row_max
            by_marginal = log_ratio > float(np.log(mthresh))
        else:
            # absolute threshold
            by_marginal = arc_marginals > mthresh  # [*, len-m, len-h]
        keep_mask |= by_marginal
    # --- criterion 2: "in topk" and "gap-to-top less than gap" for each mod
    if use_topk:
        if labeled:
            # best label per arc
            head_score = full_score.max(-1)[0]
        else:
            # todo(note): may be modified inplaced, but does not matter since will finally be masked later
            head_score = arc_score.squeeze(-1)
        # penalize padded positions on both axes, then the diagonal
        neg_value = Constants.REAL_PRAC_MIN
        pad_penalty = neg_value * (1. - mask_expr)  # [*, len]
        head_score += pad_penalty.unsqueeze(-1)
        head_score += pad_penalty.unsqueeze(-2)
        maxlen = BK.get_shape(head_score, -1)
        head_score += neg_value * BK.eye(maxlen)
        # effective k: bounded by the absolute limit, the percentage of the length, and the length
        k = min(topk_limit, int(maxlen * perc + 1), maxlen)
        if k < maxlen:
            top_scores, _ = BK.topk(head_score, k, dim=-1, sorted=False)  # [*, len, k]
        else:
            # k covers the whole row, no need for an actual topk
            top_scores = head_score
        lowest_top = top_scores.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
        highest_top = top_scores.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
        score_thresh = BK.max_elem(lowest_top, highest_top - gap)  # [*, len, 1]
        # strictly above the threshold
        keep_mask |= (head_score > score_thresh)  # [*, len-m, len-h]
    return keep_mask, arc_marginals