Example #1
 def _select_topk(self, masked_scores, pad_mask, ratio_mask, topk_ratio,
                  thresh_k):
     slen = BK.get_shape(masked_scores, -1)
     sel_mask = BK.copy(pad_mask)
     # first apply the absolute thresh
     if thresh_k is not None:
         sel_mask *= (masked_scores > thresh_k).float()
     # then ratio-ed topk
     if topk_ratio > 0.:
         # prepare number
         cur_topk_num = ratio_mask.sum(-1)  # [*]
         cur_topk_num = (cur_topk_num * topk_ratio).long()  # [*]
         cur_topk_num.clamp_(min=1, max=slen)  # at least one, at most all
         # topk
         actual_max_k = max(cur_topk_num.max().item(), 1)
         topk_score, _ = BK.topk(masked_scores,
                                 actual_max_k,
                                 dim=-1,
                                 sorted=True)  # [*, k]
         thresh_score = topk_score.gather(
             -1,
             cur_topk_num.clamp(min=1).unsqueeze(-1) - 1)  # [*, 1]
         # get mask and apply
         sel_mask *= (masked_scores >= thresh_score).float()
     return sel_mask
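The ratio-ed top-k above turns a per-row count into a per-row score threshold: take each row's k-th largest score and keep everything at or above it. A minimal sketch of the same idea in plain PyTorch, assuming BK wraps torch (all names here are illustrative, not the module's API):

import torch

def select_topk_sketch(scores, pad_mask, topk_ratio):
    # scores, pad_mask: [*, slen]; keep the top floor(ratio * valid_len) per row, at least 1
    slen = scores.shape[-1]
    k_num = (pad_mask.sum(-1) * topk_ratio).long().clamp(min=1, max=slen)  # [*]
    topk_score, _ = scores.topk(int(k_num.max().item()), dim=-1, sorted=True)
    thresh = topk_score.gather(-1, (k_num - 1).unsqueeze(-1))  # [*, 1] per-row k-th largest
    return pad_mask * (scores >= thresh).float()

sel = select_topk_sketch(torch.randn(2, 6), torch.ones(2, 6), 0.5)  # keeps 3 per row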
Example #2
 def loss_special(self, repr_t, mask_t, disturb_keep_arr, input_map,
                  masklm_node):
     pred_mask_t = BK.copy(mask_t)
     pred_mask_t *= BK.input_real(1. - disturb_keep_arr)  # no predictions for the kept (non-shuffled) ones
     abs_posi = input_map["posi"]  # shuffled positions
     # no predictions for ARTI_ROOT
     if self.add_root_token:
         pred_mask_t[:, 0] = 0.  # [bs, slen]
         abs_posi = (abs_posi[:, 1:] - 1)  # offset by 1
         pred_mask_t = pred_mask_t[:, 1:]  # remove root
     corr_targets_arr = input_map["word"][
         np.arange(len(abs_posi))[:, np.newaxis], abs_posi]
     repr_t_hid = self.speical_hid_layer(repr_t)  # go through hid here!!
     loss_item = masklm_node.loss([repr_t_hid],
                                  pred_mask_t, {"word": corr_targets_arr},
                                  active_hid=False)
     if len(loss_item) == 0:
         return loss_item
     # todo(note): simply change its name
     vs = [v for v in loss_item.values()]
     assert len(vs) == 1
     return {"orp.d-2": vs[0]}
Example #3
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     # todo(+N): currently margin is not used
     conf = self.conf
     bsize = len(insts)
     arange_t = BK.arange_idx(bsize)
     assert conf.train_force, "currently only have forced training"
     # get the gold ones
     gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
     # for all the steps
     num_step = BK.get_shape(gold_widxes, -1)
     # recurrent states
     hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
     prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
     all_tok_logprobs, all_lab_logprobs = [], []
     for cstep in range(num_step):
         slice_widx, slice_lidx = gold_widxes[:, cstep], gold_lidxes[:, cstep]
         _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
             self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
         all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
         all_lab_logprobs.append(sel_lab_logprobs)
         hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
         hard_coverage[arange_t, slice_widx] += 1.
         prev_state = [z.squeeze(-2) for z in next_state]
     # concat all the loss and mask
     # todo(note): no need to use gold_valid, since validity is already carried by the vmasks
     cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
     cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
     loss_sum = - (cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
     # todo(+N): here we are dividing lab_logprobs with the all-count, do we need to separate?
     loss_count = gold_vmasks.sum()
     ret_losses = [[loss_sum, loss_count]]
     # =====
     # mark EOS as invalid in the returned mask
     ret_valid_mask = gold_vmasks * (gold_widxes > 0).float()
     # embeddings
     sel_lab_embeds = self._hl_lookup(gold_lidxes)
     return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
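Note the BK.copy before the coverage update: hard_coverage has already been fed into _step, so an in-place += on the same tensor would corrupt the autograd graph of the earlier steps (hence the original "cannot modify inplace" note). The clone-then-index-add pattern in plain torch, with toy sizes:

import torch

bsize, slen = 2, 4
coverage = torch.zeros(bsize, slen)
arange_t = torch.arange(bsize)
for widx in (torch.tensor([1, 3]), torch.tensor([1, 0])):
    coverage = coverage.clone()     # fresh tensor per step; the old one stays valid
    coverage[arange_t, widx] += 1.  # bump the selected token of each batch row
# coverage == [[0., 2., 0., 0.], [1., 0., 0., 1.]]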
Example #4
    def __call__(self, enc_expr, valid_expr, arc_marg_expr):
        # ===== avoid NAN
        def _sum_marg(m, dim):
            s = m.sum(dim).unsqueeze(dim)
            s += (s < 1e-5).float() * 1e-5
            return s

        # =====
        output = enc_expr
        arc_marg_expr = arc_marg_expr * valid_expr  # only keep the after-pruning ones
        if self.use_par:
            m_mask = valid_expr  # [*, len-m, len-h]
            m_marg = arc_marg_expr / _sum_marg(arc_marg_expr, -1)
            senc_par_expr = self._calc_one_node(self.node_par, self.ff_par,
                                                enc_expr, m_mask, m_marg)
            output = output + senc_par_expr
        if self.use_chs:
            h_mask = BK.copy(valid_expr.transpose(-1, -2))  # [*, len-h, len-m]
            h_marg = (arc_marg_expr / _sum_marg(arc_marg_expr, -2)).transpose(
                -1, -2)
            senc_chs_expr = self._calc_one_node(self.node_chs, self.ff_chs,
                                                enc_expr, h_mask, h_marg)
            output = output + senc_chs_expr
        return output
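_sum_marg is a guard against division by zero: after pruning, an entire row of arc marginals can be all zeros, and normalizing it directly would produce NaNs. A standalone version of the trick in plain torch (the function name is illustrative):

import torch

def safe_normalize(m, dim, eps=1e-5):
    s = m.sum(dim, keepdim=True)
    s = s + (s < eps).float() * eps  # pad (near-)zero denominators up to eps
    return m / s

out = safe_normalize(torch.tensor([[1., 3.], [0., 0.]]), dim=-1)
# [[0.25, 0.75], [0., 0.]]: all-zero rows stay zero instead of becoming NaN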
Example #5
 def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack):
     # init caches and scores, [orig_bsize, max_slen, D]
     self.enc_repr = enc_repr
     self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr)
     # init other masks
     self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct)
     full_shape = BK.get_shape(self.scoring_mask_ct)
     # init oracle masks
     oracle_mask_ct = BK.constants(full_shape,
                                   value=0.,
                                   device=BK.CPU_DEVICE)
     # label=0 means no label, but the tensor is still needed to avoid index errors (dummy oracle for wrong/no-oracle states)
     oracle_label_ct = BK.constants(full_shape,
                                    value=0,
                                    dtype=BK.int64,
                                    device=BK.CPU_DEVICE)
     for i, inst in enumerate(insts):
         EfOracler.init_oracle_mask(inst, oracle_mask_ct[i],
                                    oracle_label_ct[i])
     self.oracle_mask_t = BK.to_device(oracle_mask_ct)
     self.oracle_mask_ct = oracle_mask_ct
     self.oracle_label_t = BK.to_device(oracle_label_ct)
     # scoring cache
     self.scoring_cache.init_cache(enc_repr, g1_pack)
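The oracle tensors are built on the CPU first: init_oracle_mask fills them row by row in a Python loop, and doing that against device memory would cost one tiny transfer per write. Filling on the host and moving once at the end, sketched in plain torch (the shapes and the fill itself are stand-ins):

import torch

bs, slen = 4, 8
oracle_mask = torch.zeros(bs, slen)  # stays on CPU while the loop fills it
for i in range(bs):
    oracle_mask[i, :2 + i] = 1.  # stand-in for EfOracler.init_oracle_mask
device = "cuda" if torch.cuda.is_available() else "cpu"
oracle_mask_t = oracle_mask.to(device)  # one bulk transfer, like BK.to_device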
Example #6
 def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
     conf = self.conf
     CR, PR = conf.cand_range, conf.pred_range
     # -----
     mask_single = BK.copy(mask_t)
     # no predictions for ARTI_ROOT
     if self.add_root_token:
         mask_single[:, 0] = 0.  # [bs, slen]
     # casting predicting range
     cur_slen = BK.get_shape(mask_single, -1)
     arange_t = BK.arange_idx(cur_slen)  # [slen]
     # [1, len] - [len, 1] = [len, len]
     reldist_t = arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1)  # [slen, slen]
     mask_pair = ((reldist_t.abs() <= CR) &
                  (reldist_t != 0)).float()  # within CR-range; [slen, slen]
     mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
     if disturb_keep_arr is not None:
         mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
     # get all pair scores
     score_t = self.ps_node.paired_score(
         repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*PR]
     # -----
     # loss: normalize on which dim?
     # get the answers first
     if conf.pred_abs:
         answer_t = reldist_t.abs()  # classes for distances [1,2,3,...,PR]
         # [slen, slen]; clip into range, masks distinguish the clipped ones
         answer_t.clamp_(min=0, max=PR - 1)
     else:
         # classes for distances [1,2,3,...,PR,-PR,-PR+1,...,-1]
         answer_t = BK.where(reldist_t >= 0, reldist_t - 1, reldist_t + 2 * PR)
         # [slen, slen]; clip into range, masks distinguish the clipped ones
         answer_t.clamp_(min=0, max=2 * PR - 1)
     # expand answer into idxes
     answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2 * PR])  # [len_q, len_k, 2*PR]
     answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
     answer_valid_t = ((reldist_t.abs() <= PR) &
                       (reldist_t != 0)).float().unsqueeze(
                           -1)  # [bs, len_q, len_k, 1]
     answer_hit_t = answer_hit_t * mask_pair.unsqueeze(
         -1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
     # get losses sum(log(answer*prob))
     # -- dim=-1 is the standard 2*PR-way classification; dim=-2 usually has 2*PR candidates, but can have fewer at the edges
     all_losses = []
     for one_dim, one_lambda in zip([-1, -2],
                                    [conf.lambda_n1, conf.lambda_n2]):
         if one_lambda > 0.:
             # since currently there can be only one or zero correct answer
             logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*PR]
             sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*PR]
             cur_dim_mask_t = (answer_hit_t.sum(one_dim) > 0.).float()  # [bs, len_q, len_k||2*PR]
             # loss
             cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
             cur_dim_count = cur_dim_mask_t.sum()
             # argmax and corr (any correct counts)
             _, cur_argmax_idxes = score_t.max(one_dim)
             cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*PR|1]
             cur_dim_corr_count = cur_corrs.sum()
             # compile loss
             one_loss = LossHelper.compile_leaf_info(
                 f"d{one_dim}",
                 cur_dim_loss,
                 cur_dim_count,
                 loss_lambda=one_lambda,
                 corr=cur_dim_corr_count)
             all_losses.append(one_loss)
     return self._compile_component_loss("orp", all_losses)
Example #7
 def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t,
          **kwargs):
     conf = self.conf
     # detach input?
     if self.no_detach_input.value <= 0.:
         repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
     # scoring
     label_scores, score_masks = self._score(
         repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
     # -----
     # get golds
     bsize, max_len = BK.get_shape(mask_t)
     shape_lidxes = [bsize, max_len, max_len]
     gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
     gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
     for bidx, inst in enumerate(insts):
         cur_dep_tree = inst.dep_tree
         cur_len = len(cur_dep_tree)
         gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
         gold_heads[bidx, :cur_len] = cur_dep_tree.heads
     # -----
     margin = self.margin.value
     all_losses = []
     # first is loss_labels
     lambda_label = conf.lambda_label
     if lambda_label > 0.:
         gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
         label_losses = BK.loss_nll(label_scores,
                                    gold_lidxes_t,
                                    margin=margin)  # [bs, mlen, mlen]
         positive_mask_t = (gold_lidxes_t > 0).float()  # [bs, mlen, mlen]
         negative_mask_t = (BK.rand(shape_lidxes) <
                            conf.label_neg_rate).float()  # [bs, mlen, mlen]
         loss_mask_t = score_masks * (positive_mask_t + negative_mask_t
                                      )  # [bs, mlen, mlen]
         loss_mask_t.clamp_(max=1.)
         masked_label_losses = label_losses * loss_mask_t
         # compile loss
         final_label_loss = LossHelper.compile_leaf_info(
             f"label",
             masked_label_losses.sum(),
             loss_mask_t.sum(),
             loss_lambda=lambda_label,
             npos=positive_mask_t.sum())
         all_losses.append(final_label_loss)
     # then head loss
     lambda_head = conf.lambda_head
     if lambda_head > 0.:
         # get head score simply by argmax on ranges
         head_scores, _ = self._ranged_label_scores(label_scores).max(-1)  # [bs, mlen, mlen]
         gold_heads_t = BK.input_idx(gold_heads)
         head_losses = BK.loss_nll(head_scores, gold_heads_t,
                                   margin=margin)  # [bs, mlen]
         # mask
         head_mask_t = BK.copy(mask_t)
         head_mask_t[:, 0] = 0.  # no head prediction for ARTI_ROOT
         masked_head_losses = head_losses * head_mask_t
         # compile loss
         final_head_loss = LossHelper.compile_leaf_info(
             "head",
             masked_head_losses.sum(),
             head_mask_t.sum(),
             loss_lambda=lambda_head)
         all_losses.append(final_head_loss)
     # --
     return self._compile_component_loss("dp", all_losses)