def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t, oracle_mask_t, oracle_label_t,
                              arc_margin: float, label_margin: float):
    # todo(note): in this mode, no repeated arc_margin
    dim1_range_t = bsize_range_t
    dim2_range_t = dim1_range_t.unsqueeze(-1)
    if self.system_labeled:
        selected_m_cache = [z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache]
        selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t]
        ret = self.scorer.score_label(selected_m_cache, selected_h_repr)  # [*, k, labels]
        if label_margin > 0.:
            oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t, idxes_h_t].unsqueeze(-1)  # [*, k, 1] of int
            ret.scatter_add_(-1, oracle_label_idxes, BK.constants(oracle_label_idxes.shape, -label_margin))
    else:
        # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent)
        ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label])
        ret[:, :, 0] += 0.01
    if self.g1_lab_scores is not None:
        ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t]
    return ret
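# ----- [added] Illustrative sketch: label-margin via scatter_add_ -----
# A minimal, standalone demo of the margin trick above, written in plain PyTorch
# under the assumption that BK wraps torch-like ops; the shapes and names here
# are hypothetical. Subtracting `margin` from the gold label's score before a
# later argmax implements cost-augmented (margin-aware) decoding.
import torch

def _demo_label_margin(bs=2, k=3, num_label=5, margin=1.0):
    scores = torch.zeros(bs, k, num_label)              # [*, k, labels]
    gold = torch.randint(0, num_label, (bs, k, 1))      # [*, k, 1] gold label idxes
    # same op as above: add -margin at the gold positions only
    scores.scatter_add_(-1, gold, torch.full(gold.shape, -margin))
    return scores  # gold entries are now `margin` lower than the rest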
def _score(self, repr_t, attn_t, mask_t):
    conf = self.conf
    # -----
    repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
    repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
    scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
    # mask at outside
    slen = BK.get_shape(mask_t, -1)
    score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
    score_mask *= (1. - BK.eye(slen))  # no diag
    score_mask *= mask_t.unsqueeze(-1)  # input mask at len_k
    score_mask *= mask_t.unsqueeze(-2)  # input mask at len_q
    NEG = Constants.REAL_PRAC_MIN
    scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
    # add fixed idx0 scores if set
    if conf.fix_s0:
        fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
        scores1 = (1. - fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
    # minus s0
    if conf.minus_s0:
        scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
    return scores1, score_mask
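# ----- [added] Illustrative sketch: additive masking of pairwise scores -----
# A small PyTorch demo (assumed shapes, not original code) of the masking scheme
# above: forbid the diagonal (self-arcs) and padded positions on both the query
# and key axes by adding a large negative constant instead of hard-deleting.
import torch

def _demo_pairwise_mask(NEG=-1e8):
    bs, slen, n_lab = 2, 4, 3
    scores = torch.randn(bs, slen, slen, n_lab)                # [bs, len_q, len_k, N]
    mask = torch.tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])  # [bs, slen] valid tokens
    score_mask = torch.ones(bs, slen, slen)
    score_mask *= (1. - torch.eye(slen))                       # no diag (no self-arc)
    score_mask *= mask.unsqueeze(-1)                           # mask at len_k
    score_mask *= mask.unsqueeze(-2)                           # mask at len_q
    return scores + NEG * (1. - score_mask.unsqueeze(-1))      # invalid cells ~ -inf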
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
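# ----- [added] Illustrative sketch: joint argmax over (head, label) -----
# A plain-PyTorch demo (hypothetical shapes) of the flatten-then-argmax trick in
# nmst_greedy: merging the last two dims lets one max() pick the best
# (head, label) pair jointly, recovered afterwards with // and %.
import torch

def _demo_flat_argmax(bs=2, slen=5, n_lab=4):
    s = torch.randn(bs, slen, slen, n_lab)      # [bs, mod, head, label]
    best, idx = s.view(bs, slen, -1).max(-1)    # max over the merged head*label dim
    heads, labels = idx // n_lab, idx % n_lab   # un-flatten the winning index
    assert torch.equal(s[0, 0, heads[0, 0], labels[0, 0]], best[0, 0])
    return heads, labels, best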
def get_unpruned_mask(self, valid_expr, gold_pack):
    batch_idxes, m_idxes, h_idxes, _, _, _ = gold_pack
    # look up whether each gold edge survived pruning
    gold_mask = valid_expr[batch_idxes, m_idxes, h_idxes]
    gold_mask = gold_mask.byte()
    # per-token mask: a modifier counts as unpruned if its gold head edge is kept
    mod_unpruned_mask = BK.constants(BK.get_shape(valid_expr)[:2], 0, dtype=BK.uint8)
    mod_unpruned_mask[batch_idxes[gold_mask], m_idxes[gold_mask]] = 1
    return mod_unpruned_mask, gold_mask
def __call__(self, char_input, add_root_token: bool):
    char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
    if add_root_token:
        slice_shape = BK.get_shape(char_input_t)
        slice_shape[-2] = 1
        char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
        char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
    else:
        char_input_t1 = char_input_t
    char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
    char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
    return self.dropout(char_cat_expr)  # todo(note): only final dropout
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first, maximize across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use a mask so the operation stays differentiable
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # final predicting labels
    # todo(+N): here we select only max at labeling part, only beam at previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
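# ----- [added] Illustrative sketch: hard max over relations, beam over tokens -----
# A reduced PyTorch demo (assumed shapes) of the two-stage selection in _step:
# first a hard max over the relation dim R, then a log-softmax plus top-k beam
# over token positions, remembering which relation won at each picked token.
import torch

def _demo_hard_attn_select(scores=None, beam=2):
    if scores is None:
        scores = torch.randn(2, 6, 3)                      # [bs, slen, R]
    maxr_scores, maxr_idxes = scores.max(-1)               # hard max over R: [bs, slen]
    logp = torch.log_softmax(maxr_scores, -1)              # prob-based over slen
    k = min(beam, logp.shape[-1])
    sel_lp, sel_idx = logp.topk(k, dim=-1, sorted=False)   # [bs, beam]
    sel_rel = maxr_idxes.gather(-1, sel_idx)               # winning R per picked token
    return sel_idx, sel_lp, sel_rel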
def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack):
    # init caches and scores, [orig_bsize, max_slen, D]
    self.enc_repr = enc_repr
    self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr)
    # init other masks
    self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct)
    full_shape = BK.get_shape(self.scoring_mask_ct)
    # init oracle masks
    oracle_mask_ct = BK.constants(full_shape, value=0., device=BK.CPU_DEVICE)
    # label=0 means nothing, but still need it to avoid index error (dummy oracle for wrong/no-oracle states)
    oracle_label_ct = BK.constants(full_shape, value=0, dtype=BK.int64, device=BK.CPU_DEVICE)
    for i, inst in enumerate(insts):
        EfOracler.init_oracle_mask(inst, oracle_mask_ct[i], oracle_label_ct[i])
    self.oracle_mask_t = BK.to_device(oracle_mask_ct)
    self.oracle_mask_ct = oracle_mask_ct
    self.oracle_label_t = BK.to_device(oracle_label_ct)
    # scoring cache
    self.scoring_cache.init_cache(enc_repr, g1_pack)
def _losses_single(self, score_expr, gold_idxes_expr, single_sample, is_hinge=False, margin=0.):
    # expand the idxes to 0/1
    score_shape = BK.get_shape(score_expr)
    expanded_idxes_expr = BK.constants(score_shape, 0.)
    expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr, gold_idxes_expr, -1.)  # minus -1 means +1
    # todo(+N): first adjust margin, since previously only minus margin for golds?
    if margin > 0.:
        adjusted_scores = margin + BK.minus_margin(score_expr, gold_idxes_expr, margin)
    else:
        adjusted_scores = score_expr  # [*, L]
    if is_hinge:
        # multiply pos instances with -1
        flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr)
        losses_all = BK.clamp(flipped_scores, min=0.)
    else:
        losses_all = BK.binary_cross_entropy_with_logits(adjusted_scores, expanded_idxes_expr, reduction='none')
    # special interpretation (todo(+2): there can be better implementation)
    if single_sample < 1.:
        # todo(warn): lower bound of sample_rate, ensure 2 samples
        real_sample_rate = max(single_sample, 2. / score_shape[-1])
    elif single_sample >= 2.:
        # including the positive one
        real_sample_rate = max(single_sample, 2.) / score_shape[-1]
    else:  # [1., 2.)
        real_sample_rate = single_sample
    #
    if real_sample_rate < 1.:
        sample_weight = BK.random_bernoulli(score_shape, real_sample_rate, 1.)
        # make sure positive is valid
        sample_weight = (sample_weight + expanded_idxes_expr.float()).clamp_(0., 1.)
        #
        final_losses = (losses_all * sample_weight).sum(-1) / sample_weight.sum(-1)
    else:
        final_losses = losses_all.mean(-1)
    return final_losses
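# ----- [added] Illustrative sketch: Bernoulli negative sampling with kept golds -----
# A standalone PyTorch demo (hypothetical shapes) of the sampling step above:
# draw a Bernoulli keep-mask over the label dim, force gold positions back in,
# then average the per-label losses over the sampled subset only.
import torch

def _demo_sampled_loss(bs=4, n_lab=10, rate=0.3):
    losses_all = torch.rand(bs, n_lab)                   # per-label losses [*, L]
    gold_01 = torch.zeros(bs, n_lab)
    gold_01[torch.arange(bs), torch.randint(0, n_lab, (bs,))] = 1.   # one gold per row
    sample_weight = torch.bernoulli(torch.full((bs, n_lab), rate))   # random keep-mask
    sample_weight = (sample_weight + gold_01).clamp_(0., 1.)         # golds always count
    return (losses_all * sample_weight).sum(-1) / sample_weight.sum(-1)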
def __call__(self, input, add_root_token: bool):
    voc = self.voc
    # todo(note): append a [cls/root] idx, currently use "bos"
    input_t = BK.input_idx(input)  # [*, slen]
    # rare unk in training
    if self.rop.training and self.use_rare_unk:
        rare_unk_rate = self.ec_conf.comp_rare_unk
        cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).detach().long()
        input_t = input_t * (1 - cur_unk_imask) + self.voc.unk * cur_unk_imask
    # root
    if add_root_token:
        input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1] + [1], voc.bos, dtype=input_t.dtype)  # [*, 1]
        input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
    else:
        input_t_p1 = input_t
    expr = self.E(input_t_p1)  # [*, 1?+slen]
    return self.dropout(expr)
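# ----- [added] Illustrative sketch: random UNK-ing of rare words -----
# A minimal PyTorch version (assumed names: `rare_mask` over the vocab, `unk_id`)
# of the training-time replacement above: rare tokens are swapped to UNK with
# probability `rate`, via an integer blend mask rather than a branch.
import torch

def _demo_rare_unk(ids, rare_mask, unk_id, rate=0.5):
    # ids: [*, slen] int64 token ids; rare_mask: [vocab] with 1. for rare ids
    imask = (rare_mask[ids] * (torch.rand(ids.shape) < rate)).long()  # 1 where replaced
    return ids * (1 - imask) + unk_id * imask            # blend original and UNK ids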
def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, hit_labels,
                         query_idxes0, query_idxes1, query_idxes2,
                         arc_scores, lab_scores, arc_margin: float, lab_margin: float):
    # arc
    gold_arc_mat = BK.constants(shape, 0.)
    gold_arc_mat[hit_idxes0, hit_idxes1, hit_idxes2] = arc_margin
    gold_arc_margins = gold_arc_mat[query_idxes0, query_idxes1, query_idxes2]
    arc_scores -= gold_arc_margins
    if lab_scores is not None:
        # label
        gold_lab_mat = BK.constants_idx(shape, 0)  # 0 means the padding idx
        gold_lab_mat[hit_idxes0, hit_idxes1, hit_idxes2] = hit_labels
        gold_lab_margin_idxes = gold_lab_mat[query_idxes0, query_idxes1, query_idxes2]
        lab_scores[BK.arange_idx(BK.get_shape(gold_lab_margin_idxes, 0)), gold_lab_margin_idxes] -= lab_margin
    return
def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool, noop_fixed_val: float,
               temperature: float, dim: int):
    cur_shape = BK.get_shape(orig_scores)  # original
    orig_that_dim = cur_shape[dim]
    cur_shape[dim] = 1
    if use_noop:
        noop_scores = BK.constants(cur_shape, value=noop_fixed_val)  # [*, 1, *]
        to_norm_scores = BK.concat([orig_scores, noop_scores], dim=dim)  # [*, D+1, *]
    else:
        to_norm_scores = orig_scores  # [*, D, *]
    # normalize
    prob_full = cnode(to_norm_scores, temperature=temperature, dim=dim)  # [*, ?, *]
    if use_noop:
        prob_valid, prob_noop = BK.split(prob_full, [orig_that_dim, 1], dim)  # [*, D|1, *]
    else:
        prob_valid, prob_noop = prob_full, None
    return prob_valid, prob_noop, prob_full
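# ----- [added] Illustrative sketch: normalizing with an appended NOOP slot -----
# A plain-PyTorch demo of the pattern above; softmax stands in for the
# ConcreteNode here, which is an assumption. Concat a fixed-score NOOP slice,
# normalize over the extended dim, then split the probabilities back apart.
import torch

def _demo_noop_norm(scores, noop_val=0., dim=-1):
    shape = list(scores.shape)
    that_dim = shape[dim]
    shape[dim] = 1
    noop = torch.full(shape, noop_val)                  # [*, 1, *] fixed NOOP score
    full = torch.softmax(torch.cat([scores, noop], dim), dim)
    prob_valid, prob_noop = torch.split(full, [that_dim, 1], dim)
    return prob_valid, prob_noop, full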
def __call__(self, input_map: Dict):
    exprs = []
    # get masks: this mask marks the valid positions of instance batching
    final_masks = BK.input_real(input_map["mask"])  # [*, slen]
    if self.add_root_token:  # append 1
        slice_t = BK.constants(BK.get_shape(final_masks)[:-1] + [1], 1.)
        final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
    # -----
    # for each component
    for idx, name in enumerate(self.comp_names):
        cur_node = self.nodes[idx]
        cur_input = input_map[name]
        cur_expr = cur_node(cur_input, self.add_root_token)
        exprs.append(cur_expr)
    # -----
    concated_exprs = BK.concat(exprs, dim=-1)
    # optional proj
    if self.has_proj:
        final_expr = self.final_layer(concated_exprs)
    else:
        final_expr = concated_exprs
    return final_expr, final_masks
def _get_hit_mask(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, query_idxes0, query_idxes1, query_idxes2):
    hit_mat = BK.constants(shape, 0, dtype=BK.uint8)
    hit_mat[hit_idxes0, hit_idxes1, hit_idxes2] = 1
    hit_mask = hit_mat[query_idxes0, query_idxes1, query_idxes2]
    return hit_mask
def _get_tmp_mat(self, shape, val, dtype, idx0, idx1, vals):
    # build a constant matrix of `val`, then write `vals` at the (idx0, idx1) positions
    x = BK.constants(shape, val, dtype=dtype)
    x[idx0, idx1] = vals
    return x
def forward_batch(self, batched_ids: List, batched_starts: List, batched_typeids: List,
                  training: bool, other_inputs: List[List] = None):
    conf = self.bconf
    tokenizer = self.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    MASK_IDX = tokenizer.mask_token_id
    CLS_IDX = tokenizer.cls_token_id
    SEP_IDX = tokenizer.sep_token_id
    if other_inputs is None:
        other_inputs = []
    # =====
    # batch: here add CLS and SEP
    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
    input_shape = (bsize, max_len)
    # first collect on CPU
    input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
    input_ids_arr[:, 0] = CLS_IDX
    input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
    input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
    input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
    other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
    if conf.bert2_retinc_cls:  # act as the ROOT word
        input_is_start_arr[:, 0] = 1
    training_mask_rate = conf.bert2_training_mask_rate if training else 0.
    self_sample_stream = self.random_sample_stream
    for bidx in range(bsize):
        cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
        cur_end = len(cur_ids) + 2  # plus CLS and SEP
        if training_mask_rate > 0.:
            # input dropout
            input_ids_arr[bidx, 1:cur_end] = [(MASK_IDX if next(self_sample_stream) < training_mask_rate else z)
                                              for z in cur_ids] + [SEP_IDX]
        else:
            input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
        input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
        input_mask_arr[bidx, :cur_end] = 1.
        if batched_typeids is not None and batched_typeids[bidx] is not None:
            input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
        for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
            one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
    # arr to tensor
    input_ids_t = BK.input_idx(input_ids_arr)
    input_mask_t = BK.input_real(input_mask_arr)
    input_is_start_t = BK.input_idx(input_is_start_arr)
    input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
    other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
    # =====
    # forward (maybe need multiple times to fit maxlen constraint)
    MAX_LEN = 510  # save two for [CLS] and [SEP]
    BACK_LEN = 100  # for splitting cases, still keep some of the previous sub-tokens for context
    if max_len <= MAX_LEN:
        # directly once
        final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
    else:
        all_outputs = []
        cur_sub_idx = 0
        slice_size = [bsize, 1]
        slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
        slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
        while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
            cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
            cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
            cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
            # here we simply extend the extra original masks
            cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_input_typeid_t = None if input_typeid_t is None else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
            cur_outputs = self.forward_features(cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t, cur_other_input_ts)
            # only include CLS in the first run, no SEP included
            if cur_sub_idx == 0:
                # include CLS, exclude SEP
                all_outputs.append(cur_outputs[:, :-1])
            else:
                # include only new ones, discard BACK ones, exclude CLS, SEP
                all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
            zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                  f"for all-len={max_len}")
            cur_sub_idx = cur_slice_end
        final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
    start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
    return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
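# ----- [added] Illustrative sketch: the sliding-window spans walked above -----
# A pure-Python helper (not original code) that reproduces the
# (cur_slice_start, cur_sub_idx, cur_slice_end) triples the over-length loop in
# forward_batch steps through, so the BACK_LEN overlap logic can be checked in
# isolation.

def _demo_window_spans(max_len, max_win=510, back_len=100):
    spans, cur_sub_idx = [], 0
    while cur_sub_idx < max_len - 1:              # minus 1 to ignore ending SEP
        start = max(1, cur_sub_idx - back_len)    # re-include BACK_LEN old sub-tokens
        end = min(start + max_win, max_len - 1)
        spans.append((start, cur_sub_idx, end))
        cur_sub_idx = end
    return spans

# e.g. _demo_window_spans(1200) -> [(1, 0, 511), (411, 511, 921), (821, 921, 1199)]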
def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
    gold_all_idxes = self._get_all_idxes(gold_idxes)
    # scoring
    raw_scores = self._raw_scores(input_expr)
    raw_scores_aug = []
    margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
    #
    gold_shape = BK.get_shape(gold_idxes)  # [*]
    gold_bsize_prod = np.prod(gold_shape)
    #
    gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
    # margin
    for i in range(self.eff_max_layer):
        cur_gold_inputs = gold_all_idxes[i]
        # add margin
        cur_scores = raw_scores[i]  # [*, ?]
        cur_margin = margin * self.margin_lambdas[i]
        if cur_margin > 0.:
            cur_num_target = self.prediction_sizes[i]
            cur_isnil = self.layered_isnil[i].byte()  # [NLab]
            cost_matrix = BK.constants([cur_num_target, cur_num_target], margin_T)  # [gold, pred]
            cost_matrix[cur_isnil, :] = margin_P
            cost_matrix[:, cur_isnil] = margin_R
            diag_idxes = BK.arange_idx(cur_num_target)
            cost_matrix[diag_idxes, diag_idxes] = 0.
            margin_mat = cost_matrix[cur_gold_inputs]
            cur_aug_scores = cur_scores + margin_mat  # [*, ?]
        else:
            cur_aug_scores = cur_scores
        raw_scores_aug.append(cur_aug_scores)
    # cascade scores
    final_scores = self._cascade_scores(raw_scores_aug)
    # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
    loss_weights = ((gold_idxes == 0).float() * (self.loss_fullnil_weight - 1.) + 1.) \
        if self.loss_fullnil_weight < 1. else 1.
    # calculate loss
    loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
    loss_prob_reweight = self.conf.loss_prob_reweight
    final_losses = []
    no_loss_max_gold = self.conf.no_loss_max_gold
    if loss_mask is None:
        loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
    for i in range(self.eff_max_layer):
        cur_final_scores, cur_gold_inputs = final_scores[i], gold_all_idxes[i]  # [*, ?], [*]
        # collect the loss
        if self.is_hinge_loss:
            cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
            cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
            cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
            if no_loss_max_gold:  # this should be implicit
                cur_loss = cur_loss * (cur_loss > 0.).float()
        elif self.is_prob_loss:
            # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
            cur_loss = self._my_loss_prob(cur_final_scores, cur_gold_inputs,
                                          loss_prob_entropy_lambda, loss_mask, loss_prob_reweight)  # [*]
            if no_loss_max_gold:
                cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
                cur_loss = cur_loss * (cur_gold_scores > cur_pred_scores).float()
        else:
            raise NotImplementedError(f"UNK loss {self.conf.loss_function}")
        # here first summing up, divided at the outside
        one_loss_sum = (cur_loss * (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
        final_losses.append(one_loss_sum)
    # final sum
    final_loss_sum = BK.stack(final_losses).sum()
    _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
    return [[final_loss_sum, loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
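# ----- [added] Illustrative sketch: the P/R/T cost matrix -----
# A standalone PyTorch demo (hypothetical sizes) of the margin construction in
# loss(): rows are gold labels, columns are predictions; NIL-gold rows pay the
# precision margin, NIL-pred columns pay the recall margin, other confusions pay
# margin_T, and the diagonal (correct prediction) pays nothing.
import torch

def _demo_cost_matrix(is_nil, margin_P=2., margin_R=1., margin_T=0.5):
    # is_nil: [L] bool mask of which labels count as NIL
    L = is_nil.shape[0]
    cost = torch.full((L, L), margin_T)   # [gold, pred]
    cost[is_nil, :] = margin_P            # gold NIL, predicted non-NIL: precision cost
    cost[:, is_nil] = margin_R            # predicted NIL, gold non-NIL: recall cost
    idx = torch.arange(L)
    cost[idx, idx] = 0.                   # correct prediction: no cost
    return cost                           # rows are indexable by gold label ids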
def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None):
    prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \
        pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, \
        pconf.pruning_topk, pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel
    full_score = arc_score + label_score
    final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1)
    # (put as argument) arc_marginals = None  # [*, mlen, hlen]
    if prune_use_marginal:
        if arc_marginals is None:  # not provided, calculate from scores
            if prune_labeled:
                # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0]
                # use sum of label marginals instead of max
                arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
            else:
                arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1)
        if prune_mthresh_rel:
            # relative value
            max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1)
            m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh))
        else:
            # absolute value
            m_valid_mask = (arc_marginals > prune_mthresh)  # [*, len-m, len-h]
        final_valid_mask |= m_valid_mask
    if prune_use_topk:
        # prune by "in topk" and "gap-to-top less than gap" for each mod
        if prune_labeled:  # take argmax among label dim
            tmp_arc_score, _ = full_score.max(-1)
        else:
            # todo(note): may be modified in place, but that does not matter since it will finally be masked later
            tmp_arc_score = arc_score.squeeze(-1)
        # first apply mask
        mask_value = Constants.REAL_PRAC_MIN
        mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
        tmp_arc_score += mask_mul.unsqueeze(-1)
        tmp_arc_score += mask_mul.unsqueeze(-2)
        maxlen = BK.get_shape(tmp_arc_score, -1)
        tmp_arc_score += mask_value * BK.eye(maxlen)
        prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
        if prune_topk >= maxlen:
            topk_arc_score = tmp_arc_score
        else:
            topk_arc_score, _ = BK.topk(tmp_arc_score, prune_topk, dim=-1, sorted=False)  # [*, len, k]
        min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
        max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
        arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap)  # [*, len, 1]
        t_valid_mask = (tmp_arc_score > arc_score_thresh)  # [*, len-m, len-h]
        final_valid_mask |= t_valid_mask
    return final_valid_mask, arc_marginals
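# ----- [added] Illustrative sketch: per-modifier top-k + gap pruning -----
# A compact PyTorch demo (assumed shapes) of the top-k branch above: keep a head
# candidate only if it scores strictly above both the k-th best score and
# (best score - gap) for that modifier row.
import torch

def _demo_topk_gap_prune(arc, k=3, gap=5.0):
    # arc: [bs, mlen, hlen] arc scores, already masked
    topk_scores, _ = arc.topk(min(k, arc.shape[-1]), dim=-1)         # [bs, mlen, k]
    thresh = torch.max(topk_scores.min(-1, keepdim=True)[0],         # k-th best score
                       topk_scores.max(-1, keepdim=True)[0] - gap)   # best score - gap
    return arc > thresh                                              # [bs, mlen, hlen] bool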