Example #1
 def loss(self, ms_items: List, bert_expr):
     conf = self.conf
     max_range = self.conf.max_range
     bsize = len(ms_items)
     # collect instances
     col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(
         ms_items, True)
     if len(col_efs) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
     left_scores, right_scores = self._score(bert_expr, col_bidxes_t,
                                             col_hidxes_t)  # [N, R]
     if conf.use_binary_scorer:
         left_binaries, right_binaries = (BK.arange_idx(max_range)<=col_ldists_t.unsqueeze(-1)).float(), \
                                         (BK.arange_idx(max_range)<=col_rdists_t.unsqueeze(-1)).float()  # [N,R]
         left_losses = BK.binary_cross_entropy_with_logits(
             left_scores, left_binaries, reduction='none')[:, 1:]
         right_losses = BK.binary_cross_entropy_with_logits(
             right_scores, right_binaries, reduction='none')[:, 1:]
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0) * (max_range - 1))
     else:
         left_losses = BK.loss_nll(left_scores, col_ldists_t)
         right_losses = BK.loss_nll(right_scores, col_rdists_t)
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0))
     return [[left_losses.sum(), left_count, left_count],
             [right_losses.sum(), right_count, right_count]]
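
The example above scores left/right boundary extents either as per-offset binary decisions or as a single distance class. The sketch below is a minimal, self-contained reconstruction of that choice in plain PyTorch; `BK` is assumed to wrap comparable torch ops, and the shapes and sample inputs are purely illustrative.

 import torch
 import torch.nn.functional as F

 def boundary_losses(scores, dists, use_binary):
     # scores: [N, R] logits over range offsets; dists: [N] gold distances
     max_range = scores.shape[-1]
     if use_binary:
         # binary targets: every offset up to the gold distance is a positive
         targets = (torch.arange(max_range) <= dists.unsqueeze(-1)).float()  # [N, R]
         losses = F.binary_cross_entropy_with_logits(scores, targets, reduction='none')[:, 1:]
         count = losses.shape[0] * (max_range - 1)
     else:
         # single-class view: the gold distance itself is the target index
         losses = F.cross_entropy(scores, dists, reduction='none')  # [N]
         count = losses.shape[0]
     return losses.sum(), count

 scores = torch.randn(4, 8)
 dists = torch.tensor([2, 0, 5, 7])
 print(boundary_losses(scores, dists, use_binary=True))
 print(boundary_losses(scores, dists, use_binary=False))
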
Example #2
 def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
     mask_idxes, mask_valids = BK.mask2idx(
         BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     else:
         target_reprs = BK.gather_first_dims(repr_t, mask_idxes,
                                             1)  # [bsize, ?, *]
         target_hids = self.hid_layer(target_reprs)
         target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
         pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
         target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
         target_idx_t[(mask_valids <
                       1.)] = 0  # make sure invalid ones in range
         # get loss
         pred_losses = BK.loss_nll(target_scores,
                                   target_idx_t)  # [bsize, ?]
         pred_loss_sum = (pred_losses * mask_valids).sum()
         pred_loss_count = mask_valids.sum()
         # argmax
         _, argmax_idxes = target_scores.max(-1)
         pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
         pred_corr_count = pred_corrs.sum()
         return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
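
Example #2 relies on `BK.mask2idx` and `BK.gather_first_dims` to turn a replacement mask into padded positions before scoring. Below is a rough stand-alone sketch of that pattern in plain PyTorch; the helper `mask2idx` here is a simplified guess at the toolkit's behavior, and all shapes are illustrative.

 import torch
 import torch.nn.functional as F

 def mask2idx(mask):
     # mask: [bs, slen] of 0/1 -> padded positions [bs, max_k] plus a validity mask
     max_k = int(mask.sum(-1).max().item())
     idxes = torch.zeros(mask.shape[0], max_k, dtype=torch.long)
     valids = torch.zeros(mask.shape[0], max_k)
     for b in range(mask.shape[0]):
         pos = mask[b].nonzero(as_tuple=True)[0]
         idxes[b, :len(pos)] = pos
         valids[b, :len(pos)] = 1.
     return idxes, valids

 def masked_pred_loss(scores, gold_idx, mask):
     idxes, valids = mask2idx(mask)            # [bs, ?]
     tgt = gold_idx.gather(-1, idxes)          # [bs, ?]
     tgt[valids < 1.] = 0                      # keep padded slots in vocabulary range
     sel = scores.gather(1, idxes.unsqueeze(-1).expand(-1, -1, scores.shape[-1]))  # [bs, ?, V]
     losses = F.cross_entropy(sel.transpose(1, 2), tgt, reduction='none')          # [bs, ?]
     corr = (sel.argmax(-1) == tgt).float() * valids
     return (losses * valids).sum(), valids.sum(), corr.sum()

 bs, slen, V = 2, 6, 10
 scores = torch.randn(bs, slen, V)
 gold = torch.randint(0, V, (bs, slen))
 mask = torch.tensor([[1., 0., 0., 1., 0., 0.],
                      [0., 1., 0., 0., 0., 0.]])
 print(masked_pred_loss(scores, gold, mask))
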
Example #3
 def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
     conf = self.conf
     # score
     scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
     # get gold
     gold_pidxes = np.zeros(BK.get_shape(mask_t),
                            dtype=np.long)  # [bs, ?+rlen]
     for bidx, inst in enumerate(insts):
         cur_seq_idxes = getattr(inst, self.attr_name).idxes
         if self.add_root_token:
             gold_pidxes[bidx, 1:1 + len(cur_seq_idxes)] = cur_seq_idxes
         else:
             gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
     # get loss
     margin = self.margin.value
     gold_pidxes_t = BK.input_idx(gold_pidxes)
     gold_pidxes_t *= (gold_pidxes_t <
                       self.pred_out_dim).long()  # 0 means invalid ones!!
     loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
     lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t,
                                margin=margin)  # [bs, ?+rlen]
     # argmax
     _, argmax_idxes = scores_t.max(-1)
     pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
     # compile loss
     lab_loss = LossHelper.compile_leaf_info("slab",
                                             lab_losses_t.sum(),
                                             loss_mask_t.sum(),
                                             corr=pred_corrs.sum())
     return self._compile_component_loss(self.pname, [lab_loss])
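
Example #3 builds a padded gold-index matrix (shifted by one position when an artificial root token is prepended) and masks out label id 0 as "invalid". The sketch below redoes those steps with plain NumPy/PyTorch; `margin_nll` is only one plausible reading of `BK.loss_nll` with a margin (add the margin to every non-gold score before the softmax), not the toolkit's exact implementation.

 import numpy as np
 import torch
 import torch.nn.functional as F

 def margin_nll(scores, gold, margin=0.):
     # cost-augmented NLL: boost every non-gold score by `margin` before the softmax
     cost = torch.full_like(scores, margin)
     cost.scatter_(-1, gold.unsqueeze(-1), 0.)
     return F.cross_entropy((scores + cost).transpose(1, 2), gold, reduction='none')  # [bs, slen]

 bs, slen, n_classes = 2, 5, 7
 add_root_token = True
 seq_idxes = [[3, 1, 4], [2, 2, 6, 1]]          # per-sentence gold label ids
 gold = np.zeros((bs, slen), dtype=np.int64)
 for b, idxes in enumerate(seq_idxes):
     if add_root_token:
         gold[b, 1:1 + len(idxes)] = idxes      # shift right past the artificial root
     else:
         gold[b, :len(idxes)] = idxes

 scores = torch.randn(bs, slen, n_classes)
 mask = torch.tensor([[1., 1., 1., 1., 0.], [1., 1., 1., 1., 1.]])
 gold_t = torch.as_tensor(gold)
 loss_mask = (gold_t > 0).float() * mask        # label id 0 doubles as "invalid"
 losses = margin_nll(scores, gold_t, margin=0.5) * loss_mask
 print(float(losses.sum() / loss_mask.sum()))
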
Example #4
 def loss(self, repr_t, orig_map: Dict, **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # --
     # specify input
     add_root_token = self.add_root_token
     # get from inputs
     if isinstance(repr_t, (list, tuple)):
         l2r_repr_t, r2l_repr_t = repr_t
     elif self.split_input_blm:
         l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
     else:
         l2r_repr_t, r2l_repr_t = repr_t, None
     # l2r and r2l
     word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
     slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
     if add_root_token:
         l2r_trg_t = BK.concat([word_t, slice_zero_t],
                               -1)  # pad one extra 0, [bs, rlen+1]
         r2l_trg_t = BK.concat(
             [slice_zero_t, slice_zero_t, word_t[:, :-1]],
             -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
     else:
         l2r_trg_t = BK.concat(
             [word_t[:, 1:], slice_zero_t], -1
         )  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
         r2l_trg_t = BK.concat(
             [slice_zero_t, word_t[:, :-1]],
             -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
     # gather the losses
     all_losses = []
     pred_range_min, pred_range_max = max(
         1, conf.min_pred_rank), self.pred_size - 1
     if _tie_input_embeddings:
         pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
     else:
         pred_W = None
     # get input embeddings for output
     for pred_name, hid_node, pred_node, input_t, trg_t in \
                 zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                     [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
         if input_t is None:
             continue
         # hidden
         hid_t = hid_node(
             input_t) if hid_node else input_t  # [bs, slen, hid]
         # pred: [bs, slen, Vsize]
         if _tie_input_embeddings:
             scores_t = BK.matmul(hid_t, pred_W.T)
         else:
             scores_t = pred_node(hid_t)
         # loss
         mask_t = ((trg_t >= pred_range_min) &
                   (trg_t <= pred_range_max)).float()  # [bs, slen]
         trg_t.clamp_(max=pred_range_max)  # make it in range
         losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
         _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
         corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
         # compile leaf loss
         one_loss = LossHelper.compile_leaf_info(pred_name,
                                                 losses_t.sum(),
                                                 mask_t.sum(),
                                                 loss_lambda=1.,
                                                 corr=corrs_t.sum())
         all_losses.append(one_loss)
     return self._compile_component_loss("plm", all_losses)
Example #5
 def fb_on_batch(self,
                 annotated_insts,
                 training=True,
                 loss_factor=1,
                 **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score,
                                       dim=-1)  # [bs, m, h]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
             # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
             # todo(warn): have problem since steps will be quite small, not used!
             final_losses = (mask_expr - probs_head_gold * probs_label_gold
                             )  # let loss>=0
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(
                     gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (
                     1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # firstly label losses are the same
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then specially for arc loss
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(
             children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): use prod rather than sum, but still only an approximation for the top-down
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
             -1, -2) * children_masks_expr,
                              dim=-1)
         # including the root-head is important
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
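
Example #5 covers several training criteria; the most common path is the locally normalized one (`norm_local` with `loss_prob`): per-token NLL over head scores plus NLL over the labels selected at the gold heads, masked to drop the artificial ROOT row, then divided by tokens or sentences depending on `loss_div_tok`. The following is a heavily stripped sketch of just that branch in plain PyTorch, with invented shapes and random stand-ins for the scorers.

 import torch
 import torch.nn.functional as F

 bs, slen, n_labels = 2, 5, 10                       # slen includes the ROOT slot
 arc_scores = torch.randn(bs, slen, slen)            # [bs, m, h]
 label_scores = torch.randn(bs, slen, n_labels)      # label scores at the gold head of each m
 gold_heads = torch.randint(0, slen, (bs, slen))
 gold_labels = torch.randint(0, n_labels, (bs, slen))
 mask = torch.ones(bs, slen)

 losses_heads = F.cross_entropy(arc_scores.transpose(1, 2), gold_heads, reduction='none')
 losses_labels = F.cross_entropy(label_scores.transpose(1, 2), gold_labels, reduction='none')
 final = ((losses_heads + losses_labels) * mask)[:, 1:]   # exclude the ROOT "modifier" position

 loss_div_tok = True
 num_tok = int(mask[:, 1:].sum())
 final_loss = final.sum() / (num_tok if loss_div_tok else bs)
 print(float(final_loss))
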
Example #6
 def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t,
          **kwargs):
     conf = self.conf
     # detach input?
     if self.no_detach_input.value <= 0.:
         repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
     # scoring
     label_scores, score_masks = self._score(
         repr_t, attn_t,
         mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
     # -----
     # get golds
     bsize, max_len = BK.get_shape(mask_t)
     shape_lidxes = [bsize, max_len, max_len]
     gold_lidxes = np.zeros(shape_lidxes, dtype=np.long)  # [bs, mlen, mlen]
     gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.long)  # [bs, mlen]
     for bidx, inst in enumerate(insts):
         cur_dep_tree = inst.dep_tree
         cur_len = len(cur_dep_tree)
         gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
         gold_heads[bidx, :cur_len] = cur_dep_tree.heads
     # -----
     margin = self.margin.value
     all_losses = []
     # first is loss_labels
     lambda_label = conf.lambda_label
     if lambda_label > 0.:
         gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
         label_losses = BK.loss_nll(label_scores,
                                    gold_lidxes_t,
                                    margin=margin)  # [bs, mlen, mlen]
         positive_mask_t = (gold_lidxes_t > 0).float()  # [bs, mlen, mlen]
         negative_mask_t = (BK.rand(shape_lidxes) <
                            conf.label_neg_rate).float()  # [bs, mlen, mlen]
         loss_mask_t = score_masks * (positive_mask_t + negative_mask_t
                                      )  # [bs, mlen, mlen]
         loss_mask_t.clamp_(max=1.)
         masked_label_losses = label_losses * loss_mask_t
         # compile loss
         final_label_loss = LossHelper.compile_leaf_info(
             f"label",
             masked_label_losses.sum(),
             loss_mask_t.sum(),
             loss_lambda=lambda_label,
             npos=positive_mask_t.sum())
         all_losses.append(final_label_loss)
     # then head loss
     lambda_head = conf.lambda_head
     if lambda_head > 0.:
         # get head score simply by argmax on ranges
         head_scores, _ = self._ranged_label_scores(label_scores).max(
             -1)  # [bs, mlen, mlen]
         gold_heads_t = BK.input_idx(gold_heads)
         head_losses = BK.loss_nll(head_scores, gold_heads_t,
                                   margin=margin)  # [bs, mlen]
         # mask
         head_mask_t = BK.copy(mask_t)
         head_mask_t[:, 0] = 0  # not for ARTI_ROOT
         masked_head_losses = head_losses * head_mask_t
         # compile loss
         final_head_loss = LossHelper.compile_leaf_info(
             f"head",
             masked_head_losses.sum(),
             head_mask_t.sum(),
             loss_lambda=lambda_head)
         all_losses.append(final_head_loss)
     # --
     return self._compile_component_loss("dp", all_losses)
Example #7
 def loss(self,
          repr_ts,
          input_erase_mask_arr,
          orig_map: Dict,
          active_hid=True,
          **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # prepare idxes for the masked ones
     if self.add_root_token:  # offset for the special root added in embedder
         mask_idxes, mask_valids = BK.mask2idx(
             BK.input_real(input_erase_mask_arr),
             padding_idx=-1)  # [bsize, ?]
         repr_mask_idxes = mask_idxes + 1
         mask_idxes.clamp_(min=0)
     else:
         mask_idxes, mask_valids = BK.mask2idx(
             BK.input_real(input_erase_mask_arr))  # [bsize, ?]
         repr_mask_idxes = mask_idxes
     # get the losses
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         return self._compile_component_loss("mlm", [])
     else:
         if not isinstance(repr_ts, (List, Tuple)):
             repr_ts = [repr_ts]
         target_word_scores, target_pos_scores = [], []
         target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
         for layer_idx in conf.loss_layers:
             # calculate scores
             target_reprs = BK.gather_first_dims(repr_ts[layer_idx],
                                                 repr_mask_idxes,
                                                 1)  # [bsize, ?, *]
             if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                 target_hids = self.hid_layer(target_reprs)
             else:
                 target_hids = target_reprs
             if _tie_input_embeddings:
                 pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                 target_word_scores.append(BK.matmul(
                     target_hids, pred_W.T))  # List[bsize, ?, Vw]
             else:
                 target_word_scores.append(self.pred_word_layer(
                     target_hids))  # List[bsize, ?, Vw]
         # gather the losses
         all_losses = []
         for pred_name, target_scores, loss_lambda, range_min, range_max in \
                 zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                     [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
             if loss_lambda > 0.:
                 seq_idx_t = BK.input_idx(
                     orig_map[pred_name])  # [bsize, slen]
                 target_idx_t = seq_idx_t.gather(-1,
                                                 mask_idxes)  # [bsize, ?]
                 ranged_mask_valids = mask_valids * (
                     target_idx_t >= range_min).float() * (
                         target_idx_t <= range_max).float()
                 target_idx_t[(ranged_mask_valids <
                               1.)] = 0  # make sure invalid ones in range
                 # calculate for each layer
                 all_layer_losses, all_layer_scores = [], []
                 for one_layer_idx, one_target_scores in enumerate(
                         target_scores):
                     # get loss: [bsize, ?]
                     one_pred_losses = BK.loss_nll(
                         one_target_scores,
                         target_idx_t) * conf.loss_weights[one_layer_idx]
                     all_layer_losses.append(one_pred_losses)
                     # get scores
                     one_pred_scores = BK.log_softmax(
                         one_target_scores,
                         -1) * conf.loss_weights[one_layer_idx]
                     all_layer_scores.append(one_pred_scores)
                 # combine all layers
                 pred_losses = self.loss_comb_f(all_layer_losses)
                 pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                 pred_loss_count = ranged_mask_valids.sum()
                 # argmax
                 _, argmax_idxes = self.score_comb_f(all_layer_scores).max(
                     -1)
                 pred_corrs = (argmax_idxes
                               == target_idx_t).float() * ranged_mask_valids
                 pred_corr_count = pred_corrs.sum()
                 # compile leaf loss
                 r_loss = LossHelper.compile_leaf_info(
                     pred_name,
                     pred_loss_sum,
                     pred_loss_count,
                     loss_lambda=loss_lambda,
                     corr=pred_corr_count)
                 all_losses.append(r_loss)
         return self._compile_component_loss("mlm", all_losses)