Example #1
 def loss(self, enc_expr, pad_mask, gold_mask, margin: float):
     conf = self.conf
     # =====
     # first, run test-time scoring and selection
     res_mask, all_scores = self.score_and_select(enc_expr, pad_mask)
     # optionally add the gold instances to the selected set
     if conf.ns_add_gold:
         res_mask += gold_mask
         res_mask.clamp_(max=1.)
     # =====
     with BK.no_grad_env():
         # how to select instances for training
         if conf.train_ratio2gold > 0.:
             # gold-ratio selection: offset padded and gold positions by a
             # huge negative constant so topk only picks negative candidates
             masked_all_scores = all_scores + (
                 1. - pad_mask + gold_mask) * Constants.REAL_PRAC_MIN
             loss_mask = self._select_topk(masked_all_scores, pad_mask,
                                           gold_mask, conf.train_ratio2gold,
                                           None)
             loss_mask += gold_mask
             loss_mask.clamp_(max=1.)
         elif not conf.ns_add_gold:
             loss_mask = res_mask + gold_mask
             loss_mask.clamp_(max=1.)
         else:
             # we already have the gold
             loss_mask = res_mask
     # ===== calculating losses [*, L]
     # first, augment scores by the margin: lower golds, raise negatives
     aug_scores = all_scores - (conf.margin_pos * margin) * gold_mask + (
         conf.margin_neg * margin) * (1. - gold_mask)
     if self.loss_hinge:
         # flip the sign of positive instances (multiply by -1)
         flipped_scores = aug_scores * (1. - 2 * gold_mask)
         losses_all = BK.clamp(flipped_scores, min=0.)
     elif self.loss_prob:
         losses_all = BK.binary_cross_entropy_with_logits(aug_scores,
                                                          gold_mask,
                                                          reduction='none')
         if conf.no_loss_satisfy_margin:
             # keep the loss only where the margin is still violated
             unsatisfy_mask = ((aug_scores * (1. - 2 * gold_mask)) >
                               0.).float()
             losses_all *= unsatisfy_mask
     else:
         raise NotImplementedError()
     # return loss (sum, count) plus the prediction/selection mask
     loss_sum = (losses_all * loss_mask).sum()
     if conf.train_return_loss_mask:
         return [[loss_sum, loss_mask.sum()]], loss_mask
     else:
         return [[loss_sum, loss_mask.sum()]], res_mask
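For readers without the surrounding codebase, below is a minimal standalone sketch of the margin-augmented loss computed above. It assumes BK is a thin wrapper over PyTorch; the function name, keyword defaults, and toy tensors are invented for illustration (margin_pos/margin_neg mirror the conf fields).

 import torch
 import torch.nn.functional as F

 def margin_augmented_loss(all_scores, gold_mask, loss_mask,
                           margin=1.0, margin_pos=1.0, margin_neg=1.0,
                           use_hinge=True):
     # augment scores: lower gold scores, raise negative scores, so the
     # loss only vanishes once golds clear the margin and negatives fall
     # below it
     aug_scores = (all_scores
                   - (margin_pos * margin) * gold_mask
                   + (margin_neg * margin) * (1. - gold_mask))
     if use_hinge:
         # flip the sign of gold positions, then hinge at zero: golds
         # yield max(0, margin - score), negatives max(0, score + margin)
         flipped = aug_scores * (1. - 2. * gold_mask)
         losses_all = flipped.clamp(min=0.)
     else:
         losses_all = F.binary_cross_entropy_with_logits(
             aug_scores, gold_mask, reduction='none')
     # sum over the selected instances only, as in the method above
     return (losses_all * loss_mask).sum(), loss_mask.sum()

 scores = torch.tensor([[2.0, -1.0, 0.5]])
 gold = torch.tensor([[1.0, 0.0, 0.0]])
 loss_sum, count = margin_augmented_loss(scores, gold, torch.ones_like(scores))

Returning a (sum, count) pair rather than a mean, as the method above does, leaves the normalization to the caller, which matters when losses are accumulated across batches of different sizes.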
Example #2
 def _losses_single(self,
                    score_expr,
                    gold_idxes_expr,
                    single_sample,
                    is_hinge=False,
                    margin=0.):
     # expand the gold indexes into a 0/1 one-hot mask
     score_shape = BK.get_shape(score_expr)
     expanded_idxes_expr = BK.constants(score_shape, 0.)
     # subtracting -1 at the gold positions adds 1 there
     expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr,
                                           gold_idxes_expr, -1.)
     # todo(+N): adjust the margin first, since previously the margin was only subtracted for golds?
     if margin > 0.:
         adjusted_scores = margin + BK.minus_margin(score_expr,
                                                    gold_idxes_expr, margin)
     else:
         adjusted_scores = score_expr
     # [*, L]
     if is_hinge:
         # flip the sign of positive instances (multiply by -1)
         flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr)
         losses_all = BK.clamp(flipped_scores, min=0.)
     else:
         losses_all = BK.binary_cross_entropy_with_logits(
             adjusted_scores, expanded_idxes_expr, reduction='none')
     # special interpretation of single_sample (todo(+2): there could be a better implementation)
     if single_sample < 1.:
         # a sampling rate; todo(warn): lower-bound it to ensure at least 2 samples
         real_sample_rate = max(single_sample, 2. / score_shape[-1])
     elif single_sample >= 2.:
         # a sample count (including the positive one), converted to a rate
         real_sample_rate = max(single_sample, 2.) / score_shape[-1]
     else:  # values in [1., 2.) mean no subsampling
         real_sample_rate = single_sample
     #
     if real_sample_rate < 1.:
         sample_weight = BK.random_bernoulli(score_shape, real_sample_rate,
                                             1.)
         # make sure the positive instance is always kept
         sample_weight = (sample_weight +
                          expanded_idxes_expr.float()).clamp_(0., 1.)
         #
         final_losses = (losses_all *
                         sample_weight).sum(-1) / sample_weight.sum(-1)
     else:
         final_losses = losses_all.mean(-1)
     return final_losses
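Similarly, here is a minimal standalone sketch of the sampled binary loss above, again assuming a plain-PyTorch backend: scatter_ stands in for BK.minus_margin, torch.bernoulli for BK.random_bernoulli, and the names and shapes are invented. For brevity it only implements the "rate" interpretation of single_sample (the < 1. branch).

 import torch
 import torch.nn.functional as F

 def sampled_bce_loss(score_expr, gold_idxes, sample_rate=0.5):
     L = score_expr.shape[-1]
     # expand gold indexes into a 0/1 one-hot mask over the label dimension
     gold_onehot = torch.zeros_like(score_expr)
     gold_onehot.scatter_(-1, gold_idxes.unsqueeze(-1), 1.)
     losses_all = F.binary_cross_entropy_with_logits(
         score_expr, gold_onehot, reduction='none')
     # lower-bound the rate so that, in expectation, >= 2 labels survive
     real_rate = max(sample_rate, 2. / L)
     if real_rate < 1.:
         # Bernoulli-sample which labels contribute to the loss ...
         weight = torch.bernoulli(torch.full_like(score_expr, real_rate))
         # ... but always keep the gold label
         weight = (weight + gold_onehot).clamp_(0., 1.)
         return (losses_all * weight).sum(-1) / weight.sum(-1)
     return losses_all.mean(-1)

 scores = torch.randn(4, 10)           # [batch, L]
 golds = torch.randint(0, 10, (4,))    # one gold label index per instance
 print(sampled_bce_loss(scores, golds))

Dividing by weight.sum(-1) rather than by L keeps the per-instance loss on the same scale whether or not subsampling fires, matching the final_losses computation above.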