Example #1
 def __call__(self, scores, temperature=1., dim=-1):
     is_training = self.rop.training
     # only inject stochastic noise during training
     if is_training:
         if self.use_gumbel:
             gumbel_eps = self.gumbel_eps
             # perturb with Gumbel-like noise: scores - log(eps - log U), U ~ Uniform[0,1)
             G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)
             scores = scores - (gumbel_eps - G.log()).log()
     # normalize
     probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
     # prune and re-normalize?
     if self.prune_val > 0.:
         probs = probs * (probs > self.prune_val).float()
         # todo(note): currently no re-normalize
         # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
     # argmax and ste
     if self.use_argmax:  # use the hard argmax
         max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
         # todo(+N): currently we do not re-normalize here, should it be done here?
         st_probs = (probs >= max_probs).float() * probs  # [*, S]
         if is_training:  # (hard-soft).detach() + soft
             st_probs = (st_probs - probs).detach() + probs  # [*, S]
         return st_probs
     else:
         return probs
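
For reference, here is a minimal self-contained sketch of the same pattern in plain PyTorch (the function name and defaults are hypothetical, not part of the library above): perturb the scores with Gumbel noise, soften them with a temperature, then apply the straight-through trick so the forward pass yields a hard one-hot while gradients flow through the soft probabilities.

import torch
import torch.nn.functional as F

def gumbel_st_sample(scores: torch.Tensor, temperature: float = 1.0,
                     eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    # Gumbel(0,1) noise: -log(-log(U)) with U ~ Uniform[0,1)
    u = torch.rand_like(scores).clamp_(min=eps)
    gumbel = -torch.log(-torch.log(u) + eps)
    probs = F.softmax((scores + gumbel) / temperature, dim=dim)
    # hard one-hot of the argmax
    hard = torch.zeros_like(probs).scatter_(dim, probs.argmax(dim, keepdim=True), 1.0)
    # straight-through: forward pass uses `hard`, backward uses the gradient of `probs`
    return (hard - probs).detach() + probs

# usage
scores = torch.randn(2, 5, requires_grad=True)
sample = gumbel_st_sample(scores, temperature=0.5)
sample.sum().backward()  # gradients reach `scores` through the soft probabilities
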
Example #2
 def _select_cands_training(self, input_mask, gold_mask, train_min_rate):
     # first select examples (randomly)
     sel_mask = (BK.rand(BK.get_shape(input_mask)) <
                 train_min_rate).float()  # [*, slen]
     # add gold and exclude pad
     sel_mask += gold_mask
     sel_mask.clamp_(max=1.)
     sel_mask *= input_mask
     return sel_mask
Example #3
 def _prepare_tmask(self, input_mask, gold_mask, trate):
     # todo(+3): currently simple sampling
     sel_mask = (BK.rand(BK.get_shape(gold_mask)) < trate).float()
     # add gold and exclude pad
     sel_mask += gold_mask
     sel_mask.clamp_(max=1.)
     if input_mask.dim() < sel_mask.dim():
         input_mask = input_mask.unsqueeze(-1)
     sel_mask *= input_mask
     return sel_mask
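
Examples #2 and #3 implement the same candidate-selection pattern: keep every gold position, keep a random fraction of the remaining positions as negatives, and mask out padding. Below is a standalone sketch of that pattern in plain PyTorch; the function and argument names are illustrative, not the library's.

import torch

def sample_training_mask(input_mask: torch.Tensor, gold_mask: torch.Tensor,
                         keep_rate: float) -> torch.Tensor:
    # randomly selected negatives
    sel_mask = (torch.rand_like(gold_mask) < keep_rate).float()
    # always keep gold positions
    sel_mask = (sel_mask + gold_mask).clamp(max=1.)
    # broadcast the input mask if the selection mask has extra trailing dims
    if input_mask.dim() < sel_mask.dim():
        input_mask = input_mask.unsqueeze(-1)
    # exclude padding
    return sel_mask * input_mask

# usage: all gold tokens plus roughly 30% of the remaining real tokens are kept
input_mask = torch.tensor([[1., 1., 1., 0.]])
gold_mask = torch.tensor([[0., 1., 0., 0.]])
sel = sample_training_mask(input_mask, gold_mask, keep_rate=0.3)
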
Example #4
 def __call__(self, input, add_root_token: bool):
     voc = self.voc
     # todo(note): append a [cls/root] idx, currently use "bos"
     input_t = BK.input_idx(input)  # [*, slen]
     # randomly replace rare words with UNK during training
     if self.rop.training and self.use_rare_unk:
         rare_unk_rate = self.ec_conf.comp_rare_unk
         cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t))<rare_unk_rate)).detach().long()
         input_t = input_t * (1-cur_unk_imask) + self.voc.unk * cur_unk_imask
     # root
     if add_root_token:
         input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1]+[1], voc.bos, dtype=input_t.dtype)  # [*, 1]
         input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
     else:
         input_t_p1 = input_t
     expr = self.E(input_t_p1)  # [*, 1?+slen]
     return self.dropout(expr)
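
The rare-UNK step above can be isolated into a small standalone sketch (plain PyTorch, hypothetical names): during training, token ids flagged as rare are swapped for the UNK id with some probability, so the model does not over-rely on poorly estimated rare-word embeddings.

import torch

def replace_rare_with_unk(input_t: torch.Tensor, rare_mask: torch.Tensor,
                          unk_idx: int, rare_unk_rate: float) -> torch.Tensor:
    # rare_mask: per-vocabulary 0/1 float tensor indexed by token id
    drop = (rare_mask[input_t] * (torch.rand(input_t.shape) < rare_unk_rate).float()).long()
    return input_t * (1 - drop) + unk_idx * drop

# usage
vocab_size, unk_idx = 100, 1
rare_mask = (torch.rand(vocab_size) < 0.2).float()   # mark ~20% of types as rare
ids = torch.randint(0, vocab_size, (2, 6))
train_ids = replace_rare_with_unk(ids, rare_mask, unk_idx, rare_unk_rate=0.5)
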
Example #5
 def loss(self,
          repr_ef,
          repr_evt,
          lab_ef,
          lab_evt,
          mask_ef,
          mask_evt,
          gold_idxes,
          margin=0.):
     conf = self.conf
     # -----
     if np.prod(BK.get_shape(gold_idxes)) == 0:
         return [[BK.zeros([]), BK.zeros([])]]
     # -----
     # todo(note): +1 to leave index 0 as the slot for DROPPED(UNK)
     lab_ef = self._dropout_idxes(lab_ef + 1, conf.train_drop_ef_lab)
     lab_evt = self._dropout_idxes(lab_evt + 1, conf.train_drop_evt_lab)
     if conf.linker_ef_detach:
         repr_ef = repr_ef.detach()
     if conf.linker_evt_detach:
         repr_evt = repr_evt.detach()
     full_score = self._score(repr_ef, repr_evt, lab_ef,
                              lab_evt)  # [*, len-ef, len-evt, D]
     if margin > 0.:
         aug_score = BK.zeros(BK.get_shape(full_score)) + margin
         aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)
         full_score += aug_score
     full_logprobs = BK.log_softmax(full_score, -1)
     gold_logprobs = full_logprobs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)  # [*, len-ef, len-evt]
     # sampling and mask
     loss_mask = mask_ef.unsqueeze(-1) * mask_evt.unsqueeze(-2)
     # ====
     # first select examples (randomly)
     sel_mask = (BK.rand(BK.get_shape(loss_mask)) <
                 conf.train_min_rate).float()  # [*, len-ef, len-evt]
     # add gold and exclude pad
     sel_mask += (gold_idxes > 0).float()
     sel_mask.clamp_(max=1.)
     loss_mask *= sel_mask
     # =====
     loss_sum = -(gold_logprobs * loss_mask).sum()
     loss_count = loss_mask.sum()
     ret_losses = [[loss_sum, loss_count]]
     return ret_losses
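
The margin handling in this loss is the usual cost-augmented trick: before the softmax, every non-gold label score is bumped up by `margin`, so the gold label must win by at least that margin for the loss to go down. A minimal standalone sketch of just that step (plain PyTorch; names are illustrative, not the library's):

import torch
import torch.nn.functional as F

def margin_nll(scores: torch.Tensor, gold_idxes: torch.Tensor, margin: float = 0.) -> torch.Tensor:
    # scores: [*, C], gold_idxes: [*]
    if margin > 0.:
        aug = torch.full_like(scores, margin)
        aug.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)  # no bonus on the gold label
        scores = scores + aug
    logprobs = F.log_softmax(scores, dim=-1)
    return -logprobs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)  # per-item NLL, shape [*]

# usage
scores = torch.randn(4, 3, requires_grad=True)
gold = torch.tensor([0, 2, 1, 1])
loss = margin_nll(scores, gold, margin=0.5).mean()
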
Example #6
 def _dropout_idxes(self, idxes, rate):
     # zero out (drop to the UNK slot) each idx with probability `rate`
     zero_mask = (BK.rand(BK.get_shape(idxes)) < rate).long()
     return (1 - zero_mask) * idxes
Example #7
 def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t,
          **kwargs):
     conf = self.conf
     # detach input?
     if self.no_detach_input.value <= 0.:
         repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
     # scoring
     label_scores, score_masks = self._score(
         repr_t, attn_t,
         mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
     # -----
     # get golds
     bsize, max_len = BK.get_shape(mask_t)
     shape_lidxes = [bsize, max_len, max_len]
     gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
     gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
     for bidx, inst in enumerate(insts):
         cur_dep_tree = inst.dep_tree
         cur_len = len(cur_dep_tree)
         gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
         gold_heads[bidx, :cur_len] = cur_dep_tree.heads
     # -----
     margin = self.margin.value
     all_losses = []
     # first is loss_labels
     lambda_label = conf.lambda_label
     if lambda_label > 0.:
         gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
         label_losses = BK.loss_nll(label_scores,
                                    gold_lidxes_t,
                                    margin=margin)  # [bs, mlen, mlen]
         positive_mask_t = (gold_lidxes_t > 0).float()  # [bs, mlen, mlen]
         negative_mask_t = (BK.rand(shape_lidxes) <
                            conf.label_neg_rate).float()  # [bs, mlen, mlen]
         loss_mask_t = score_masks * (positive_mask_t + negative_mask_t
                                      )  # [bs, mlen, mlen]
         loss_mask_t.clamp_(max=1.)
         masked_label_losses = label_losses * loss_mask_t
         # compile loss
         final_label_loss = LossHelper.compile_leaf_info(
             "label",
             masked_label_losses.sum(),
             loss_mask_t.sum(),
             loss_lambda=lambda_label,
             npos=positive_mask_t.sum())
         all_losses.append(final_label_loss)
     # then head loss
     lambda_head = conf.lambda_head
     if lambda_head > 0.:
         # get head score simply by argmax on ranges
         head_scores, _ = self._ranged_label_scores(label_scores).max(
             -1)  # [bs, mlen, mlen]
         gold_heads_t = BK.input_idx(gold_heads)
         head_losses = BK.loss_nll(head_scores, gold_heads_t,
                                   margin=margin)  # [bs, mlen]
         # mask
         head_mask_t = BK.copy(mask_t)
         head_mask_t[:, 0] = 0  # not for ARTI_ROOT
         masked_head_losses = head_losses * head_mask_t
         # compile loss
         final_head_loss = LossHelper.compile_leaf_info(
             "head",
             masked_head_losses.sum(),
             head_mask_t.sum(),
             loss_lambda=lambda_head)
         all_losses.append(final_head_loss)
     # --
     return self._compile_component_loss("dp", all_losses)
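
For the head loss, the per-label scores of shape [bs, len_mod, len_head, 1+N] are collapsed by a max over the label dimension, and the result is treated as head-attachment logits. A simplified standalone sketch of that step (plain PyTorch, without the margin augmentation or the ranged-score detail; names are illustrative):

import torch
import torch.nn.functional as F

def head_loss(label_scores: torch.Tensor, gold_heads: torch.Tensor,
              token_mask: torch.Tensor) -> torch.Tensor:
    # label_scores: [bs, len_mod, len_head, 1+N] with label 0 meaning "no arc"
    head_scores, _ = label_scores.max(dim=-1)            # [bs, len_mod, len_head]
    # cross_entropy expects [bs, C, ...], so move the head dimension to position 1
    losses = F.cross_entropy(head_scores.transpose(1, 2), gold_heads,
                             reduction="none")           # [bs, len_mod]
    mask = token_mask.clone()
    mask[:, 0] = 0.                                      # the artificial root has no head
    return (losses * mask).sum() / mask.sum().clamp(min=1.)

# usage
bs, slen, n_lab = 2, 5, 4
label_scores = torch.randn(bs, slen, slen, 1 + n_lab, requires_grad=True)
gold_heads = torch.randint(0, slen, (bs, slen))
token_mask = torch.ones(bs, slen)
loss = head_loss(label_scores, gold_heads, token_mask)
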