def loss(self, ms_items: List, bert_expr):
    conf = self.conf
    max_range = conf.max_range
    bsize = len(ms_items)
    # collect instances
    col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(ms_items, True)
    if len(col_efs) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)  # [N, R]
    if conf.use_binary_scorer:
        # multi-hot targets: every offset within the gold distance is a positive, [N, R]
        left_binaries = (BK.arange_idx(max_range) <= col_ldists_t.unsqueeze(-1)).float()
        right_binaries = (BK.arange_idx(max_range) <= col_rdists_t.unsqueeze(-1)).float()
        # offset 0 is always positive, so exclude it from the loss
        left_losses = BK.binary_cross_entropy_with_logits(left_scores, left_binaries, reduction='none')[:, 1:]
        right_losses = BK.binary_cross_entropy_with_logits(right_scores, right_binaries, reduction='none')[:, 1:]
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0) * (max_range - 1))
    else:
        # single-class targets: plain NLL over the R distances
        left_losses = BK.loss_nll(left_scores, col_ldists_t)
        right_losses = BK.loss_nll(right_scores, col_rdists_t)
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0))
    return [[left_losses.sum(), left_count, left_count], [right_losses.sum(), right_count, right_count]]
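# -- illustrative sketch (not from the original module) --
# The two branches above differ only in how a gold distance d becomes a training
# target: the binary scorer marks every offset <= d as a positive and applies
# per-offset BCE, while the default path treats d as a single class under a
# softmax. A minimal plain-PyTorch rendering of both, assuming BK.arange_idx and
# BK.binary_cross_entropy_with_logits wrap the corresponding torch ops:
import torch
import torch.nn.functional as F

def _demo_range_losses(scores: torch.Tensor, dists: torch.Tensor):
    # scores: [N, R] raw logits; dists: [N] gold distances (long, in [0, R))
    max_range = scores.size(-1)
    binaries = (torch.arange(max_range) <= dists.unsqueeze(-1)).float()  # [N, R] multi-hot
    bce = F.binary_cross_entropy_with_logits(scores, binaries, reduction='none')[:, 1:]  # drop trivial offset 0
    nll = F.cross_entropy(scores, dists, reduction='none')  # [N] softmax alternative
    return bce, nll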
def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[mask_valids < 1.] = 0  # make sure invalid ones are in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
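# -- illustrative sketch (not from the original module) --
# BK.mask2idx turns a 0/1 mask over positions into padded index lists plus a
# validity mask, so that only the masked positions get gathered and scored.
# A simple loop-based rendering of that contract (the real helper is
# presumably vectorized):
import torch

def _demo_mask2idx(mask: torch.Tensor, padding_idx: int = 0):
    # mask: [bs, slen] of 0/1 -> (idxes: [bs, max_count] long, valids: [bs, max_count] float)
    max_count = int(mask.sum(-1).max().item())
    idxes = torch.full((mask.size(0), max_count), padding_idx, dtype=torch.long)
    valids = torch.zeros(mask.size(0), max_count)
    for b in range(mask.size(0)):
        pos = mask[b].nonzero(as_tuple=True)[0]
        idxes[b, :len(pos)] = pos
        valids[b, :len(pos)] = 1.
    return idxes, valids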
def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    # score
    scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
    # get gold
    gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, ?+rlen]
    for bidx, inst in enumerate(insts):
        cur_seq_idxes = getattr(inst, self.attr_name).idxes
        if self.add_root_token:
            gold_pidxes[bidx, 1:1+len(cur_seq_idxes)] = cur_seq_idxes
        else:
            gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
    # get loss
    margin = self.margin.value
    gold_pidxes_t = BK.input_idx(gold_pidxes)
    gold_pidxes_t *= (gold_pidxes_t < self.pred_out_dim).long()  # 0 means invalid ones!!
    loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
    lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin)  # [bs, ?+rlen]
    # argmax
    _, argmax_idxes = scores_t.max(-1)
    pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
    # compile loss
    lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum())
    return self._compile_component_loss(self.pname, [lab_loss])
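# -- illustrative sketch (not from the original module) --
# The remapping above folds out-of-range gold labels into index 0 and then
# excludes index 0 (padding/invalid) from the loss mask, so out-of-vocabulary
# labels neither crash the NLL lookup nor contribute loss. A tiny worked example:
import torch

def _demo_oov_gold_masking():
    gold = torch.tensor([[3, 7, 12, 0]])        # 12 is out of range, 0 is padding
    pred_out_dim = 10
    gold = gold * (gold < pred_out_dim).long()  # -> [[3, 7, 0, 0]]
    loss_mask = (gold > 0).float()              # -> [[1., 1., 0., 0.]]
    return gold, loss_mask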
def loss(self, repr_t, orig_map: Dict, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # --
    # specify input
    add_root_token = self.add_root_token
    # get from inputs
    if isinstance(repr_t, (list, tuple)):
        l2r_repr_t, r2l_repr_t = repr_t
    elif self.split_input_blm:
        l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
    else:
        l2r_repr_t, r2l_repr_t = repr_t, None  # l2r and r2l
    word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
    slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
    if add_root_token:
        l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0s at the front, [bs, 2+rlen-1]
    else:
        l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at the front, [bs, 1+rlen-1]
    # gather the losses
    all_losses = []
    pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
    if _tie_input_embeddings:
        pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
    else:
        pred_W = None
    # get input embeddings for output
    for pred_name, hid_node, pred_node, input_t, trg_t in \
            zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer],
                [self.l2r_pred, self.r2l_pred], [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
        if input_t is None:
            continue
        # hidden
        hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
        # pred: [bs, slen, Vsize]
        if _tie_input_embeddings:
            scores_t = BK.matmul(hid_t, pred_W.T)
        else:
            scores_t = pred_node(hid_t)
        # loss
        mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
        trg_t.clamp_(max=pred_range_max)  # make it in range
        losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
        _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
        corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
        # compile leaf loss
        one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(), loss_lambda=1., corr=corrs_t.sum())
        all_losses.append(one_loss)
    return self._compile_component_loss("plm", all_losses)
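# -- illustrative sketch (not from the original module) --
# The target construction above shifts the word sequence one step in each
# direction: the l2r stream predicts the next word, the r2l stream the previous
# one, with 0-padding so positions lacking a valid target fall outside
# [pred_range_min, pred_range_max] and get masked out. In plain PyTorch:
import torch

def _demo_blm_targets(word_t: torch.Tensor, add_root_token: bool):
    # word_t: [bs, rlen] long
    zero = torch.zeros(word_t.size(0), 1, dtype=torch.long)  # [bs, 1]
    if add_root_token:  # inputs already carry an extra ROOT position
        l2r = torch.cat([word_t, zero], -1)                # [bs, rlen+1]
        r2l = torch.cat([zero, zero, word_t[:, :-1]], -1)  # [bs, rlen+1]
    else:
        l2r = torch.cat([word_t[:, 1:], zero], -1)         # [bs, rlen]
        r2l = torch.cat([zero, word_t[:, :-1]], -1)        # [bs, rlen]
    return l2r, r2l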
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    self.refresh_batch(training)
    margin = self.margin.value
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    gold_labels_arr, _ = self.predict_padder.pad([self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    # ===== calculate
    scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr)
    #
    final_losses = None
    if self.norm_local or self.norm_single:
        select_label_score = self._score_label_selected(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # already added margin previously
        losses_heads = losses_labels = None
        if self.loss_prob:
            if self.norm_local:
                losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_hinge:
            if self.norm_local:
                losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=True, margin=margin)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=True, margin=margin)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_mr:
            # special treatment!
            probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
            probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, h]
            # select
            probs_head_gold = BK.gather_one_lastdim(probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
            probs_label_gold = BK.gather_one_lastdim(probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
            # root and pad will be excluded later
            # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
            # todo(warn): have problem since steps will be quite small, not used!
            final_losses = (mask_expr - probs_head_gold * probs_label_gold)  # let loss >= 0
    elif self.norm_global:
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # for this one, use the merged full score
        full_score = full_arc_score.unsqueeze(-1) + full_label_score  # [BS, m, h, L]
        # +=1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray([len(z)+1 for z in annotated_insts], dtype=np.int32)
        # do inference
        if self.loss_prob:
            marginals_expr = self._marginal(full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
            final_losses = self._losses_global_prob(full_score, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                #  but this might be too loose, although the unproj edges are few?
                gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts])
                gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                final_losses = BK.max_elem(final_losses, comparing_expr)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(full_score, mask_expr, mst_lengths_arr)
            pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
            #
            final_losses = self._losses_global_hinge(full_score, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # firstly label losses are the same
        select_label_score = self._score_label_selected(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
        # then specially for arc loss
        children_masks_arr, _ = self.hlocal_padder.pad([z.get_children_mask_arr() for z in annotated_insts])
        children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
        # todo(warn): use prod rather than sum, but still only an approximation for the top-down
        # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
        losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)  # [bs, h]
        # including the root-head is important
        losses_arc[:, 1] += losses_arc[:, 0]
        final_losses = losses_arc + losses_labels
    #
    # jpos loss? (the same mask as parsing)
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        final_losses += jpos_losses_expr
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    #
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    if training:
        BK.backward(final_loss, loss_factor)
    return info
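# -- illustrative sketch (not from the original module) --
# For the global-hinge branch, the loss compares the merged (arc+label) score of
# the decoded MST tree against the gold tree, per modifier token; the margin is
# already folded into the scores during scoring. A hypothetical plain-PyTorch
# rendering of what _losses_global_hinge could compute under those assumptions:
import torch

def _demo_global_hinge(full_score, gold_heads, gold_labels, pred_heads, pred_labels, mask):
    # full_score: [bs, m, h, L]; heads/labels: [bs, m] long; mask: [bs, m]
    def pick(heads, labels):
        # score of the chosen (head, label) pair for each modifier -> [bs, m]
        idx = heads.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, full_score.size(-1))
        by_head = full_score.gather(2, idx).squeeze(2)  # [bs, m, L]
        return by_head.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    delta = pick(pred_heads, pred_labels) - pick(gold_heads, gold_labels)
    return (delta * mask).clamp(min=0.)  # hinge: only penalize when pred outscores gold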
def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
    conf = self.conf
    # detach input?
    if self.no_detach_input.value <= 0.:
        repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
    # scoring
    label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
    # -----
    # get golds
    bsize, max_len = BK.get_shape(mask_t)
    shape_lidxes = [bsize, max_len, max_len]
    gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
    gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
    for bidx, inst in enumerate(insts):
        cur_dep_tree = inst.dep_tree
        cur_len = len(cur_dep_tree)
        gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
        gold_heads[bidx, :cur_len] = cur_dep_tree.heads
    # -----
    margin = self.margin.value
    all_losses = []
    # first is loss_labels
    lambda_label = conf.lambda_label
    if lambda_label > 0.:
        gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
        label_losses = BK.loss_nll(label_scores, gold_lidxes_t, margin=margin)  # [bs, mlen, mlen]
        # keep all positives plus a sampled subset of negatives
        positive_mask_t = (gold_lidxes_t > 0).float()  # [bs, mlen, mlen]
        negative_mask_t = (BK.rand(shape_lidxes) < conf.label_neg_rate).float()  # [bs, mlen, mlen]
        loss_mask_t = score_masks * (positive_mask_t + negative_mask_t)  # [bs, mlen, mlen]
        loss_mask_t.clamp_(max=1.)
        masked_label_losses = label_losses * loss_mask_t
        # compile loss
        final_label_loss = LossHelper.compile_leaf_info(
            "label", masked_label_losses.sum(), loss_mask_t.sum(), loss_lambda=lambda_label, npos=positive_mask_t.sum())
        all_losses.append(final_label_loss)
    # then head loss
    lambda_head = conf.lambda_head
    if lambda_head > 0.:
        # get head score simply by argmax on ranges
        head_scores, _ = self._ranged_label_scores(label_scores).max(-1)  # [bs, mlen, mlen]
        gold_heads_t = BK.input_idx(gold_heads)
        head_losses = BK.loss_nll(head_scores, gold_heads_t, margin=margin)  # [bs, mlen]
        # mask
        head_mask_t = BK.copy(mask_t)
        head_mask_t[:, 0] = 0  # not for ARTI_ROOT
        masked_head_losses = head_losses * head_mask_t
        # compile loss
        final_head_loss = LossHelper.compile_leaf_info(
            "head", masked_head_losses.sum(), head_mask_t.sum(), loss_lambda=lambda_head)
        all_losses.append(final_head_loss)
    # --
    return self._compile_component_loss("dp", all_losses)
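# -- illustrative sketch (not from the original module) --
# The label loss above is computed on all gold (positive) cells of the
# [len_q, len_k] grid plus a Bernoulli sample of the remaining cells, which
# keeps the O(n^2) loss from being dominated by easy negatives. In plain
# PyTorch (the helper name here is hypothetical):
import torch

def _demo_label_loss_mask(gold_lidxes_t, score_masks, label_neg_rate=0.1):
    positive = (gold_lidxes_t > 0).float()                           # gold cells
    negative = (torch.rand_like(positive) < label_neg_rate).float()  # sampled negatives
    return (score_masks * (positive + negative)).clamp(max=1.)       # overlaps counted once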
def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:
        # offset for the special root added in the embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        repr_mask_idxes = mask_idxes + 1
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        return self._compile_component_loss("mlm", [])
    else:
        if not isinstance(repr_ts, (list, tuple)):
            repr_ts = [repr_ts]
        target_word_scores = []
        target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
        for layer_idx in conf.loss_layers:
            # calculate scores
            target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
            if self.hid_layer and active_hid:
                # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                target_hids = self.hid_layer(target_reprs)
            else:
                target_hids = target_reprs
            if _tie_input_embeddings:
                pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
            else:
                target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
        # gather the losses
        all_losses = []
        for pred_name, target_scores, loss_lambda, range_min, range_max in \
                zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                    [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
            if loss_lambda > 0.:
                seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
                target_idx_t[ranged_mask_valids < 1.] = 0  # make sure invalid ones are in range
                # calculate for each layer
                all_layer_losses, all_layer_scores = [], []
                for one_layer_idx, one_target_scores in enumerate(target_scores):
                    # get loss: [bsize, ?]
                    one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                    all_layer_losses.append(one_pred_losses)
                    # get scores
                    one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                    all_layer_scores.append(one_pred_scores)
                # combine all layers
                pred_losses = self.loss_comb_f(all_layer_losses)
                pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                pred_loss_count = ranged_mask_valids.sum()
                # argmax
                _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                pred_corr_count = pred_corrs.sum()
                # compile leaf loss
                r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count, loss_lambda=loss_lambda, corr=pred_corr_count)
                all_losses.append(r_loss)
        return self._compile_component_loss("mlm", all_losses)
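# -- illustrative sketch (not from the original module) --
# Each selected layer contributes a weighted NLL over the masked positions, and
# the per-layer pieces are merged by loss_comb_f / score_comb_f. Assuming both
# combiners simply sum (the actual reduction is configurable), the per-token
# computation looks like:
import torch
import torch.nn.functional as F

def _demo_multilayer_mlm(layer_scores, target_idx, layer_weights):
    # layer_scores: list of [bs, k, V]; target_idx: [bs, k] long; layer_weights: list of float
    losses, logps = [], []
    for scores, w in zip(layer_scores, layer_weights):
        lp = F.log_softmax(scores, -1) * w
        logps.append(lp)
        losses.append(-lp.gather(-1, target_idx.unsqueeze(-1)).squeeze(-1))  # weighted NLL, [bs, k]
    comb_losses = sum(losses)             # loss_comb_f == sum (assumed)
    argmax_idxes = sum(logps).max(-1)[1]  # score_comb_f == sum (assumed)
    return comb_losses, argmax_idxes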