def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    # loss
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
    # collect the losses
    arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
    arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
    # logsoftmax and losses
    arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1)  # [bs, m, h]
    lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr]  # [bs, Len]
    lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr]  # [bs, Len]
    # head selection (no root)
    arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
    lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
    final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
    final_loss_count = mask_expr[:, 1:].sum()
    return [[final_loss, final_loss_count]]

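# The gold arc/label selection above relies on advanced indexing with two broadcast
# index tensors. A minimal PyTorch-only sketch of the same pattern; shapes, names,
# and the use of plain torch tensors in place of the BK wrappers are illustrative
# assumptions, not the model's actual API.
def _sketch_gold_logprob_selection():
    import torch
    bs, slen, n_lab = 2, 5, 8
    arc_score = torch.randn(bs, slen, slen)            # [bs, m, h]
    lab_score = torch.randn(bs, slen, slen, n_lab)     # [bs, m, h, Lab]
    gold_heads = torch.randint(0, slen, (bs, slen))    # [bs, m]
    gold_labels = torch.randint(0, n_lab, (bs, slen))  # [bs, m]
    arange_bs = torch.arange(bs).unsqueeze(-1)         # [bs, 1]
    arange_m = torch.arange(slen).unsqueeze(0)         # [1, m]
    # pick log P(gold_head | m) and log P(gold_label | m, gold_head) per token
    arc_lp = arc_score.log_softmax(-1)[arange_bs, arange_m, gold_heads]               # [bs, m]
    lab_lp = lab_score.log_softmax(-1)[arange_bs, arange_m, gold_heads, gold_labels]  # [bs, m]
    return arc_lp, lab_lp
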
def predict(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    # decode
    mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
    mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
    mst_heads_arr, mst_labels_arr, mst_scores_arr = \
        nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
    # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
    misc_prefix = "g"
    for one_idx, one_inst in enumerate(insts):
        cur_length = mst_lengths[one_idx]
        one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
        one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
        one_scores = mst_scores_arr[one_idx][:cur_length]
        one_inst.pred_par_scores.set_vals(one_scores)
        # extra output
        one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()

def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True,
        conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if there are no targets at all
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]

def predict(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef=None, mask_evt=None, ret_full_logprobs=False):
    # -----
    ret_shape = BK.get_shape(lab_ef)[:-1] + [BK.get_shape(lab_ef, -1), BK.get_shape(lab_evt, -1)]
    if np.prod(ret_shape) == 0:
        if ret_full_logprobs:
            return BK.zeros(ret_shape + [self.num_label])
        else:
            return BK.zeros(ret_shape), BK.zeros(ret_shape).long()
    # -----
    # todo(note): +1 for space of DROPPED(UNK)
    full_score = self._score(repr_ef, repr_evt, lab_ef + 1, lab_evt + 1)  # [*, len-ef, len-evt, D]
    full_logprobs = BK.log_softmax(full_score, -1)
    if ret_full_logprobs:
        return full_logprobs
    else:
        # greedy maximum decode
        ret_logprobs, ret_idxes = full_logprobs.max(-1)  # [*, len-ef, len-evt]
        # mask non-valid ones
        if mask_ef is not None:
            ret_idxes *= (mask_ef.unsqueeze(-1)).long()
        if mask_evt is not None:
            ret_idxes *= (mask_evt.unsqueeze(-2)).long()
        return ret_logprobs, ret_idxes

def _get_one_loss(self, predictor, hidden_t, labels_t, masks_t, lambda_loss):
    logits = predictor(hidden_t)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count (only for type)
    return [masked_losses.sum() * lambda_loss, masks_t.sum(), (labels_t > 0).float().sum()]

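# `BK.gather_one_lastdim(t, idx)` is used above as "pick one entry along the last
# dim per position". Assuming it behaves like a plain last-dim gather with the
# index unsqueezed (an assumption inferred from how callers .squeeze(-1) the
# result), a PyTorch-only sketch of the equivalent operation:
def _sketch_gather_one_lastdim():
    import torch
    log_probs = torch.randn(2, 7, 5).log_softmax(-1)   # [bsize, ?, Out]
    labels = torch.randint(0, 5, (2, 7))               # [bsize, ?]
    picked = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [bsize, ?]
    nll = -picked                                       # per-position negative log-likelihood
    return nll
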
def _pred_and_put_res(self, predictor, hidden_t, evt_arr, put_f):
    logits = predictor(hidden_t)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    max_log_probs, max_label_idxes = log_probs.max(-1)  # [bs, ?], simply argmax prediction
    max_log_probs_arr, max_label_idxes_arr = BK.get_value(max_log_probs), BK.get_value(max_label_idxes)
    for evt_row, lprob_row, lidx_row in zip(evt_arr, max_log_probs_arr, max_label_idxes_arr):
        for one_evt, one_lprob, one_lidx in zip(evt_row, lprob_row, lidx_row):
            if one_evt is not None:
                put_f(one_evt, one_lprob, one_lidx)  # callback for inplace setting

def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first maximum across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use mask to make it able to be backward
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]
    # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # final predicting labels
    # todo(+N): here we select only max at labeling part, only beam at previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state

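# The token-selection step above first takes a hard max over the relation (R)
# dimension and only then normalizes and beams over the token (slen) dimension.
# A minimal PyTorch-only sketch of that two-stage selection; shapes and the plain
# topk standing in for the model's beam bookkeeping are illustrative assumptions.
def _sketch_hard_max_then_topk(beam_size=3):
    import torch
    bs, slen, R = 2, 6, 4
    scores = torch.randn(bs, slen, R)                   # [*, slen, R]
    maxr_scores, maxr_idxes = scores.max(-1)            # hard max over R: [*, slen]
    maxr_logprobs = maxr_scores.log_softmax(-1)         # prob-based step over tokens
    k = min(beam_size, maxr_logprobs.size(-1))
    sel_logprobs, sel_idxes = maxr_logprobs.topk(k, dim=-1, sorted=False)  # [*, beam]
    sel_r = maxr_idxes.gather(-1, sel_idxes)            # which relation won at each selected token
    return sel_logprobs, sel_idxes, sel_r
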
def loss(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef, mask_evt, gold_idxes, margin=0.):
    conf = self.conf
    # -----
    if np.prod(BK.get_shape(gold_idxes)) == 0:
        return [[BK.zeros([]), BK.zeros([])]]
    # -----
    # todo(note): +1 for space of DROPPED(UNK)
    lab_ef = self._dropout_idxes(lab_ef + 1, conf.train_drop_ef_lab)
    lab_evt = self._dropout_idxes(lab_evt + 1, conf.train_drop_evt_lab)
    if conf.linker_ef_detach:
        repr_ef = repr_ef.detach()
    if conf.linker_evt_detach:
        repr_evt = repr_evt.detach()
    full_score = self._score(repr_ef, repr_evt, lab_ef, lab_evt)  # [*, len-ef, len-evt, D]
    if margin > 0.:
        aug_score = BK.zeros(BK.get_shape(full_score)) + margin
        aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)
        full_score += aug_score
    full_logprobs = BK.log_softmax(full_score, -1)
    gold_logprobs = full_logprobs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)  # [*, len-ef, len-evt]
    # sampling and mask
    loss_mask = mask_ef.unsqueeze(-1) * mask_evt.unsqueeze(-2)
    # =====
    # first select examples (randomly)
    sel_mask = (BK.rand(BK.get_shape(loss_mask)) < conf.train_min_rate).float()  # [*, len-ef, len-evt]
    # add gold and exclude pad
    sel_mask += (gold_idxes > 0).float()
    sel_mask.clamp_(max=1.)
    loss_mask *= sel_mask
    # =====
    loss_sum = -(gold_logprobs * loss_mask).sum()
    loss_count = loss_mask.sum()
    ret_losses = [[loss_sum, loss_count]]
    return ret_losses

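# The margin handling above is cost-augmented scoring: every non-gold class gets
# +margin before the softmax, so the loss prefers the gold score to win by at
# least the margin. A PyTorch-only sketch of that augmentation; tensor shapes and
# names are illustrative assumptions.
def _sketch_margin_augment(margin=1.0):
    import torch
    scores = torch.randn(2, 3, 4, 5)                    # [*, len-ef, len-evt, D]
    gold = torch.randint(0, 5, (2, 3, 4))               # [*, len-ef, len-evt]
    aug = torch.full_like(scores, margin)
    aug.scatter_(-1, gold.unsqueeze(-1), 0.)            # zero out the gold positions
    gold_logprobs = (scores + aug).log_softmax(-1).gather(-1, gold.unsqueeze(-1)).squeeze(-1)
    return -gold_logprobs                               # per-pair cost-augmented NLL
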
def _predict(self, all_scores, force_idxes):
    # predicting on the last one
    last_score = BK.log_softmax(all_scores[self.eff_max_layer - 1], -1)  # [*, ?]
    if force_idxes is None:
        # todo(note): currently only do max
        res_logprobs, res_idxes = last_score.max(-1)  # [*]
    else:
        res_idxes = force_idxes
        res_logprobs = last_score.gather(-1, res_idxes.unsqueeze(-1)).squeeze(-1)  # [*]
    # lookup: [*, D]
    conf = self.conf
    if conf.use_lookup_soft:
        ret_lab_embeds = self.lookup_soft(all_scores)
    else:
        ret_lab_embeds = self.lookup(res_idxes)
    return res_logprobs, res_idxes, ret_lab_embeds

def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
    conf = self.conf
    CR, PR = conf.cand_range, conf.pred_range
    # -----
    mask_single = BK.copy(mask_t)
    # no predictions for ARTI_ROOT
    if self.add_root_token:
        mask_single[:, 0] = 0.  # [bs, slen]
    # casting predicting range
    cur_slen = BK.get_shape(mask_single, -1)
    arange_t = BK.arange_idx(cur_slen)  # [slen]
    # [1, len] - [len, 1] = [len, len]
    reldist_t = (arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1))  # [slen, slen]
    mask_pair = ((reldist_t.abs() <= CR) & (reldist_t != 0)).float()  # within CR-range; [slen, slen]
    mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
    if disturb_keep_arr is not None:
        mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
    # get all pair scores
    score_t = self.ps_node.paired_score(repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
    # -----
    # loss: normalize on which dim?
    # get the answers first
    if conf.pred_abs:
        answer_t = reldist_t.abs()  # [1,2,3,...,PR]
        answer_t.clamp_(min=0, max=PR - 1)  # [slen, slen], clip in range, distinguish using masks
    else:
        answer_t = BK.where((reldist_t >= 0), reldist_t - 1, reldist_t + 2 * PR)  # [1,2,3,...,PR,-PR,-PR+1,...,-1]
        answer_t.clamp_(min=0, max=2 * PR - 1)  # [slen, slen], clip in range, distinguish using masks
    # expand answer into idxes
    answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2 * PR])  # [len_q, len_k, 2*R]
    answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
    answer_valid_t = ((reldist_t.abs() <= PR) & (reldist_t != 0)).float().unsqueeze(-1)  # [bs, len_q, len_k, 1]
    answer_hit_t = answer_hit_t * mask_pair.unsqueeze(-1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
    # get losses sum(log(answer*prob))
    # -- dim=-1 is standard 2*PR classification; dim=-2 usually has 2*PR candidates, but can have fewer at edges
    all_losses = []
    for one_dim, one_lambda in zip([-1, -2], [conf.lambda_n1, conf.lambda_n2]):
        if one_lambda > 0.:
            # since currently there can be only one or zero correct answer
            logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*R]
            sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*R]
            cur_dim_mask_t = (answer_hit_t.sum(one_dim) > 0.).float()  # [bs, len_q, len_k||2*R]
            # loss
            cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
            cur_dim_count = cur_dim_mask_t.sum()
            # argmax and corr (any correct counts)
            _, cur_argmax_idxes = score_t.max(one_dim)
            cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*R|1]
            cur_dim_corr_count = cur_corrs.sum()
            # compile loss
            one_loss = LossHelper.compile_leaf_info(
                f"d{one_dim}", cur_dim_loss, cur_dim_count, loss_lambda=one_lambda, corr=cur_dim_corr_count)
            all_losses.append(one_loss)
    return self._compile_component_loss("orp", all_losses)

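# The signed-distance targets above map relative offsets into 2*PR classes:
# d in [1..PR] -> d-1 and d in [-PR..-1] -> d+2*PR, with out-of-range offsets
# clamped and later removed by the validity mask. A PyTorch-only sketch of that
# bucketing; the PR/slen values and tensor names are illustrative assumptions.
def _sketch_reldist_targets(PR=3, slen=6):
    import torch
    arange_t = torch.arange(slen)
    reldist = arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1)          # [slen, slen], offset k - q
    answer = torch.where(reldist >= 0, reldist - 1, reldist + 2 * PR)  # signed bucket id
    answer = answer.clamp(min=0, max=2 * PR - 1)                       # keep indices in [0, 2*PR)
    valid = (reldist.abs() <= PR) & (reldist != 0)                     # drop self and far pairs
    return answer, valid
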
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    self.refresh_batch(training)
    margin = self.margin.value
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    gold_labels_arr, _ = self.predict_padder.pad([self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    # ===== calculate
    scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr)
    #
    final_losses = None
    if self.norm_local or self.norm_single:
        select_label_score = self._score_label_selected(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # already added margin previously
        losses_heads = losses_labels = None
        if self.loss_prob:
            if self.norm_local:
                losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_hinge:
            if self.norm_local:
                losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=True, margin=margin)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=True, margin=margin)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_mr:
            # special treatment!
            probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
            probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, h]
            # select
            probs_head_gold = BK.gather_one_lastdim(probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
            probs_label_gold = BK.gather_one_lastdim(probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
            # root and pad will be excluded later
            # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
            # todo(warn): have problem since steps will be quite small, not used!
            final_losses = (mask_expr - probs_head_gold * probs_label_gold)  # let loss >= 0
    elif self.norm_global:
        full_label_score = self._score_label_full(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # for this one, use the merged full score
        full_score = full_arc_score.unsqueeze(-1) + full_label_score  # [BS, m, h, L]
        # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        # do inference
        if self.loss_prob:
            marginals_expr = self._marginal(full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
            final_losses = self._losses_global_prob(
                full_score, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                #  but this might be too loose, although the unproj edges are few?
                gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts])
                gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                final_losses = BK.max_elem(final_losses, comparing_expr)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(full_score, mask_expr, mst_lengths_arr)
            pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
            #
            final_losses = self._losses_global_hinge(
                full_score, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # firstly label losses are the same
        select_label_score = self._score_label_selected(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
        # then specially for arc loss
        children_masks_arr, _ = self.hlocal_padder.pad([z.get_children_mask_arr() for z in annotated_insts])
        children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
        # [bs, h]
        # todo(warn): use prod rather than sum, but still only an approximation for the top-down
        # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
        losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)
        # including the root-head is important
        losses_arc[:, 1] += losses_arc[:, 0]
        final_losses = losses_arc + losses_labels
    #
    # jpos loss? (the same mask as parsing)
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        final_losses += jpos_losses_expr
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    #
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    if training:
        BK.backward(final_loss, loss_factor)
    return info

def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # ===== calculate
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
        # normalizing scores
        full_score = None
        final_exp_score = False  # whether to provide PROB by exp
        if self.norm_local and self.loss_prob:
            full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
            final_exp_score = True
        elif self.norm_hlocal and self.loss_prob:
            # normalize at m dimension, ignore each node's self-finish step.
            full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
        elif self.norm_single and self.loss_prob:
            if self.conf.iconf.dec_single_neg:
                # todo(+2): add all-neg for prob explanation
                full_arc_probs = BK.sigmoid(full_arc_score)
                full_label_probs = BK.sigmoid(full_label_score)
                fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
            else:
                full_score = BK.logsigmoid(full_arc_score).unsqueeze(-1) + BK.logsigmoid(full_label_score)
                final_exp_score = True
        else:
            full_score = full_arc_score.unsqueeze(-1) + full_label_score
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
        mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
            full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
        if final_exp_score:
            mst_scores_arr = np.exp(mst_scores_arr)
        # jpos prediction (directly index, no converting as in parsing)
        jpos_preds_expr = jpos_pack[2]
        has_jpos_pred = jpos_preds_expr is not None
        jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
        # ===== assign
        info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
        mst_real_labels = self.pred2real_labels(mst_labels_arr)
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab)
            one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
            if has_jpos_pred:
                one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
        return info

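# In the dec_single_neg branch above, the "fake" scores are the log-odds of the
# sigmoid probabilities: log p - log(1-p) with p = sigmoid(x), which is
# mathematically the original score x again (up to floating-point error).
# A quick PyTorch-only check of that identity, as a sketch:
def _sketch_log_odds_identity():
    import torch
    x = torch.randn(4, 5)
    p = torch.sigmoid(x)
    fake = torch.log(p) - torch.log(1. - p)   # logit(p)
    assert torch.allclose(fake, x, atol=1e-5)
    return fake
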
def predict(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    offsets_t, masks_t, _, _, _ = PrepHelper.prep_targets(ms_items, lambda x: [], True, True, 1., 1., False)
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    # -----
    log_probs = BK.log_softmax(logits, -1)
    log_probs[:, :, 0] -= conf.nil_penalty  # encourage more predictions
    topk_log_probs, topk_log_labels = log_probs.max(dim=-1)  # [bsize, ?, k]
    # decoding
    head_offsets_arr = BK.get_value(offsets_t)  # [bs, ?]
    masks_arr = BK.get_value(masks_t)
    topk_log_probs_arr, topk_log_labels_arr = BK.get_value(topk_log_probs), BK.get_value(topk_log_labels)  # [bsize, ?, k]
    for one_ms_item, one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr \
            in zip(ms_items, head_offsets_arr, masks_arr, topk_log_probs_arr, topk_log_labels_arr):
        # build tidx2sidx
        one_sents = one_ms_item.sents
        one_offsets = one_ms_item.offsets
        tidx2sidx = []
        for idx in range(1, len(one_offsets)):
            tidx2sidx.extend([idx - 1] * (one_offsets[idx] - one_offsets[idx - 1]))
        # get all candidates
        all_candidates = [[] for _ in one_sents]
        for cur_offset, cur_valid, cur_logprob, cur_label in zip(
                one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr):
            if not cur_valid or cur_label <= 0:
                continue
            # which sent
            cur_offset = int(cur_offset)
            cur_sidx = tidx2sidx[cur_offset]
            cur_sent = one_sents[cur_sidx]
            minus_offset = one_ms_item.offsets[cur_sidx] - 1  # again consider the ROOT
            cur_mention = Mention(HardSpan(cur_sent.sid, cur_offset - minus_offset, None, None))
            all_candidates[cur_sidx].append((cur_sent, cur_mention, cur_label, cur_logprob))
        # keep certain ratio for each sent separately?
        final_candidates = []
        if conf.pred_sent_ratio_sep:
            for one_sent, one_sent_candidates in zip(one_sents, all_candidates):
                cur_keep_num = max(int(conf.pred_sent_ratio * (one_sent.length - 1)), 1)
                one_sent_candidates.sort(key=lambda x: x[-1], reverse=True)
                final_candidates.extend(one_sent_candidates[:cur_keep_num])
        else:
            all_size = 0
            for one_sent, one_sent_candidates in zip(one_sents, all_candidates):
                all_size += one_sent.length - 1
                final_candidates.extend(one_sent_candidates)
            final_candidates.sort(key=lambda x: x[-1], reverse=True)
            final_keep_num = max(int(conf.pred_sent_ratio * all_size), len(one_sents))
            final_candidates = final_candidates[:final_keep_num]
        # add them all
        for cur_sent, cur_mention, cur_label, cur_logprob in final_candidates:
            cur_logprob = float(cur_logprob)
            doc_id = cur_sent.doc.doc_id
            self.id_counter[doc_id] += 1
            new_id = f"ef-{doc_id}-{self.id_counter[doc_id]}"
            hlidx = self.valid_hlidx
            new_ef = EntityFiller(new_id, cur_mention, str(hlidx), None, True, type_idx=hlidx, score=cur_logprob)
            cur_sent.pred_entity_fillers.append(new_ef)

def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:
        # offset for the special root added in embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        repr_mask_idxes = mask_idxes + 1
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        return self._compile_component_loss("mlm", [])
    else:
        if not isinstance(repr_ts, (List, Tuple)):
            repr_ts = [repr_ts]
        target_word_scores, target_pos_scores = [], []
        target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
        for layer_idx in conf.loss_layers:
            # calculate scores
            target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
            if self.hid_layer and active_hid:
                # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                target_hids = self.hid_layer(target_reprs)
            else:
                target_hids = target_reprs
            if _tie_input_embeddings:
                pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
            else:
                target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
        # gather the losses
        all_losses = []
        for pred_name, target_scores, loss_lambda, range_min, range_max in \
                zip(["word", "pos"], [target_word_scores, target_pos_scores],
                    [conf.lambda_word, conf.lambda_pos], [conf.min_pred_rank, 0],
                    [min(conf.max_pred_rank, self.pred_word_size - 1), self.pred_pos_size - 1]):
            if loss_lambda > 0.:
                seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
                target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones in range
                # calculate for each layer
                all_layer_losses, all_layer_scores = [], []
                for one_layer_idx, one_target_scores in enumerate(target_scores):
                    # get loss: [bsize, ?]
                    one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                    all_layer_losses.append(one_pred_losses)
                    # get scores
                    one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                    all_layer_scores.append(one_pred_scores)
                # combine all layers
                pred_losses = self.loss_comb_f(all_layer_losses)
                pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                pred_loss_count = ranged_mask_valids.sum()
                # argmax
                _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                pred_corr_count = pred_corrs.sum()
                # compile leaf loss
                r_loss = LossHelper.compile_leaf_info(
                    pred_name, pred_loss_sum, pred_loss_count, loss_lambda=loss_lambda, corr=pred_corr_count)
                all_losses.append(r_loss)
        return self._compile_component_loss("mlm", all_losses)

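# `BK.mask2idx` above turns a 0/1 erase mask into padded per-row index lists plus a
# validity mask, so only the masked positions are gathered and scored. Assuming that
# behavior (the exact BK signature and padding convention may differ), a PyTorch-only
# sketch of the pattern:
def _sketch_mask2idx(padding_idx=0):
    import torch
    mask = torch.tensor([[0., 1., 1., 0.], [1., 0., 0., 0.]])    # [bsize, slen]
    counts = mask.sum(-1).long()                                  # masked positions per row
    max_count = int(counts.max().item())
    idxes = torch.full((mask.size(0), max_count), padding_idx, dtype=torch.long)
    valids = torch.zeros(mask.size(0), max_count)
    for b in range(mask.size(0)):
        pos = mask[b].nonzero(as_tuple=True)[0]
        idxes[b, :pos.numel()] = pos
        valids[b, :pos.numel()] = 1.
    return idxes, valids   # e.g. [[1, 2], [0, 0]] and [[1, 1], [1, 0]]
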