def loss(self, ms_items: List, bert_expr):
    """Compute training losses for left/right span-boundary distances.

    Collects gold instances from `ms_items`, scores each candidate head for
    its left and right extent with `self._score`, and returns per-side
    (loss_sum, count, count) triples.

    :param ms_items: batch of sentence/mention items consumed by
        `self._collect_insts` (project type; exact schema defined there).
    :param bert_expr: encoder output fed to `self._score`
        (presumably [bsize, slen, D] — TODO confirm against _score).
    :return: [[left_loss_sum, left_count, left_count],
              [right_loss_sum, right_count, right_count]]; all-zero scalars
        when the batch contributes no instances.
    """
    conf = self.conf
    max_range = conf.max_range
    # collect gold instances across the whole batch
    col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(ms_items, True)
    if len(col_efs) == 0:
        # nothing to train on: return zero losses with zero counts
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)  # [N, R]
    if conf.use_binary_scorer:
        # binary targets: position r is 1 iff r <= gold distance -> [N, R]
        left_binaries, right_binaries = (BK.arange_idx(max_range) <= col_ldists_t.unsqueeze(-1)).float(), \
            (BK.arange_idx(max_range) <= col_rdists_t.unsqueeze(-1)).float()  # [N,R]
        # [:, 1:] drops distance-0, which is trivially satisfied (always 1)
        left_losses = BK.binary_cross_entropy_with_logits(left_scores, left_binaries, reduction='none')[:, 1:]
        right_losses = BK.binary_cross_entropy_with_logits(right_scores, right_binaries, reduction='none')[:, 1:]
        # each instance contributes (max_range-1) binary decisions
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0) * (max_range - 1))
    else:
        # multi-class mode: one NLL decision per instance over R distances
        left_losses = BK.loss_nll(left_scores, col_ldists_t)
        right_losses = BK.loss_nll(right_scores, col_rdists_t)
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0))
    return [[left_losses.sum(), left_count, left_count], [right_losses.sum(), right_count, right_count]]
def loss(self, enc_expr, pad_mask, gold_mask, margin: float): conf = self.conf # ===== # first testing-mode scoring and selecting res_mask, all_scores = self.score_and_select(enc_expr, pad_mask) # add gold if conf.ns_add_gold: res_mask += gold_mask res_mask.clamp_(max=1.) # ===== with BK.no_grad_env(): # how to select instances for training if conf.train_ratio2gold > 0.: # use gold-ratio for training masked_all_scores = all_scores + ( 1. - pad_mask + gold_mask) * Constants.REAL_PRAC_MIN loss_mask = self._select_topk(masked_all_scores, pad_mask, gold_mask, conf.train_ratio2gold, None) loss_mask += gold_mask loss_mask.clamp_(max=1.) elif not conf.ns_add_gold: loss_mask = res_mask + gold_mask loss_mask.clamp_(max=1.) else: # we already have the gold loss_mask = res_mask # ===== calculating losses [*, L] # first aug scores by margin aug_scores = all_scores - (conf.margin_pos * margin) * gold_mask + ( conf.margin_neg * margin) * (1. - gold_mask) if self.loss_hinge: # multiply pos instances with -1 flipped_scores = aug_scores * (1. - 2 * gold_mask) losses_all = BK.clamp(flipped_scores, min=0.) elif self.loss_prob: losses_all = BK.binary_cross_entropy_with_logits(aug_scores, gold_mask, reduction='none') if conf.no_loss_satisfy_margin: unsatisfy_mask = ((aug_scores * (1. - 2 * gold_mask)) > 0.).float() # those still with hinge loss losses_all *= unsatisfy_mask else: raise NotImplementedError() # return prediction and loss(sum/count) loss_sum = (losses_all * loss_mask).sum() if conf.train_return_loss_mask: return [[loss_sum, loss_mask.sum()]], loss_mask else: return [[loss_sum, loss_mask.sum()]], res_mask
def _losses_single(self, score_expr, gold_idxes_expr, single_sample, is_hinge=False, margin=0.):
    """Per-position loss over a label dimension with optional sub-sampling.

    Expands gold indexes to a 0/1 target over the last dim, applies a margin
    adjustment, computes hinge or BCE losses, and optionally averages over a
    Bernoulli-sampled subset of labels (gold always kept).

    :param score_expr: scores with labels on the last axis — presumably
        [*, L]; verify against caller.
    :param gold_idxes_expr: gold label indexes into the last axis of
        score_expr (shape [*] assumed — TODO confirm BK.minus_margin contract).
    :param single_sample: sampling control, interpreted piecewise below:
        <1. → a rate; >=2. → an absolute count; [1., 2.) → used as-is.
    :param is_hinge: hinge loss if True, else BCE-with-logits.
    :param margin: added to non-gold scores / subtracted from gold scores
        when > 0.
    :return: losses reduced over the last axis (mean, or sampled weighted mean).
    """
    # expand the idxes to 0/1
    score_shape = BK.get_shape(score_expr)
    expanded_idxes_expr = BK.constants(score_shape, 0.)
    # scatter +1 at gold positions (minus -1 means +1) -> one-hot targets
    expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr, gold_idxes_expr, -1.)  # minus -1 means +1
    # todo(+N): first adjust margin, since previously only minus margin for golds?
    if margin > 0.:
        # +margin everywhere, then -margin at golds: net effect is
        # gold scores unchanged... presumably gold -0, others +margin — confirm
        # against BK.minus_margin before relying on this reading.
        adjusted_scores = margin + BK.minus_margin(score_expr, gold_idxes_expr, margin)
    else:
        adjusted_scores = score_expr
    # [*, L]
    if is_hinge:
        # multiply pos instances with -1
        flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr)
        losses_all = BK.clamp(flipped_scores, min=0.)
    else:
        losses_all = BK.binary_cross_entropy_with_logits(adjusted_scores, expanded_idxes_expr, reduction='none')
    # special interpretation (todo(+2): there can be better implementation)
    if single_sample < 1.:
        # todo(warn): lower bound of sample_rate, ensure 2 samples
        real_sample_rate = max(single_sample, 2. / score_shape[-1])
    elif single_sample >= 2.:
        # including the positive one: treat as an absolute count, convert to rate
        real_sample_rate = max(single_sample, 2.) / score_shape[-1]
    else:  # [1., 2.)
        real_sample_rate = single_sample
    #
    if real_sample_rate < 1.:
        # sample a subset of labels to contribute to the loss
        sample_weight = BK.random_bernoulli(score_shape, real_sample_rate, 1.)
        # make sure positive is valid
        sample_weight = (sample_weight + expanded_idxes_expr.float()).clamp_(0., 1.)
        # weighted mean over the sampled labels only
        final_losses = (losses_all * sample_weight).sum(-1) / sample_weight.sum(-1)
    else:
        final_losses = losses_all.mean(-1)
    return final_losses