def __call__(self, scores, temperature=1., dim=-1):
    is_training = self.rop.training  # only use stochastic at training
    if is_training:
        if self.use_gumbel:
            gumbel_eps = self.gumbel_eps
            G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)  # [0,1)
            scores = scores - (gumbel_eps - G.log()).log()
    # normalize
    probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
    # prune and re-normalize?
    if self.prune_val > 0.:
        probs = probs * (probs > self.prune_val).float()
        # todo(note): currently no re-normalize
        # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
    # argmax and ste
    if self.use_argmax:  # use the hard argmax
        max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
        # todo(+N): currently we do not re-normalize here, should it be done here?
        st_probs = (probs >= max_probs).float() * probs  # [*, S]
        if is_training:  # (hard-soft).detach() + soft
            st_probs = (st_probs - probs).detach() + probs  # [*, S]
        return st_probs
    else:
        return probs
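# Illustrative sketch (not part of the original module): the same straight-through
# Gumbel-softmax idea written in plain PyTorch, assuming the BK backend wraps torch;
# `gumbel_softmax_ste` and its arguments are hypothetical names for illustration.
import torch

def gumbel_softmax_ste(scores: torch.Tensor, temperature: float = 1.0,
                       eps: float = 1e-10, hard: bool = True) -> torch.Tensor:
    # perturb scores with Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1)
    u = torch.rand_like(scores).clamp(min=eps, max=1. - eps)
    g = -torch.log(-torch.log(u))
    probs = torch.softmax((scores + g) / temperature, dim=-1)
    if hard:
        # hard one-hot in the forward pass, soft gradients via (hard - soft).detach() + soft
        one_hot = torch.zeros_like(probs).scatter_(-1, probs.argmax(-1, keepdim=True), 1.)
        return (one_hot - probs).detach() + probs
    return probs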
def _select_cands_training(self, input_mask, gold_mask, train_min_rate):
    # first select examples (randomly)
    sel_mask = (BK.rand(BK.get_shape(input_mask)) < train_min_rate).float()  # [*, slen]
    # add gold and exclude pad
    sel_mask += gold_mask
    sel_mask.clamp_(max=1.)
    sel_mask *= input_mask
    return sel_mask
def _prepare_tmask(self, input_mask, gold_mask, trate):
    # todo(+3): currently simple sampling
    sel_mask = (BK.rand(BK.get_shape(gold_mask)) < trate).float()
    # add gold and exclude pad
    sel_mask += gold_mask
    sel_mask.clamp_(max=1.)
    if input_mask.dim() < sel_mask.dim():
        input_mask = input_mask.unsqueeze(-1)
    sel_mask *= input_mask
    return sel_mask
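# Illustrative sketch (not part of the original module): the candidate-sampling pattern
# shared by the two helpers above, in plain PyTorch with 0./1. float masks;
# `sample_training_mask` is a hypothetical name for illustration.
import torch

def sample_training_mask(input_mask: torch.Tensor, gold_mask: torch.Tensor,
                         min_rate: float) -> torch.Tensor:
    # randomly keep roughly `min_rate` of all candidates ...
    sel_mask = (torch.rand_like(input_mask) < min_rate).float()
    # ... always keep the gold ones ...
    sel_mask = (sel_mask + gold_mask).clamp(max=1.)
    # ... and never keep padded positions
    return sel_mask * input_mask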
def __call__(self, input, add_root_token: bool):
    voc = self.voc
    # todo(note): append a [cls/root] idx, currently use "bos"
    input_t = BK.input_idx(input)  # [*, slen]
    # rare unk in training
    if self.rop.training and self.use_rare_unk:
        rare_unk_rate = self.ec_conf.comp_rare_unk
        cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).detach().long()
        input_t = input_t * (1 - cur_unk_imask) + self.voc.unk * cur_unk_imask
    # root
    if add_root_token:
        input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1] + [1], voc.bos, dtype=input_t.dtype)  # [*, 1]
        input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
    else:
        input_t_p1 = input_t
    expr = self.E(input_t_p1)  # [*, 1?+slen]
    return self.dropout(expr)
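# Illustrative sketch (not part of the original module): the rare-word UNK replacement
# above in plain PyTorch; `rare_mask` is assumed to be a [vocab_size] 0./1. float tensor
# marking rare word ids, and `replace_rare_with_unk` is a hypothetical name.
import torch

def replace_rare_with_unk(input_t: torch.Tensor, rare_mask: torch.Tensor,
                          unk_idx: int, rate: float) -> torch.Tensor:
    # a token is replaced only if it is rare AND its coin flip fires
    hit = (rare_mask[input_t] * (torch.rand(input_t.shape, device=input_t.device) < rate).float()).long()
    return input_t * (1 - hit) + unk_idx * hit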
def loss(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef, mask_evt, gold_idxes, margin=0.):
    conf = self.conf
    # -----
    # no candidate pairs at all: return a zero loss
    if np.prod(BK.get_shape(gold_idxes)) == 0:
        return [[BK.zeros([]), BK.zeros([])]]
    # -----
    # todo(note): +1 for the extra space of DROPPED(UNK)
    lab_ef = self._dropout_idxes(lab_ef + 1, conf.train_drop_ef_lab)
    lab_evt = self._dropout_idxes(lab_evt + 1, conf.train_drop_evt_lab)
    if conf.linker_ef_detach:
        repr_ef = repr_ef.detach()
    if conf.linker_evt_detach:
        repr_evt = repr_evt.detach()
    full_score = self._score(repr_ef, repr_evt, lab_ef, lab_evt)  # [*, len-ef, len-evt, D]
    if margin > 0.:
        # cost-augmented scoring: add margin to every non-gold label
        aug_score = BK.zeros(BK.get_shape(full_score)) + margin
        aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)
        full_score += aug_score
    full_logprobs = BK.log_softmax(full_score, -1)
    gold_logprobs = full_logprobs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)  # [*, len-ef, len-evt]
    # sampling and mask
    loss_mask = mask_ef.unsqueeze(-1) * mask_evt.unsqueeze(-2)  # [*, len-ef, len-evt]
    # =====
    # first select examples (randomly)
    sel_mask = (BK.rand(BK.get_shape(loss_mask)) < conf.train_min_rate).float()  # [*, len-ef, len-evt]
    # add gold and exclude pad
    sel_mask += (gold_idxes > 0).float()
    sel_mask.clamp_(max=1.)
    loss_mask *= sel_mask
    # =====
    loss_sum = -(gold_logprobs * loss_mask).sum()
    loss_count = loss_mask.sum()
    ret_losses = [[loss_sum, loss_count]]
    return ret_losses
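# Illustrative sketch (not part of the original module): the margin trick used above,
# i.e. cost-augmented scoring that adds `margin` to every non-gold label before the
# softmax, in plain PyTorch; `add_margin` is a hypothetical name for illustration.
import torch

def add_margin(scores: torch.Tensor, gold_idxes: torch.Tensor, margin: float) -> torch.Tensor:
    # scores: [..., L]; gold_idxes: [...] with label ids in [0, L)
    aug = torch.full_like(scores, margin)
    aug.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)  # gold labels receive no extra cost
    return scores + aug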
def _dropout_idxes(self, idxes, rate):
    # with probability `rate`, drop an idx to 0 (the reserved DROPPED/UNK slot)
    zero_mask = (BK.rand(BK.get_shape(idxes)) < rate).long()
    return (1 - zero_mask) * idxes
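# Illustrative usage note (not part of the original module): callers shift labels by +1
# before `_dropout_idxes`, reserving index 0 as the DROPPED(UNK) slot that dropped
# labels collapse onto. A minimal sketch in plain PyTorch:
import torch

lab = torch.tensor([2, 5, 0, 7])                 # raw label ids (0 is a real label here)
shifted = lab + 1                                # [3, 6, 1, 8]: slot 0 is now free
drop = (torch.rand(shifted.shape) < 0.3).long()  # drop each position with prob 0.3
dropped = (1 - drop) * shifted                   # dropped positions become 0 == DROPPED(UNK)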
def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
    conf = self.conf
    # detach input?
    if self.no_detach_input.value <= 0.:
        repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
    # scoring
    label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
    # -----
    # get golds
    bsize, max_len = BK.get_shape(mask_t)
    shape_lidxes = [bsize, max_len, max_len]
    gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
    gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
    for bidx, inst in enumerate(insts):
        cur_dep_tree = inst.dep_tree
        cur_len = len(cur_dep_tree)
        gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
        gold_heads[bidx, :cur_len] = cur_dep_tree.heads
    # -----
    margin = self.margin.value
    all_losses = []
    # first is loss_labels
    lambda_label = conf.lambda_label
    if lambda_label > 0.:
        gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
        label_losses = BK.loss_nll(label_scores, gold_lidxes_t, margin=margin)  # [bs, mlen, mlen]
        # keep all gold (positive) edges plus a randomly sampled subset of negatives
        positive_mask_t = (gold_lidxes_t > 0).float()  # [bs, mlen, mlen]
        negative_mask_t = (BK.rand(shape_lidxes) < conf.label_neg_rate).float()  # [bs, mlen, mlen]
        loss_mask_t = score_masks * (positive_mask_t + negative_mask_t)  # [bs, mlen, mlen]
        loss_mask_t.clamp_(max=1.)
        masked_label_losses = label_losses * loss_mask_t
        # compile loss
        final_label_loss = LossHelper.compile_leaf_info(
            f"label", masked_label_losses.sum(), loss_mask_t.sum(),
            loss_lambda=lambda_label, npos=positive_mask_t.sum())
        all_losses.append(final_label_loss)
    # then head loss
    lambda_head = conf.lambda_head
    if lambda_head > 0.:
        # get head score simply by argmax over the label dimension of each range
        head_scores, _ = self._ranged_label_scores(label_scores).max(-1)  # [bs, mlen, mlen]
        gold_heads_t = BK.input_idx(gold_heads)
        head_losses = BK.loss_nll(head_scores, gold_heads_t, margin=margin)  # [bs, mlen]
        # mask: exclude the artificial root position
        head_mask_t = BK.copy(mask_t)
        head_mask_t[:, 0] = 0.  # not for ARTI_ROOT
        masked_head_losses = head_losses * head_mask_t
        # compile loss
        final_head_loss = LossHelper.compile_leaf_info(
            f"head", masked_head_losses.sum(), head_mask_t.sum(), loss_lambda=lambda_head)
        all_losses.append(final_head_loss)
    # --
    return self._compile_component_loss("dp", all_losses)