def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]:
    flattened_states, cur_arc_scores, scoring_mask_ct = candidates
    cur_cache = self.cache
    cur_bsize = len(flattened_states)
    cur_slen = cur_cache.max_slen
    if mode == "topk":
        # todo(note): there can be multiple oracles; select top-k (usually top-1) in this mode.
        # get and apply the oracle mask
        cur_oracle_mask_t, cur_oracle_label_t = self._get_oracle_mask(flattened_states)  # [bs, Lm*Lh]
        cur_oracle_arc_scores = (cur_arc_scores + Constants.REAL_PRAC_MIN * (1. - cur_oracle_mask_t)).view([cur_bsize, -1])
        # arcs [*, k]
        topk_arc_scores, topk_arc_idxes = BK.topk(cur_oracle_arc_scores, k_arc, dim=-1, sorted=False)
        topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen  # decode flat index into [m, h]
        # labels [*, k, 1]
        # todo(note): here we gather labels, since one arc can only have one oracle label
        cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, 0., 0.)  # [*, k, labels]
        topk_label_idxes = cur_oracle_label_t[cur_cache.bsize_range_t.unsqueeze(-1), topk_m, topk_h].unsqueeze(-1)  # [*, k, 1]
        # todo(+N): here is the trick to avoid repeated calculations; maybe not correct when using a full dynamic oracle
        topk_label_scores = BK.gather(cur_label_scores, topk_label_idxes, -1) - self.mw_label
        # todo(+N): here we use both masks, which may lead to no oracles! Can we simply drop the oracle_mask?
        return self._new_states(flattened_states, scoring_mask_ct * cur_cache.oracle_mask_ct,
                                topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes)
    elif mode == "":
        return [[]] * cur_bsize
    # todo(+N): other modes like sampling to be implemented: sample, topk-sample, gather
    else:
        raise NotImplementedError(mode)
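# A minimal standalone sketch (plain PyTorch, not the BK backend used above) of the
# masked top-k trick in select_oracle: non-oracle arcs are pushed to an effectively
# -inf score before topk, and the flat indices are decoded back into (modifier, head)
# pairs via integer division and modulo. Names and NEG_INF are illustrative stand-ins,
# not taken from the original codebase.
import torch

NEG_INF = -1e30  # stand-in for Constants.REAL_PRAC_MIN

def topk_oracle_arcs(arc_scores: torch.Tensor, oracle_mask: torch.Tensor, k: int):
    # arc_scores, oracle_mask: [bs, slen_m, slen_h]; mask is 1. exactly on oracle arcs
    bs, slen_m, slen_h = arc_scores.shape
    masked = (arc_scores + NEG_INF * (1. - oracle_mask)).view(bs, -1)  # [bs, slen_m*slen_h]
    topk_scores, topk_idxes = masked.topk(k, dim=-1, sorted=False)     # [bs, k]
    topk_m, topk_h = topk_idxes // slen_h, topk_idxes % slen_h         # decode flat index
    return topk_scores, topk_m, topk_h

# usage, e.g.:
#   scores = torch.randn(2, 5, 5)
#   mask = (torch.rand(2, 5, 5) > 0.8).float()
#   topk_oracle_arcs(scores, mask, 2)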
def _score_label_selected(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr=None):
    _, _, lm_expr, lh_expr = scoring_expr_pack  # [BS, len-m, D]
    lh_expr_shape = BK.get_shape(lh_expr)
    # for each modifier, gather the representation of its gold head along the length dim
    selected_lh_expr = BK.gather(lh_expr, gold_heads_expr.unsqueeze(-1).expand(*lh_expr_shape),
                                 dim=len(lh_expr_shape) - 2)
    select_label_score = self.scorer.score_label_select(lm_expr, selected_lh_expr, mask_expr)  # [BS, len-m, L]
    # margin?
    if training and margin > 0.:
        select_label_score = BK.minus_margin(select_label_score, gold_labels_expr, margin)
    return select_label_score
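# A minimal PyTorch sketch of the head-selection gather in _score_label_selected:
# for each modifier position we pick out the representation of its gold head by
# expanding the head indices over the hidden dimension and gathering along the
# length axis. "lh" / "gold_heads" are illustrative names for this sketch.
import torch

def select_head_reprs(lh: torch.Tensor, gold_heads: torch.Tensor) -> torch.Tensor:
    # lh: [bs, slen, D] head representations; gold_heads: [bs, slen] indices into slen
    idx = gold_heads.unsqueeze(-1).expand(-1, -1, lh.size(-1))  # [bs, slen, D]
    return lh.gather(dim=-2, index=idx)                          # out[b, i] = lh[b, gold_heads[b, i]]

# usage, e.g.:
#   lh = torch.randn(2, 6, 16); heads = torch.randint(0, 6, (2, 6))
#   select_head_reprs(lh, heads)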
def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
    gold_all_idxes = self._get_all_idxes(gold_idxes)
    # scoring
    raw_scores = self._raw_scores(input_expr)
    raw_scores_aug = []
    margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
    # gold_shape = BK.get_shape(gold_idxes)  # [*]
    # gold_bsize_prod = np.prod(gold_shape)
    # gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
    # margin
    for i in range(self.eff_max_layer):
        cur_gold_inputs = gold_all_idxes[i]
        # add margin
        cur_scores = raw_scores[i]  # [*, ?]
        cur_margin = margin * self.margin_lambdas[i]
        if cur_margin > 0.:
            cur_num_target = self.prediction_sizes[i]
            cur_isnil = self.layered_isnil[i].bool()  # [NLab]
            cost_matrix = BK.constants([cur_num_target, cur_num_target], margin_T)  # [gold, pred]
            cost_matrix[cur_isnil, :] = margin_P  # gold is NIL: precision cost
            cost_matrix[:, cur_isnil] = margin_R  # pred is NIL: recall cost
            diag_idxes = BK.arange_idx(cur_num_target)
            cost_matrix[diag_idxes, diag_idxes] = 0.  # correct prediction: no cost
            margin_mat = cost_matrix[cur_gold_inputs]
            cur_aug_scores = cur_scores + margin_mat  # [*, ?]
        else:
            cur_aug_scores = cur_scores
        raw_scores_aug.append(cur_aug_scores)
    # cascade scores
    final_scores = self._cascade_scores(raw_scores_aug)
    # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
    loss_weights = ((gold_idxes == 0).float() * (self.loss_fullnil_weight - 1.) + 1.) \
        if self.loss_fullnil_weight < 1. else 1.
    # calculate loss
    loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
    loss_prob_reweight = self.conf.loss_prob_reweight
    final_losses = []
    no_loss_max_gold = self.conf.no_loss_max_gold
    if loss_mask is None:
        loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
    for i in range(self.eff_max_layer):
        cur_final_scores, cur_gold_inputs = final_scores[i], gold_all_idxes[i]  # [*, ?], [*]
        # collect the loss
        if self.is_hinge_loss:
            cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
            cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
            cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
            if no_loss_max_gold:  # this should be implicit
                cur_loss = cur_loss * (cur_loss > 0.).float()
        elif self.is_prob_loss:
            # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
            cur_loss = self._my_loss_prob(cur_final_scores, cur_gold_inputs,
                                          loss_prob_entropy_lambda, loss_mask, loss_prob_reweight)  # [*]
            if no_loss_max_gold:
                cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
                cur_loss = cur_loss * (cur_gold_scores > cur_pred_scores).float()
        else:
            raise NotImplementedError(f"UNK loss {self.conf.loss_function}")
        # first sum up here; the division happens outside
        one_loss_sum = (cur_loss * (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
        final_losses.append(one_loss_sum)
    # final sum
    final_loss_sum = BK.stack(final_losses).sum()
    _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
    return [[final_loss_sum, loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
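# A minimal PyTorch sketch of the cost-augmented margin used in loss() above: a
# [num_labels, num_labels] cost matrix indexed as [gold, pred], with separate costs
# for precision errors (gold is NIL, pred is not), recall errors (pred is NIL, gold
# is not), and plain mislabelings, and zero cost on the diagonal; the gold row is
# then broadcast-added onto the raw scores. This simplifies layered_isnil by assuming
# label 0 is the single NIL label (consistent with the nil_as_zero note above).
import torch

def margin_augment(scores: torch.Tensor, gold: torch.Tensor,
                   m_P: float, m_R: float, m_T: float) -> torch.Tensor:
    # scores: [*, L]; gold: [*] label indices; label 0 taken as NIL here
    L = scores.size(-1)
    cost = torch.full((L, L), m_T)      # default: wrong-label cost
    cost[0, :] = m_P                    # gold NIL, pred non-NIL: precision cost
    cost[:, 0] = m_R                    # gold non-NIL, pred NIL: recall cost
    cost.fill_diagonal_(0.)             # correct prediction: no cost
    return scores + cost[gold]          # [*, L]

# usage, e.g.:
#   margin_augment(torch.randn(4, 5), torch.tensor([0, 2, 1, 4]), 2., 1., 0.5)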