Example #1
0
 def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc,
                   k_label) -> List[List]:
     """Select oracle (gold) expansion candidates for the current beam states.

     Args:
         ags: per-sentence agendas (unused directly here; states come flattened
             in `candidates`).
         candidates: tuple of (flattened_states, cur_arc_scores, scoring_mask_ct)
             as produced upstream; `cur_arc_scores` is presumably shaped
             [bs, Lm, Lh] since it is flattened to [bs, Lm*Lh] below — TODO confirm.
         mode: "topk" selects the top-k scoring oracle arcs; "" returns empty
             lists; other modes (sample, topk-sample, gather) not implemented.
         k_arc: number of arcs to keep per state in "topk" mode.
         k_label: unused in the current modes (kept for interface compatibility).

     Returns:
         List of lists of new states (one inner list per flattened state);
         in mode "" each inner list is empty.

     Raises:
         NotImplementedError: for any unrecognized mode.
     """
     flattened_states, cur_arc_scores, scoring_mask_ct = candidates
     cur_cache = self.cache
     cur_bsize = len(flattened_states)
     cur_slen = cur_cache.max_slen
     if mode == "topk":
         # todo(note): there can be multiple oracles, select topk(usually top1) in this mode.
         # Mask out non-oracle arcs by pushing their scores to -inf-like values,
         # then flatten [bs, Lm, Lh] -> [bs, Lm*Lh] for a single topk call.
         cur_oracle_mask_t, cur_oracle_label_t = self._get_oracle_mask(
             flattened_states)
         # [bs, Lm*Lh]
         cur_oracle_arc_scores = (cur_arc_scores + Constants.REAL_PRAC_MIN *
                                  (1. - cur_oracle_mask_t)).view(
                                      [cur_bsize, -1])
         # arcs [*, k]
         topk_arc_scores, topk_arc_idxes = BK.topk(cur_oracle_arc_scores,
                                                   k_arc,
                                                   dim=-1,
                                                   sorted=False)
         # Decompose flat index back into (modifier, head). NOTE: must be floor
         # division `//`; true division `/` on integer tensors produces floats
         # in PyTorch >= 1.5 and breaks subsequent indexing.
         topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen  # [m, h]
         # labels [*, k, 1]
         # todo(note): here we gather labels since one arc can only have one oracle label
         cur_label_scores = cur_cache.get_selected_label_scores(
             topk_m, topk_h, 0., 0.)  # [*, k, labels]
         topk_label_idxes = cur_oracle_label_t[
             cur_cache.bsize_range_t.unsqueeze(-1), topk_m,
             topk_h].unsqueeze(-1)  # [*, k, 1]
         # todo(+N): here is the trick to avoid repeated calculations, maybe not correct when using full dynamic oracle
         topk_label_scores = BK.gather(cur_label_scores, topk_label_idxes,
                                       -1) - self.mw_label
         # todo(+N): here use both masks, which may lead to no oracles! Can we simply drop the oracle_mask?
         return self._new_states(flattened_states,
                                 scoring_mask_ct * cur_cache.oracle_mask_ct,
                                 topk_arc_scores, topk_m, topk_h,
                                 topk_label_scores, topk_label_idxes)
     elif mode == "":
         # No oracle selection requested: one empty candidate list per state.
         return [[]] * cur_bsize
     # todo(+N): other modes like sampling to be implemented: sample, topk-sample, gather
     else:
         raise NotImplementedError(mode)
Example #2
0
 def _score_label_selected(self,
                           scoring_expr_pack,
                           mask_expr,
                           training,
                           margin,
                           gold_heads_expr,
                           gold_labels_expr=None):
     """Score labels for each modifier against its gold-selected head.

     Gathers, for every modifier position, the head representation at the
     gold head index, scores labels with the label scorer, and (when
     training with a positive margin) subtracts the margin at the gold
     label positions.
     """
     lm_repr = scoring_expr_pack[2]
     lh_repr = scoring_expr_pack[3]
     # [BS, len-m, D]: pick lh rows at the gold head indices
     lh_shape = BK.get_shape(lh_repr)
     gather_dim = len(lh_shape) - 2
     head_index = gold_heads_expr.unsqueeze(-1).expand(*lh_shape)
     gold_head_repr = BK.gather(lh_repr, head_index, dim=gather_dim)
     # [BS, len-m, L]
     label_scores = self.scorer.score_label_select(lm_repr, gold_head_repr,
                                                   mask_expr)
     # margin-augment only during training
     if training and margin > 0.:
         label_scores = BK.minus_margin(label_scores, gold_labels_expr,
                                        margin)
     return label_scores
Example #3
0
 def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
     """Compute the layered labeling loss (hinge or prob-based) over all layers.

     Args:
         input_expr: input representations; only its leading shape (minus the
             last dim) is used to build a default all-ones loss mask.
         loss_mask: [*] float mask over positions, or None for all-ones.
         gold_idxes: [*] gold label indices (full-vocab; expanded to per-layer
             indices via `self._get_all_idxes`).
         margin: base margin; scaled per layer by `self.margin_lambdas`.

     Returns:
         ([[loss_sum, mask_sum]], ret_lab_idxes, ret_lab_embeds) — the summed
         loss plus denominator, and prediction outputs from `self._predict`.
     """
     gold_all_idxes = self._get_all_idxes(gold_idxes)
     # scoring
     raw_scores = self._raw_scores(input_expr)
     raw_scores_aug = []
     margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
     # margin: build a cost-augmented copy of each layer's scores
     for i in range(self.eff_max_layer):
         cur_gold_inputs = gold_all_idxes[i]
         # add margin
         cur_scores = raw_scores[i]  # [*, ?]
         cur_margin = margin * self.margin_lambdas[i]
         if cur_margin > 0.:
             cur_num_target = self.prediction_sizes[i]
             # NOTE(review): `.byte()` mask indexing is deprecated in modern
             # torch in favor of `.bool()` — confirm backend version before changing.
             cur_isnil = self.layered_isnil[i].byte()  # [NLab]
             # cost matrix [gold, pred]: base cost margin_T everywhere,
             # margin_P for gold-is-nil rows, margin_R for pred-is-nil cols,
             # zero on the diagonal (correct prediction costs nothing).
             cost_matrix = BK.constants([cur_num_target, cur_num_target],
                                        margin_T)  # [gold, pred]
             cost_matrix[cur_isnil, :] = margin_P
             cost_matrix[:, cur_isnil] = margin_R
             diag_idxes = BK.arange_idx(cur_num_target)
             cost_matrix[diag_idxes, diag_idxes] = 0.
             margin_mat = cost_matrix[cur_gold_inputs]
             cur_aug_scores = cur_scores + margin_mat  # [*, ?]
         else:
             cur_aug_scores = cur_scores
         raw_scores_aug.append(cur_aug_scores)
     # cascade scores
     final_scores = self._cascade_scores(raw_scores_aug)
     # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
     # down-weight full-nil gold positions (idx 0) when loss_fullnil_weight < 1
     loss_weights = ((gold_idxes == 0).float() *
                     (self.loss_fullnil_weight - 1.) +
                     1.) if self.loss_fullnil_weight < 1. else 1.
     # calculate loss
     loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
     loss_prob_reweight = self.conf.loss_prob_reweight
     final_losses = []
     no_loss_max_gold = self.conf.no_loss_max_gold
     if loss_mask is None:
         loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
     for i in range(self.eff_max_layer):
         cur_final_scores, cur_gold_inputs = final_scores[
             i], gold_all_idxes[i]  # [*, ?], [*]
         # collect the loss
         if self.is_hinge_loss:
             cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
             cur_gold_scores = BK.gather(cur_final_scores,
                                         cur_gold_inputs.unsqueeze(-1),
                                         -1).squeeze(-1)
             cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
             if no_loss_max_gold:  # this should be implicit
                 # zero the loss wherever gold already is the max
                 cur_loss = cur_loss * (cur_loss > 0.).float()
         elif self.is_prob_loss:
             cur_loss = self._my_loss_prob(cur_final_scores,
                                           cur_gold_inputs,
                                           loss_prob_entropy_lambda,
                                           loss_mask,
                                           loss_prob_reweight)  # [*]
             if no_loss_max_gold:
                 # keep loss only where gold strictly beats the best prediction
                 cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                 cur_gold_scores = BK.gather(cur_final_scores,
                                             cur_gold_inputs.unsqueeze(-1),
                                             -1).squeeze(-1)
                 cur_loss = cur_loss * (cur_gold_scores >
                                        cur_pred_scores).float()
         else:
             raise NotImplementedError(
                 f"UNK loss {self.conf.loss_function}")
         # here first summing up, divided at the outside
         one_loss_sum = (
             cur_loss *
             (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
         final_losses.append(one_loss_sum)
     # final sum
     final_loss_sum = BK.stack(final_losses).sum()
     _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
     return [[final_loss_sum,
              loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds