Example #1
 def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float,
                   loss_mask, neg_reweight: bool):
     probs = BK.softmax(score_expr, -1)  # [*, NLab]
     log_probs = BK.log(probs + 1e-8)
     # first plain NLL loss
     nll_loss = -BK.gather_one_lastdim(log_probs,
                                       gold_idxes_expr).squeeze(-1)
     # next the special loss
     if entropy_lambda > 0.:
         negative_entropy = probs * log_probs  # [*, NLab]
         last_dim = BK.get_shape(score_expr, -1)
         confusion_matrix = 1. - BK.eye(last_dim)  # [NLab, NLab]
         entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, NLab]
         entropy_loss = (negative_entropy * entropy_mask).sum(-1)
         final_loss = nll_loss + entropy_lambda * entropy_loss
     else:
         final_loss = nll_loss
     # reweight?
     if neg_reweight:
         golden_prob = BK.gather_one_lastdim(probs,
                                             gold_idxes_expr).squeeze(-1)
         is_full_nil = (gold_idxes_expr == 0.).float()
         not_full_nil = 1. - is_full_nil
         count_pos = (loss_mask * not_full_nil).sum()
         count_neg = (loss_mask * is_full_nil).sum()
         prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
         prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
         neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
         final_weights = not_full_nil + is_full_nil * neg_weight
         # todo(note): the final mask will be applied outside
         final_loss = final_loss * final_weights
     return final_loss
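
All of the snippets on this page center on BK.gather_one_lastdim. The BK backend itself is not shown; the tensor methods used throughout (.view, .squeeze, .clamp_, fancy indexing) suggest it wraps PyTorch. Below is a minimal, purely illustrative sketch of how the call appears to behave, judging from how it is used here; it is not the library's actual implementation.

import torch

def gather_one_lastdim(t, idx):
    # hypothetical re-implementation: pick one entry along the last dimension
    # per position, keeping that dimension as size 1 so callers can .squeeze(-1) it
    if not torch.is_tensor(idx):
        idx = torch.as_tensor(idx, dtype=torch.long, device=t.device)
    return torch.gather(t, -1, idx.unsqueeze(-1))  # [*, 1]

# e.g. the plain NLL part of Example #1
probs = torch.softmax(torch.randn(4, 5), dim=-1)   # [*, NLab]
gold = torch.tensor([0, 2, 1, 4])                  # [*]
nll_loss = -torch.log(gather_one_lastdim(probs, gold).squeeze(-1) + 1e-8)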
Example #2
 def get_losses_global_hinge(full_score_expr,
                             gold_heads_expr,
                             gold_labels_expr,
                             pred_heads_expr,
                             pred_labels_expr,
                             mask_expr,
                             clamping=True):
     # combine the last two dimensions
     full_shape = BK.get_shape(full_score_expr)
     # [*, m, h*L]
     last_size = full_shape[-1]
     combined_score_expr = full_score_expr.view(full_shape[:-2] + [-1])
     # [*, m]
     gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr
     pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr
     # [*, m]
     gold_scores = BK.gather_one_lastdim(combined_score_expr,
                                         gold_combined_idx_expr).squeeze(-1)
     pred_scores = BK.gather_one_lastdim(combined_score_expr,
                                         pred_combined_idx_expr).squeeze(-1)
     # todo(warn): be aware of search error!
     # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.)  # this is previous version
     hinge_losses = pred_scores - gold_scores  # [*, len]
     if clamping:
         valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) >
                         0.).float().unsqueeze(-1)  # [*, 1]
         return hinge_losses * valid_losses
     else:
         # for this mode, could there be problems from search errors? Maybe, but rarely.
         return hinge_losses
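
The combined gather in Example #2 merges the head and label axes so that a single index can pick one (head, label) cell per modifier. A small self-contained illustration of the same trick, with invented shapes:

import torch

B, m, H, L = 2, 6, 6, 3
full_scores = torch.randn(B, m, H, L)          # [*, m, h, L]
gold_heads = torch.randint(0, H, (B, m))       # [*, m]
gold_labels = torch.randint(0, L, (B, m))      # [*, m]

combined_scores = full_scores.view(B, m, H * L)     # [*, m, h*L]
gold_combined_idx = gold_heads * L + gold_labels    # head*L + label, [*, m]
gold_scores = torch.gather(combined_scores, -1,
                           gold_combined_idx.unsqueeze(-1)).squeeze(-1)  # [*, m]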
Example #3
 def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float], label_weight_list: List[float],
           bidxes_list: List[int]):
     # 1. collect (batched) features; todo(note): use prev state for scoring
     hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
     # 2. get new sreprs
     scorer = self.scorer
     s_enc = self.slayer
     bsize_range_t = BK.input_idx(bidxes_list)
     node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
     node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
     # label loss
     if self.system_labeled:
         node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
         _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
         label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
         label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
         final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
     else:
         label_scores = final_label_loss_sum = BK.zeros([])
     # arc loss
     node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
     _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
     arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
     final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
     # score reg
     return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
Example #4
 def __call__(self,
              input_repr,
              mask_arr,
              require_loss,
              require_pred,
              gold_pos_arr=None):
     enc0_expr = self.enc(input_repr, mask_arr)  # [*, len, d]
     #
     enc1_expr = enc0_expr
     pos_probs, pos_losses_expr, pos_preds_expr = None, None, None
     if self.jpos_multitask:
         # get probabilities
         pos_logits = self.pred(enc0_expr)  # [*, len, nl]
         pos_probs = BK.softmax(pos_logits, dim=-1)
         # stacking for input -> output
         if self.jpos_stacking:
             enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights)
         # simple cross entropy loss
         if require_loss and self.jpos_lambda > 0.:
             gold_probs = BK.gather_one_lastdim(
                 pos_probs, gold_pos_arr).squeeze(-1)  # [*, len]
         # todo(warn): multiplying the factor here, but not masking here (masking in the final steps)
             pos_losses_expr = (-self.jpos_lambda) * gold_probs.log()
         # simple argmax for prediction
         if require_pred and self.jpos_decode:
             pos_preds_expr = pos_probs.max(dim=-1)[1]
     return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr)
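
In Example #4, the jpos_stacking branch feeds soft POS information forward: the predicted tag distribution mixes one vector per tag (self.pos_weights) and the mixture is added to the encoder states. A sketch of that single line with assumed shapes (all names below are stand-ins):

import torch

B, T, D, NL = 2, 7, 16, 12
enc0_expr = torch.randn(B, T, D)        # encoder output, [*, len, d]
pos_weights = torch.randn(NL, D)        # one learnable vector per POS tag
pos_logits = torch.randn(B, T, NL)      # stand-in for self.pred(enc0_expr)

pos_probs = torch.softmax(pos_logits, dim=-1)     # [*, len, nl]
enc1_expr = enc0_expr + pos_probs @ pos_weights   # [*, len, d]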
Example #5
 def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
     conf = self.conf
     bsize = len(ms_items)
     # build targets (include all sents)
     # todo(note): use "x.entity_fillers" for getting gold args
     offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
         ms_items, lambda x: x.entity_fillers, True, True,
         conf.train_neg_rate, conf.train_neg_rate_outside, True)
     labels_t.clamp_(max=1)  # either 0 or 1
     # -----
     # return zeros if there are no targets at all
     if BK.get_shape(offsets_t, -1) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     # -----
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # build loss
     logits = self.predictor(hiddens)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(
         -1)  # [bsize, ?]
     masked_losses = picked_neg_log_probs * masks_t
     # loss_sum, loss_count, gold_count
     return [[
         masked_losses.sum(),
         masks_t.sum(), (labels_t > 0).float().sum()
     ]]
Example #6
 def _get_one_loss(self, predictor, hidden_t, labels_t, masks_t,
                   lambda_loss):
     logits = predictor(hidden_t)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     picked_neg_log_probs = -BK.gather_one_lastdim(
         log_probs, labels_t).squeeze(-1)  # [bsize, ?]
     masked_losses = picked_neg_log_probs * masks_t
     # loss_sum, loss_count, gold_count(only for type)
     return [
         masked_losses.sum() * lambda_loss,
         masks_t.sum(), (labels_t > 0).float().sum()
     ]
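
Examples #5 and #6 return [loss_sum, loss_count, gold_count] triples rather than a finished scalar. Presumably a caller sums the triples and divides later; a sketch of that assumption (not code from this repository):

import torch

def finish_loss(triples):
    # sum the per-component triples, then average over unmasked positions
    loss_sum = sum(t[0] for t in triples)
    loss_count = sum(t[1] for t in triples)
    return loss_sum / torch.clamp(loss_count, min=1.)

triples = [[torch.tensor(3.2), torch.tensor(5.), torch.tensor(2.)]]
avg_loss = finish_loss(triples)   # 3.2 / 5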
Example #7
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         # pruning and scores from g1
         valid_mask, go1_pack = self._get_g1_pack(
             insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         # decode
         final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
         ret_heads, ret_labels, _, _ = self.dl.decode(
             insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
         # collect the results together
         all_heads = Helper.join_list(ret_heads)
         if ret_labels is None:
             # todo(note): simply get labels from the go1-label classifier; must provide g1parser
             if go1_pack is None:
                 _, go1_pack = self._get_g1_pack(insts, 1., 1.)
             _, go1_label_max_idxes = go1_pack[1].max(
                 -1)  # [bs, slen, slen]
             pred_heads_arr, _ = self.predict_padder.pad(
                 all_heads)  # [bs, slen]
             pred_heads_expr = BK.input_idx(pred_heads_arr)
             pred_labels_expr = BK.gather_one_lastdim(
                 go1_label_max_idxes, pred_heads_expr).squeeze(-1)
             all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
         else:
             all_labels = np.concatenate(ret_labels, 0)
         # ===== assign; todo(warn): the labels here are already the original idxes, no need to change
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1
             one_inst.pred_heads.set_vals(
                 all_heads[one_idx]
                 [:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(
                 all_labels[one_idx][:cur_length], self.label_vocab)
             # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
         # =====
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #8
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr,
                                                 lengths_arr,
                                                 labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax,
                                                mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(
                mst_scores_arr)
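
The gather in Example #8 reads labels off an argmax table: labels_argmax holds the best label for every (modifier, head) pair, and once the MST decoder fixes one head per modifier, the matching label is picked along the head axis. An illustration with invented shapes:

import torch

BS, m, NLab = 2, 6, 40
scores = torch.randn(BS, m, m, NLab)                  # [BS, m, h, L]
scores_unlabeled_max, labels_argmax = scores.max(-1)  # both [BS, m, h]
mst_heads = torch.randint(0, m, (BS, m))              # [BS, m], from the MST decoder
mst_labels = torch.gather(labels_argmax, -1,
                          mst_heads.unsqueeze(-1)).squeeze(-1)  # [BS, m]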
Example #9
 def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr,
          go1_pack, training: bool, margin: float):
     # first do decoding and related preparation
     with BK.no_grad_env():
         _, _, g_packs, p_packs = self.decode(insts, enc_expr,
                                              final_valid_expr, go1_pack,
                                              training, margin)
         # flatten the packs (remember to rebase the indexes)
         gold_pack = self._flatten_packs(g_packs)
         pred_pack = self._flatten_packs(p_packs)
         if self.filter_pruned:
             # filter out non-valid (pruned) edges, to avoid prune error
             mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(
                 final_valid_expr, gold_pack)
             pred_mask = mod_unpruned_mask[
                 pred_pack[0], pred_pack[1]]  # filter by specific mod
             gold_pack = [(None if z is None else z[gold_mask])
                          for z in gold_pack]
             pred_pack = [(None if z is None else z[pred_mask])
                          for z in pred_pack]
     # calculate the scores for loss
     gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack
     gold_arc_score, gold_label_score_all = self._get_basic_score(
         enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes,
         gold_gp_idxes)
     pred_arc_score, pred_label_score_all = self._get_basic_score(
         enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes,
         pred_gp_idxes)
     # whether have labeled scores
     if self.system_labeled:
         gold_label_score = BK.gather_one_lastdim(
             gold_label_score_all, gold_lab_idxes).squeeze(-1)
         pred_label_score = BK.gather_one_lastdim(
             pred_label_score_all, pred_lab_idxes).squeeze(-1)
         ret_scores = (gold_arc_score, pred_arc_score, gold_label_score,
                       pred_label_score)
         pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score
     else:
         ret_scores = (gold_arc_score, pred_arc_score)
         pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score
     # hinge loss: filter sentences by margin*errors to stay aware of search errors
     if self.filter_margin:
         with BK.no_grad_env():
             mat_shape = BK.get_shape(enc_expr)[:2]  # [bs, slen]
             heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            gold_b_idxes, gold_m_idxes,
                                            gold_h_idxes)
             heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            pred_b_idxes, pred_m_idxes,
                                            pred_h_idxes)
             error_count = (heads_gold != heads_pred).float()
             if self.system_labeled:
                 labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 gold_b_idxes, gold_m_idxes,
                                                 gold_lab_idxes)
                 labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 pred_b_idxes, pred_m_idxes,
                                                 pred_lab_idxes)
                 error_count += (labels_gold != labels_pred).float()
             scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             gold_b_idxes, gold_m_idxes,
                                             gold_full_scores)
             scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             pred_b_idxes, pred_m_idxes,
                                             pred_full_scores)
             # todo(note): the small 0.1 here excludes the zero-error case; those get zero gradient anyway
             sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <=
                          (margin * error_count.sum(-1) + 0.1)).float()
             num_valid_sent = float(BK.get_value(sent_mask.sum()))
         final_loss_sum = (
             pred_full_scores * sent_mask[pred_b_idxes] -
             gold_full_scores * sent_mask[gold_b_idxes]).sum()
     else:
         num_valid_sent = len(insts)
         final_loss_sum = (pred_full_scores - gold_full_scores).sum()
     # prepare final loss
     # divide loss by what?
     num_sent = len(insts)
     num_valid_tok = sum(len(z) for z in insts)
     if self.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "sent_valid": num_valid_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     return final_loss, ret_scores, info
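
The filter_margin block in Example #9 keeps a sentence in the hinge loss only when the gold analysis does not already beat the prediction by more than the margin-scaled error count (the search-error guard mentioned in the comments). A reduced, per-sentence sketch of that test, with made-up numbers:

import torch

margin = 1.0
gold_scores = torch.tensor([10.0, 4.0])   # per-sentence summed gold scores
pred_scores = torch.tensor([6.0, 3.9])    # per-sentence summed predicted scores
errors = torch.tensor([2.0, 1.0])         # per-sentence head/label error counts

# sentence 0 is dropped: gold already wins by more than margin*errors (+0.1 slack)
sent_mask = ((gold_scores - pred_scores) <= (margin * errors + 0.1)).float()
loss_sum = ((pred_scores - gold_scores) * sent_mask).sum()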
Example #10
 def fb_on_batch(self,
                 annotated_insts,
                 training=True,
                 loss_factor=1,
                 **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score,
                                       dim=-1)  # [bs, m, h]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
              # Reward = \sum_i 1.*marginal(GEdge_i); for global models, this needs gradients of the marginal functions
              # todo(warn): problematic since the steps will be quite small; not used!
             final_losses = (mask_expr - probs_head_gold * probs_label_gold
                             )  # let loss>=0
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(
                     gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (
                     1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # firstly label losses are the same
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then specially for arc loss
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(
             children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): use prod rather than sum, but still only an approximation for the top-down
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
             -1, -2) * children_masks_expr,
                              dim=-1)
         # including the root-head is important
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
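
Several branches of Example #10 call BK.loss_nll on [bs, m, C] scores with [bs, m] gold indexes. A hypothetical stand-in for what such a helper would do (per-position negative log-likelihood over the last dimension; masking, as in the snippet, happens afterwards); the real BK implementation may differ:

import torch
import torch.nn.functional as F

def loss_nll(scores, gold_idx):
    # per-position NLL of the gold index over the last dimension
    log_probs = F.log_softmax(scores, dim=-1)                 # [bs, m, C]
    return -torch.gather(log_probs, -1,
                         gold_idx.unsqueeze(-1)).squeeze(-1)  # [bs, m]

scores = torch.randn(2, 5, 7)
gold = torch.randint(0, 7, (2, 5))
per_tok_loss = loss_nll(scores, gold)   # [2, 5]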