def _get_rel_dist(self, len_q: int, len_k: int = None):
    if len_k is None:
        len_k = len_q
    dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
    dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
    distance = dist_x - dist_y  # [len_q, len_k]
    return distance
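# A minimal sketch of the broadcasting trick above, assuming BK.arange_idx
# wraps a torch.arange-like op: the [1, len_k] and [len_q, 1] shapes broadcast
# into a full [len_q, len_k] matrix of signed key-minus-query offsets.
import torch

def rel_dist_demo(len_q: int, len_k: int) -> torch.Tensor:
    dist_x = torch.arange(len_k).unsqueeze(0)  # [1, len_k]
    dist_y = torch.arange(len_q).unsqueeze(1)  # [len_q, 1]
    return dist_x - dist_y  # [len_q, len_k]

# rel_dist_demo(3, 3) ->
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])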
def loss(self, ms_items: List, bert_expr):
    conf = self.conf
    max_range = self.conf.max_range
    bsize = len(ms_items)
    # collect instances
    col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(ms_items, True)
    if len(col_efs) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)  # [N, R]
    if conf.use_binary_scorer:
        left_binaries, right_binaries = (BK.arange_idx(max_range) <= col_ldists_t.unsqueeze(-1)).float(), \
            (BK.arange_idx(max_range) <= col_rdists_t.unsqueeze(-1)).float()  # [N, R]
        left_losses = BK.binary_cross_entropy_with_logits(left_scores, left_binaries, reduction='none')[:, 1:]
        right_losses = BK.binary_cross_entropy_with_logits(right_scores, right_binaries, reduction='none')[:, 1:]
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0) * (max_range - 1))
    else:
        left_losses = BK.loss_nll(left_scores, col_ldists_t)
        right_losses = BK.loss_nll(right_scores, col_rdists_t)
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0))
    return [[left_losses.sum(), left_count, left_count], [right_losses.sum(), right_count, right_count]]
def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    # loss
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
    # collect the losses
    arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
    arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
    # logsoftmax and losses
    arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1)  # [bs, m, h]
    lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr]  # [bs, Len]
    lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr]  # [bs, Len]
    # head selection (no root)
    arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
    lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
    final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
    final_loss_count = mask_expr[:, 1:].sum()
    return [[final_loss, final_loss_count]]
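# A minimal sketch of the gold-score gather above, assuming a plain PyTorch
# backend: pairing a [bs, 1] batch arange with a [1, Len] position arange
# broadcasts to [bs, Len], so indexing a [bs, m, h] score tensor with
# (batch, modifier, gold_head) picks one head score per token in one shot.
import torch

bs, m = 2, 4
arc_ls = torch.randn(bs, m, m).log_softmax(-1)         # [bs, m, h]
gold_heads = torch.randint(0, m, (bs, m))              # [bs, m]
arange_bs = torch.arange(bs).unsqueeze(-1)             # [bs, 1]
arange_m = torch.arange(m).unsqueeze(0)                # [1, m]
gold_scores = arc_ls[arange_bs, arange_m, gold_heads]  # [bs, m]
assert gold_scores.shape == (bs, m)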
def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
    if self.dmxnn:
        bsize, slen = BK.get_shape(input_mask)
        if sel_idxes is None:
            sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
        ncand = BK.get_shape(sel_idxes, -1)
        # enc_expr aug with PE
        rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
        pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
        aug_enc_expr = BK.concat([pe_embeds.expand(bsize, -1, -1, -1),
                                  input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)], -1)  # [*, ?, slen, D+Dpe]
        # [*, ?, slen, Denc]
        hidden_expr = self.e_encoder(
            aug_enc_expr.view(bsize * ncand, slen, -1),
            input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
        hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
        # dynamic max-pooling (dist<0, dist=0, dist>0)
        NEG = Constants.REAL_PRAC_MIN
        mp_hiddens = []
        mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
        for mp_mask in mp_masks:
            float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
            valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
            mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
            # todo(+2): or do we simply multiply mask?
            mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
            mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
            mp_hiddens.append(self.special_drop(mp_hid))
            # mp_hiddens.append(mp_hid)
        final_hiddens = mp_hiddens
    else:
        hidden_expr = self.e_encoder(input_expr, input_mask)  # [*, slen, D']
        if sel_idxes is None:
            hidden_expr1 = hidden_expr
        else:
            hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes, -2)  # [*, ?, D']
        final_hiddens = [self.special_drop(hidden_expr1)]
    if self.lab_f_use_lexi:
        final_hiddens.append(BK.gather_first_dims(input_lexi, sel_idxes, -2))  # [*, ?, DLex]
    ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
    return ret_expr
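# A minimal sketch of the dynamic max-pooling above (PCNN-style), assuming a
# torch-like backend: each segment mask keeps only the left / anchor / right
# region, invalid positions are pushed to a large negative value before the
# max, and segments with no valid position at all are zeroed afterwards.
import torch

NEG = -1e4  # stand-in for Constants.REAL_PRAC_MIN
hidden = torch.randn(2, 3, 5, 8)            # [bs, ncand, slen, D]
rel_dist = torch.randint(-2, 3, (2, 3, 5))  # [bs, ncand, slen]
seq_mask = torch.ones(2, 5)                 # [bs, slen]
pooled = []
for seg_mask in (rel_dist < 0, rel_dist == 0, rel_dist > 0):
    fmask = seg_mask.float() * seq_mask.unsqueeze(-2)       # [bs, ncand, slen]
    has_any = (fmask.sum(-1) > 0.).float().unsqueeze(-1)    # [bs, ncand, 1]
    neg_pad = (1. - fmask).unsqueeze(-1) * NEG              # [bs, ncand, slen, 1]
    pooled.append((hidden + neg_pad).max(-2)[0] * has_any)  # [bs, ncand, D]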
def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True,
        conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if there are no targets at all
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]
def __call__(self, word_arr: np.ndarray = None, char_arr: np.ndarray = None,
             extra_arrs: Iterable[np.ndarray] = (), aux_arrs: Iterable[np.ndarray] = ()):
    exprs = []
    # word/char/extras/posi
    seq_shape = None
    if self.has_word:
        # todo(warn): singleton-UNK-dropout should be done outside before
        seq_shape = word_arr.shape
        word_expr = self.dropmd_word(self.word_embed(word_arr))
        exprs.append(word_expr)
    if self.has_char:
        seq_shape = char_arr.shape[:-1]
        char_embeds = self.char_embed(char_arr)  # [*, seq-len, word-len, D]
        char_cat_expr = self.dropmd_char(BK.concat([z(char_embeds) for z in self.char_cnns]))
        exprs.append(char_cat_expr)
    zcheck(len(extra_arrs) == len(self.extra_embeds), "Unmatched extra fields.")
    for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(extra_arrs, self.extra_embeds, self.dropmd_extras):
        seq_shape = one_extra_arr.shape
        exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
    if self.has_posi:
        seq_len = seq_shape[-1]
        posi_idxes = BK.arange_idx(seq_len)
        posi_input0 = self.posi_embed(posi_idxes)
        for _ in range(len(seq_shape) - 1):
            posi_input0 = BK.unsqueeze(posi_input0, 0)
        posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1,))
        exprs.append(posi_input1)
    #
    assert len(aux_arrs) == len(self.drop_auxes)
    for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
            zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes,
                self.aux_overall_gammas, self.aux_fold_lambdas):
        # fold and apply trainable lambdas
        input_aux_repr = BK.input_real(one_aux_arr)
        input_shape = BK.get_shape(input_aux_repr)
        # todo(note): assume the original concat is [fold/layer, D]
        reshaped_aux_repr = input_aux_repr.view(input_shape[:-1] + [one_fold, one_aux_dim])  # [*, slen, fold, D]
        lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
        weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma  # [*, slen, D]
        one_aux_expr = one_aux_drop(weighted_aux_repr)
        exprs.append(one_aux_expr)
    #
    concated_exprs = BK.concat(exprs, dim=-1)
    # optional proj
    if self.has_proj:
        final_expr = self.final_layer(concated_exprs)
    else:
        final_expr = concated_exprs
    return final_expr
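# A minimal sketch of the fold-weighted mixing above (an ELMo-style scalar
# mix), assuming plain PyTorch: softmax-normalized per-layer weights combine
# the [fold, D] slices, and a single overall gamma rescales the result.
import torch

fold, D = 4, 16
aux = torch.randn(2, 7, fold * D).view(2, 7, fold, D)  # [bs, slen, fold, D]
lambdas = torch.nn.Parameter(torch.zeros(fold))        # per-layer weights
gamma = torch.nn.Parameter(torch.ones(1))              # overall scale
weights = torch.softmax(lambdas, -1).unsqueeze(-1)     # [fold, 1]
mixed = (aux * weights).sum(-2) * gamma                # [bs, slen, D]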
def loss(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # use gold targets: only use positive samples!!
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
    realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx) for z in items_arr.flatten()]
    realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
    realis_mask = (realis_t >= 0).float()
    realis_t.clamp_(min=0)  # make sure all idxes are legal
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build losses
    loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens, realis_t, realis_mask, conf.lambda_realis)
    loss_item_type = self._get_one_loss(self.type_predictor, hiddens, labels_t, masks_t, conf.lambda_type)
    return [loss_item_realis, loss_item_type]
def _decode(self, insts: List[ParseInstance], full_score, mask_expr, misc_prefix):
    # decode
    mst_lengths = [len(z) + 1 for z in insts]  # +=1 to include ROOT for mst decoding
    mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
    mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
        full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
    if self.conf.iconf.output_marginals:
        # todo(note): here, we care about marginals for arc
        # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
        arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
        bsize, max_len = BK.get_shape(mask_expr)
        idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
        idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
        output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr, BK.input_idx(mst_heads_arr)]
        mst_marg_arr = BK.get_value(output_marg)
    else:
        mst_marg_arr = None
    # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
    for one_idx, one_inst in enumerate(insts):
        cur_length = mst_lengths[one_idx]
        one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
        one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
        one_scores = mst_scores_arr[one_idx][:cur_length]
        one_inst.pred_par_scores.set_vals(one_scores)
        # extra output
        one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
        if mst_marg_arr is not None:
            one_inst.extra_pred_misc[misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()
def _score(self, bert_expr, bidxes_t, hidxes_t):
    # ----
    # # debug
    # print(f"# ====\n Debug: {ArgSpanExpander._debug_count}")
    # ArgSpanExpander._debug_count += 1
    # ----
    bert_expr = bert_expr.view(BK.get_shape(bert_expr)[:-2] + [-1])  # flatten
    #
    max_range = self.conf.max_range
    max_slen = BK.get_shape(bert_expr, 1)
    # get candidates
    range_t = BK.arange_idx(max_range).unsqueeze(0)  # [1, R]
    bidxes_t = bidxes_t.unsqueeze(1)  # [N, 1]
    hidxes_t = hidxes_t.unsqueeze(1)  # [N, 1]
    left_cands = hidxes_t - range_t  # [N, R]
    right_cands = hidxes_t + range_t
    left_masks = (left_cands >= 0).float()
    right_masks = (right_cands < max_slen).float()
    left_cands.clamp_(min=0)
    right_cands.clamp_(max=max_slen - 1)
    # score
    head_exprs = bert_expr[bidxes_t, hidxes_t]  # [N, 1, D']
    left_cand_exprs = bert_expr[bidxes_t, left_cands]  # [N, R, D']
    right_cand_exprs = bert_expr[bidxes_t, right_cands]
    # actual scoring
    if self.use_lstm_scorer:
        batch_size = BK.get_shape(bidxes_t, 0)
        all_concat_outputs = []
        for cand_exprs, lstm_node in zip([left_cand_exprs, right_cand_exprs], [self.llstm, self.rlstm]):
            cur_state = lstm_node.zero_init_hidden(batch_size)
            step_size = BK.get_shape(cand_exprs, 1)
            all_outputs = []
            for step_i in range(step_size):
                cur_state = lstm_node(cand_exprs[:, step_i], cur_state, None)
                all_outputs.append(cur_state[0])  # using h
            concat_output = BK.stack(all_outputs, 1)  # [N, R, ?]
            all_concat_outputs.append(concat_output)
        left_hidden, right_hidden = all_concat_outputs
        left_scores = self.lscorer(left_hidden).squeeze(-1)  # [N, R]
        right_scores = self.rscorer(right_hidden).squeeze(-1)  # [N, R]
    else:
        left_scores = self.lscorer([left_cand_exprs, head_exprs]).squeeze(-1)  # [N, R]
        right_scores = self.rscorer([right_cand_exprs, head_exprs]).squeeze(-1)
    # mask
    left_scores += Constants.REAL_PRAC_MIN * (1. - left_masks)
    right_scores += Constants.REAL_PRAC_MIN * (1. - right_masks)
    return left_scores, right_scores
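# A minimal sketch of the clamped candidate windows above, assuming plain
# PyTorch: offsets 0..R-1 around each head give left/right candidates, a
# validity mask records the slots that ran out of range, and clamping keeps
# the indices legal so they can still be used for gathering (the invalid
# slots are later pushed to a large negative score, as in the masking step).
import torch

R, max_slen = 4, 10
hidxes = torch.tensor([0, 5, 9]).unsqueeze(1)  # [N, 1] head positions
offsets = torch.arange(R).unsqueeze(0)         # [1, R]
left = hidxes - offsets                        # [N, R]
left_mask = (left >= 0).float()                # 0. where the window ran off
left = left.clamp(min=0)                       # now safe to index with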
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    input_mask[:, 0] = 0.  # no artificial root
    final_score, attn, attn2 = self._score(input_expr, input_mask)
    pred_mask = self._predict(final_score, attn, attn2, input_mask)  # [*, slen, L]
    sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = self._pmask2idxes(pred_mask)
    all_logprobs = final_score.log().unsqueeze(-2) + (attn + 1e-10).log()  # [*, slen, L]
    bsize = len(insts)
    sel_lab_logprobs = all_logprobs[BK.arange_idx(bsize).unsqueeze(-1), sel_idxes, sel_lab_idxes]  # [*, ?]
    return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def __call__(self, input_v, add_root_token: bool):
    if isinstance(input_v, np.ndarray):
        # directly use this [batch_size, slen] as input
        posi_idxes = BK.input_idx(input_v)
        expr = self.node(posi_idxes)  # [batch_size, slen, D]
    else:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = input_v
        if add_root_token:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [1?+slen], root=0 is added here
        expr = self.node(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    return self.dropout(expr)
def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, hit_labels,
                         query_idxes0, query_idxes1, query_idxes2,
                         arc_scores, lab_scores, arc_margin: float, lab_margin: float):
    # arc
    gold_arc_mat = BK.constants(shape, 0.)
    gold_arc_mat[hit_idxes0, hit_idxes1, hit_idxes2] = arc_margin
    gold_arc_margins = gold_arc_mat[query_idxes0, query_idxes1, query_idxes2]
    arc_scores -= gold_arc_margins
    if lab_scores is not None:
        # label
        gold_lab_mat = BK.constants_idx(shape, 0)  # 0 means the padding idx
        gold_lab_mat[hit_idxes0, hit_idxes1, hit_idxes2] = hit_labels
        gold_lab_margin_idxes = gold_lab_mat[query_idxes0, query_idxes1, query_idxes2]
        lab_scores[BK.arange_idx(BK.get_shape(gold_lab_margin_idxes, 0)), gold_lab_margin_idxes] -= lab_margin
    return
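# A minimal sketch of the scatter-then-gather margin trick above, assuming
# plain PyTorch: write the margin into a dense zero matrix at the gold
# coordinates, then read it back at the (possibly pruned) query coordinates,
# so gold entries are cost-augmented even though the queries are an
# arbitrary subset of all positions.
import torch

shape = (2, 3, 3)  # [bs, m, h]
gold = torch.zeros(shape)
gold[torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([2, 0])] = 0.5
q0 = torch.tensor([0, 0, 1])   # queried batch idxes
q1 = torch.tensor([1, 2, 2])   # queried modifier idxes
q2 = torch.tensor([2, 2, 0])   # queried head idxes
arc_scores = torch.randn(3)
arc_scores -= gold[q0, q1, q2]  # only gold queries receive the margin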
def postprocess_scores(self, scores_expr, mask_expr, margin, gold_heads_expr, gold_labels_expr):
    final_full_scores = scores_expr
    # first apply mask
    mask_value = Constants.REAL_PRAC_MIN
    mask_mul = (mask_value * (1. - mask_expr)).unsqueeze(-1)  # [*, len, 1]
    final_full_scores += mask_mul.unsqueeze(-2)
    final_full_scores += mask_mul.unsqueeze(-3)
    # then margin
    if margin > 0.:
        full_shape = BK.get_shape(final_full_scores)
        # combine the first two dims and subtract the margin correspondingly
        combined_size = full_shape[0] * full_shape[1]
        combined_score_expr = final_full_scores.view([combined_size] + full_shape[-2:])
        arange_idx_expr = BK.arange_idx(combined_size)
        combined_score_expr[arange_idx_expr, gold_heads_expr.view(-1)] -= margin
        combined_score_expr[arange_idx_expr, gold_heads_expr.view(-1), gold_labels_expr.view(-1)] -= margin
        final_full_scores = combined_score_expr.view(full_shape)
    return final_full_scores
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    # todo(+N): currently margin is not used
    conf = self.conf
    bsize = len(insts)
    arange_t = BK.arange_idx(bsize)
    assert conf.train_force, "currently only have forced training"
    # get the gold ones
    gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
    # for all the steps
    num_step = BK.get_shape(gold_widxes, -1)
    # recurrent states
    hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
    prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
    all_tok_logprobs, all_lab_logprobs = [], []
    for cstep in range(num_step):
        slice_widx, slice_lidx = gold_widxes[:, cstep], gold_lidxes[:, cstep]
        _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
            self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
        all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
        all_lab_logprobs.append(sel_lab_logprobs)
        hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
        hard_coverage[arange_t, slice_widx] += 1.
        prev_state = [z.squeeze(-2) for z in next_state]
    # concat all the losses and masks
    # todo(note): no need to use gold_valid since validity is already encoded in vmasks
    cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
    cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
    loss_sum = -(cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
    # todo(+N): here we are dividing lab_logprobs by the all-count, do we need to separate?
    loss_count = gold_vmasks.sum()
    ret_losses = [[loss_sum, loss_count]]
    # =====
    # make eos invalid for the return
    ret_valid_mask = gold_vmasks * (gold_widxes > 0).float()
    # embeddings
    sel_lab_embeds = self._hl_lookup(gold_lidxes)
    return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
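# A minimal sketch of the coverage update above, assuming plain PyTorch: the
# tensor is cloned before the in-place "+= 1." (the "cannot modify inplace"
# note), which keeps each step's coverage a distinct value rather than
# overwriting the tensor the previous step already consumed.
import torch

coverage = torch.zeros(2, 5)  # [bs, slen]
arange_b = torch.arange(2)    # [bs]
widx = torch.tensor([3, 1])   # gold word index per batch element
coverage = coverage.clone()   # avoid clobbering the previous step's tensor
coverage[arange_b, widx] += 1.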
def predict(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # todo(note): use the pred_events which are shallow copied from inputs
    offsets_t, masks_t, _, items_arr, _ = PrepHelper.prep_targets(
        ms_items, lambda x: x.pred_events, True, False, 0., 0., True)  # [bs, ?]
    # -----
    if BK.get_shape(offsets_t, -1) == 0:
        return  # no input
    # -----
    # similar to the loss routine
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # predict: only top-1!
    if conf.pred_realis:
        self._pred_and_put_res(self.realis_predictor, hiddens, items_arr, self._put_realis)
    if conf.pred_type:
        self._pred_and_put_res(self.type_predictor, hiddens, items_arr, self._put_type)
def predict(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    offsets_t, masks_t, _, _, _ = PrepHelper.prep_targets(ms_items, lambda x: [], True, True, 1., 1., False)
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    # -----
    log_probs = BK.log_softmax(logits, -1)
    log_probs[:, :, 0] -= conf.nil_penalty  # encourage more predictions
    topk_log_probs, topk_log_labels = log_probs.max(dim=-1)  # [bsize, ?]
    # decoding
    head_offsets_arr = BK.get_value(offsets_t)  # [bs, ?]
    masks_arr = BK.get_value(masks_t)
    topk_log_probs_arr, topk_log_labels_arr = BK.get_value(topk_log_probs), BK.get_value(topk_log_labels)  # [bsize, ?]
    for one_ms_item, one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr \
            in zip(ms_items, head_offsets_arr, masks_arr, topk_log_probs_arr, topk_log_labels_arr):
        # build tidx2sidx
        one_sents = one_ms_item.sents
        one_offsets = one_ms_item.offsets
        tidx2sidx = []
        for idx in range(1, len(one_offsets)):
            tidx2sidx.extend([idx - 1] * (one_offsets[idx] - one_offsets[idx - 1]))
        # get all candidates
        all_candidates = [[] for _ in one_sents]
        for cur_offset, cur_valid, cur_logprob, cur_label in zip(
                one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr):
            if not cur_valid or cur_label <= 0:
                continue
            # which sent
            cur_offset = int(cur_offset)
            cur_sidx = tidx2sidx[cur_offset]
            cur_sent = one_sents[cur_sidx]
            minus_offset = one_ms_item.offsets[cur_sidx] - 1  # again consider the ROOT
            cur_mention = Mention(HardSpan(cur_sent.sid, cur_offset - minus_offset, None, None))
            all_candidates[cur_sidx].append((cur_sent, cur_mention, cur_label, cur_logprob))
        # keep a certain ratio for each sent separately?
        final_candidates = []
        if conf.pred_sent_ratio_sep:
            for one_sent, one_sent_candidates in zip(one_sents, all_candidates):
                cur_keep_num = max(int(conf.pred_sent_ratio * (one_sent.length - 1)), 1)
                one_sent_candidates.sort(key=lambda x: x[-1], reverse=True)
                final_candidates.extend(one_sent_candidates[:cur_keep_num])
        else:
            all_size = 0
            for one_sent, one_sent_candidates in zip(one_sents, all_candidates):
                all_size += one_sent.length - 1
                final_candidates.extend(one_sent_candidates)
            final_candidates.sort(key=lambda x: x[-1], reverse=True)
            final_keep_num = max(int(conf.pred_sent_ratio * all_size), len(one_sents))
            final_candidates = final_candidates[:final_keep_num]
        # add them all
        for cur_sent, cur_mention, cur_label, cur_logprob in final_candidates:
            cur_logprob = float(cur_logprob)
            doc_id = cur_sent.doc.doc_id
            self.id_counter[doc_id] += 1
            new_id = f"ef-{doc_id}-{self.id_counter[doc_id]}"
            hlidx = self.valid_hlidx
            new_ef = EntityFiller(new_id, cur_mention, str(hlidx), None, True, type_idx=hlidx, score=cur_logprob)
            cur_sent.pred_entity_fillers.append(new_ef)
def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
    conf = self.conf
    CR, PR = conf.cand_range, conf.pred_range
    # -----
    mask_single = BK.copy(mask_t)
    # no predictions for ARTI_ROOT
    if self.add_root_token:
        mask_single[:, 0] = 0.  # [bs, slen]
    # casting predicting range
    cur_slen = BK.get_shape(mask_single, -1)
    arange_t = BK.arange_idx(cur_slen)  # [slen]
    # [1, len] - [len, 1] = [len, len]
    reldist_t = arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1)  # [slen, slen]
    mask_pair = ((reldist_t.abs() <= CR) & (reldist_t != 0)).float()  # within CR-range; [slen, slen]
    mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
    if disturb_keep_arr is not None:
        mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
    # get all pair scores
    score_t = self.ps_node.paired_score(repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
    # -----
    # loss: normalize on which dim?
    # get the answers first
    if conf.pred_abs:
        answer_t = reldist_t.abs()  # [1,2,3,...,PR]
        answer_t.clamp_(min=0, max=PR - 1)  # [slen, slen], clip in range, distinguish using masks
    else:
        answer_t = BK.where((reldist_t >= 0), reldist_t - 1, reldist_t + 2 * PR)  # [1,2,3,...,PR,-PR,-PR+1,...,-1]
        answer_t.clamp_(min=0, max=2 * PR - 1)  # [slen, slen], clip in range, distinguish using masks
    # expand answer into idxes
    answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2 * PR])  # [len_q, len_k, 2*R]
    answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
    answer_valid_t = ((reldist_t.abs() <= PR) & (reldist_t != 0)).float().unsqueeze(-1)  # [bs, len_q, len_k, 1]
    answer_hit_t = answer_hit_t * mask_pair.unsqueeze(-1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
    # get losses sum(log(answer*prob))
    # -- dim=-1 is a standard 2*PR-way classification; dim=-2 usually has 2*PR candidates, but can have fewer at the edges
    all_losses = []
    for one_dim, one_lambda in zip([-1, -2], [conf.lambda_n1, conf.lambda_n2]):
        if one_lambda > 0.:
            # since currently there can be only one or zero correct answers
            logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*R]
            sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*R]
            cur_dim_mask_t = (answer_hit_t.sum(one_dim) > 0.).float()  # [bs, len_q, len_k||2*R]
            # loss
            cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
            cur_dim_count = cur_dim_mask_t.sum()
            # argmax and corr (any correct counts)
            _, cur_argmax_idxes = score_t.max(one_dim)
            cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*R|1]
            cur_dim_corr_count = cur_corrs.sum()
            # compile loss
            one_loss = LossHelper.compile_leaf_info(
                f"d{one_dim}", cur_dim_loss, cur_dim_count, loss_lambda=one_lambda, corr=cur_dim_corr_count)
            all_losses.append(one_loss)
    return self._compile_component_loss("orp", all_losses)
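# A minimal sketch of the signed-distance labeling above, assuming plain
# PyTorch: positive offsets 1..PR map to classes 0..PR-1 and negative offsets
# -PR..-1 map to classes PR..2*PR-1, so a single 2*PR-way softmax covers both
# directions; out-of-range pairs are filtered by a separate validity mask.
import torch

PR, slen = 3, 5
ar = torch.arange(slen)
reldist = ar.unsqueeze(-2) - ar.unsqueeze(-1)             # [slen, slen]
answer = torch.where(reldist >= 0, reldist - 1, reldist + 2 * PR)
answer = answer.clamp(min=0, max=2 * PR - 1)              # class ids in [0, 2*PR)
valid = (reldist.abs() <= PR) & (reldist != 0)            # mask off self and far pairs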
def run(self, insts: List[DocInstance], training: bool):
    conf = self.conf
    BERT_MAX_LEN = 510  # save 2 for CLS and SEP
    # =====
    # encoder 1: the basic encoder
    # todo(note): only DocInstance input for this mode, otherwise will break
    if conf.m2e_use_basic:
        reidx_pad_len = conf.ms_extend_budget
        # enc the basic part + also get some indexes
        sentid2offset = {}  # id(sent) -> overall_seq_offset
        seq_offset = 0  # if look at the docs in one seq
        all_sents = []  # (inst, d_idx, s_idx)
        for d_idx, one_doc in enumerate(insts):
            assert isinstance(one_doc, DocInstance)
            for s_idx, one_sent in enumerate(one_doc.sents):
                # todo(note): here we encode all the sentences
                all_sents.append((one_sent, d_idx, s_idx))
                sentid2offset[id(one_sent)] = seq_offset
                seq_offset += one_sent.length - 1  # exclude extra ROOT node
        sent_reprs = self.run_sents(all_sents, insts, training)
        # flatten and concatenate and re-index
        reidxes_arr = np.zeros(seq_offset + reidx_pad_len, dtype=np.long)  # todo(note): extra padding to avoid out of boundary
        all_flattened_reprs = []
        all_flatten_offset = 0  # the local offset for batched basic encoding
        for one_pack in sent_reprs:
            one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            one_repr_t = one_repr_evt
            _, one_slen, one_ldim = BK.get_shape(one_repr_t)
            all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
            # fill in the indexes
            for one_sent in one_sents:
                cur_start_offset = sentid2offset[id(one_sent)]
                cur_real_slen = one_sent.length - 1  # again, minus 1 to get rid of the extra ROOT
                reidxes_arr[cur_start_offset:cur_start_offset + cur_real_slen] = \
                    np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset + 1)
                all_flatten_offset += one_slen  # here add the slen in batched version
        # re-idxing
        seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
        seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0)  # [all_seq_len, D]
    else:
        sentid2offset = defaultdict(int)
        seq_sent_repr = None
    # =====
    # repack and prepare for multiple sent enc
    # todo(note): here, the criterion is based on bert's tokenizer
    all_ms_info = []
    if isinstance(insts[0], DocInstance):
        for d_idx, one_doc in enumerate(insts):
            for s_idx, x in enumerate(one_doc.sents):
                # the basic criterion is the same as the basic one
                include_flag = False
                if training:
                    if x.length < self.train_skip_length and x.length >= self.train_min_length \
                            and (len(x.events) > 0 or next(self.random_sample_stream) > self.train_skip_noevt_rate):
                        include_flag = True
                else:
                    if x.length >= self.test_min_length:
                        include_flag = True
                if include_flag:
                    all_ms_info.append(x.preps["ms"])  # use the pre-calculated one
    else:
        # multisent based
        all_ms_info = insts.copy()  # shallow copy
    # =====
    # encoder 2: the bert one (multi-sent encoding)
    ms_size_f = lambda x: x.subword_size
    all_ms_info.sort(key=ms_size_f)
    all_ms_buckets = self._bucket_sents_by_length(
        all_ms_info, conf.benc_bucket_range, ms_size_f, max_bsize=conf.benc_bucket_msize)
    berter = self.berter
    rets = []
    bert_use_center_typeids = conf.bert_use_center_typeids
    bert_use_special_typeids = conf.bert_use_special_typeids
    bert_other_inputs = conf.bert_other_inputs
    for one_bucket in all_ms_buckets:
        # prepare
        batched_ids = []
        batched_starts = []
        batched_seq_offset = []
        batched_typeids = []
        batched_other_inputs_list: List = [[] for _ in bert_other_inputs]  # List(comp) of List(batch) of List(idx)
        for one_item in one_bucket:
            one_sents = one_item.sents
            one_center_sid = one_item.center_idx
            one_ids, one_starts, one_typeids = [], [], []
            one_other_inputs_list = [[] for _ in bert_other_inputs]  # List(comp) of List(idx)
            for one_sid, one_sent in enumerate(one_sents):
                # for bert
                one_bidxes = one_sent.preps["bidx"]
                one_ids.extend(one_bidxes.subword_ids)
                one_starts.extend(one_bidxes.subword_is_start)
                # prepare other inputs
                for this_field_name, this_tofill_list in zip(bert_other_inputs, one_other_inputs_list):
                    this_tofill_list.extend(one_sent.preps["sub_" + this_field_name])
                # todo(note): special procedure
                if bert_use_center_typeids:
                    if one_sid != one_center_sid:
                        one_typeids.extend([0] * len(one_bidxes.subword_ids))
                    else:
                        this_typeids = [1] * len(one_bidxes.subword_ids)
                        if bert_use_special_typeids:
                            # todo(note): this is the special mode where we are given the events!!
                            for this_event in one_sents[one_center_sid].events:
                                _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                for a, b in one_item.center_word2sub[this_wid - 1:this_wid - 1 + this_wlen]:
                                    this_typeids[a:b] = [0] * (b - a)
                        one_typeids.extend(this_typeids)
            batched_ids.append(one_ids)
            batched_starts.append(one_starts)
            batched_typeids.append(one_typeids)
            for comp_one_oi, comp_batched_oi in zip(one_other_inputs_list, batched_other_inputs_list):
                comp_batched_oi.append(comp_one_oi)
            # for basic part
            batched_seq_offset.append(sentid2offset[id(one_sents[0])])
        # bert forward: [bs, slen, fold, D]
        if not bert_use_center_typeids:
            batched_typeids = None
        bert_expr0, mask_expr = berter.forward_batch(
            batched_ids, batched_starts, batched_typeids,
            training=training, other_inputs=batched_other_inputs_list)
        if self.m3_enc_is_empty:
            bert_expr = bert_expr0
        else:
            mask_arr = BK.get_value(mask_expr)  # [bs, slen]
            m3e_exprs = [cur_enc(bert_expr0[:, :, cur_i], mask_arr) for cur_i, cur_enc in enumerate(self.m3_encs)]
            bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
        # collect basic ones: [bs, slen, D'] or None
        if seq_sent_repr is not None:
            arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
            offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
            basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
        elif conf.m2e_use_basic_dep:
            # collect each token's head-bert and ud-label, then forward with adp
            fake_sents = [one_item.fake_sent for one_item in one_bucket]
            # head idx and labels, no artificial ROOT
            padded_head_arr, _ = self.dep_padder.pad([s.ud_heads.vals[1:] for s in fake_sents])
            padded_label_arr, _ = self.dep_padder.pad([s.ud_labels.idxes[1:] for s in fake_sents])
            # get tensor
            padded_head_t = BK.input_idx(padded_head_arr) - 1  # here, the idx excludes root
            padded_head_t.clamp_(min=0)  # [bs, slen]
            padded_label_t = BK.input_idx(padded_label_arr)
            # get inputs
            input_head_bert_t = bert_expr[BK.arange_idx(len(fake_sents)).unsqueeze(-1), padded_head_t]  # [bs, slen, fold, D]
            input_label_emb_t = self.dep_label_emb(padded_label_t)  # [bs, slen, D']
            basic_expr = self.dep_layer(input_head_bert_t, None, [input_label_emb_t])  # [bs, slen, ?]
        elif conf.m2e_use_basic_plus:
            sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                        insts, training, use_one_bucket=True)
            assert len(sent_reprs) == 1, "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
            _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
            assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
        else:
            basic_expr = None
        # pack: (List[ms_item], bert_expr, basic_expr)
        rets.append((one_bucket, bert_expr, basic_expr))
    return rets
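# A minimal sketch of the offset-window gather above, assuming plain PyTorch:
# each batch row starts at a different offset into one long flattened
# document-level representation, and offset + arange broadcasts into a
# [bs, slen] index matrix that slices out aligned windows in one indexing op.
import torch

all_seq = torch.randn(100, 8)        # [all_seq_len, D]
offsets = torch.tensor([0, 17, 42])  # per-item start positions
slen = 10
idx = offsets.unsqueeze(-1) + torch.arange(slen).unsqueeze(0)  # [bs, slen]
windows = all_seq[idx]               # [bs, slen, D]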
def update_call(self, cache: VRecCache, src_mask=None, qk_mask=None, attn_range=None,
                rel_dist=None, temperature=1., forced_attn=None):
    conf = self.conf
    # -----
    # first call matt to get v
    matt_input_qk = cache.orig_t * conf.feat_qk_lambda_orig + cache.rec_t * (1. - conf.feat_qk_lambda_orig)
    if self.att_pre_norm:
        matt_input_qk = self.att_pre_norm(matt_input_qk)
    matt_input_v = cache.orig_t * conf.feat_v_lambda_orig + cache.rec_t * (1. - conf.feat_v_lambda_orig)
    # todo(note): currently no pre-norm for matt_input_v
    # put attn_range as mask_qk
    if attn_range is not None and attn_range >= 0:  # <0 means not effective
        cur_slen = BK.get_shape(matt_input_qk, -2)
        tmp_arange_t = BK.arange_idx(cur_slen)  # [slen]
        # less or equal!!
        mask_qk = ((tmp_arange_t.unsqueeze(-1) - tmp_arange_t.unsqueeze(0)).abs() <= attn_range).float()
        if qk_mask is not None:  # further with input masks
            mask_qk *= qk_mask
    else:
        mask_qk = qk_mask
    scores, attn_info, result_value = self.feat_node(
        matt_input_qk, matt_input_qk, matt_input_v, cache.accu_attn,
        mask_k=src_mask, mask_qk=mask_qk, rel_dist=rel_dist,
        temperature=temperature, forced_attn=forced_attn)  # ..., [*, len_q, dv]
    # -----
    # then combine q(hidden) and v(input)
    comb_input_q = cache.orig_t * conf.comb_q_lambda_orig + cache.rec_t * (1. - conf.comb_q_lambda_orig)
    comb_result, comb_c = self.comb_f(comb_input_q, result_value, cache.rec_lstm_c_t)  # [*, len_q, dim]
    if self.att_post_norm:
        comb_result = self.att_post_norm(comb_result)
    # -----
    # ff
    if self.has_ff:
        if self.ff_pre_norm:
            ff_input = self.ff_pre_norm(comb_result)
        else:
            ff_input = comb_result
        ff_output = comb_result + self.dropout2(self.linear2(self.dropout1(self.linear1(ff_input))))
        if self.ff_post_norm:
            ff_output = self.ff_post_norm(ff_output)
    else:  # otherwise no ff
        ff_output = comb_result
    # -----
    # update cache and return output
    # cache.orig_t = cache.orig_t  # this does not change
    cache.rec_t = ff_output
    cache.accu_attn = cache.accu_attn + attn_info[0]  # accumulating attn
    cache.rec_lstm_c_t = comb_c  # optional C for lstm
    cache.list_hidden.append(ff_output)  # all hidden layers
    cache.list_score.append(scores)  # all un-normed scores
    cache.list_attn.append(attn_info[0])  # all normed scores
    cache.list_accu_attn.append(cache.accu_attn)  # all accumulated attns
    cache.list_attn_info.append(attn_info)  # all attn infos
    return ff_output
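# A minimal sketch of the banded attention mask above, assuming plain
# PyTorch: |i - j| <= attn_range keeps a diagonal band of query-key pairs,
# and any externally supplied mask is intersected multiplicatively.
import torch

slen, attn_range = 6, 2
ar = torch.arange(slen)
band = ((ar.unsqueeze(-1) - ar.unsqueeze(0)).abs() <= attn_range).float()  # [slen, slen]
ext = torch.ones(slen, slen)  # stand-in for an input qk_mask
mask_qk = band * ext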
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize, slen = BK.get_shape(input_mask)
    bsize_arange_t_1d = BK.arange_idx(bsize)  # [*]
    bsize_arange_t_2d = bsize_arange_t_1d.unsqueeze(-1)  # [*, 1]
    beam_size = conf.beam_size
    # prepare things with an extra beam dimension
    beam_input_expr, beam_input_mask = input_expr.unsqueeze(-3).expand(-1, beam_size, -1, -1).contiguous(), \
        input_mask.unsqueeze(-2).expand(-1, beam_size, -1).contiguous()  # [*, beam, slen, D?]
    # -----
    # recurrent states
    beam_hard_coverage = BK.zeros([bsize, beam_size, slen])  # [*, beam, slen]
    # tuple([*, beam, D], )
    beam_prev_state = [z.unsqueeze(-2).expand(-1, beam_size, -1) for z in self.rnn_unit.zero_init_hidden(bsize)]
    # frozen after reaching eos
    beam_noneos = 1. - BK.zeros([bsize, beam_size])  # [*, beam]
    beam_logprobs = BK.zeros([bsize, beam_size])  # [*, beam], sum of logprobs
    beam_logprobs_paths = BK.zeros([bsize, beam_size, 0])  # [*, beam, step]
    beam_tok_paths = BK.zeros([bsize, beam_size, 0]).long()
    beam_lab_paths = BK.zeros([bsize, beam_size, 0]).long()
    # -----
    for cstep in range(conf.max_step):
        # get things of [*, beam, beam]
        sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state = \
            self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, None, None, beam_size)
        sel_logprobs = sel_tok_logprobs + sel_lab_logprobs  # [*, beam, beam]
        if cstep == 0:
            # special for the first step, only select for the first element
            cur_selections = BK.arange_idx(beam_size).unsqueeze(0).expand(bsize, beam_size)  # [*, beam]
        else:
            # then select the topk in beam*beam (be careful about the frozen ones!!)
            beam_noneos_3d = beam_noneos.unsqueeze(-1)
            # eos can only be followed by eos
            sel_tok_idxes *= beam_noneos_3d.long()
            sel_lab_idxes *= beam_noneos_3d.long()
            # numeric tricks to keep the frozen ones ([0] with 0. score, [1:] with -inf scores)
            sel_logprobs *= beam_noneos_3d
            tmp_exclude_mask = 1. - beam_noneos_3d.expand_as(sel_logprobs)
            tmp_exclude_mask[:, :, 0] = 0.
            sel_logprobs += tmp_exclude_mask * Constants.REAL_PRAC_MIN
            # select for topk
            topk_logprobs = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
            _, cur_selections = topk_logprobs.view([bsize, -1]).topk(beam_size, dim=-1, sorted=True)  # [*, beam]
        # read and write the selections
        # gathering previous ones
        cur_sel_previ = cur_selections // beam_size  # [*, beam]
        prev_hard_coverage = beam_hard_coverage[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, slen]
        prev_noneos = beam_noneos[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs = beam_logprobs[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs_paths = beam_logprobs_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_tok_paths = beam_tok_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_lab_paths = beam_lab_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        # prepare new ones
        cur_sel_newi = cur_selections % beam_size
        new_tok_idxes = sel_tok_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_lab_idxes = sel_lab_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_logprobs = sel_logprobs[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_prev_state = [z[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi] for z in next_state]  # [*, beam, ~]
        # update
        prev_hard_coverage[bsize_arange_t_2d, BK.arange_idx(beam_size).unsqueeze(0), new_tok_idxes] += 1.
        beam_hard_coverage = prev_hard_coverage
        beam_prev_state = new_prev_state
        beam_noneos = prev_noneos * (new_tok_idxes != 0).float()
        beam_logprobs = prev_logprobs + new_logprobs
        beam_logprobs_paths = BK.concat([prev_logprobs_paths, new_logprobs.unsqueeze(-1)], -1)
        beam_tok_paths = BK.concat([prev_tok_paths, new_tok_idxes.unsqueeze(-1)], -1)
        beam_lab_paths = BK.concat([prev_lab_paths, new_lab_idxes.unsqueeze(-1)], -1)
    # finally force an extra eos step to get the ending tok-logprob (no need to update other things)
    final_eos_idxes = BK.zeros([bsize, beam_size]).long()
    _, eos_logprobs, _, _, _, _ = self._step(beam_input_expr, beam_input_mask, beam_hard_coverage,
                                             beam_prev_state, final_eos_idxes, final_eos_idxes, None)
    beam_logprobs += eos_logprobs.squeeze(-1) * beam_noneos  # [*, beam]
    # select and return the best one
    beam_tok_valids = (beam_tok_paths > 0).float()  # [*, beam, steps]
    final_scores = beam_logprobs / ((beam_tok_valids.sum(-1) + 1.) ** conf.len_alpha)  # [*, beam]
    _, best_beam_idx = final_scores.max(-1)  # [*]
    # -----
    # prepare returns; cut by max length: [*, all_step] -> [*, max_step]
    ret0_valid_mask = beam_tok_valids[bsize_arange_t_1d, best_beam_idx]
    cur_max_step = ret0_valid_mask.long().sum(-1).max().item()
    ret_valid_mask = ret0_valid_mask[:, :cur_max_step]
    ret_logprobs = beam_logprobs_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_tok_idxes = beam_tok_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_lab_idxes = beam_lab_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    # embeddings
    ret_lab_embeds = self._hl_lookup(ret_lab_idxes)
    return ret_logprobs, ret_tok_idxes, ret_valid_mask, ret_lab_idxes, ret_lab_embeds
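# A minimal sketch of the beam bookkeeping above, assuming plain PyTorch:
# after flattening the [beam, beam] expansion and taking a top-k over
# beam*beam, integer division recovers which parent beam each survivor came
# from and the remainder recovers which of its candidates was taken.
import torch

bsize, beam = 2, 3
scores = torch.randn(bsize, beam, beam)                  # parent x candidate
_, flat_sel = scores.view(bsize, -1).topk(beam, dim=-1)  # [bsize, beam]
parent_idx = flat_sel // beam                            # which previous beam
cand_idx = flat_sel % beam                               # which new candidate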
def update_bsize(self, new_bsize):
    if new_bsize != self.cur_bsize:
        self.cur_bsize = new_bsize
        self.bsize_range_t = BK.arange_idx(new_bsize)
def get_rel_dist(self, len_q: int, len_k: int):
    dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
    dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
    distance = dist_x - dist_y  # [len_q, len_k]
    return distance
def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
    gold_all_idxes = self._get_all_idxes(gold_idxes)
    # scoring
    raw_scores = self._raw_scores(input_expr)
    raw_scores_aug = []
    margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
    #
    gold_shape = BK.get_shape(gold_idxes)  # [*]
    gold_bsize_prod = np.prod(gold_shape)
    #
    gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
    # margin
    for i in range(self.eff_max_layer):
        cur_gold_inputs = gold_all_idxes[i]
        # add margin
        cur_scores = raw_scores[i]  # [*, ?]
        cur_margin = margin * self.margin_lambdas[i]
        if cur_margin > 0.:
            cur_num_target = self.prediction_sizes[i]
            cur_isnil = self.layered_isnil[i].byte()  # [NLab]
            cost_matrix = BK.constants([cur_num_target, cur_num_target], margin_T)  # [gold, pred]
            cost_matrix[cur_isnil, :] = margin_P
            cost_matrix[:, cur_isnil] = margin_R
            diag_idxes = BK.arange_idx(cur_num_target)
            cost_matrix[diag_idxes, diag_idxes] = 0.
            margin_mat = cost_matrix[cur_gold_inputs]
            cur_aug_scores = cur_scores + margin_mat  # [*, ?]
        else:
            cur_aug_scores = cur_scores
        raw_scores_aug.append(cur_aug_scores)
    # cascade scores
    final_scores = self._cascade_scores(raw_scores_aug)
    # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
    loss_weights = ((gold_idxes == 0).float() * (self.loss_fullnil_weight - 1.) + 1.) \
        if self.loss_fullnil_weight < 1. else 1.
    # calculate loss
    loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
    loss_prob_reweight = self.conf.loss_prob_reweight
    final_losses = []
    no_loss_max_gold = self.conf.no_loss_max_gold
    if loss_mask is None:
        loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
    for i in range(self.eff_max_layer):
        cur_final_scores, cur_gold_inputs = final_scores[i], gold_all_idxes[i]  # [*, ?], [*]
        # collect the loss
        if self.is_hinge_loss:
            cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
            cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
            cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
            if no_loss_max_gold:  # this should be implicit
                cur_loss = cur_loss * (cur_loss > 0.).float()
        elif self.is_prob_loss:
            # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
            cur_loss = self._my_loss_prob(cur_final_scores, cur_gold_inputs,
                                          loss_prob_entropy_lambda, loss_mask, loss_prob_reweight)  # [*]
            if no_loss_max_gold:
                cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                cur_gold_scores = BK.gather(cur_final_scores, cur_gold_inputs.unsqueeze(-1), -1).squeeze(-1)
                cur_loss = cur_loss * (cur_gold_scores > cur_pred_scores).float()
        else:
            raise NotImplementedError(f"UNK loss {self.conf.loss_function}")
        # here first summing up, divided at the outside
        one_loss_sum = (cur_loss * (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
        final_losses.append(one_loss_sum)
    # final sum
    final_loss_sum = BK.stack(final_losses).sum()
    _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
    return [[final_loss_sum, loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
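# A minimal sketch of the asymmetric cost matrix above, assuming plain
# PyTorch: confusions involving a gold NIL cost margin_P (precision side),
# predicting NIL for a gold label costs margin_R (recall side), other
# confusions cost margin_T, and the diagonal (correct label) costs nothing;
# rows indexed by the gold label give per-example additive margins.
import torch

n_lab, margin_P, margin_R, margin_T = 4, 0.5, 1.0, 0.75
is_nil = torch.tensor([True, False, False, False])
cost = torch.full((n_lab, n_lab), margin_T)  # [gold, pred]
cost[is_nil, :] = margin_P
cost[:, is_nil] = margin_R
diag = torch.arange(n_lab)
cost[diag, diag] = 0.
gold = torch.tensor([0, 2, 3])
aug_scores = torch.randn(3, n_lab) + cost[gold]  # cost-augmented scores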
def _loss(self, annotated_insts: List[ParseInstance], full_score_expr, mask_expr, valid_expr=None):
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    #
    idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
    idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
    # scores for decoding or marginal
    margin = self.margin.value
    decoding_scores = full_score_expr.clone().detach()
    decoding_scores = self.scorer_helper.postprocess_scores(
        decoding_scores, mask_expr, margin, gold_heads_expr, gold_labels_expr)
    if self.loss_hinge:
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        pred_heads_expr, pred_labels_expr, _ = nmst_unproj(
            decoding_scores, mask_expr, mst_lengths_arr, labeled=True, ret_arr=False)
        # ===== add margin*cost, [bs, len]
        gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr]
        pred_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, pred_heads_expr, pred_labels_expr] \
            + margin * (gold_heads_expr != pred_heads_expr).float() \
            + margin * (gold_labels_expr != pred_labels_expr).float()  # plus margin
        hinge_losses = pred_final_scores - gold_final_scores
        valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1)  # [*, 1]
        final_losses = hinge_losses * valid_losses
    else:
        lab_marginals = nmarginal_unproj(decoding_scores, mask_expr, None, labeled=True)
        lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] -= 1.
        grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze(-1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
        final_losses = (full_score_expr * grads_masked).sum(-1).sum(-1)  # [bs, m]
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    # exclude non-valid ones: there can be pruning errors
    if valid_expr is not None:
        final_valids = valid_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr]  # [bs, m] of (0. or 1.)
        final_losses = final_losses * final_valids
        tok_valid = float(BK.get_value(final_valids[:, 1:].sum()))
        assert tok_valid <= num_valid_tok
        tok_prune_err = num_valid_tok - tok_valid
    else:
        tok_prune_err = 0
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "tok_prune_err": tok_prune_err, "loss_sum": final_loss_sum_val}
    return final_loss, info
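# A minimal sketch of the marginal-based loss trick above, assuming plain
# PyTorch: for a CRF-style objective, d(-log p(gold))/d(score) equals
# (marginal - gold_indicator), so sum(score * (marginal - gold)) with the
# marginals treated as constants reproduces exactly that gradient under
# autograd (the scalar itself is a surrogate, not the true NLL value).
import torch

scores = torch.randn(4, requires_grad=True)
marginals = torch.softmax(scores.detach(), -1)  # stand-in for nmarginal_unproj output
gold = torch.tensor([0., 1., 0., 0.])
surrogate = (scores * (marginals - gold)).sum()
surrogate.backward()
# scores.grad == marginals - gold, i.e. the gradient of -log p(gold)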