Example #1
 def jpos_loss(self, jpos_pack, mask_expr):
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         # collect loss with mask, also excluding the first (ROOT) symbol
         final_losses_masked = (jpos_losses_expr * mask_expr)[:, 1:]
         # todo(note): no need to scale lambda since already multiplied previously
         final_loss_sum = BK.sum(final_losses_masked)
         return final_loss_sum
     else:
         return None
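The pattern in Example #1 (zero out padding with the mask, drop position 0 which holds the artificial ROOT, then reduce to a scalar) recurs in the later examples. Below is a minimal sketch of the same reduction in plain PyTorch, assuming BK wraps torch-like tensors; all tensors and values here are illustrative only, not part of the original code.

import torch

# hypothetical per-token losses and validity mask, both of shape [batch, length]
losses = torch.rand(2, 5)
mask = torch.ones(2, 5)
mask[1, 3:] = 0.  # pretend the second sentence is shorter than the first

# zero out padded positions, drop index 0 (the artificial ROOT), then sum
loss_sum = (losses * mask)[:, 1:].sum()
print(loss_sum.item())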
Example #2
 def fb(self, annotated_insts, scoring_expr_pack, training: bool,
        loss_factor: float):
     # depth constraint: <= sched_depth
     cur_depth_constrain = int(self.sched_depth.value)
     # run
     ags = [
         BfsLinearAgenda.init_agenda(TdState, z, self.require_sg)
         for z in annotated_insts
     ]
     self.oracle_manager.refresh_insts(annotated_insts)
     self.searcher.refresh(scoring_expr_pack)
     self.searcher.go(ags)
     # collect local loss: credit assignment
     if self.train_force or self.train_ss:
         states = []
         for ag in ags:
             for final_state in ag.local_golds:
                 # todo(warn): remember to use depth_eff rather than depth
                 # todo(warn): deprecated
                 # if final_state.depth_eff > cur_depth_constrain:
                 #     continue
                 states.append(final_state)
         logprobs_arc = [s.arc_score_slice for s in states]
         # no labeling scores for reduce operations
         logprobs_label = [
             s.label_score_slice for s in states
             if s.label_score_slice is not None
         ]
         credits_arc, credits_label = None, None
     elif self.train_of:
         states = []
         for ag in ags:
             for final_state in ag.ends:
                 for s in final_state.get_path(True):
                     states.append(s)
         logprobs_arc = [s.arc_score_slice for s in states]
         # no labeling scores for reduce operations
         logprobs_label = [
             s.label_score_slice for s in states
             if s.label_score_slice is not None
         ]
         credits_arc, credits_label = None, None
     elif self.train_rl:
         logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], []
         for ag in ags:
             # todo(+2): need to check search failure?
             # todo(+2): ignoring labels when reducing or wrong-arc
             for final_state in ag.ends:
                 # todo(warn): deprecated
                 # if final_state.depth_eff > cur_depth_constrain:
                 #     continue
                 one_credits_arc = []
                 one_credits_label = []
                 self.oracle_manager.set_losses(final_state)
                 for s in final_state.get_path(True):
                     _, _, delta_arc, delta_label = s.oracle_loss_cache
                     logprobs_arc.append(s.arc_score_slice)
                     if delta_arc > 0:
                         # only blame arc
                         one_credits_arc.append(-delta_arc)
                     else:
                         one_credits_arc.append(0)
                         if delta_label > 0:
                             logprobs_label.append(s.label_score_slice)
                             one_credits_label.append(-delta_label)
                         elif s.label_score_slice is not None:
                             # labeling is not wrong: keep the logprob with zero credit
                             logprobs_label.append(s.label_score_slice)
                             one_credits_label.append(0)
                 # TODO(+N): minus average may encourage bad moves?
                 # balance
                 # avg_arc = sum(one_credits_arc) / len(one_credits_arc)
                 # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label)
                 baseline_arc = baseline_label = -0.5
                 credits_arc.extend(z - baseline_arc
                                    for z in one_credits_arc)
                 credits_label.extend(z - baseline_label
                                      for z in one_credits_label)
     else:
         raise NotImplementedError("CANNOT get here!")
     # sum all local losses
     loss_zero = BK.zeros([])
     if len(logprobs_arc) > 0:
         batched_logprobs_arc = SliceManager.combine_slices(
             logprobs_arc, None)
         loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \
             else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc)))
     else:
         loss_arc = loss_zero
     if len(logprobs_label) > 0:
         batched_logprobs_label = SliceManager.combine_slices(
             logprobs_label, None)
         loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \
             else (-BK.sum(batched_logprobs_label*BK.input_real(credits_label)))
     else:
         loss_label = loss_zero
     final_loss_sum = loss_arc + loss_label
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_arcs, num_valid_labels = len(logprobs_arc), len(
         logprobs_label)
     # num_valid_steps = len(states)
     if self.tconf.loss_div_step:
         final_loss = loss_arc / max(1, num_valid_arcs) + loss_label / max(
             1, num_valid_labels)
     else:
         final_loss = final_loss_sum / num_sent
     #
     val_loss_arc = BK.get_value(loss_arc).item()
     val_loss_label = BK.get_value(loss_label).item()
     val_loss_sum = val_loss_arc + val_loss_label
     #
     cur_has_loss = 1 if ((num_valid_arcs + num_valid_labels) > 0) else 0
     if training and cur_has_loss:
         BK.backward(final_loss, loss_factor)
     # todo(warn): make tok==steps for dividing in common.run
     info = {
         "sent": num_sent,
         "tok": num_valid_arcs,
         "valid_arc": num_valid_arcs,
         "valid_label": num_valid_labels,
         "loss_sum": val_loss_sum,
         "loss_arc": val_loss_arc,
         "loss_label": val_loss_label,
         "fb_all": 1,
         "fb_valid": cur_has_loss
     }
     return info
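The train_rl branch above is essentially a policy-gradient surrogate: each action's log-probability is weighted by a credit (-delta shifted by the constant baseline of -0.5), and the weighted sum is negated. A minimal sketch of that surrogate in plain PyTorch, with made-up probabilities and oracle deltas standing in for the searcher's output (illustrative only):

import torch

# stand-ins for the RL branch: per-action probabilities and oracle deltas
# (delta = 0 for a correct action, > 0 for the loss it incurs)
probs = torch.tensor([0.7, 0.4, 0.9], requires_grad=True)
logprobs = probs.log()
deltas = torch.tensor([0.0, 1.0, 0.0])

# credit = -delta minus the constant baseline (-0.5 in the snippet),
# so correct actions receive +0.5 and wrong ones receive 0.5 - delta
baseline = -0.5
credits = -deltas - baseline

# policy-gradient style surrogate: minimize -(logprob * credit)
loss = -(logprobs * credits).sum()
loss.backward()
print(loss.item(), probs.grad)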
Example #3
 def fb_on_batch(self,
                 annotated_insts,
                 training=True,
                 loss_factor=1,
                 **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score,
                                       dim=-1)  # [bs, m, h]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
             # Reward = \sum_i 1.*marginal(GEdge_i); for global models this would require gradients of the marginal functions
             # todo(warn): problematic since the steps will be quite small, not used!
             final_losses = (mask_expr - probs_head_gold * probs_label_gold
                             )  # let loss>=0
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(
                     gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (
                     1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients of one-edge or two-edge marginals
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # first, the label losses are the same as in the local case
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then handle the arc loss specially
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(
             children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): uses the product of probs (sum of log-probs) rather than the sum of probs; still only an approximation for the top-down factorization
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
             -1, -2) * children_masks_expr,
                              dim=-1)
         # including the root-head is important
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first (ROOT) symbol
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
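In the norm_local branch of Example #3, the head loss is a per-modifier negative log-likelihood over candidate heads (and, analogously, over labels at the gold head). A minimal PyTorch sketch of that local head loss plus the final masked reduction follows, assuming BK.loss_nll is a per-token NLL over the last score dimension (an assumption based on these snippets, not something they confirm); shapes and values are illustrative only.

import torch
import torch.nn.functional as F

# illustrative shapes: batch of 2, length 5 (position 0 is ROOT);
# arc scores [bs, m, h] score every candidate head h for every modifier m
bs, slen = 2, 5
arc_scores = torch.randn(bs, slen, slen, requires_grad=True)
gold_heads = torch.randint(0, slen, (bs, slen))
mask = torch.ones(bs, slen)

# per-token NLL over heads, analogous to BK.loss_nll(full_arc_score, gold_heads_expr)
losses_heads = F.cross_entropy(
    arc_scores.reshape(bs * slen, slen), gold_heads.reshape(-1),
    reduction="none").reshape(bs, slen)

# same masked reduction as above: drop ROOT, sum, then divide
loss_sum = (losses_heads * mask)[:, 1:].sum()
loss = loss_sum / mask[:, 1:].sum()  # roughly the loss_div_tok option
loss.backward()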
Example #4
 def _loss(self,
           annotated_insts: List[ParseInstance],
           full_score_expr,
           mask_expr,
           valid_expr=None):
     bsize, max_len = BK.get_shape(mask_expr)
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     # todo(note): use the original label indices here, no shift!
     gold_labels_arr, _ = self.predict_padder.pad(
         [z.labels.idxes for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     #
     idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
     idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
     # scores for decoding or marginal
     margin = self.margin.value
     decoding_scores = full_score_expr.clone().detach()
     decoding_scores = self.scorer_helper.postprocess_scores(
         decoding_scores, mask_expr, margin, gold_heads_expr,
         gold_labels_expr)
     if self.loss_hinge:
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores,
                                                            mask_expr,
                                                            mst_lengths_arr,
                                                            labeled=True,
                                                            ret_arr=False)
         # ===== add margin*cost, [bs, len]
         gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr,
                                             gold_heads_expr,
                                             gold_labels_expr]
         pred_final_scores = full_score_expr[
             idxes_bs_expr, idxes_m_expr, pred_heads_expr,
             pred_labels_expr] + margin * (
                 gold_heads_expr != pred_heads_expr).float() + margin * (
                     gold_labels_expr !=
                     pred_labels_expr).float()  # plus margin
         hinge_losses = pred_final_scores - gold_final_scores
         valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) >
                         0.).float().unsqueeze(-1)  # [*, 1]
         final_losses = hinge_losses * valid_losses
     else:
         lab_marginals = nmarginal_unproj(decoding_scores,
                                          mask_expr,
                                          None,
                                          labeled=True)
         lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr,
                       gold_labels_expr] -= 1.
         grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze(
             -1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
         final_losses = (full_score_expr * grads_masked).sum(-1).sum(
             -1)  # [bs, m]
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     # exclude non-valid ones: there can be pruning errors
     if valid_expr is not None:
         final_valids = valid_expr[idxes_bs_expr, idxes_m_expr,
                                   gold_heads_expr]  # [bs, m] of (0. or 1.)
         final_losses = final_losses * final_valids
         tok_valid = float(BK.get_value(final_valids[:, 1:].sum()))
         assert tok_valid <= num_valid_tok
         tok_prune_err = num_valid_tok - tok_valid
     else:
         tok_prune_err = 0
     # collect loss with mask, also excluding the first (ROOT) symbol
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "tok_prune_err": tok_prune_err,
         "loss_sum": final_loss_sum_val
     }
     return final_loss, info
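The else branch of Example #4 uses a standard trick for structured probabilistic losses: the gradient of the negative log-likelihood with respect to the scores is (marginals - gold indicator), so multiplying the live scores by that detached tensor yields a surrogate whose backward pass is exactly the desired gradient. A minimal PyTorch sketch of the trick, with a plain softmax standing in for the real nmarginal_unproj marginals (illustrative only, not the actual inference routine):

import torch

# illustrative labeled scores [bs, m, h, L] plus gold heads/labels and a mask
bs, slen, n_lab = 1, 4, 3
scores = torch.randn(bs, slen, slen, n_lab, requires_grad=True)
gold_heads = torch.randint(0, slen, (bs, slen))
gold_labels = torch.randint(0, n_lab, (bs, slen))
mask = torch.ones(bs, slen)

with torch.no_grad():
    # stand-in for the real structured marginals over (head, label) pairs
    marginals = torch.softmax(scores.reshape(bs, slen, -1), -1).reshape(scores.shape)
    idx_b = torch.arange(bs).unsqueeze(-1)
    idx_m = torch.arange(slen).unsqueeze(0)
    grads = marginals.clone()
    grads[idx_b, idx_m, gold_heads, gold_labels] -= 1.  # marginals minus gold indicator
    # mask both the modifier and the head dimension, as in the snippet
    grads = grads * mask.unsqueeze(-1).unsqueeze(-1) * mask.unsqueeze(-2).unsqueeze(-1)

# backward through this surrogate reproduces d(-log P(gold)) / d(scores)
surrogate = (scores * grads).sum(-1).sum(-1)  # [bs, m]
loss = (surrogate * mask)[:, 1:].sum()
loss.backward()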