def jpos_loss(self, jpos_pack, mask_expr):
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        # collect loss with mask, also excluding the first symbol of ROOT
        final_losses_masked = (jpos_losses_expr * mask_expr)[:, 1:]
        # todo(note): no need to scale lambda since already multiplied previously
        final_loss_sum = BK.sum(final_losses_masked)
        return final_loss_sum
    else:
        return None
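# The reduction above (multiply by the mask, drop position 0 for the artificial ROOT, then sum)
# recurs for every loss in this file. A minimal standalone sketch of the same pattern, written
# directly against PyTorch (an assumption here: BK is treated as a thin torch wrapper), with a
# hypothetical helper name:
def _sketch_masked_loss_sum(losses, mask):
    # losses, mask: [batch, max_len] tensors; column 0 corresponds to the artificial ROOT token
    return (losses * mask)[:, 1:].sum()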
def fb(self, annotated_insts, scoring_expr_pack, training: bool, loss_factor: float):
    # depth constrain: <= sched_depth
    cur_depth_constrain = int(self.sched_depth.value)
    # run
    ags = [BfsLinearAgenda.init_agenda(TdState, z, self.require_sg) for z in annotated_insts]
    self.oracle_manager.refresh_insts(annotated_insts)
    self.searcher.refresh(scoring_expr_pack)
    self.searcher.go(ags)
    # collect local loss: credit assignment
    if self.train_force or self.train_ss:
        states = []
        for ag in ags:
            for final_state in ag.local_golds:
                # todo(warn): remember to use depth_eff rather than depth
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                states.append(final_state)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_of:
        states = []
        for ag in ags:
            for final_state in ag.ends:
                for s in final_state.get_path(True):
                    states.append(s)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_rl:
        logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], []
        for ag in ags:
            # todo(+2): need to check search failure?
            # todo(+2): ignoring labels when reducing or wrong-arc
            for final_state in ag.ends:
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                one_credits_arc = []
                one_credits_label = []
                self.oracle_manager.set_losses(final_state)
                for s in final_state.get_path(True):
                    _, _, delta_arc, delta_label = s.oracle_loss_cache
                    logprobs_arc.append(s.arc_score_slice)
                    if delta_arc > 0:  # only blame arc
                        one_credits_arc.append(-delta_arc)
                    else:
                        one_credits_arc.append(0)
                        if delta_label > 0:
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(-delta_label)
                        elif s.label_score_slice is not None:  # not bad labeling
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(0)
                # TODO(+N): minus average may encourage bad moves?
                # balance
                # avg_arc = sum(one_credits_arc) / len(one_credits_arc)
                # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label)
                baseline_arc = baseline_label = -0.5
                credits_arc.extend(z - baseline_arc for z in one_credits_arc)
                credits_label.extend(z - baseline_label for z in one_credits_label)
    else:
        raise NotImplementedError("CANNOT get here!")
    # sum all local losses
    loss_zero = BK.zeros([])
    if len(logprobs_arc) > 0:
        batched_logprobs_arc = SliceManager.combine_slices(logprobs_arc, None)
        loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \
            else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc)))
    else:
        loss_arc = loss_zero
    if len(logprobs_label) > 0:
        batched_logprobs_label = SliceManager.combine_slices(logprobs_label, None)
        loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \
            else (-BK.sum(batched_logprobs_label * BK.input_real(credits_label)))
    else:
        loss_label = loss_zero
    final_loss_sum = loss_arc + loss_label
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_arcs, num_valid_labels = len(logprobs_arc), len(logprobs_label)
    # num_valid_steps = len(states)
    if self.tconf.loss_div_step:
        final_loss = loss_arc / max(1, num_valid_arcs) + loss_label / max(1, num_valid_labels)
    else:
        final_loss = final_loss_sum / num_sent
    #
    val_loss_arc = BK.get_value(loss_arc).item()
    val_loss_label = BK.get_value(loss_label).item()
    val_loss_sum = val_loss_arc + val_loss_label
    #
    cur_has_loss = 1 if ((num_valid_arcs + num_valid_labels) > 0) else 0
    if training and cur_has_loss:
        BK.backward(final_loss, loss_factor)
    # todo(warn): make tok==steps for dividing in common.run
    info = {"sent": num_sent, "tok": num_valid_arcs, "valid_arc": num_valid_arcs,
            "valid_label": num_valid_labels, "loss_sum": val_loss_sum, "loss_arc": val_loss_arc,
            "loss_label": val_loss_label, "fb_all": 1, "fb_valid": cur_has_loss}
    return info
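# In the train_rl branch of fb, each step's log-probability is weighted by a per-step credit
# (the negative oracle loss delta, shifted by the constant -0.5 baseline). A minimal sketch of
# that weighted reduction on plain PyTorch tensors (hypothetical helper, not part of the parser's
# API; the SliceManager/BK batching details are abstracted away):
def _sketch_credit_weighted_loss(logprobs, credits=None):
    # logprobs: 1D tensor of per-step log-probabilities; credits: optional 1D tensor of weights
    import torch
    if logprobs.numel() == 0:
        return torch.zeros([])
    if credits is None:
        return -logprobs.sum()          # plain negative log-likelihood over all steps
    return -(logprobs * credits).sum()  # credit-weighted (REINFORCE-style) objective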
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    self.refresh_batch(training)
    margin = self.margin.value
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    gold_labels_arr, _ = self.predict_padder.pad([self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    # ===== calculate
    scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr)
    #
    final_losses = None
    if self.norm_local or self.norm_single:
        select_label_score = self._score_label_selected(scoring_expr_pack, mask_expr, training, margin,
                                                        gold_heads_expr, gold_labels_expr)
        # already added margin previously
        losses_heads = losses_labels = None
        if self.loss_prob:
            if self.norm_local:
                losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_hinge:
            if self.norm_local:
                losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample,
                                                   is_hinge=True, margin=margin)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample,
                                                    is_hinge=True, margin=margin)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_mr:
            # special treatment!
            probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
            probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, h]
            # select
            probs_head_gold = BK.gather_one_lastdim(probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
            probs_label_gold = BK.gather_one_lastdim(probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
            # root and pad will be excluded later
            # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
            # todo(warn): have problem since steps will be quite small, not used!
            final_losses = (mask_expr - probs_head_gold * probs_label_gold)  # let loss>=0
    elif self.norm_global:
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, training, margin,
                                                  gold_heads_expr, gold_labels_expr)
        # for this one, use the merged full score
        full_score = full_arc_score.unsqueeze(-1) + full_label_score  # [BS, m, h, L]
        # +=1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        # do inference
        if self.loss_prob:
            marginals_expr = self._marginal(full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
            final_losses = self._losses_global_prob(full_score, gold_heads_expr, gold_labels_expr,
                                                    marginals_expr, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                #  but this might be too loose, although the unproj edges are few?
                gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts])
                gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                final_losses = BK.max_elem(final_losses, comparing_expr)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(full_score, mask_expr, mst_lengths_arr)
            pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
            #
            final_losses = self._losses_global_hinge(full_score, gold_heads_expr, gold_labels_expr,
                                                     pred_heads_expr, pred_labels_expr, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # firstly label losses are the same
        select_label_score = self._score_label_selected(scoring_expr_pack, mask_expr, training, margin,
                                                        gold_heads_expr, gold_labels_expr)
        losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
        # then specially for arc loss
        children_masks_arr, _ = self.hlocal_padder.pad([z.get_children_mask_arr() for z in annotated_insts])
        children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
        # [bs, h]
        # todo(warn): use prod rather than sum, but still only an approximation for the top-down
        # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
        losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)
        # including the root-head is important
        losses_arc[:, 1] += losses_arc[:, 0]
        final_losses = losses_arc + losses_labels
    #
    # jpos loss? (the same mask as parsing)
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        final_losses += jpos_losses_expr
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    #
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    if training:
        BK.backward(final_loss, loss_factor)
    return info
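# For the norm_local + loss_prob path of fb_on_batch, the per-token losses are simply NLL over
# head candidates plus NLL over labels at the gold head. A compact sketch of that computation
# with plain PyTorch (hypothetical helper and shapes; margin handling and the ROOT/pad masking
# are as in fb_on_batch, and BK.loss_nll is assumed to behave like cross-entropy):
def _sketch_local_nll_loss(arc_scores, label_scores, gold_heads, gold_labels, mask):
    # arc_scores: [bs, m, h]; label_scores: [bs, m, L] (already selected at the gold head)
    # gold_heads, gold_labels: [bs, m] index tensors; mask: [bs, m] with ROOT at position 0
    import torch.nn.functional as F
    loss_heads = F.cross_entropy(arc_scores.transpose(1, 2), gold_heads, reduction='none')      # [bs, m]
    loss_labels = F.cross_entropy(label_scores.transpose(1, 2), gold_labels, reduction='none')  # [bs, m]
    return ((loss_heads + loss_labels) * mask)[:, 1:].sum()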
def _loss(self, annotated_insts: List[ParseInstance], full_score_expr, mask_expr, valid_expr=None):
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    #
    idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
    idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
    # scores for decoding or marginal
    margin = self.margin.value
    decoding_scores = full_score_expr.clone().detach()
    decoding_scores = self.scorer_helper.postprocess_scores(decoding_scores, mask_expr, margin,
                                                            gold_heads_expr, gold_labels_expr)
    if self.loss_hinge:
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores, mask_expr, mst_lengths_arr,
                                                           labeled=True, ret_arr=False)
        # ===== add margin*cost, [bs, len]
        gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr]
        pred_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, pred_heads_expr, pred_labels_expr] \
            + margin * (gold_heads_expr != pred_heads_expr).float() \
            + margin * (gold_labels_expr != pred_labels_expr).float()  # plus margin
        hinge_losses = pred_final_scores - gold_final_scores
        valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1)  # [*, 1]
        final_losses = hinge_losses * valid_losses
    else:
        lab_marginals = nmarginal_unproj(decoding_scores, mask_expr, None, labeled=True)
        lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] -= 1.
        grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze(-1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
        final_losses = (full_score_expr * grads_masked).sum(-1).sum(-1)  # [bs, m]
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    # exclude non-valid ones: there can be pruning error
    if valid_expr is not None:
        final_valids = valid_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr]  # [bs, m] of (0. or 1.)
        final_losses = final_losses * final_valids
        tok_valid = float(BK.get_value(final_valids[:, 1:].sum()))
        assert tok_valid <= num_valid_tok
        tok_prune_err = num_valid_tok - tok_valid
    else:
        tok_prune_err = 0
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "tok_prune_err": tok_prune_err,
            "loss_sum": final_loss_sum_val}
    return final_loss, info
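# The else-branch of _loss builds the loss as sum(score * (marginal - gold_indicator)) with the
# marginals detached, so d(loss)/d(score) = marginal - gold_indicator, i.e. exactly the gradient
# of the negative log-likelihood of a globally normalized model. A tiny sketch of the same
# surrogate on a single unstructured distribution (illustrative only; plain PyTorch assumed):
def _sketch_crf_gradient_surrogate(scores, gold_idx):
    # scores: 1D tensor requiring grad; gold_idx: int index of the gold item
    import torch
    grads = torch.softmax(scores, dim=-1).detach()  # "marginals" of a trivial 1-of-n structure
    grads[gold_idx] -= 1.
    return (scores * grads).sum()  # same gradient as -log_softmax(scores)[gold_idx]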