def __call__(self, query_up, key_up, rel_dist=None, input_scores=None):
    _att_scale_qk = self._att_scale_qk
    # -----
    # get dim info
    len_q, len_k = BK.get_shape(query_up, -2), BK.get_shape(key_up, -2)
    # get distance embeddings
    if rel_dist is None:
        rel_dist = self.get_rel_dist(len_q, len_k)
    if self.rel_dist_abs:  # use abs?
        rel_dist = BK.abs(rel_dist)
    dist_embs = self.E(rel_dist)  # [len_q, len_k, Demb]
    # -----
    # dist_up
    dist_up0 = self.affine_rel(dist_embs)  # [len_q, len_k, head*D]
    # -> [head, len_q, len_k, D]
    dist_up1 = dist_up0.view(BK.get_shape(dist_up0)[:-1] + self.split_dims).transpose(-2, -3).transpose(-3, -4)
    # -----
    # all items are [*, head, len_q, len_k]
    posi_scores = (input_scores if (input_scores is not None) else 0.)
    # item (b): <query, dist>: [head, len_q, len_k, D] * [*, head, len_q, D, 1] -> [*, head, len_q, len_k]
    item_b = (BK.matmul(dist_up1, query_up.unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
    posi_scores += item_b
    # todo(note): remove this item_c since it is not related to rel_dist
    # # item (c): <key, u>: [*, head, len_k, D] * [head, D, 1] -> [*, head, 1, len_k]
    # item_c = (BK.matmul(key_up, self.vec_u.unsqueeze(-1)) / _att_scale_qk).squeeze(-1).unsqueeze(-2)
    # posi_scores += item_c
    # item (d): <dist, v>: [head, len_q, len_k, D] * [head, 1, D, 1] -> [head, len_q, len_k]
    item_d = (BK.matmul(dist_up1, self.vec_v.unsqueeze(-2).unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
    posi_scores += item_d
    return posi_scores
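# note: a minimal standalone sketch of the two relative-distance score items above (item (b)
# <query, dist> and item (d) <dist, v>), written in plain PyTorch instead of the BK backend;
# names and shapes here are illustrative, not the module's actual API.
import torch

def rel_dist_scores(query_up, dist_up, vec_v, scale):
    # query_up: [bs, head, len_q, D]; dist_up: [head, len_q, len_k, D]; vec_v: [head, D]
    item_b = torch.matmul(dist_up, query_up.unsqueeze(-1)).squeeze(-1) / scale   # [bs, head, len_q, len_k]
    item_d = torch.matmul(dist_up, vec_v.unsqueeze(-2).unsqueeze(-1)).squeeze(-1) / scale  # [head, len_q, len_k]
    return item_b + item_d  # item_d broadcasts over the batch dim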
def loss(self, ms_items: List, bert_expr):
    conf = self.conf
    max_range = self.conf.max_range
    bsize = len(ms_items)
    # collect instances
    col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(ms_items, True)
    if len(col_efs) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)  # [N, R]
    if conf.use_binary_scorer:
        left_binaries = (BK.arange_idx(max_range) <= col_ldists_t.unsqueeze(-1)).float()  # [N, R]
        right_binaries = (BK.arange_idx(max_range) <= col_rdists_t.unsqueeze(-1)).float()  # [N, R]
        left_losses = BK.binary_cross_entropy_with_logits(left_scores, left_binaries, reduction='none')[:, 1:]
        right_losses = BK.binary_cross_entropy_with_logits(right_scores, right_binaries, reduction='none')[:, 1:]
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0) * (max_range - 1))
    else:
        left_losses = BK.loss_nll(left_scores, col_ldists_t)
        right_losses = BK.loss_nll(right_scores, col_rdists_t)
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0))
    return [[left_losses.sum(), left_count, left_count],
            [right_losses.sum(), right_count, right_count]]
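# note: a tiny check (plain PyTorch in place of BK) of how the binary targets above are built:
# with use_binary_scorer, every offset up to the gold distance becomes a positive label.
import torch

max_range = 5
col_ldists = torch.tensor([2, 0])  # gold left distances, [N]
left_binaries = (torch.arange(max_range) <= col_ldists.unsqueeze(-1)).float()
# -> [[1., 1., 1., 0., 0.],
#     [1., 0., 0., 0., 0.]]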
def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]:
    flattened_states, cur_arc_scores, scoring_mask_ct = candidates
    cur_cache = self.cache
    cur_bsize = len(flattened_states)
    cur_slen = cur_cache.max_slen
    cur_arc_scores_flattened = cur_arc_scores.view([cur_bsize, -1])  # [bs, Lm*Lh]
    if mode == "topk":
        # arcs [*, k]
        topk_arc_scores, topk_arc_idxes = BK.topk(
            cur_arc_scores_flattened, min(k_arc, BK.get_shape(cur_arc_scores_flattened, -1)), dim=-1, sorted=False)
        # note: floor-div here; plain "/" on int tensors breaks under newer torch versions
        topk_m, topk_h = topk_arc_idxes // cur_slen, topk_arc_idxes % cur_slen  # [m, h]
        # labels [*, k, k']
        cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, self.mw_arc, self.mw_label)
        topk_label_scores, topk_label_idxes = BK.topk(
            cur_label_scores, min(k_label, BK.get_shape(cur_label_scores, -1)), dim=-1, sorted=False)
        return self._new_states(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h,
                                topk_label_scores, topk_label_idxes)
    elif mode == "":
        return [[]] * cur_bsize
    # todo(+N): other modes like sampling to be implemented: sample, topk-sample
    else:
        raise NotImplementedError(mode)
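# note: an isolated illustration (plain PyTorch, toy sizes) of the flatten-topk-recover pattern
# in the "topk" branch above: topk over the flattened [Lm*Lh] scores, then div/mod for (m, h).
import torch

bs, slen = 2, 5
arc_scores = torch.randn(bs, slen, slen)                 # [bs, Lm, Lh]
flat_scores = arc_scores.view(bs, -1)                    # [bs, Lm*Lh]
topk_vals, topk_idxes = flat_scores.topk(3, dim=-1, sorted=False)
topk_m, topk_h = topk_idxes // slen, topk_idxes % slen   # recover the two coordinates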
def predict(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef=None, mask_evt=None, ret_full_logprobs=False):
    # -----
    ret_shape = BK.get_shape(lab_ef)[:-1] + [BK.get_shape(lab_ef, -1), BK.get_shape(lab_evt, -1)]
    if np.prod(ret_shape) == 0:
        if ret_full_logprobs:
            return BK.zeros(ret_shape + [self.num_label])
        else:
            return BK.zeros(ret_shape), BK.zeros(ret_shape).long()
    # -----
    # todo(note): +1 for space of DROPED(UNK)
    full_score = self._score(repr_ef, repr_evt, lab_ef + 1, lab_evt + 1)  # [*, len-ef, len-evt, D]
    full_logprobs = BK.log_softmax(full_score, -1)
    if ret_full_logprobs:
        return full_logprobs
    else:
        # greedy maximum decode
        ret_logprobs, ret_idxes = full_logprobs.max(-1)  # [*, len-ef, len-evt]
        # mask non-valid ones
        if mask_ef is not None:
            ret_idxes *= (mask_ef.unsqueeze(-1)).long()
        if mask_evt is not None:
            ret_idxes *= (mask_evt.unsqueeze(-2)).long()
        return ret_logprobs, ret_idxes
def _score(self, repr_t, attn_t, mask_t):
    conf = self.conf
    # -----
    repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
    repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
    scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
    # mask at outside
    slen = BK.get_shape(mask_t, -1)
    score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
    score_mask *= (1. - BK.eye(slen))  # no diag
    score_mask *= mask_t.unsqueeze(-1)  # input mask at len_q
    score_mask *= mask_t.unsqueeze(-2)  # input mask at len_k
    NEG = Constants.REAL_PRAC_MIN
    scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
    # add fixed idx0 scores if set
    if conf.fix_s0:
        fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
        scores1 = (1. - fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
    # minus s0
    if conf.minus_s0:
        scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
    return scores1, score_mask
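# note: a compact sketch of the pairwise masking scheme above (plain PyTorch; NEG stands in
# for Constants.REAL_PRAC_MIN): exclude the diagonal and padded tokens on both axes, then
# push invalid pair scores down additively so a later softmax ignores them.
import torch

def mask_pair_scores(scores, mask_t, NEG=-1e4):
    # scores: [bs, slen, slen, C]; mask_t: [bs, slen] with 1. at real tokens
    slen = mask_t.shape[-1]
    score_mask = (1. - torch.eye(slen)) * mask_t.unsqueeze(-1) * mask_t.unsqueeze(-2)  # [bs, slen, slen]
    return scores + NEG * (1. - score_mask).unsqueeze(-1), score_mask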
def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t):
    ret_t = cur_t  # [*, D]
    # padding 0 if not using labels
    dim_label = self.dim_label
    # child features
    if self.use_chs and chs_t is not None:
        if self.use_label_feat:
            chs_label_rt = self.label_embeddings(chs_label_t)  # [*, max-chs, dlab]
        else:
            labels_shape = BK.get_shape(chs_t)
            labels_shape[-1] = dim_label
            chs_label_rt = BK.zeros(labels_shape)
        chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
        chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t)
        chs_feat = self.chs_ff(chs_feat0)
        ret_t += chs_feat
    # parent features
    if self.use_par and par_t is not None:
        if self.use_label_feat:
            cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
        else:
            labels_shape = BK.get_shape(par_t)
            labels_shape[-1] = dim_label
            cur_label_t = BK.zeros(labels_shape)
        par_feat = self.par_ff([par_t, cur_label_t])
        if par_mask_t is not None:
            par_feat *= par_mask_t.unsqueeze(-1)
        ret_t += par_feat
    return ret_t
def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
    conf = self.conf
    # == calculate the dot-product scores
    # calculate the three: [bs, len_?, head*D]; and also add sta ones if needed
    query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
    query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
    # original scores
    scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
    # == adding rel_dist ones
    if conf.use_rel_dist:
        scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
    # transpose
    scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
    # == unhead score
    if conf.use_unhead_score:
        scores_t0, score_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
        scores = scores_t0 + score_t1  # [*, len_q, len_k, head]
    # == combining with history accumulated attns
    if conf.use_lambq and accu_attn is not None:
        # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
        lambq_vals = self.lambq_aff(query)  # [*, len_q, head]; if using e.g. relu as fact, this >= 0
        scores -= lambq_vals.unsqueeze(-2) * accu_attn
    # == score offset
    if conf.use_soff:
        # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
        score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
        score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
        scores -= score_offset_t0.unsqueeze(-2)
        scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
    # == apply mask & no-self-loop
    # NEG_INF = Constants.REAL_PRAC_MIN
    NEG_INF = -1000.  # this should be enough
    NEG_INF2 = -2000.  # this should be enough
    if mask_k is not None:  # [*, 1, len_k, 1]
        scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
    if mask_qk is not None:  # [*, len_q, len_k, 1]
        scores += (1. - mask_qk).unsqueeze(-1) * NEG_INF2
    if self.no_self_loop:
        query_len = BK.get_shape(query, -2)
        assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
        scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
    return scores.contiguous()  # [*, len_q, len_k, head]
def init_call(self, src):
    # init accumulated attn: all 0.
    src_shape = BK.get_shape(src)
    attn_shape = src_shape[:-1] + [src_shape[-2], self.attn_count]  # [*, len_q, len_k, head]
    cache = VRecCache()
    cache.orig_t = src
    cache.rec_t = src  # initially, the same as "orig_t"
    cache.rec_lstm_c_t = BK.zeros(BK.get_shape(src))  # initially zero
    cache.accu_attn = BK.zeros(attn_shape)
    return cache
def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
    if self.dmxnn:
        bsize, slen = BK.get_shape(input_mask)
        if sel_idxes is None:
            sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
        ncand = BK.get_shape(sel_idxes, -1)
        # enc_expr aug with PE
        rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
        pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
        aug_enc_expr = BK.concat([pe_embeds.expand(bsize, -1, -1, -1),
                                  input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)], -1)  # [*, ?, slen, D+Dpe]
        # [*, ?, slen, Denc]
        hidden_expr = self.e_encoder(
            aug_enc_expr.view(bsize * ncand, slen, -1),
            input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
        hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
        # dynamic max-pooling (dist<0, dist=0, dist>0)
        NEG = Constants.REAL_PRAC_MIN
        mp_hiddens = []
        mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
        for mp_mask in mp_masks:
            float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
            valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
            mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
            # todo(+2): or do we simply multiply mask?
            mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
            mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
            mp_hiddens.append(self.special_drop(mp_hid))
            # mp_hiddens.append(mp_hid)
        final_hiddens = mp_hiddens
    else:
        hidden_expr = self.e_encoder(input_expr, input_mask)  # [*, slen, D']
        if sel_idxes is None:
            hidden_expr1 = hidden_expr
        else:
            hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes, -2)  # [*, ?, D']
        final_hiddens = [self.special_drop(hidden_expr1)]
    if self.lab_f_use_lexi:
        final_hiddens.append(BK.gather_first_dims(input_lexi, sel_idxes, -2))  # [*, ?, DLex]
    ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
    return ret_expr
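# note: the dynamic max-pooling above splits positions into dist<0, dist==0 and dist>0 regions
# relative to each candidate and max-pools each region separately (DMCNN-style); below is a
# sketch of one region's masked max in plain PyTorch, with a hypothetical NEG floor.
import torch

def masked_max_pool(hidden, region_mask, NEG=-1e4):
    # hidden: [bs, cand, slen, D]; region_mask: [bs, cand, slen] in {0., 1.}
    valid = (region_mask.sum(-1) > 0.).float().unsqueeze(-1)           # [bs, cand, 1]
    pooled = (hidden + (1. - region_mask).unsqueeze(-1) * NEG).max(-2)[0]
    return pooled * valid                                              # zero out empty regions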
def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
    conf = self.conf
    self.refresh_batch(False)
    with BK.no_grad_env():
        # special mode
        # use: CUDA_VISIBLE_DEVICES=3 PYTHONPATH=../../src/ python3 -m pdb ../../src/tasks/cmd.py zmlm.main.test ${RUN_DIR}/_conf device:0 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best test:./_en.debug test_interactive:1
        if conf.test_interactive:
            iinput_sent = input(">> (Interactive testing) Input sent sep by blanks: ")
            iinput_tokens = iinput_sent.split()
            if len(iinput_sent) > 0:
                iinput_inst = GeneralSentence.create(iinput_tokens)
                iinput_inst.word_seq.set_idxes([self.word_vocab.get_else_unk(w) for w in iinput_inst.word_seq.vals])
                iinput_inst.char_seq.build_idxes(self.inputter.vpack.get_voc("char"))
                iinput_map = self.inputter([iinput_inst])
                iinput_erase_mask = np.asarray([[z == "Z" for z in iinput_tokens]]).astype(dtype=np.float32)
                iinput_masked_map = self.inputter.mask_input(iinput_map, iinput_erase_mask, set("pos"))
                emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(iinput_masked_map, collect_loss=False, insts=[iinput_inst])
                mlm_loss = self.masklm.loss(enc_t, iinput_erase_mask, iinput_map)
                dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
                self.dpar.predict([iinput_inst], enc_t, dpar_input_attn, mask_t)
                self.upos.predict([iinput_inst], enc_t, mask_t)
                # print them
                import pandas as pd
                cur_fields = {
                    "idxes": list(range(1, len(iinput_inst) + 1)),
                    "word": iinput_inst.word_seq.vals,
                    "pos": iinput_inst.pred_pos_seq.vals,
                    "head": iinput_inst.pred_dep_tree.heads[1:],
                    "dlab": iinput_inst.pred_dep_tree.labels[1:],
                }
                zlog(f"Result:\n{pd.DataFrame(cur_fields).to_string()}")
            return {}  # simply return here for interactive mode
        # -----
        # test for MLM simply as in training (use special separate rand_gen to keep the masks the same for testing)
        # todo(+2): do we need to keep testing/validating during training the same? Currently not!
        info = self.fb_on_batch(insts, training=False, rand_gen=self.testing_rand_gen, assign_attns=conf.testing_get_attns)
        # -----
        if len(insts) == 0:
            return info
        # decode for dpar
        input_map = self.inputter(insts)
        emb_t, mask_t, enc_t, cache, _ = self._emb_and_enc(input_map, collect_loss=False, insts=insts)
        dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
        self.dpar.predict(insts, enc_t, dpar_input_attn, mask_t)
        self.upos.predict(insts, enc_t, mask_t)
        if self.ner is not None:
            self.ner.predict(insts, enc_t, mask_t)
        # -----
        if conf.testing_get_attns:
            if conf.enc_choice == "vrec":
                self._assign_attns_item(insts, "orig", cache=cache)
            elif conf.enc_choice in ["original"]:
                pass
            else:
                raise NotImplementedError()
        return info
def _score(self, bert_expr, bidxes_t, hidxes_t):
    # -----
    # # debug
    # print(f"# ====\n Debug: {ArgSpanExpander._debug_count}")
    # ArgSpanExpander._debug_count += 1
    # -----
    bert_expr = bert_expr.view(BK.get_shape(bert_expr)[:-2] + [-1])  # flatten
    max_range = self.conf.max_range
    max_slen = BK.get_shape(bert_expr, 1)
    # get candidates
    range_t = BK.arange_idx(max_range).unsqueeze(0)  # [1, R]
    bidxes_t = bidxes_t.unsqueeze(1)  # [N, 1]
    hidxes_t = hidxes_t.unsqueeze(1)  # [N, 1]
    left_cands = hidxes_t - range_t  # [N, R]
    right_cands = hidxes_t + range_t
    left_masks = (left_cands >= 0).float()
    right_masks = (right_cands < max_slen).float()
    left_cands.clamp_(min=0)
    right_cands.clamp_(max=max_slen - 1)
    # score
    head_exprs = bert_expr[bidxes_t, hidxes_t]  # [N, 1, D']
    left_cand_exprs = bert_expr[bidxes_t, left_cands]  # [N, R, D']
    right_cand_exprs = bert_expr[bidxes_t, right_cands]
    # actual scoring
    if self.use_lstm_scorer:
        batch_size = BK.get_shape(bidxes_t, 0)
        all_concat_outputs = []
        for cand_exprs, lstm_node in zip([left_cand_exprs, right_cand_exprs], [self.llstm, self.rlstm]):
            cur_state = lstm_node.zero_init_hidden(batch_size)
            step_size = BK.get_shape(cand_exprs, 1)
            all_outputs = []
            for step_i in range(step_size):
                cur_state = lstm_node(cand_exprs[:, step_i], cur_state, None)
                all_outputs.append(cur_state[0])  # using h
            concat_output = BK.stack(all_outputs, 1)  # [N, R, ?]
            all_concat_outputs.append(concat_output)
        left_hidden, right_hidden = all_concat_outputs
        left_scores = self.lscorer(left_hidden).squeeze(-1)  # [N, R]
        right_scores = self.rscorer(right_hidden).squeeze(-1)  # [N, R]
    else:
        left_scores = self.lscorer([left_cand_exprs, head_exprs]).squeeze(-1)  # [N, R]
        right_scores = self.rscorer([right_cand_exprs, head_exprs]).squeeze(-1)
    # mask
    left_scores += Constants.REAL_PRAC_MIN * (1. - left_masks)
    right_scores += Constants.REAL_PRAC_MIN * (1. - right_masks)
    return left_scores, right_scores
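# note: a toy run (plain PyTorch) of the candidate construction above: offsets from each head
# index, clamped into the sentence, with masks recording which offsets were actually in range.
import torch

max_range, max_slen = 4, 10
hidxes = torch.tensor([[3], [8]])                  # [N, 1]
rng = torch.arange(max_range).unsqueeze(0)         # [1, R]
left_cands = (hidxes - rng).clamp(min=0)           # [N, R]
right_cands = (hidxes + rng).clamp(max=max_slen - 1)
left_masks = ((hidxes - rng) >= 0).float()         # 1. where the offset stayed in range
right_masks = ((hidxes + rng) < max_slen).float()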
def score_arc_all(self, am_expr, ah_expr, m_mask_expr, h_mask_expr):
    if self.dist_helper:
        lenm, lenh = BK.get_shape(am_expr, -2), BK.get_shape(ah_expr, -2)
        ah_rel1, _ = self.dist_helper(lenm, lenh)
    else:
        ah_rel1 = None
    arc_scores = self.arc_scorer.paired_score(am_expr, ah_expr, m_mask_expr, h_mask_expr, rel1_t=ah_rel1)
    ret = arc_scores.squeeze(-1)  # squeeze the last one
    return ret
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first maximum across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use mask to make it able to be backward
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]
    # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # final predicting labels
    # todo(+N): here we select only max at labeling part, only beam at previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
def _pmask2idxes(self, pred_mask):
    orig_shape = BK.get_shape(pred_mask)
    dim_type = orig_shape[-1]
    flattened_mask = pred_mask.view(orig_shape[:-2] + [-1])  # [*, slen*L]
    f_idxes, sel_valid_mask = BK.mask2idx(flattened_mask)  # [*, max-count]
    # then back to the two dimensions
    sel_idxes, sel_lab_idxes = f_idxes // dim_type, f_idxes % dim_type
    # the embeddings
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_embeds = BK.zeros(sel_shape + [self.conf.lab_conf.n_dim])
    else:
        assert not self.hl.conf.use_lookup_soft, "Cannot do soft-lookup in this mode"
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
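# note: BK.mask2idx (used above) packs the positions of 1s in a mask into a dense index tensor
# plus a validity mask; this is a minimal re-implementation of that contract as understood from
# the call sites (assumption: invalid slots are padded with index 0), in plain PyTorch.
import torch

def mask2idx(mask):
    # mask: [bs, slen] of 0./1. -> (idxes [bs, max_count], valids [bs, max_count])
    counts = mask.sum(-1).long()                     # [bs]
    max_count = max(int(counts.max().item()), 1)
    idxes = torch.zeros(mask.shape[0], max_count, dtype=torch.long)
    valids = torch.zeros(mask.shape[0], max_count)
    for i, row in enumerate(mask):
        pos = row.nonzero(as_tuple=True)[0]
        idxes[i, :len(pos)] = pos
        valids[i, :len(pos)] = 1.
    return idxes, valids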
def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    # loss
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
    # collect the losses
    arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
    arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
    # logsoftmax and losses
    arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1)  # [bs, m, h]
    lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr]  # [bs, Len]
    lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr]  # [bs, Len]
    # head selection (no root)
    arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
    lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
    final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
    final_loss_count = mask_expr[:, 1:].sum()
    return [[final_loss, final_loss_count]]
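# note: the loss above reads each token's gold-head log-prob with two broadcast arange index
# tensors; the same pattern in isolation (plain PyTorch, toy sizes):
import torch

bs, slen = 2, 5
logp = torch.randn(bs, slen, slen).log_softmax(-1)   # [bs, m, h]
gold_heads = torch.randint(0, slen, (bs, slen))      # [bs, m]
b_idx = torch.arange(bs).unsqueeze(-1)               # [bs, 1]
m_idx = torch.arange(slen).unsqueeze(0)              # [1, m]
gold_lp = logp[b_idx, m_idx, gold_heads]             # [bs, m]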
def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    # score
    scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
    # get gold
    gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, ?+rlen]; note: np.long is gone in newer numpy
    for bidx, inst in enumerate(insts):
        cur_seq_idxes = getattr(inst, self.attr_name).idxes
        if self.add_root_token:
            gold_pidxes[bidx, 1:1 + len(cur_seq_idxes)] = cur_seq_idxes
        else:
            gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
    # get loss
    margin = self.margin.value
    gold_pidxes_t = BK.input_idx(gold_pidxes)
    gold_pidxes_t *= (gold_pidxes_t < self.pred_out_dim).long()  # 0 means invalid ones!!
    loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
    lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin)  # [bs, ?+rlen]
    # argmax
    _, argmax_idxes = scores_t.max(-1)
    pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
    # compile loss
    lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum())
    return self._compile_component_loss(self.pname, [lab_loss])
def loss(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # use gold targets: only use positive samples!!
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
    realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx) for z in items_arr.flatten()]
    realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
    realis_mask = (realis_t >= 0).float()
    realis_t.clamp_(min=0)  # make sure all idxes are legal
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build losses
    loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens, realis_t, realis_mask, conf.lambda_realis)
    loss_item_type = self._get_one_loss(self.type_predictor, hiddens, labels_t, masks_t, conf.lambda_type)
    return [loss_item_realis, loss_item_type]
def __call__(self, word_arr: np.ndarray = None, char_arr: np.ndarray = None,
             extra_arrs: Iterable[np.ndarray] = (), aux_arrs: Iterable[np.ndarray] = ()):
    exprs = []
    # word/char/extras/posi
    seq_shape = None
    if self.has_word:
        # todo(warn): singleton-UNK-dropout should be done outside before
        seq_shape = word_arr.shape
        word_expr = self.dropmd_word(self.word_embed(word_arr))
        exprs.append(word_expr)
    if self.has_char:
        seq_shape = char_arr.shape[:-1]
        char_embeds = self.char_embed(char_arr)  # [*, seq-len, word-len, D]
        char_cat_expr = self.dropmd_char(BK.concat([z(char_embeds) for z in self.char_cnns]))
        exprs.append(char_cat_expr)
    zcheck(len(extra_arrs) == len(self.extra_embeds), "Unmatched extra fields.")
    for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(extra_arrs, self.extra_embeds, self.dropmd_extras):
        seq_shape = one_extra_arr.shape
        exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
    if self.has_posi:
        seq_len = seq_shape[-1]
        posi_idxes = BK.arange_idx(seq_len)
        posi_input0 = self.posi_embed(posi_idxes)
        for _ in range(len(seq_shape) - 1):
            posi_input0 = BK.unsqueeze(posi_input0, 0)
        posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1,))
        exprs.append(posi_input1)
    #
    assert len(aux_arrs) == len(self.drop_auxes)
    for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
            zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes,
                self.aux_overall_gammas, self.aux_fold_lambdas):
        # fold and apply trainable lambdas
        input_aux_repr = BK.input_real(one_aux_arr)
        input_shape = BK.get_shape(input_aux_repr)
        # todo(note): assume the original concat is [fold/layer, D]
        reshaped_aux_repr = input_aux_repr.view(input_shape[:-1] + [one_fold, one_aux_dim])  # [*, slen, fold, D]
        # note: softmax over the per-fold lambdas, then rescale by the overall gamma
        lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
        weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma  # [*, slen, D]
        one_aux_expr = one_aux_drop(weighted_aux_repr)
        exprs.append(one_aux_expr)
    #
    concated_exprs = BK.concat(exprs, dim=-1)
    # optional proj
    if self.has_proj:
        final_expr = self.final_layer(concated_exprs)
    else:
        final_expr = concated_exprs
    return final_expr
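# note: the aux-mixing loop above is an ELMo-style scalar mix; a standalone version in plain
# PyTorch, under the assumption (see the note in the loop above) that the softmax runs over
# the per-fold lambdas and the overall gamma only rescales the result.
import torch

def scalar_mix(stacked, lambdas, gamma):
    # stacked: [*, slen, fold, D]; lambdas: [fold]; gamma: scalar tensor
    w = torch.softmax(lambdas, -1).unsqueeze(-1)     # [fold, 1]
    return (stacked * w).sum(-2) * gamma             # [*, slen, D]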
def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True, conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_neg_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_neg_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]
def _select_topk(self, masked_scores, pad_mask, ratio_mask, topk_ratio, thresh_k):
    slen = BK.get_shape(masked_scores, -1)
    sel_mask = BK.copy(pad_mask)
    # first apply the absolute thresh
    if thresh_k is not None:
        sel_mask *= (masked_scores > thresh_k).float()
    # then ratio-ed topk
    if topk_ratio > 0.:
        # prepare number
        cur_topk_num = ratio_mask.sum(-1)  # [*]
        cur_topk_num = (cur_topk_num * topk_ratio).long()  # [*]
        cur_topk_num.clamp_(min=1, max=slen)  # at least one, at most all
        # topk
        actual_max_k = max(cur_topk_num.max().item(), 1)
        topk_score, _ = BK.topk(masked_scores, actual_max_k, dim=-1, sorted=True)  # [*, k]
        thresh_score = topk_score.gather(-1, cur_topk_num.clamp(min=1).unsqueeze(-1) - 1)  # [*, 1]
        # get mask and apply
        sel_mask *= (masked_scores >= thresh_score).float()
    return sel_mask
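# note: a sketch of the per-row "keep the top ratio" trick above: run topk once with the
# batch-wide maximum k, then read each row's own k-th best score as its threshold. Plain
# PyTorch; k_per_row is assumed to be a long tensor already clamped to [1, slen].
import torch

def ratio_threshold_mask(scores, k_per_row):
    # scores: [bs, slen]; k_per_row: [bs]
    max_k = int(k_per_row.max().item())
    topk_score, _ = scores.topk(max_k, dim=-1, sorted=True)          # [bs, max_k]
    thresh = topk_score.gather(-1, (k_per_row - 1).unsqueeze(-1))    # [bs, 1]
    return (scores >= thresh).float()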
def _losses_global_prob(self, full_score_expr, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr):
    # combine the last two dimensions
    full_shape = BK.get_shape(full_score_expr)
    last_size = full_shape[-1]
    # [*, m, h*L]
    combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1])
    # # todo(warn): make sure sum to 1., handled in algorithm instead
    # combined_marginals_expr = combined_marginals_expr / combined_marginals_expr.sum(dim=-1, keepdim=True)
    # [*, m]
    gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr
    # [*, m, h, L]
    gradients = BK.minus_margin(combined_marginals_expr, gold_combined_idx_expr, 1.).view(full_shape)
    # the gradients on h are already 0. from the marginal algorithm
    gradients_masked = gradients * mask_expr.unsqueeze(-1).unsqueeze(-1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
    # for the h-dimension, need to divide by the real length.
    # todo(warn): these values should be directly summed rather than averaged, since directly from loss
    fake_losses = (full_score_expr * gradients_masked).sum(-1).sum(-1)  # [BS, m]
    # todo(warn): be aware of search-error-like output constraints;
    # but this clamp for all is not good for loss-prob, dealt at outside with unproj-mask.
    # <bad> fake_losses = BK.clamp(fake_losses, min=0.)
    return fake_losses
def get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr,
                            pred_heads_expr, pred_labels_expr, mask_expr, clamping=True):
    # combine the last two dimensions
    full_shape = BK.get_shape(full_score_expr)
    last_size = full_shape[-1]
    # [*, m, h*L]
    combined_score_expr = full_score_expr.view(full_shape[:-2] + [-1])
    # [*, m]
    gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr
    pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr
    # [*, m]
    gold_scores = BK.gather_one_lastdim(combined_score_expr, gold_combined_idx_expr).squeeze(-1)
    pred_scores = BK.gather_one_lastdim(combined_score_expr, pred_combined_idx_expr).squeeze(-1)
    # todo(warn): be aware of search error!
    # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.)  # this is the previous version
    hinge_losses = pred_scores - gold_scores  # [*, len]
    if clamping:
        valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1)  # [*, 1]
        return hinge_losses * valid_losses
    else:
        # for this mode, will there be problems of search error? Maybe rarely.
        return hinge_losses
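# note: both global losses above address a (head, label) pair through one flattened index,
# combined_idx = head * num_labels + label; a tiny demonstration of that encoding plus the
# gather it enables (plain PyTorch, toy sizes):
import torch

L = 5                                             # number of labels
scores = torch.randn(2, 4, 3 * L)                 # [bs, m, h*L] with h=3
heads = torch.randint(0, 3, (2, 4))               # [bs, m]
labels = torch.randint(0, L, (2, 4))              # [bs, m]
combined = heads * L + labels                     # [bs, m]
picked = scores.gather(-1, combined.unsqueeze(-1)).squeeze(-1)  # [bs, m]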
def _score_label_full(self, scoring_expr_pack, mask_expr, training, margin,
                      gold_heads_expr=None, gold_labels_expr=None):
    _, _, lm_expr, lh_expr = scoring_expr_pack
    # [BS, len-m, len-h, L]
    full_label_score = self.scorer.score_label_all(lm_expr, lh_expr, mask_expr, mask_expr)
    # # set diag to small values # todo(warn): handled specifically in algorithms
    # maxlen = BK.get_shape(full_label_score, 1)
    # full_label_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
    # margin? -- specially reshaping
    if training and margin > 0.:
        full_shape = BK.get_shape(full_label_score)
        # combine last two dims
        combined_score_expr = full_label_score.view(full_shape[:-2] + [-1])
        combined_idx_expr = gold_heads_expr * full_shape[-1] + gold_labels_expr
        combined_changed_score = BK.minus_margin(combined_score_expr, combined_idx_expr, margin)
        full_label_score = combined_changed_score.view(full_shape)
    return full_label_score
def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f, loss_prefix, loss_lambda):
    loss_num = None
    loss_counts: List[int] = []
    loss_sums: List[List] = []
    rets = []
    # -----
    for one_attn_info in list_attn_info:  # each update step
        one_ts: List = ts_f(one_attn_info)  # get tensor list from attn_info
        # get number of losses
        if loss_num is None:
            loss_num = len(one_ts)
            loss_counts = [0] * loss_num
            loss_sums = [[] for _ in range(loss_num)]
        else:
            assert len(one_ts) == loss_num, "mismatched ts length"
        # iter them
        for one_t_idx, one_t in enumerate(one_ts):  # iter on the tensor list
            one_loss = loss_f(one_t)  # need it to be in the corresponding shape
            loss_counts[one_t_idx] += np.prod(BK.get_shape(one_loss)).item()
            loss_sums[one_t_idx].append(one_loss.sum())
    # for different steps
    for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)), loss_counts, loss_sums):
        loss_leaf = LossHelper.compile_leaf_info(
            f"{loss_prefix}{i}", BK.stack(one_loss_sums, 0).sum(), BK.input_real(one_loss_count),
            loss_lambda=loss_lambda)
        rets.append(loss_leaf)
    return rets
def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t,
                              oracle_mask_t, oracle_label_t, arc_margin: float, label_margin: float):
    # todo(note): in this mode, no repeated arc_margin
    dim1_range_t = bsize_range_t
    dim2_range_t = dim1_range_t.unsqueeze(-1)
    if self.system_labeled:
        selected_m_cache = [z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache]
        selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t]
        ret = self.scorer.score_label(selected_m_cache, selected_h_repr)  # [*, k, labels]
        if label_margin > 0.:
            oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t, idxes_h_t].unsqueeze(-1)  # [*, k, 1] of int
            ret.scatter_add_(-1, oracle_label_idxes, BK.constants(oracle_label_idxes.shape, -label_margin))
    else:
        # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent)
        ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label])
        ret[:, :, 0] += 0.01
    if self.g1_lab_scores is not None:
        ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t]
    return ret
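# note: the label margin above is applied by scatter-adding -margin at the oracle label
# position of every selected pair; minimal form of that operation (plain PyTorch):
import torch

scores = torch.zeros(2, 3, 4)                              # [*, k, labels]
oracle = torch.randint(0, 4, (2, 3, 1))                    # [*, k, 1] oracle label ids
scores.scatter_add_(-1, oracle, torch.full(oracle.shape, -0.5))  # subtract margin at oracle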
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [
                BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
def _get_basic_score(self, mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes):
    allp_size = BK.get_shape(batch_idxes, 0)
    all_arc_scores, all_lab_scores = [], []
    cur_pidx = 0
    while cur_pidx < allp_size:
        next_pidx = min(allp_size, cur_pidx + self.mb_dec_sb)
        # first calculate srepr
        s_enc = self.slayer
        cur_batch_idxes = batch_idxes[cur_pidx:next_pidx]
        h_expr = mb_enc_expr[cur_batch_idxes, h_idxes[cur_pidx:next_pidx]]
        m_expr = mb_enc_expr[cur_batch_idxes, m_idxes[cur_pidx:next_pidx]]
        s_expr = mb_enc_expr[cur_batch_idxes, sib_idxes[cur_pidx:next_pidx]].unsqueeze(-2) \
            if (sib_idxes is not None) else None  # [*, 1, D]
        g_expr = mb_enc_expr[cur_batch_idxes, gp_idxes[cur_pidx:next_pidx]] if (gp_idxes is not None) else None
        head_srepr = s_enc.calculate_repr(h_expr, g_expr, None, None, s_expr, None, None, None)
        mod_srepr = s_enc.forward_repr(m_expr)
        # then get the scores
        arc_score = self.scorer.transform_and_arc_score_plain(mod_srepr, head_srepr).squeeze(-1)
        all_arc_scores.append(arc_score)
        if self.system_labeled:
            lab_score = self.scorer.transform_and_label_score_plain(mod_srepr, head_srepr)
            all_lab_scores.append(lab_score)
        cur_pidx = next_pidx
    final_arc_score = BK.concat(all_arc_scores, 0)
    final_lab_score = BK.concat(all_lab_scores, 0) if self.system_labeled else None
    return final_arc_score, final_lab_score
def __call__(self, scores, temperature=1., dim=-1):
    is_training = self.rop.training
    # only use stochastic at training
    if is_training:
        if self.use_gumbel:
            gumbel_eps = self.gumbel_eps
            G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)  # rand is [0,1)
            scores = scores - (gumbel_eps - G.log()).log()
    # normalize
    probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
    # prune and re-normalize?
    if self.prune_val > 0.:
        probs = probs * (probs > self.prune_val).float()
        # todo(note): currently no re-normalize
        # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
    # argmax and ste
    if self.use_argmax:
        # use the hard argmax
        max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
        # todo(+N): currently we do not re-normalize here, should it be done here?
        st_probs = (probs >= max_probs).float() * probs  # [*, S]
        if is_training:
            # (hard - soft).detach() + soft
            st_probs = (st_probs - probs).detach() + probs  # [*, S]
        return st_probs
    else:
        return probs
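# note: the straight-through step above relies on the identity (hard - soft).detach() + soft:
# the forward value equals the hard probs while the gradient flows as if the soft probs were
# returned. A stripped-down version over an argmax mask (plain PyTorch):
import torch

def st_argmax(probs):
    hard = (probs >= probs.max(-1, keepdim=True)[0]).float() * probs
    return (hard - probs).detach() + probs   # forward: hard; backward: grad of probs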
def _my_loss_prob(self, score_expr, gold_idxes_expr, entropy_lambda: float, loss_mask, neg_reweight: bool):
    probs = BK.softmax(score_expr, -1)  # [*, NLab]
    log_probs = BK.log(probs + 1e-8)
    # first plain NLL loss
    nll_loss = -BK.gather_one_lastdim(log_probs, gold_idxes_expr).squeeze(-1)
    # next the special loss
    if entropy_lambda > 0.:
        negative_entropy = probs * log_probs  # [*, NLab]
        last_dim = BK.get_shape(score_expr, -1)
        confusion_matrix = 1. - BK.eye(last_dim)  # [Nlab, Nlab]
        entropy_mask = confusion_matrix[gold_idxes_expr]  # [*, Nlab]
        entropy_loss = (negative_entropy * entropy_mask).sum(-1)
        final_loss = nll_loss + entropy_lambda * entropy_loss
    else:
        final_loss = nll_loss
    # reweight?
    if neg_reweight:
        golden_prob = BK.gather_one_lastdim(probs, gold_idxes_expr).squeeze(-1)
        is_full_nil = (gold_idxes_expr == 0.).float()
        not_full_nil = 1. - is_full_nil
        count_pos = (loss_mask * not_full_nil).sum()
        count_neg = (loss_mask * is_full_nil).sum()
        prob_pos = (loss_mask * not_full_nil * golden_prob).sum()
        prob_neg = (loss_mask * is_full_nil * golden_prob).sum()
        neg_weight = prob_pos / (count_pos + count_neg - prob_neg + 1e-8)
        final_weights = not_full_nil + is_full_nil * neg_weight
        # todo(note): final mask will be applied at outside
        final_loss = final_loss * final_weights
    return final_loss
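# note: the entropy term above only penalizes probability mass on non-gold labels: rows of
# (1 - I) indexed by the gold ids zero out exactly the gold position. A quick check in plain
# PyTorch:
import torch

n_lab = 4
gold = torch.tensor([2, 0])
confusion = 1. - torch.eye(n_lab)   # [L, L], zero diagonal
entropy_mask = confusion[gold]      # [bs, L]; entropy_mask[i, gold[i]] == 0.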