def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float, loss_mask, neg_reweight: bool):
    probs = BK.softmax(score_expr, -1)  # [*, NLab]
    log_probs = BK.log(probs + 1e-8)
    # first plain NLL loss
    nll_loss = -BK.gather_one_lastdim(log_probs, gold_idxes_expr).squeeze(-1)
    # next the special loss
    if entropy_lambda > 0.:
        negative_entropy = probs * log_probs  # [*, NLab]
        last_dim = BK.get_shape(score_expr, -1)
        confusion_matrix = 1. - BK.eye(last_dim)  # [NLab, NLab]
        entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, NLab]
        entropy_loss = (negative_entropy * entropy_mask).sum(-1)
        final_loss = nll_loss + entropy_lambda * entropy_loss
    else:
        final_loss = nll_loss
    # reweight?
    if neg_reweight:
        golden_prob = BK.gather_one_lastdim(probs, gold_idxes_expr).squeeze(-1)
        is_full_nil = (gold_idxes_expr == 0.).float()
        not_full_nil = 1. - is_full_nil
        count_pos = (loss_mask * not_full_nil).sum()
        count_neg = (loss_mask * is_full_nil).sum()
        prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
        prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
        neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
        final_weights = not_full_nil + is_full_nil * neg_weight
        # todo(note): final mask will be applied at outside
        final_loss = final_loss * final_weights
    return final_loss
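# ----- illustrative sketch (not part of the original code) -----
# The confusion-matrix lookup above (`confusion_matrix[gold_idxes_expr]`) just builds a
# per-example mask that zeroes the gold column, so the entropy term only touches non-gold
# labels. A minimal plain-PyTorch sketch of the same loss, using one_hot instead of the
# eye-matrix lookup; all names here are illustrative, not from the repo:
import torch
import torch.nn.functional as F

def entropy_regularized_nll(scores, gold, entropy_lambda=0.1):
    probs = F.softmax(scores, dim=-1)                       # [*, NLab]
    log_probs = torch.log(probs + 1e-8)
    nll = -log_probs.gather(-1, gold.unsqueeze(-1)).squeeze(-1)
    # 1 everywhere except the gold label's column
    mask = 1. - F.one_hot(gold, scores.shape[-1]).float()   # [*, NLab]
    # negative entropy restricted to non-gold labels
    neg_entropy = (probs * log_probs * mask).sum(-1)        # [*]
    return nll + entropy_lambda * neg_entropy

scores = torch.randn(4, 5, requires_grad=True)
gold = torch.tensor([0, 2, 1, 4])
entropy_regularized_nll(scores, gold).mean().backward()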
def get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr,
                            pred_heads_expr, pred_labels_expr, mask_expr, clamping=True):
    # combine the last two dimensions
    full_shape = BK.get_shape(full_score_expr)  # [*, m, h, L]
    last_size = full_shape[-1]
    combined_score_expr = full_score_expr.view(full_shape[:-2] + [-1])  # [*, m, h*L]
    gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr  # [*, m]
    pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr  # [*, m]
    gold_scores = BK.gather_one_lastdim(combined_score_expr, gold_combined_idx_expr).squeeze(-1)
    pred_scores = BK.gather_one_lastdim(combined_score_expr, pred_combined_idx_expr).squeeze(-1)
    # todo(warn): be aware of search error!
    # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.)  # this is previous version
    hinge_losses = pred_scores - gold_scores  # [*, len]
    if clamping:
        valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1)  # [*, 1]
        return hinge_losses * valid_losses
    else:
        # for this mode, will there be problems of search error? Maybe rare.
        return hinge_losses
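# ----- illustrative sketch (not part of the original code) -----
# The head and label picks are fused into one gather by flattening the last two score
# dimensions: index h*L + l in the flattened [h*L] dimension addresses score[..., h, l].
# A tiny sketch verifying that equivalence:
import torch

# toy scores: 1 sentence, 3 modifiers, 3 candidate heads, 4 labels
scores = torch.randn(1, 3, 3, 4)
heads = torch.tensor([[0, 2, 1]])    # chosen head per modifier
labels = torch.tensor([[1, 3, 0]])   # chosen label per modifier

L = scores.shape[-1]
flat = scores.view(1, 3, -1)                              # [1, m, h*L]
idx = heads * L + labels                                  # [1, m]
picked = flat.gather(-1, idx.unsqueeze(-1)).squeeze(-1)   # [1, m]
# equivalent to scores[0, m, heads[0, m], labels[0, m]] for each m
assert torch.allclose(picked[0, 1], scores[0, 1, 2, 3])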
def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float],
          label_weight_list: List[float], bidxes_list: List[int]):
    # 1. collect (batched) features; todo(note): use prev state for scoring
    hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
    # 2. get new sreprs
    scorer = self.scorer
    s_enc = self.slayer
    bsize_range_t = BK.input_idx(bidxes_list)
    node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
    node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
    # label loss
    if self.system_labeled:
        node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
        _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
        label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
        label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
        final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
    else:
        label_scores = final_label_loss_sum = BK.zeros([])
    # arc loss
    node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
    _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
    arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
    final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
    # score reg
    return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
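# ----- illustrative sketch (not part of the original code) -----
# transform_space_arc / score_arc are repo internals; as a rough mental model only,
# here is a minimal biaffine-style pairwise scorer showing the "transform head and
# modifier into separate spaces, then score the pairs" pattern. This is NOT the repo's
# Scorer, just illustrative shape logic:
import torch
import torch.nn as nn

class TinyArcScorer(nn.Module):
    def __init__(self, dim, hid):
        super().__init__()
        self.head_mlp = nn.Linear(dim, hid)   # "transform space" for heads
        self.mod_mlp = nn.Linear(dim, hid)    # "transform space" for modifiers
        self.W = nn.Parameter(torch.randn(hid, hid))

    def forward(self, head_repr, mod_repr):
        h = torch.relu(self.head_mlp(head_repr))   # [N, hid]
        m = torch.relu(self.mod_mlp(mod_repr))     # [N, hid]
        # one biaffine score m^T W h per (mod, head) pair along the batch dim
        return (m @ self.W * h).sum(-1)            # [N]

scorer = TinyArcScorer(8, 16)
arc_scores = scorer(torch.randn(5, 8), torch.randn(5, 8))  # [5]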
def __call__(self, input_repr, mask_arr, require_loss, require_pred, gold_pos_arr=None):
    enc0_expr = self.enc(input_repr, mask_arr)  # [*, len, d]
    #
    enc1_expr = enc0_expr
    pos_probs, pos_losses_expr, pos_preds_expr = None, None, None
    if self.jpos_multitask:
        # get probabilities
        pos_logits = self.pred(enc0_expr)  # [*, len, nl]
        pos_probs = BK.softmax(pos_logits, dim=-1)
        # stacking for input -> output
        if self.jpos_stacking:
            enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights)
        # simple cross entropy loss
        if require_loss and self.jpos_lambda > 0.:
            gold_probs = BK.gather_one_lastdim(pos_probs, gold_pos_arr).squeeze(-1)  # [*, len]
            # todo(warn): multiplying the factor here, but not masking here (masking in the final steps)
            pos_losses_expr = (-self.jpos_lambda) * gold_probs.log()
        # simple argmax for prediction
        if require_pred and self.jpos_decode:
            pos_preds_expr = pos_probs.max(dim=-1)[1]
    return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr)
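# ----- illustrative sketch (not part of the original code) -----
# jpos_stacking feeds the POS prediction back into the encoder output as a soft mixture
# of learned tag vectors, rather than the hard argmax tag. A toy sketch with illustrative
# shapes (batch 2, length 5, hidden 8, 12 tags):
import torch
import torch.nn as nn
import torch.nn.functional as F

enc0 = torch.randn(2, 5, 8)
pos_logits = nn.Linear(8, 12)(enc0)
pos_probs = F.softmax(pos_logits, dim=-1)   # [2, 5, 12]
pos_weights = torch.randn(12, 8)            # one learned vector per tag
# soft "stacking": add the probability-weighted mixture of tag vectors
enc1 = enc0 + pos_probs @ pos_weights       # [2, 5, 8]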
def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True,
        conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]
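# ----- illustrative sketch (not part of the original code) -----
# The `bert_expr[arange_t, offsets_t]` selection relies on broadcasting a [bsize, 1]
# batch index against [bsize, ?] token offsets; a self-contained sketch of that
# indexing pattern:
import torch

# toy: batch 2, seq length 6, hidden 4; pick 3 token offsets per item
bert_expr = torch.randn(2, 6, 4)
offsets = torch.tensor([[0, 2, 5],
                        [1, 1, 3]])            # [bsize, ?]
arange_t = torch.arange(2).unsqueeze(-1)       # [bsize, 1], broadcasts over offsets
selected = bert_expr[arange_t, offsets]        # [bsize, 3, 4]
assert torch.equal(selected[1, 2], bert_expr[1, 3])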
def _get_one_loss(self, predictor, hidden_t, labels_t, masks_t, lambda_loss):
    logits = predictor(hidden_t)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count (only for type)
    return [masked_losses.sum() * lambda_loss, masks_t.sum(), (labels_t > 0).float().sum()]
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    # iconf = self.conf.iconf
    with BK.no_grad_env():
        self.refresh_batch(False)
        # pruning and scores from g1
        valid_mask, go1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        mask_expr = BK.input_real(mask_arr)
        # decode
        final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
        ret_heads, ret_labels, _, _ = self.dl.decode(insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
        # collect the results together
        all_heads = Helper.join_list(ret_heads)
        if ret_labels is None:
            # todo(note): simply get labels from the go1-label classifier; must provide g1parser
            if go1_pack is None:
                _, go1_pack = self._get_g1_pack(insts, 1., 1.)
            _, go1_label_max_idxes = go1_pack[1].max(-1)  # [bs, slen, slen]
            pred_heads_arr, _ = self.predict_padder.pad(all_heads)  # [bs, slen]
            pred_heads_expr = BK.input_idx(pred_heads_arr)
            pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1)
            all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
        else:
            all_labels = np.concatenate(ret_labels, 0)
        # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
        for one_idx, one_inst in enumerate(insts):
            cur_length = len(one_inst) + 1
            one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
            # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
        # =====
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        return info
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
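# ----- illustrative sketch (not part of the original code) -----
# The labeled MST is reduced to an unlabeled one: first max over the label dimension,
# then decode a tree on the maxed scores, then read the argmax label off each predicted
# edge. A toy sketch of the reduce-and-gather steps (the actual MST call is omitted):
import torch

# toy labeled scores: batch 1, 4 modifiers, 4 heads, 3 labels
scores = torch.randn(1, 4, 4, 3)
# stage 1: best label (and its score) per (mod, head) pair
unlabeled_max, labels_argmax = scores.max(-1)   # [1, 4, 4] each
# stage 2: tree decoding would run on unlabeled_max; suppose it returned these heads:
mst_heads = torch.tensor([[0, 0, 1, 1]])        # [1, m]
# stage 3: read off the argmax label along each predicted edge
mst_labels = labels_argmax.gather(-1, mst_heads.unsqueeze(-1)).squeeze(-1)  # [1, m]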
def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr, go1_pack, training: bool, margin: float):
    # first do decoding and related preparation
    with BK.no_grad_env():
        _, _, g_packs, p_packs = self.decode(insts, enc_expr, final_valid_expr, go1_pack, training, margin)
        # flatten the packs (remember to rebase the indexes)
        gold_pack = self._flatten_packs(g_packs)
        pred_pack = self._flatten_packs(p_packs)
        if self.filter_pruned:
            # filter out non-valid (pruned) edges, to avoid prune error
            mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(final_valid_expr, gold_pack)
            pred_mask = mod_unpruned_mask[pred_pack[0], pred_pack[1]]  # filter by specific mod
            gold_pack = [(None if z is None else z[gold_mask]) for z in gold_pack]
            pred_pack = [(None if z is None else z[pred_mask]) for z in pred_pack]
    # calculate the scores for loss
    gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack
    pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack
    gold_arc_score, gold_label_score_all = self._get_basic_score(
        enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes)
    pred_arc_score, pred_label_score_all = self._get_basic_score(
        enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes)
    # whether have labeled scores
    if self.system_labeled:
        gold_label_score = BK.gather_one_lastdim(gold_label_score_all, gold_lab_idxes).squeeze(-1)
        pred_label_score = BK.gather_one_lastdim(pred_label_score_all, pred_lab_idxes).squeeze(-1)
        ret_scores = (gold_arc_score, pred_arc_score, gold_label_score, pred_label_score)
        pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score
    else:
        ret_scores = (gold_arc_score, pred_arc_score)
        pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score
    # hinge loss: filter-margin by loss*margin to be aware of search error
    if self.filter_margin:
        with BK.no_grad_env():
            mat_shape = BK.get_shape(enc_expr)[:2]  # [bs, slen]
            heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_h_idxes)
            heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_h_idxes)
            error_count = (heads_gold != heads_pred).float()
            if self.system_labeled:
                labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_lab_idxes)
                labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_lab_idxes)
                error_count += (labels_gold != labels_pred).float()
            scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32, gold_b_idxes, gold_m_idxes, gold_full_scores)
            scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32, pred_b_idxes, pred_m_idxes, pred_full_scores)
            # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient
            sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <= (margin * error_count.sum(-1) + 0.1)).float()
            num_valid_sent = float(BK.get_value(sent_mask.sum()))
        final_loss_sum = (pred_full_scores * sent_mask[pred_b_idxes]
                          - gold_full_scores * sent_mask[gold_b_idxes]).sum()
    else:
        num_valid_sent = len(insts)
        final_loss_sum = (pred_full_scores - gold_full_scores).sum()
    # prepare final loss
    # divide loss by what?
    num_sent = len(insts)
    num_valid_tok = sum(len(z) for z in insts)
    if self.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "sent_valid": num_valid_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    return final_loss, ret_scores, info
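# ----- illustrative sketch (not part of the original code) -----
# In the filter_margin branch above, a whole sentence is dropped from the loss once its
# gold tree already outscores the predicted tree by more than margin * (number of
# erroneous edges). A toy sketch of that gate (values illustrative):
import torch

score_gold = torch.tensor([5.0, 2.0, 7.0])   # summed gold-tree scores per sentence
score_pred = torch.tensor([4.0, 3.5, 3.0])   # summed predicted-tree scores per sentence
errors = torch.tensor([1.0, 2.0, 1.0])       # per-sentence edge-error counts
margin = 1.0
# keep only sentences where gold does not already beat pred by margin * errors;
# the +0.1 excludes exact-zero-error sentences, which would get zero gradient anyway
sent_mask = (score_gold - score_pred <= margin * errors + 0.1).float()
# -> tensor([1., 1., 0.]): the third sentence is already "safe" and is dropped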
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    self.refresh_batch(training)
    margin = self.margin.value
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    gold_labels_arr, _ = self.predict_padder.pad([self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    # ===== calculate
    scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr)
    #
    final_losses = None
    if self.norm_local or self.norm_single:
        select_label_score = self._score_label_selected(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # already added margin previously
        losses_heads = losses_labels = None
        if self.loss_prob:
            if self.norm_local:
                losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_hinge:
            if self.norm_local:
                losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample,
                                                   is_hinge=True, margin=margin)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample,
                                                    is_hinge=True, margin=margin)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_mr:
            # special treatment!
            probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
            probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, L]
            # select
            probs_head_gold = BK.gather_one_lastdim(probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
            probs_label_gold = BK.gather_one_lastdim(probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
            # root and pad will be excluded later
            # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
            # todo(warn): have problem since steps will be quite small, not used!
            final_losses = (mask_expr - probs_head_gold * probs_label_gold)  # let loss>=0
    elif self.norm_global:
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, training, margin,
                                                  gold_heads_expr, gold_labels_expr)
        # for this one, use the merged full score
        full_score = full_arc_score.unsqueeze(-1) + full_label_score  # [BS, m, h, L]
        # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        # do inference
        if self.loss_prob:
            marginals_expr = self._marginal(full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
            final_losses = self._losses_global_prob(full_score, gold_heads_expr, gold_labels_expr,
                                                    marginals_expr, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                #  but this might be too loose, although the unproj edges are few?
                gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts])
                gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                final_losses = BK.max_elem(final_losses, comparing_expr)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(full_score, mask_expr, mst_lengths_arr)
            pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
            #
            final_losses = self._losses_global_hinge(full_score, gold_heads_expr, gold_labels_expr,
                                                     pred_heads_expr, pred_labels_expr, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # firstly label losses are the same
        select_label_score = self._score_label_selected(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
        # then specially for arc loss
        children_masks_arr, _ = self.hlocal_padder.pad([z.get_children_mask_arr() for z in annotated_insts])
        children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
        # todo(warn): use prod rather than sum, but still only an approximation for the top-down
        # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
        losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)  # [bs, h]
        # including the root-head is important
        losses_arc[:, 1] += losses_arc[:, 0]
        final_losses = losses_arc + losses_labels
    #
    # jpos loss? (the same mask as parsing)
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        final_losses += jpos_losses_expr
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    #
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    if training:
        BK.backward(final_loss, loss_factor)
    return info
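# ----- illustrative sketch (not part of the original code) -----
# The norm_hlocal branch normalizes arc scores over modifiers for each head (a top-down,
# children-generating view) instead of over heads for each modifier. A toy sketch of that
# arc loss; the shapes and the example tree are illustrative:
import torch
import torch.nn.functional as F

# toy arc scores: batch 1, 4 positions (0 = ROOT); score[b, m, h] rates head h for modifier m
arc_score = torch.randn(1, 4, 4)
# children_mask[b, h, m] = 1 iff m is a gold child of h
children_mask = torch.zeros(1, 4, 4)
children_mask[0, 0, 1] = 1.   # ROOT -> 1
children_mask[0, 1, 2] = 1.   # 1 -> 2
children_mask[0, 1, 3] = 1.   # 1 -> 3
# normalize over modifiers for each head (dim=-2), then sum the gold children's log-probs
log_p_children = F.log_softmax(arc_score, dim=-2).transpose(-1, -2)  # [1, h, m]
losses_arc = -(log_p_children * children_mask).sum(-1)               # [1, h]
# fold the ROOT row (position 0) into position 1, since position 0 is masked out later
losses_arc[:, 1] += losses_arc[:, 0]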