Example #1
 def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t,
                               oracle_mask_t, oracle_label_t,
                               arc_margin: float, label_margin: float):
     # todo(note): in this mode, no repeated arc_margin
     dim1_range_t = bsize_range_t
     dim2_range_t = dim1_range_t.unsqueeze(-1)
     if self.system_labeled:
         selected_m_cache = [
             z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache
         ]
         selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t]
         ret = self.scorer.score_label(selected_m_cache,
                                       selected_h_repr)  # [*, k, labels]
         if label_margin > 0.:
             oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t,
                                                 idxes_h_t].unsqueeze(
                                                     -1)  # [*, k, 1] of int
             ret.scatter_add_(
                 -1, oracle_label_idxes,
                 BK.constants(oracle_label_idxes.shape, -label_margin))
     else:
         # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent)
         ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label])
         ret[:, :, 0] += 0.01
     if self.g1_lab_scores is not None:
         ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t]
     return ret
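The key move above is subtracting a margin only at the gold label's position. A minimal plain-PyTorch sketch of that trick (illustrative shapes and names only; BK appears to wrap torch in these snippets, so torch.full stands in for BK.constants):
import torch
scores = torch.randn(2, 3, 5)                # [*, k, labels]
gold = torch.randint(0, 5, (2, 3, 1))        # [*, k, 1] gold label idxes
label_margin = 0.5
# add -margin at the gold label so non-gold labels must win by at least the margin
scores.scatter_add_(-1, gold, torch.full(gold.shape, -label_margin))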
Example #2
 def _score(self, repr_t, attn_t, mask_t):
     conf = self.conf
     # -----
     repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
     repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
     scores0 = self.dps_node.paired_score(
         repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
     # mask at outside
     slen = BK.get_shape(mask_t, -1)
     score_mask = BK.constants(BK.get_shape(scores0)[:-1],
                               1.)  # [bs, len_q, len_k]
     score_mask *= (1. - BK.eye(slen))  # no diag
     score_mask *= mask_t.unsqueeze(-1)  # input mask at len_k
     score_mask *= mask_t.unsqueeze(-2)  # input mask at len_q
     NEG = Constants.REAL_PRAC_MIN
     scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
     # add fixed idx0 scores if set
     if conf.fix_s0:
         fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
         scores1 = (
             1. - fix_s0_mask_t
         ) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
     # minus s0
     if conf.minus_s0:
         scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
     return scores1, score_mask
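A self-contained PyTorch sketch of the masking pattern above (all shapes are made up): build an all-ones mask, knock out the diagonal and the padded rows/columns, then push masked entries toward a large negative before any normalization.
import torch
bs, slen, N1 = 2, 4, 3
scores0 = torch.randn(bs, slen, slen, N1)        # [bs, len_q, len_k, 1+N]
mask_t = torch.tensor([[1., 1., 1., 0.]] * bs)   # [bs, slen], last token padded
score_mask = torch.ones(bs, slen, slen)
score_mask *= (1. - torch.eye(slen))             # no self-loops on the diagonal
score_mask *= mask_t.unsqueeze(-1)               # mask padded rows
score_mask *= mask_t.unsqueeze(-2)               # mask padded columns
NEG = -1e30
scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))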
Example #3
def nmst_greedy(scores_expr,
                mask_expr,
                lengths_arr,
                labeled=True,
                ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(
            BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combined last two dimension and Max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr,
                                                        dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [
                BK.get_value(z)
                for z in (greedy_heads, greedy_labels, combine_max_scores)
            ]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
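The flatten-and-argmax step is the core of this greedy decoder. A tiny PyTorch illustration (hypothetical shapes) of how one max over the merged (head, label) dim recovers both indices with // and %:
import torch
bs, slen, L = 2, 5, 4
scores = torch.randn(bs, slen, slen, L)          # [bs, mod, head, label]
flat = scores.view(bs, slen, -1)                 # [bs, mod, head*label]
best_scores, best_idx = flat.max(-1)
heads = best_idx // L                            # recover the head index
labels = best_idx % L                            # recover the label index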
Example #4
 def get_unpruned_mask(self, valid_expr, gold_pack):
     batch_idxes, m_idxes, h_idxes, _, _, _ = gold_pack
     gold_mask = valid_expr[batch_idxes, m_idxes, h_idxes]
     gold_mask = gold_mask.byte()
     mod_unpruned_mask = BK.constants(BK.get_shape(valid_expr)[:2],
                                      0,
                                      dtype=BK.uint8)
     mod_unpruned_mask[batch_idxes[gold_mask], m_idxes[gold_mask]] = 1
     return mod_unpruned_mask, gold_mask
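A small PyTorch sketch of the same advanced-indexing pattern (toy tensors, not the real gold_pack): gather a survival flag per gold edge, then mark the modifiers whose gold head survived pruning.
import torch
valid = torch.randint(0, 2, (2, 5, 5))           # [bs, mod, head] 0/1 pruning mask
b_idx = torch.tensor([0, 0, 1])
m_idx = torch.tensor([1, 2, 3])
h_idx = torch.tensor([0, 1, 0])
gold_mask = valid[b_idx, m_idx, h_idx].bool()    # per-gold-edge survival flag
mod_unpruned = torch.zeros(2, 5, dtype=torch.bool)
mod_unpruned[b_idx[gold_mask], m_idx[gold_mask]] = True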
Example #5
 def __call__(self, char_input, add_root_token: bool):
     char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
     if add_root_token:
         slice_shape = BK.get_shape(char_input_t)
         slice_shape[-2] = 1
         char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
         char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
     else:
         char_input_t1 = char_input_t
     char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
     char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
     return self.dropout(char_cat_expr)  # todo(note): only final dropout
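The add-root branch simply prepends a dummy length-1 slice along the token dimension. A minimal PyTorch sketch with made-up shapes:
import torch
char_input_t = torch.randint(1, 100, (2, 6, 8))      # [*, slen, wlen]
slice_shape = list(char_input_t.shape)
slice_shape[-2] = 1
root_slice = torch.zeros(slice_shape, dtype=char_input_t.dtype)   # idx 0 for the root slot
with_root = torch.cat([root_slice, char_input_t], dim=-2)         # [*, 1+slen, wlen]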
Example #6
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use a multiplicative mask (rather than assignment) so gradients can still flow
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
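The selection logic above is two-staged: a hard max over the relation dim R, then a token-level distribution that is either top-k'd (free decoding) or gathered at forced indices (teacher forcing). A reduced PyTorch sketch of just that part, with illustrative shapes:
import torch
import torch.nn.functional as F
bs, slen, R = 2, 6, 3
orig_scores = torch.randn(bs, slen, R)
maxr_scores, maxr_idxes = orig_scores.max(-1)        # hard max over R: [bs, slen]
maxr_logprobs = F.log_softmax(maxr_scores, dim=-1)   # distribution over tokens
beam = 2
free_lp, free_idx = maxr_logprobs.topk(beam, dim=-1, sorted=False)   # free mode
forced_idx = torch.tensor([[1], [4]])                                 # forced (oracle) mode
forced_lp = maxr_logprobs.gather(-1, forced_idx)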
Example #7
 def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack):
     # init caches and scores, [orig_bsize, max_slen, D]
     self.enc_repr = enc_repr
     self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr)
     # init other masks
     self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct)
     full_shape = BK.get_shape(self.scoring_mask_ct)
     # init oracle masks
     oracle_mask_ct = BK.constants(full_shape,
                                   value=0.,
                                   device=BK.CPU_DEVICE)
     # label=0 means nothing, but still need it to avoid index error (dummy oracle for wrong/no-oracle states)
     oracle_label_ct = BK.constants(full_shape,
                                    value=0,
                                    dtype=BK.int64,
                                    device=BK.CPU_DEVICE)
     for i, inst in enumerate(insts):
         EfOracler.init_oracle_mask(inst, oracle_mask_ct[i],
                                    oracle_label_ct[i])
     self.oracle_mask_t = BK.to_device(oracle_mask_ct)
     self.oracle_mask_ct = oracle_mask_ct
     self.oracle_label_t = BK.to_device(oracle_label_ct)
     # scoring cache
     self.scoring_cache.init_cache(enc_repr, g1_pack)
Example #8
 def _losses_single(self,
                    score_expr,
                    gold_idxes_expr,
                    single_sample,
                    is_hinge=False,
                    margin=0.):
     # expand the idxes to 0/1
     score_shape = BK.get_shape(score_expr)
     expanded_idxes_expr = BK.constants(score_shape, 0.)
     expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr,
                                           gold_idxes_expr,
                                           -1.)  # minus -1 means +1
     # todo(+N): first adjust margin, since previously only minus margin for golds?
     if margin > 0.:
         adjusted_scores = margin + BK.minus_margin(score_expr,
                                                    gold_idxes_expr, margin)
     else:
         adjusted_scores = score_expr
     # [*, L]
     if is_hinge:
         # multiply pos instances with -1
         flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr)
         losses_all = BK.clamp(flipped_scores, min=0.)
     else:
         losses_all = BK.binary_cross_entropy_with_logits(
             adjusted_scores, expanded_idxes_expr, reduction='none')
     # special interpretation (todo(+2): there can be better implementation)
     if single_sample < 1.:
         # todo(warn): lower bound of sample_rate, ensure 2 samples
         real_sample_rate = max(single_sample, 2. / score_shape[-1])
     elif single_sample >= 2.:
         # including the positive one
         real_sample_rate = max(single_sample, 2.) / score_shape[-1]
     else:  # [1., 2.)
         real_sample_rate = single_sample
     #
     if real_sample_rate < 1.:
         sample_weight = BK.random_bernoulli(score_shape, real_sample_rate,
                                             1.)
         # make sure positive is valid
         sample_weight = (sample_weight +
                          expanded_idxes_expr.float()).clamp_(0., 1.)
         #
         final_losses = (losses_all *
                         sample_weight).sum(-1) / sample_weight.sum(-1)
     else:
         final_losses = losses_all.mean(-1)
     return final_losses
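A compact PyTorch sketch of the hinge branch above (toy tensors; BK.minus_margin is emulated with a one-hot scatter): the margin is added to non-gold scores, gold entries are sign-flipped, and only positive values survive the clamp.
import torch
scores = torch.randn(2, 5)                           # [*, L]
gold = torch.tensor([1, 3])
onehot = torch.zeros_like(scores).scatter_(-1, gold.unsqueeze(-1), 1.)
margin = 0.5
adjusted = scores + margin * (1. - onehot)           # margin added to non-gold only
flipped = adjusted * (1. - 2. * onehot)              # gold entries negated
hinge = flipped.clamp(min=0.)                        # per-label hinge losses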
Example #9
 def __call__(self, input, add_root_token: bool):
     voc = self.voc
     # todo(note): append a [cls/root] idx, currently use "bos"
     input_t = BK.input_idx(input)  # [*, 1+slen]
     # rare unk in training
     if self.rop.training and self.use_rare_unk:
         rare_unk_rate = self.ec_conf.comp_rare_unk
         cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t))<rare_unk_rate)).detach().long()
         input_t = input_t * (1-cur_unk_imask) + self.voc.unk * cur_unk_imask
     # root
     if add_root_token:
         input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1]+[1], voc.bos, dtype=input_t.dtype)  # [*, 1+slen]
         input_t_p1 = BK.concat([input_t_p0, input_t], -1)
     else:
         input_t_p1 = input_t
     expr = self.E(input_t_p1)  # [*, 1?+slen]
     return self.dropout(expr)
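The rare-word dropout replaces rare ids with UNK at a small rate during training. A plain-PyTorch sketch under assumed ids and rates (rare_mask, unk_id and the rates below are invented):
import torch
input_t = torch.randint(0, 50, (2, 7))               # [*, slen] token ids
rare_mask = (torch.rand(50) < 0.2).float()           # 1. for "rare" vocab entries
unk_id, rare_unk_rate = 1, 0.3
flip = (rare_mask[input_t] * (torch.rand(input_t.shape) < rare_unk_rate)).long()
input_t = input_t * (1 - flip) + unk_id * flip       # flipped positions become UNK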
Example #10
 def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2,
                          hit_labels, query_idxes0, query_idxes1,
                          query_idxes2, arc_scores, lab_scores,
                          arc_margin: float, lab_margin: float):
     # arc
     gold_arc_mat = BK.constants(shape, 0.)
     gold_arc_mat[hit_idxes0, hit_idxes1, hit_idxes2] = arc_margin
     gold_arc_margins = gold_arc_mat[query_idxes0, query_idxes1,
                                     query_idxes2]
     arc_scores -= gold_arc_margins
     if lab_scores is not None:
         # label
         gold_lab_mat = BK.constants_idx(shape,
                                         0)  # 0 means the padding idx
         gold_lab_mat[hit_idxes0, hit_idxes1, hit_idxes2] = hit_labels
         gold_lab_margin_idxes = gold_lab_mat[query_idxes0, query_idxes1,
                                              query_idxes2]
         lab_scores[BK.arange_idx(BK.get_shape(gold_lab_margin_idxes, 0)),
                    gold_lab_margin_idxes] -= lab_margin
     return
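The margin here is materialized as a dense gold matrix and then read back at the queried edges, which avoids matching gold and query index lists directly. A toy PyTorch version of the arc part:
import torch
shape = (2, 5, 5)                                    # [bs, mod, head]
gold_arc_mat = torch.zeros(shape)
hit0 = torch.tensor([0, 1]); hit1 = torch.tensor([2, 3]); hit2 = torch.tensor([0, 1])
gold_arc_mat[hit0, hit1, hit2] = 0.5                 # write arc_margin at gold edges
q0 = torch.tensor([0, 0, 1]); q1 = torch.tensor([2, 4, 3]); q2 = torch.tensor([0, 0, 1])
arc_scores = torch.randn(3)                          # scores of the queried edges
arc_scores -= gold_arc_mat[q0, q1, q2]               # only gold edges get lowered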
Example #11
 def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool,
                noop_fixed_val: float, temperature: float, dim: int):
     cur_shape = BK.get_shape(orig_scores)  # original
     orig_that_dim = cur_shape[dim]
     cur_shape[dim] = 1
     if use_noop:
         noop_scores = BK.constants(cur_shape,
                                    value=noop_fixed_val)  # [*, 1, *]
         to_norm_scores = BK.concat([orig_scores, noop_scores],
                                    dim=dim)  # [*, D+1, *]
     else:
         to_norm_scores = orig_scores  # [*, D, *]
     # normalize
     prob_full = cnode(to_norm_scores, temperature=temperature,
                       dim=dim)  # [*, ?, *]
     if use_noop:
         prob_valid, prob_noop = BK.split(prob_full, [orig_that_dim, 1],
                                          dim)  # [*, D|1, *]
     else:
         prob_valid, prob_noop = prob_full, None
     return prob_valid, prob_noop, prob_full
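A standalone sketch of the noop handling with a plain softmax in place of cnode (which this snippet does not define): append one fixed-score slot on the normalized dim, normalize, then split the probability mass back apart.
import torch
import torch.nn.functional as F
orig = torch.randn(2, 4, 3)                          # [*, D, *]
dim, noop_val = 1, 0.0
noop_shape = list(orig.shape)
noop_shape[dim] = 1
noop = torch.full(noop_shape, noop_val)              # fixed noop score, [*, 1, *]
full = torch.cat([orig, noop], dim=dim)              # [*, D+1, *]
prob_full = F.softmax(full, dim=dim)
prob_valid, prob_noop = torch.split(prob_full, [orig.shape[dim], 1], dim=dim)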
Example #12
 def __call__(self, input_map: Dict):
     exprs = []
     # get masks: this mask marks the valid positions from instance batching
     final_masks = BK.input_real(input_map["mask"])  # [*, slen]
     if self.add_root_token:  # append 1
         slice_t = BK.constants(BK.get_shape(final_masks)[:-1]+[1], 1.)
         final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
     # -----
     # for each component
     for idx, name in enumerate(self.comp_names):
         cur_node = self.nodes[idx]
         cur_input = input_map[name]
         cur_expr = cur_node(cur_input, self.add_root_token)
         exprs.append(cur_expr)
     # -----
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr, final_masks
Example #13
 def _get_hit_mask(self, shape, hit_idxes0, hit_idxes1, hit_idxes2,
                   query_idxes0, query_idxes1, query_idxes2):
     hit_mat = BK.constants(shape, 0, dtype=BK.uint8)
     hit_mat[hit_idxes0, hit_idxes1, hit_idxes2] = 1
     hit_mask = hit_mat[query_idxes0, query_idxes1, query_idxes2]
     return hit_mask
Example #14
 def _get_tmp_mat(self, shape, val, dtype, idx0, idx1, vals):
     x = BK.constants(shape, val, dtype=dtype)
     x[idx0, idx1] = vals
     return x
Example #15
 def forward_batch(self,
                   batched_ids: List,
                   batched_starts: List,
                   batched_typeids: List,
                   training: bool,
                   other_inputs: List[List] = None):
     conf = self.bconf
     tokenizer = self.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     MASK_IDX = tokenizer.mask_token_id
     CLS_IDX = tokenizer.cls_token_id
     SEP_IDX = tokenizer.sep_token_id
     if other_inputs is None:
         other_inputs = []
     # =====
     # batch: here add CLS and SEP
     bsize = len(batched_ids)
     max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
     input_shape = (bsize, max_len)
     # first collect on CPU
     input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
     input_ids_arr[:, 0] = CLS_IDX
     input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
     input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
     input_typeids = None if batched_typeids is None else np.full(
         input_shape, 0, dtype=np.int64)
     other_input_arrs = [
         np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs
     ]
     if conf.bert2_retinc_cls:  # act as the ROOT word
         input_is_start_arr[:, 0] = 1
     training_mask_rate = conf.bert2_training_mask_rate if training else 0.
     self_sample_stream = self.random_sample_stream
     for bidx in range(bsize):
         cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
         cur_end = len(cur_ids) + 2  # plus CLS and SEP
         if training_mask_rate > 0.:
             # input dropout
             input_ids_arr[bidx, 1:cur_end] = [
                 (MASK_IDX
                  if next(self_sample_stream) < training_mask_rate else z)
                 for z in cur_ids
             ] + [SEP_IDX]
         else:
             input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
         input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
         input_mask_arr[bidx, :cur_end] = 1.
         if batched_typeids is not None and batched_typeids[bidx] is not None:
             input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
         for one_other_input_arr, one_other_input_list in zip(
                 other_input_arrs, other_inputs):
             one_other_input_arr[bidx,
                                 1:cur_end - 1] = one_other_input_list[bidx]
     # arr to tensor
     input_ids_t = BK.input_idx(input_ids_arr)
     input_mask_t = BK.input_real(input_mask_arr)
     input_is_start_t = BK.input_idx(input_is_start_arr)
     input_typeid_t = None if input_typeids is None else BK.input_idx(
         input_typeids)
     other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
     # =====
     # forward (maybe need multiple times to fit maxlen constraint)
     MAX_LEN = 510  # save two for [CLS] and [SEP]
     BACK_LEN = 100  # for splitting cases, still remaining some of previous sub-tokens for context
     if max_len <= MAX_LEN:
         # directly once
         final_outputs = self.forward_features(
             input_ids_t, input_mask_t, input_typeid_t,
             other_input_ts)  # [bs, slen, *...]
         start_idxes, start_masks = BK.mask2idx(
             input_is_start_t.float())  # [bsize, ?]
     else:
         all_outputs = []
         cur_sub_idx = 0
         slice_size = [bsize, 1]
         slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
         slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
         while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
             cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
             cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
             cur_input_ids_t = BK.concat([
                 slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end],
                 slice_sep
             ], 1)
             # here we simply extend extra original masks
             cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_input_typeid_t = None if input_typeid_t is None \
                 else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_other_input_ts = [
                 z[:, cur_slice_start - 1:cur_slice_end + 1]
                 for z in other_input_ts
             ]
             cur_outputs = self.forward_features(cur_input_ids_t,
                                                 cur_input_mask_t,
                                                 cur_input_typeid_t,
                                                 cur_other_input_ts)
             # only include CLS in the first run, no SEP included
             if cur_sub_idx == 0:
                 # include CLS, exclude SEP
                 all_outputs.append(cur_outputs[:, :-1])
             else:
                 # include only new ones, discard BACK ones, exclude CLS, SEP
                 all_outputs.append(cur_outputs[:, cur_sub_idx -
                                                cur_slice_start + 1:-1])
                 zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                       f"for all-len={max_len}")
             cur_sub_idx = cur_slice_end
         final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
         start_idxes, start_masks = BK.mask2idx(
             input_is_start_t[:, :-1].float())  # [bsize, ?]
     start_expr = BK.gather_first_dims(final_outputs, start_idxes,
                                       1)  # [bsize, ?, *...]
     return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
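For over-length inputs, the loop above slides a window of MAX_LEN sub-tokens and re-feeds BACK_LEN previous sub-tokens as context, keeping only the newly covered span from each pass. A pure-Python sketch of just the window arithmetic (max_len below is a made-up value):
MAX_LEN, BACK_LEN = 510, 100
max_len = 1200                                       # hypothetical sub-token length
cur_sub_idx = 0
while cur_sub_idx < max_len - 1:                     # ignore the ending SEP
    cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
    cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
    keep_from = 0 if cur_sub_idx == 0 else cur_sub_idx - cur_slice_start + 1
    print(f"window [{cur_slice_start}, {cur_slice_end}) keeps outputs from offset {keep_from}")
    cur_sub_idx = cur_slice_end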
Example #16
 def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
     gold_all_idxes = self._get_all_idxes(gold_idxes)
     # scoring
     raw_scores = self._raw_scores(input_expr)
     raw_scores_aug = []
     margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
     #
     gold_shape = BK.get_shape(gold_idxes)  # [*]
     gold_bsize_prod = np.prod(gold_shape)
     # gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
     # margin
     for i in range(self.eff_max_layer):
         cur_gold_inputs = gold_all_idxes[i]
         # add margin
         cur_scores = raw_scores[i]  # [*, ?]
         cur_margin = margin * self.margin_lambdas[i]
         if cur_margin > 0.:
             cur_num_target = self.prediction_sizes[i]
             cur_isnil = self.layered_isnil[i].byte()  # [NLab]
             cost_matrix = BK.constants([cur_num_target, cur_num_target],
                                        margin_T)  # [gold, pred]
             cost_matrix[cur_isnil, :] = margin_P
             cost_matrix[:, cur_isnil] = margin_R
             diag_idxes = BK.arange_idx(cur_num_target)
             cost_matrix[diag_idxes, diag_idxes] = 0.
             margin_mat = cost_matrix[cur_gold_inputs]
             cur_aug_scores = cur_scores + margin_mat  # [*, ?]
         else:
             cur_aug_scores = cur_scores
         raw_scores_aug.append(cur_aug_scores)
     # cascade scores
     final_scores = self._cascade_scores(raw_scores_aug)
     # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
     loss_weights = ((gold_idxes == 0).float() *
                     (self.loss_fullnil_weight - 1.) +
                     1.) if self.loss_fullnil_weight < 1. else 1.
     # calculate loss
     loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
     loss_prob_reweight = self.conf.loss_prob_reweight
     final_losses = []
     no_loss_max_gold = self.conf.no_loss_max_gold
     if loss_mask is None:
         loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
     for i in range(self.eff_max_layer):
         cur_final_scores, cur_gold_inputs = final_scores[i], gold_all_idxes[i]  # [*, ?], [*]
         # collect the loss
         if self.is_hinge_loss:
             cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
             cur_gold_scores = BK.gather(cur_final_scores,
                                         cur_gold_inputs.unsqueeze(-1),
                                         -1).squeeze(-1)
             cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
             if no_loss_max_gold:  # this should be implicit
                 cur_loss = cur_loss * (cur_loss > 0.).float()
         elif self.is_prob_loss:
             # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
             cur_loss = self._my_loss_prob(cur_final_scores,
                                           cur_gold_inputs,
                                           loss_prob_entropy_lambda,
                                           loss_mask,
                                           loss_prob_reweight)  # [*]
             if no_loss_max_gold:
                 cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                 cur_gold_scores = BK.gather(cur_final_scores,
                                             cur_gold_inputs.unsqueeze(-1),
                                             -1).squeeze(-1)
                 cur_loss = cur_loss * (cur_gold_scores >
                                        cur_pred_scores).float()
         else:
             raise NotImplementedError(
                 f"UNK loss {self.conf.loss_function}")
         # here first summing up, divided at the outside
         one_loss_sum = (
             cur_loss *
             (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
         final_losses.append(one_loss_sum)
     # final sum
     final_loss_sum = BK.stack(final_losses).sum()
     _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
     return [[final_loss_sum,
              loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
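The cost-sensitive margin above is driven by a small [gold, pred] cost matrix. A minimal PyTorch sketch of how it is assembled and looked up (sizes and margins are invented; isnil marks NIL labels):
import torch
num_target = 4
margin_P, margin_R, margin_T = 1.0, 0.5, 0.2
isnil = torch.tensor([True, False, False, False])    # label 0 is NIL
cost = torch.full((num_target, num_target), margin_T)  # [gold, pred]
cost[isnil, :] = margin_P                            # rows with NIL gold
cost[:, isnil] = margin_R                            # columns with NIL prediction
diag = torch.arange(num_target)
cost[diag, diag] = 0.                                # no cost for correct prediction
gold = torch.tensor([0, 2, 3])
margin_mat = cost[gold]                              # per-instance margins, [*, num_target]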
Example #17
 def prune_with_scores(arc_score,
                       label_score,
                       mask_expr,
                       pconf: PruneG1Conf,
                       arc_marginals=None):
     prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \
         pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \
         pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel
     full_score = arc_score + label_score
     final_valid_mask = BK.constants(BK.get_shape(arc_score),
                                     0,
                                     dtype=BK.uint8).squeeze(-1)
     # (put as argument) arc_marginals = None  # [*, mlen, hlen]
     if prune_use_marginal:
         if arc_marginals is None:  # not provided, so calculate from scores
             if prune_labeled:
                 # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0]
                 # use sum of label marginals instead of max
                 arc_marginals = nmarginal_unproj(full_score,
                                                  mask_expr,
                                                  None,
                                                  labeled=True).sum(-1)
             else:
                 arc_marginals = nmarginal_unproj(arc_score,
                                                  mask_expr,
                                                  None,
                                                  labeled=True).squeeze(-1)
         if prune_mthresh_rel:
             # relative value
             max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1)
             m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh))
         else:
             # absolute value
             m_valid_mask = (arc_marginals > prune_mthresh)  # [*, len-m, len-h]
         final_valid_mask |= m_valid_mask
     if prune_use_topk:
         # prune by "in topk" and "gap-to-top less than gap" for each mod
         if prune_labeled:  # take argmax among label dim
             tmp_arc_score, _ = full_score.max(-1)
         else:
             # todo(note): may be modified in place, but that does not matter since it will be masked later anyway
             tmp_arc_score = arc_score.squeeze(-1)
         # first apply mask
         mask_value = Constants.REAL_PRAC_MIN
         mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
         tmp_arc_score += mask_mul.unsqueeze(-1)
         tmp_arc_score += mask_mul.unsqueeze(-2)
         maxlen = BK.get_shape(tmp_arc_score, -1)
         tmp_arc_score += mask_value * BK.eye(maxlen)
         prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
         if prune_topk >= maxlen:
             topk_arc_score = tmp_arc_score
         else:
             topk_arc_score, _ = BK.topk(tmp_arc_score,
                                         prune_topk,
                                         dim=-1,
                                         sorted=False)  # [*, len, k]
         min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1)  # [*, len, 1]
         max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1)  # [*, len, 1]
         arc_score_thresh = BK.max_elem(min_topk_arc_score,
                                        max_topk_arc_score - prune_gap)  # [*, len, 1]
         t_valid_mask = (tmp_arc_score > arc_score_thresh)  # [*, len-m, len-h]
         final_valid_mask |= t_valid_mask
     return final_valid_mask, arc_marginals
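A condensed PyTorch sketch of the top-k/gap pruning criterion above (toy scores; the diagonal/padding masking is omitted): a head survives for a modifier if its score is at least the k-th best and within prune_gap of the best.
import torch
scores = torch.randn(2, 6, 6)                        # [bs, mod, head]
k, gap = 3, 2.0
topk_scores, _ = scores.topk(k, dim=-1, sorted=False)   # [bs, mod, k]
min_topk = topk_scores.min(-1)[0].unsqueeze(-1)
max_topk = topk_scores.max(-1)[0].unsqueeze(-1)
thresh = torch.max(min_topk, max_topk - gap)         # per-modifier threshold
keep = scores > thresh                               # [bs, mod, head] bool mask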