def _get_basic_score(self, mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes):
    allp_size = BK.get_shape(batch_idxes, 0)
    all_arc_scores, all_lab_scores = [], []
    cur_pidx = 0
    while cur_pidx < allp_size:
        next_pidx = min(allp_size, cur_pidx + self.mb_dec_sb)
        # first calculate srepr
        s_enc = self.slayer
        cur_batch_idxes = batch_idxes[cur_pidx:next_pidx]
        h_expr = mb_enc_expr[cur_batch_idxes, h_idxes[cur_pidx:next_pidx]]
        m_expr = mb_enc_expr[cur_batch_idxes, m_idxes[cur_pidx:next_pidx]]
        s_expr = mb_enc_expr[cur_batch_idxes, sib_idxes[cur_pidx:next_pidx]].unsqueeze(-2) \
            if (sib_idxes is not None) else None  # [*, 1, D]
        g_expr = mb_enc_expr[cur_batch_idxes, gp_idxes[cur_pidx:next_pidx]] if (gp_idxes is not None) else None
        head_srepr = s_enc.calculate_repr(h_expr, g_expr, None, None, s_expr, None, None, None)
        mod_srepr = s_enc.forward_repr(m_expr)
        # then get the scores
        arc_score = self.scorer.transform_and_arc_score_plain(mod_srepr, head_srepr).squeeze(-1)
        all_arc_scores.append(arc_score)
        if self.system_labeled:
            lab_score = self.scorer.transform_and_label_score_plain(mod_srepr, head_srepr)
            all_lab_scores.append(lab_score)
        cur_pidx = next_pidx
    final_arc_score = BK.concat(all_arc_scores, 0)
    final_lab_score = BK.concat(all_lab_scores, 0) if self.system_labeled else None
    return final_arc_score, final_lab_score
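
# ----- illustrative sketch (not part of the original code) -----
# The loop above bounds peak memory by scoring at most `mb_dec_sb` (head, mod)
# pairs per iteration and concatenating the partial results. A minimal plain-torch
# sketch of the same chunking pattern; `score_fn` is a hypothetical stand-in for
# the biaffine scorer:
import torch

def chunked_scores(pair_feats: torch.Tensor, score_fn, chunk_size: int = 1024) -> torch.Tensor:
    """Score a flat batch of pairs chunk by chunk, then re-concatenate."""
    outs = []
    for start in range(0, pair_feats.shape[0], chunk_size):
        outs.append(score_fn(pair_feats[start:start + chunk_size]))
    return torch.cat(outs, dim=0)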
def __call__(self, word_arr: np.ndarray = None, char_arr: np.ndarray = None,
             extra_arrs: Iterable[np.ndarray] = (), aux_arrs: Iterable[np.ndarray] = ()):
    exprs = []
    # word/char/extras/posi
    seq_shape = None
    if self.has_word:
        # todo(warn): singleton-UNK-dropout should be done outside before
        seq_shape = word_arr.shape
        word_expr = self.dropmd_word(self.word_embed(word_arr))
        exprs.append(word_expr)
    if self.has_char:
        seq_shape = char_arr.shape[:-1]
        char_embeds = self.char_embed(char_arr)  # [*, seq-len, word-len, D]
        char_cat_expr = self.dropmd_char(BK.concat([z(char_embeds) for z in self.char_cnns]))
        exprs.append(char_cat_expr)
    zcheck(len(extra_arrs) == len(self.extra_embeds), "Unmatched extra fields.")
    for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(extra_arrs, self.extra_embeds, self.dropmd_extras):
        seq_shape = one_extra_arr.shape
        exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
    if self.has_posi:
        seq_len = seq_shape[-1]
        posi_idxes = BK.arange_idx(seq_len)
        posi_input0 = self.posi_embed(posi_idxes)
        for _ in range(len(seq_shape) - 1):
            posi_input0 = BK.unsqueeze(posi_input0, 0)
        posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1,))
        exprs.append(posi_input1)
    #
    # assert len(aux_arrs) == len(self.drop_auxes)
    for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
            zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes,
                self.aux_overall_gammas, self.aux_fold_lambdas):
        # fold and apply trainable lambdas
        input_aux_repr = BK.input_real(one_aux_arr)
        input_shape = BK.get_shape(input_aux_repr)
        # todo(note): assume the original concat is [fold/layer, D]
        reshaped_aux_repr = input_aux_repr.view(input_shape[:-1] + [one_fold, one_aux_dim])  # [*, slen, fold, D]
        # softmax over the per-fold lambdas, then scale by the overall gamma
        # (bugfix: the original applied softmax to `one_gamma` and left `one_lambdas` unused)
        lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
        weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma  # [*, slen, D]
        one_aux_expr = one_aux_drop(weighted_aux_repr)
        exprs.append(one_aux_expr)
    #
    concated_exprs = BK.concat(exprs, dim=-1)
    # optional proj
    if self.has_proj:
        final_expr = self.final_layer(concated_exprs)
    else:
        final_expr = concated_exprs
    return final_expr
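
# ----- illustrative sketch (not part of the original code) -----
# The aux branch above is an ELMo-style scalar mix: a softmax over per-layer
# weights, a weighted sum over the fold/layer dim, then a single overall gamma.
# A standalone plain-torch version (hypothetical module, not this repo's BK API):
import torch
import torch.nn as nn

class ScalarMix(nn.Module):
    """Trainable weighted sum over `num_layers` stacked representations."""
    def __init__(self, num_layers: int):
        super().__init__()
        self.lambdas = nn.Parameter(torch.zeros(num_layers))  # per-layer logits
        self.gamma = nn.Parameter(torch.ones(()))             # overall scale

    def forward(self, reprs: torch.Tensor) -> torch.Tensor:
        # reprs: [*, slen, num_layers, D] -> [*, slen, D]
        weights = torch.softmax(self.lambdas, dim=-1).unsqueeze(-1)  # [num_layers, 1]
        return (reprs * weights).sum(dim=-2) * self.gamma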
def lookup(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize = len(insts)
    # first get gold/input info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    # step 1: no selection, simply forward using gold_masks
    sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
        ret_items = sel_items
        # second type
        if self.use_secondary_type:
            sel2_lab_idxes = sel_gold_idxes2
            sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
            sel2_valid_mask = (sel2_lab_idxes > 0).float()
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # step 3: exclude nil, assuming no deliberate nil in gold/inputs
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # step 4: return
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # mask out invalid items with None (note: use builtin bool, np.bool is removed in modern numpy)
    ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
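
# ----- illustrative sketch (not part of the original code) -----
# `BK.mask2idx` (used throughout) turns a 0/1 mask into right-padded position
# indices plus a validity mask, so ragged selections can be batched. A plain-torch
# sketch of the assumed semantics (the real BK helper may pad differently):
import torch

def mask2idx(mask: torch.Tensor, pad: int = 0):
    """[bs, slen] 0/1 mask -> ([bs, max_count] indices, [bs, max_count] 0/1 valid)."""
    counts = mask.sum(-1).long()  # [bs]
    max_count = max(int(counts.max().item()), 1)
    # stable descending sort puts the 1-positions first, in left-to-right order
    _, order = torch.sort(mask, dim=-1, descending=True, stable=True)
    idxes = order[:, :max_count].clone()
    valid = (torch.arange(max_count, device=mask.device).unsqueeze(0) < counts.unsqueeze(-1)).float()
    idxes[valid == 0.] = pad
    return idxes, valid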
def __call__(self, char_input, add_root_token: bool):
    char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
    if add_root_token:
        slice_shape = BK.get_shape(char_input_t)
        slice_shape[-2] = 1
        char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
        char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
    else:
        char_input_t1 = char_input_t
    char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
    char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
    return self.dropout(char_cat_expr)  # todo(note): only final dropout
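
# ----- illustrative sketch (not part of the original code) -----
# Prepending the artificial ROOT slice amounts to building a constant block with
# the same batch shape and concatenating along the sequence axis. In plain torch:
import torch

def prepend_root(ids: torch.Tensor, root_id: int = 0, seq_dim: int = -2) -> torch.Tensor:
    """ids: [*, slen, wlen] -> [*, 1+slen, wlen] with a constant ROOT row."""
    shape = list(ids.shape)
    shape[seq_dim] = 1
    root = torch.full(shape, root_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([root, ids], dim=seq_dim)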
def _special_score(one_score):  # specially change ablpair scores into [bs, m, h, *]
    root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
    tmp_shape = BK.get_shape(root_score)
    tmp_shape[1] = 1  # [bs, 1, 1, *]
    padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score], dim=1)  # [bs, rlen+1, 1, *]
    final_score = BK.concat([padded_root_score, one_score.transpose(1, 2)], dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
    return final_score
def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
    if use_one_bucket:
        # when we do not want to split, e.g., if we know the input lengths do not vary too much
        all_buckets = [all_sents]
    else:
        all_sents.sort(key=lambda x: x[0].length)
        all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
    # doc hint
    use_doc_hint = self.use_doc_hint
    if use_doc_hint:
        dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
    else:
        dh_sent_repr = None
    # encoding for each of the buckets
    rets = []
    dh_add, dh_both, dh_cls = [self.dh_combine_method == z for z in ["add", "both", "cls"]]
    for one_bucket in all_buckets:
        one_sents = [z[0] for z in one_bucket]
        # [BS, Len, Di], [BS, Len]
        input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
        if use_doc_hint:
            one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
            one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
            one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
            if dh_add:
                input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                mask_arr = mask_arr0
            elif dh_both:
                input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                mask_arr = np.pad(mask_arr0, ((0, 0), (1, 1)), 'constant', constant_values=1.)  # [BS, 2+slen]
            elif dh_cls:
                input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                mask_arr = mask_arr0
            else:
                raise NotImplementedError()
        else:
            input_repr, mask_arr = input_repr0, mask_arr0
        # [BS, Len, De]
        enc_repr = self.enc(input_repr, mask_arr)
        # separate ones (possibly using detach to avoid gradients for some of them)
        enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
        enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
        if use_doc_hint and dh_both:
            one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(),
                       enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
        else:
            one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
        rets.append(one_ret)
    # todo(note): each returned tuple is (List[Sentence], Tensor, Tensor, Tensor, Tensor)
    return rets
def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t):
    ret_t = cur_t  # [*, D]
    # pad zeros if not using labels
    dim_label = self.dim_label
    # child features
    if self.use_chs and chs_t is not None:
        if self.use_label_feat:
            chs_label_rt = self.label_embeddings(chs_label_t)  # [*, max-chs, dlab]
        else:
            labels_shape = BK.get_shape(chs_t)
            labels_shape[-1] = dim_label
            chs_label_rt = BK.zeros(labels_shape)
        chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
        chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t)
        chs_feat = self.chs_ff(chs_feat0)
        # note: out-of-place add, so the caller's cur_t is not mutated (the original used `+=`)
        ret_t = ret_t + chs_feat
    # parent features
    if self.use_par and par_t is not None:
        if self.use_label_feat:
            cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
        else:
            labels_shape = BK.get_shape(par_t)
            labels_shape[-1] = dim_label
            cur_label_t = BK.zeros(labels_shape)
        par_feat = self.par_ff([par_t, cur_label_t])
        if par_mask_t is not None:
            par_feat *= par_mask_t.unsqueeze(-1)
        ret_t = ret_t + par_feat
    return ret_t
def _ts_self_f(_list_attn_info):
    _rets = []
    for _t, _d in zip(_list_attn_info[3], _list_attn_info[4]):
        # extra dim at idx 0; todo(note): must repeat insts at outmost idx: repeated = insts * copy_num
        _t1 = _t.view([copy_num, -1] + BK.get_shape(_t)[1:])  # [copy, bs, ...]
        # roll it by 1
        _t2 = BK.concat([_t1[-1].unsqueeze(0), _t1[:-1]], dim=0)
        _rets.append((_t1, _t2, _d))
    return _rets
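
# ----- illustrative sketch (not part of the original code) -----
# The concat above rotates the copy dimension by one step (the last copy moves to
# the front), so each instance gets paired with a different copy's attention. In
# plain torch the same rotation is simply torch.roll:
import torch

t = torch.arange(24).view(4, 2, 3)             # [copy, bs, ...]
rolled_a = torch.cat([t[-1:], t[:-1]], dim=0)  # concat-based roll by 1
rolled_b = torch.roll(t, shifts=1, dims=0)     # equivalent built-in
assert torch.equal(rolled_a, rolled_b)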
def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
    conf = self.conf
    self.refresh_batch(False)
    # print(f"{len(insts)}: {insts[0].sid}")
    with BK.no_grad_env():
        # decode for dpar
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
        self.dpar.predict(insts, BK.zeros([1, 1]), input_t, mask_t)
    return {}
def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
    if self.dmxnn:
        bsize, slen = BK.get_shape(input_mask)
        if sel_idxes is None:
            sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
        ncand = BK.get_shape(sel_idxes, -1)
        # enc_expr aug with PE
        rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
        pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
        aug_enc_expr = BK.concat([pe_embeds.expand(bsize, -1, -1, -1),
                                  input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)], -1)  # [*, ?, slen, D+Dpe]
        # [*, ?, slen, Denc]
        hidden_expr = self.e_encoder(
            aug_enc_expr.view(bsize * ncand, slen, -1),
            input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
        hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
        # dynamic max-pooling (dist<0, dist=0, dist>0)
        NEG = Constants.REAL_PRAC_MIN
        mp_hiddens = []
        mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
        for mp_mask in mp_masks:
            float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
            valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
            mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
            # todo(+2): or do we simply multiply mask?
            mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
            mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
            mp_hiddens.append(self.special_drop(mp_hid))
            # mp_hiddens.append(mp_hid)
        final_hiddens = mp_hiddens
    else:
        hidden_expr = self.e_encoder(input_expr, input_mask)  # [*, slen, D']
        if sel_idxes is None:
            hidden_expr1 = hidden_expr
        else:
            hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes, -2)  # [*, ?, D']
        final_hiddens = [self.special_drop(hidden_expr1)]
    if self.lab_f_use_lexi:
        final_hiddens.append(BK.gather_first_dims(input_lexi, sel_idxes, -2))  # [*, ?, DLex]
    ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
    return ret_expr
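
# ----- illustrative sketch (not part of the original code) -----
# The dmxnn branch is DMCNN-style dynamic max-pooling: the sequence is split into
# three segments relative to each candidate (left of / at / right of it) and each
# segment is max-pooled under its own mask. A self-contained masked-max sketch:
import torch

def masked_segment_max(hidden: torch.Tensor, seg_mask: torch.Tensor, neg: float = -1e8) -> torch.Tensor:
    """hidden: [bs, slen, D]; seg_mask: [bs, slen] in {0,1} -> [bs, D] masked max."""
    neg_fill = (1. - seg_mask).unsqueeze(-1) * neg           # bury out-of-segment positions
    pooled = (hidden + neg_fill).max(dim=-2)[0]              # [bs, D]
    nonempty = (seg_mask.sum(-1) > 0).float().unsqueeze(-1)  # zero out empty segments
    return pooled * nonempty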
def _flatten_packs(self, packs):
    NUM_RET_PACK = 6  # discard the first mb-size
    ret_packs = [[] for _ in range(NUM_RET_PACK)]
    cur_base_idx = 0
    for one_pack in packs:
        mb_size = one_pack[0]
        ret_packs[0].append(one_pack[1] + cur_base_idx)
        for i in range(1, NUM_RET_PACK):
            ret_packs[i].append(one_pack[i + 1])
        cur_base_idx += mb_size
    ret = [(None if z[0] is None else BK.concat(z, 0)) for z in ret_packs]
    return ret
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    # step 1: select mention candidates
    if conf.use_selector:
        sel_mask = self.sel.predict(input_expr, input_mask)
    else:
        sel_mask = input_mask
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    # step 2: encoding and labeling
    sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)
    sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(sel_hid_exprs, None)  # [*, mc], [*, mc, D]
    # =====
    if self.use_secondary_type:
        sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
        sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
        sel2_lab_logprobs, sel2_lab_idxes, sel2_lab_embeds = self.hl.predict(sel2_input, None)
        if conf.sectype_t2ift1:
            sel2_lab_idxes *= (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
        # first concat here and then exclude nil at one pass
        # [*, mc*2, ~]
        if sel2_lab_idxes.sum().item() > 0:  # if there are any predictions
            sel_lab_logprobs = BK.concat([sel_lab_logprobs, sel2_lab_logprobs], -1)
            sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
            sel_valid_mask = BK.concat([sel_valid_mask, sel_valid_mask], -1)
            sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
            sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_lab_logprobs, _ = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs=sel_lab_logprobs)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
                rand_gen=None, assign_attns=False, **kwargs):
    self.refresh_batch(training)
    # get inputs from the model; no gradients flow back into it
    with BK.no_grad_env():
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
    # the dpar loss is computed outside the no-grad scope so its own parameters receive gradients
    losses = [self.dpar.loss(insts, BK.zeros([1, 1]), input_t, mask_t)]
    # -----
    info = self.collect_loss_and_backward(losses, training, loss_factor)
    info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
    return info
def __call__(self, input_map: Dict):
    exprs = []
    # get masks: this mask marks the valid positions of inst batching
    final_masks = BK.input_real(input_map["mask"])  # [*, slen]
    if self.add_root_token:  # append 1
        slice_t = BK.constants(BK.get_shape(final_masks)[:-1] + [1], 1.)
        final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
    # -----
    # for each component
    for idx, name in enumerate(self.comp_names):
        cur_node = self.nodes[idx]
        cur_input = input_map[name]
        cur_expr = cur_node(cur_input, self.add_root_token)
        exprs.append(cur_expr)
    # -----
    concated_exprs = BK.concat(exprs, dim=-1)
    # optional proj
    if self.has_proj:
        final_expr = self.final_layer(concated_exprs)
    else:
        final_expr = concated_exprs
    return final_expr, final_masks
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    # todo(+N): currently margin is not used
    conf = self.conf
    bsize = len(insts)
    arange_t = BK.arange_idx(bsize)
    assert conf.train_force, "currently only have forced training"
    # get the gold ones
    gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
    # for all the steps
    num_step = BK.get_shape(gold_widxes, -1)
    # recurrent states
    hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
    prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
    all_tok_logprobs, all_lab_logprobs = [], []
    for cstep in range(num_step):
        slice_widx, slice_lidx = gold_widxes[:, cstep], gold_lidxes[:, cstep]
        _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
            self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
        all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
        all_lab_logprobs.append(sel_lab_logprobs)
        hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
        hard_coverage[arange_t, slice_widx] += 1.
        prev_state = [z.squeeze(-2) for z in next_state]
    # concat all the losses and masks
    # todo(note): no need to use gold_valid since validity is already encoded in vmasks
    cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
    cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
    loss_sum = -(cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
    # todo(+N): here we are dividing lab_logprobs by the all-count; do we need to separate?
    loss_count = gold_vmasks.sum()
    ret_losses = [[loss_sum, loss_count]]
    # =====
    # make eos invalid for return
    ret_valid_mask = gold_vmasks * (gold_widxes > 0).float()
    # embeddings
    sel_lab_embeds = self._hl_lookup(gold_lidxes)
    return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first take the maximum across the R dim (this step is a hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use a mask so the operation stays differentiable
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # finally predict the labels
    # todo(+N): here we select only the max at the labeling part; the beam applies only to the previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
def __call__(self, input, add_root_token: bool):
    voc = self.voc
    # todo(note): append a [cls/root] idx, currently use "bos"
    input_t = BK.input_idx(input)  # [*, slen]
    # rare unk in training
    if self.rop.training and self.use_rare_unk:
        rare_unk_rate = self.ec_conf.comp_rare_unk
        cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).detach().long()
        input_t = input_t * (1 - cur_unk_imask) + self.voc.unk * cur_unk_imask
    # root
    if add_root_token:
        input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1] + [1], voc.bos, dtype=input_t.dtype)  # [*, 1]
        input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
    else:
        input_t_p1 = input_t
    expr = self.E(input_t_p1)  # [*, 1?+slen, D]
    return self.dropout(expr)
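
# ----- illustrative sketch (not part of the original code) -----
# The rare-unk branch stochastically replaces rare word ids with UNK at training
# time, a common regularizer against overfitting rare embeddings. A torch sketch,
# assuming `rare_mask[v] == 1` marks rare vocabulary entries:
import torch

def rare_word_dropout(ids: torch.Tensor, rare_mask: torch.Tensor, unk_id: int, rate: float) -> torch.Tensor:
    """Replace each rare token id with unk_id with probability `rate`."""
    hit = rare_mask[ids] * (torch.rand(ids.shape, device=ids.device) < rate).long()  # [*, slen]
    return ids * (1 - hit) + unk_id * hit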
def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool, noop_fixed_val: float,
               temperature: float, dim: int):
    cur_shape = BK.get_shape(orig_scores)  # original
    orig_that_dim = cur_shape[dim]
    cur_shape[dim] = 1
    if use_noop:
        noop_scores = BK.constants(cur_shape, value=noop_fixed_val)  # [*, 1, *]
        to_norm_scores = BK.concat([orig_scores, noop_scores], dim=dim)  # [*, D+1, *]
    else:
        to_norm_scores = orig_scores  # [*, D, *]
    # normalize
    prob_full = cnode(to_norm_scores, temperature=temperature, dim=dim)  # [*, ?, *]
    if use_noop:
        prob_valid, prob_noop = BK.split(prob_full, [orig_that_dim, 1], dim)  # [*, D|1, *]
    else:
        prob_valid, prob_noop = prob_full, None
    return prob_valid, prob_noop, prob_full
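
# ----- illustrative sketch (not part of the original code) -----
# Appending a fixed-score "noop" slot before normalizing gives the model an
# explicit do-nothing option whose probability competes with the real candidates.
# A plain-torch version of the same trick:
import torch

def softmax_with_noop(scores: torch.Tensor, noop_val: float = 0., dim: int = -1):
    """scores: [*, D] -> (probs over the D real slots, prob of the noop slot)."""
    shape = list(scores.shape)
    shape[dim] = 1
    noop = torch.full(shape, noop_val, dtype=scores.dtype, device=scores.device)
    full = torch.softmax(torch.cat([scores, noop], dim=dim), dim=dim)
    return torch.split(full, [scores.shape[dim], 1], dim=dim)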
def run(self, insts, training, input_word_mask_repl=None):
    self._cache_subword_tokens(insts)
    # prepare inputs
    word_arr, char_arr, extra_arrs, aux_arrs, mask_arr = \
        self.prepare_inputs(insts, training, input_word_mask_repl=input_word_mask_repl)
    # layer0: emb + bert
    layer0_reprs = []
    if self.emb_output_dim > 0:
        emb_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs)  # [BS, Len, Dim]
        layer0_reprs.append(emb_repr)
    if self.bert_output_dim > 0:
        # prepare bert inputs
        BERT_MASK_ID = self.bert.tokenizer.mask_token_id
        batch_subword_ids, batch_subword_is_starts = [], []
        for bidx, one_inst in enumerate(insts):
            st = one_inst.extra_features["st"]
            if input_word_mask_repl is not None:
                cur_subword_ids, cur_subword_is_start, _ = \
                    st.mask_and_return(input_word_mask_repl[bidx][1:], BERT_MASK_ID)  # todo(note): exclude ROOT for bert tokens
            else:
                cur_subword_ids, cur_subword_is_start = st.subword_ids, st.subword_is_start
            batch_subword_ids.append(cur_subword_ids)
            batch_subword_is_starts.append(cur_subword_is_start)
        bert_repr, _ = self.bert.forward_batch(batch_subword_ids, batch_subword_is_starts,
                                               batched_typeids=None, training=training)  # [BS, Len, D']
        layer0_reprs.append(bert_repr)
    # layer1: enc
    enc_input_repr = BK.concat(layer0_reprs, -1)  # [BS, Len, D+D']
    if self.middle_node is not None:
        enc_input_repr = self.middle_node(enc_input_repr)  # [BS, Len, D??]
    enc_repr = self.enc(enc_input_repr, mask_arr)
    mask_repr = BK.input_real(mask_arr)
    return enc_repr, mask_repr  # [bs, len, *], [bs, len]
def run(self, insts: List[DocInstance], training: bool):
    conf = self.conf
    BERT_MAX_LEN = 510  # save 2 for CLS and SEP
    # =====
    # encoder 1: the basic encoder
    # todo(note): only DocInstance input for this mode, otherwise will break
    if conf.m2e_use_basic:
        reidx_pad_len = conf.ms_extend_budget
        # enc the basic part + also get some indexes
        sentid2offset = {}  # id(sent) -> overall_seq_offset
        seq_offset = 0  # as if looking at all the docs as one seq
        all_sents = []  # (inst, d_idx, s_idx)
        for d_idx, one_doc in enumerate(insts):
            assert isinstance(one_doc, DocInstance)
            for s_idx, one_sent in enumerate(one_doc.sents):
                # todo(note): here we encode all the sentences
                all_sents.append((one_sent, d_idx, s_idx))
                sentid2offset[id(one_sent)] = seq_offset
                seq_offset += one_sent.length - 1  # exclude extra ROOT node
        sent_reprs = self.run_sents(all_sents, insts, training)
        # flatten, concatenate and re-index
        # (note: np.int64 instead of the removed np.long alias)
        reidxes_arr = np.zeros(seq_offset + reidx_pad_len, dtype=np.int64)  # todo(note): extra padding to avoid out-of-boundary
        all_flattened_reprs = []
        all_flatten_offset = 0  # the local offset for batched basic encoding
        for one_pack in sent_reprs:
            one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            one_repr_t = one_repr_evt
            _, one_slen, one_ldim = BK.get_shape(one_repr_t)
            all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
            # fill in the indexes
            for one_sent in one_sents:
                cur_start_offset = sentid2offset[id(one_sent)]
                cur_real_slen = one_sent.length - 1  # again, -1 to get rid of the extra ROOT
                reidxes_arr[cur_start_offset:cur_start_offset + cur_real_slen] = \
                    np.arange(cur_real_slen, dtype=np.int64) + (all_flatten_offset + 1)
                all_flatten_offset += one_slen  # here add the slen of the batched version
        # re-idxing
        seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
        seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0)  # [all_seq_len, D]
    else:
        sentid2offset = defaultdict(int)
        seq_sent_repr = None
    # =====
    # repack and prepare for multi-sent enc
    # todo(note): here, the criterion is based on bert's tokenizer
    all_ms_info = []
    if isinstance(insts[0], DocInstance):
        for d_idx, one_doc in enumerate(insts):
            for s_idx, x in enumerate(one_doc.sents):
                # the basic criterion is the same as the basic one
                include_flag = False
                if training:
                    if x.length < self.train_skip_length and x.length >= self.train_min_length \
                            and (len(x.events) > 0 or next(self.random_sample_stream) > self.train_skip_noevt_rate):
                        include_flag = True
                else:
                    if x.length >= self.test_min_length:
                        include_flag = True
                if include_flag:
                    all_ms_info.append(x.preps["ms"])  # use the pre-calculated one
    else:
        # multisent based
        all_ms_info = insts.copy()  # shallow copy
    # =====
    # encoder 2: the bert one (multi-sent encoding)
    ms_size_f = lambda x: x.subword_size
    all_ms_info.sort(key=ms_size_f)
    all_ms_buckets = self._bucket_sents_by_length(all_ms_info, conf.benc_bucket_range,
                                                  ms_size_f, max_bsize=conf.benc_bucket_msize)
    berter = self.berter
    rets = []
    bert_use_center_typeids = conf.bert_use_center_typeids
    bert_use_special_typeids = conf.bert_use_special_typeids
    bert_other_inputs = conf.bert_other_inputs
    for one_bucket in all_ms_buckets:
        # prepare
        batched_ids = []
        batched_starts = []
        batched_seq_offset = []
        batched_typeids = []
        batched_other_inputs_list: List = [[] for _ in bert_other_inputs]  # List(comp) of List(batch) of List(idx)
        for one_item in one_bucket:
            one_sents = one_item.sents
            one_center_sid = one_item.center_idx
            one_ids, one_starts, one_typeids = [], [], []
            one_other_inputs_list = [[] for _ in bert_other_inputs]  # List(comp) of List(idx)
            for one_sid, one_sent in enumerate(one_sents):
                # for bert
                one_bidxes = one_sent.preps["bidx"]
                one_ids.extend(one_bidxes.subword_ids)
                one_starts.extend(one_bidxes.subword_is_start)
                # prepare other inputs
                for this_field_name, this_tofill_list in zip(bert_other_inputs, one_other_inputs_list):
                    this_tofill_list.extend(one_sent.preps["sub_" + this_field_name])
                # todo(note): special procedure
                if bert_use_center_typeids:
                    if one_sid != one_center_sid:
                        one_typeids.extend([0] * len(one_bidxes.subword_ids))
                    else:
                        this_typeids = [1] * len(one_bidxes.subword_ids)
                        if bert_use_special_typeids:
                            # todo(note): this is the special mode where we are given the events!!
                            for this_event in one_sents[one_center_sid].events:
                                _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                for a, b in one_item.center_word2sub[this_wid - 1:this_wid - 1 + this_wlen]:
                                    this_typeids[a:b] = [0] * (b - a)
                        one_typeids.extend(this_typeids)
            batched_ids.append(one_ids)
            batched_starts.append(one_starts)
            batched_typeids.append(one_typeids)
            for comp_one_oi, comp_batched_oi in zip(one_other_inputs_list, batched_other_inputs_list):
                comp_batched_oi.append(comp_one_oi)
            # for the basic part
            batched_seq_offset.append(sentid2offset[id(one_sents[0])])
        # bert forward: [bs, slen, fold, D]
        if not bert_use_center_typeids:
            batched_typeids = None
        bert_expr0, mask_expr = berter.forward_batch(batched_ids, batched_starts, batched_typeids,
                                                     training=training, other_inputs=batched_other_inputs_list)
        if self.m3_enc_is_empty:
            bert_expr = bert_expr0
        else:
            mask_arr = BK.get_value(mask_expr)  # [bs, slen]
            m3e_exprs = [cur_enc(bert_expr0[:, :, cur_i], mask_arr) for cur_i, cur_enc in enumerate(self.m3_encs)]
            bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
        # collect basic ones: [bs, slen, D'] or None
        if seq_sent_repr is not None:
            arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
            offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
            basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
        elif conf.m2e_use_basic_dep:
            # collect each token's head-bert and ud-label, then forward with adp
            fake_sents = [one_item.fake_sent for one_item in one_bucket]
            # head idxes and labels, no artificial ROOT
            padded_head_arr, _ = self.dep_padder.pad([s.ud_heads.vals[1:] for s in fake_sents])
            padded_label_arr, _ = self.dep_padder.pad([s.ud_labels.idxes[1:] for s in fake_sents])
            # get tensors
            padded_head_t = (BK.input_idx(padded_head_arr) - 1)  # here, the idx excludes root
            padded_head_t.clamp_(min=0)  # [bs, slen]
            padded_label_t = BK.input_idx(padded_label_arr)
            # get inputs
            input_head_bert_t = bert_expr[BK.arange_idx(len(fake_sents)).unsqueeze(-1), padded_head_t]  # [bs, slen, fold, D]
            input_label_emb_t = self.dep_label_emb(padded_label_t)  # [bs, slen, D']
            basic_expr = self.dep_layer(input_head_bert_t, None, [input_label_emb_t])  # [bs, slen, ?]
        elif conf.m2e_use_basic_plus:
            sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                        insts, training, use_one_bucket=True)
            assert len(sent_reprs) == 1, "Unsupported split reprs for basic encoder, " \
                                         "please set enc_bucket_range<=benc_bucket_range"
            _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
            assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
        else:
            basic_expr = None
        # pack: (List[ms_item], bert_expr, basic_expr)
        rets.append((one_bucket, bert_expr, basic_expr))
    return rets
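
# ----- illustrative sketch (not part of the original code) -----
# The re-indexing in `run` flattens per-bucket [BS, slen, D] encodings into one
# [sum(BS*slen), D] matrix and then gathers rows via precomputed global offsets,
# so any sentence in a doc can be addressed with a single row index. Torch sketch:
import torch

def flatten_and_reindex(packs, reidx: torch.Tensor) -> torch.Tensor:
    """packs: list of [bs_i, slen_i, D] tensors; reidx: [total] rows into the flat concat."""
    flat = torch.cat([p.reshape(-1, p.shape[-1]) for p in packs], dim=0)  # [sum(bs*slen), D]
    return flat[reidx]                                                    # [total, D]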
def forward_batch(self, batched_ids: List, batched_starts: List, batched_typeids: List,
                  training: bool, other_inputs: List[List] = None):
    conf = self.bconf
    tokenizer = self.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    MASK_IDX = tokenizer.mask_token_id
    CLS_IDX = tokenizer.cls_token_id
    SEP_IDX = tokenizer.sep_token_id
    if other_inputs is None:
        other_inputs = []
    # =====
    # batch: here add CLS and SEP
    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
    input_shape = (bsize, max_len)
    # first collect on CPU
    input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
    input_ids_arr[:, 0] = CLS_IDX
    input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
    input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
    input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
    other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
    if conf.bert2_retinc_cls:  # act as the ROOT word
        input_is_start_arr[:, 0] = 1
    training_mask_rate = conf.bert2_training_mask_rate if training else 0.
    self_sample_stream = self.random_sample_stream
    for bidx in range(bsize):
        cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
        cur_end = len(cur_ids) + 2  # plus CLS and SEP
        if training_mask_rate > 0.:
            # input dropout
            input_ids_arr[bidx, 1:cur_end] = [(MASK_IDX if next(self_sample_stream) < training_mask_rate else z)
                                              for z in cur_ids] + [SEP_IDX]
        else:
            input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
        input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
        input_mask_arr[bidx, :cur_end] = 1.
        if batched_typeids is not None and batched_typeids[bidx] is not None:
            input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
        for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
            one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
    # arr to tensor
    input_ids_t = BK.input_idx(input_ids_arr)
    input_mask_t = BK.input_real(input_mask_arr)
    input_is_start_t = BK.input_idx(input_is_start_arr)
    input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
    other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
    # =====
    # forward (may need multiple runs to fit the max-length constraint)
    MAX_LEN = 510  # save two for [CLS] and [SEP]
    BACK_LEN = 100  # for splitting cases, still keep some of the previous sub-tokens for context
    if max_len <= MAX_LEN:
        # directly in one run
        final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
    else:
        all_outputs = []
        cur_sub_idx = 0
        slice_size = [bsize, 1]
        slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
        slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
        while cur_sub_idx < max_len - 1:  # minus 1 to ignore the ending SEP
            cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
            cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
            cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
            # here we simply extend the original masks
            cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_input_typeid_t = None if input_typeid_t is None else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
            cur_outputs = self.forward_features(cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t, cur_other_input_ts)
            # only include CLS in the first run, no SEP included
            if cur_sub_idx == 0:
                # include CLS, exclude SEP
                all_outputs.append(cur_outputs[:, :-1])
            else:
                # include only the new ones, discard BACK ones; exclude CLS and SEP
                all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
            zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end}) for all-len={max_len}")
            cur_sub_idx = cur_slice_end
        final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
    start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
    return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
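
# ----- illustrative sketch (not part of the original code) -----
# For inputs beyond BERT's length limit, the branch above encodes overlapping
# windows (re-attaching [CLS]/[SEP] per window, re-feeding BACK_LEN sub-tokens of
# left context) and keeps only each window's new positions. The window arithmetic,
# isolated as a generator:
def long_seq_windows(total_len: int, max_len: int = 510, back_len: int = 100):
    """Yield (start, end, keep_from) windows covering positions [0, total_len)."""
    cur = 0
    while cur < total_len:
        start = max(0, cur - back_len)       # re-feed some left context
        end = min(start + max_len, total_len)
        keep_from = cur - start              # positions before this are context only
        yield start, end, keep_from
        cur = end

# e.g. list(long_seq_windows(1200)) -> [(0, 510, 0), (410, 920, 100), (820, 1200, 100)]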
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    conf = self.conf
    bsize = len(insts)
    # first get gold info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
    # step 1: selector
    if conf.use_selector:
        sel_loss, sel_mask = self.sel.loss(input_expr, input_mask, gold_masks, margin=margin)
    else:
        sel_loss, sel_mask = None, self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    # if we select nothing
    # ----- debug
    # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
    # -----
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        lab_loss = [[BK.zeros([]), BK.zeros([])]]
        sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
        if conf.train_gold_corr:
            sel_lab_idxes = sel_gold_idxes
            if not self.hl.conf.use_lookup_soft:
                sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
        ret_items = sel_items
        # =====
        if self.use_secondary_type:
            sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
            if conf.sectype_noback_enc:
                sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
            else:
                sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
            # =====
            # special for the sectype mask (sample it within the gold ones)
            sel2_valid_mask = self._select_cands_training((sel_gold_idxes > 0).float(),
                                                          (sel_gold_idxes2 > 0).float(), conf.train_min_rate_s2)
            # =====
            sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(sel2_input, sel2_valid_mask,
                                                                          sel_gold_idxes2, margin=margin)
            if conf.train_gold_corr:
                sel2_lab_idxes = sel_gold_idxes2
                if not self.hl.conf.use_lookup_soft:
                    sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
            if conf.sectype_t2ift1:
                sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
        else:
            sel2_lab_loss = None
    # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # step 4: finally prepare losses and items
    for one_loss in lab_loss:
        one_loss[0] *= conf.lambda_ne
    ret_losses = lab_loss
    if sel2_lab_loss is not None:
        for one_loss in sel2_lab_loss:
            one_loss[0] *= conf.lambda_ne2
        ret_losses = ret_losses + sel2_lab_loss
    if sel_loss is not None:
        for one_loss in sel_loss:
            one_loss[0] *= conf.lambda_ns
        ret_losses = ret_losses + sel_loss
    # ----- debug
    # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
    # -----
    # mask out invalid items with None
    ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
    return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize, slen = BK.get_shape(input_mask)
    bsize_arange_t_1d = BK.arange_idx(bsize)  # [*]
    bsize_arange_t_2d = bsize_arange_t_1d.unsqueeze(-1)  # [*, 1]
    beam_size = conf.beam_size
    # prepare things with an extra beam dimension
    beam_input_expr = input_expr.unsqueeze(-3).expand(-1, beam_size, -1, -1).contiguous()  # [*, beam, slen, D]
    beam_input_mask = input_mask.unsqueeze(-2).expand(-1, beam_size, -1).contiguous()  # [*, beam, slen]
    # -----
    # recurrent states
    beam_hard_coverage = BK.zeros([bsize, beam_size, slen])  # [*, beam, slen]
    # tuple([*, beam, D], )
    beam_prev_state = [z.unsqueeze(-2).expand(-1, beam_size, -1) for z in self.rnn_unit.zero_init_hidden(bsize)]
    # frozen after reaching eos
    beam_noneos = 1. - BK.zeros([bsize, beam_size])  # [*, beam]
    beam_logprobs = BK.zeros([bsize, beam_size])  # [*, beam], sum of logprobs
    beam_logprobs_paths = BK.zeros([bsize, beam_size, 0])  # [*, beam, step]
    beam_tok_paths = BK.zeros([bsize, beam_size, 0]).long()
    beam_lab_paths = BK.zeros([bsize, beam_size, 0]).long()
    # -----
    for cstep in range(conf.max_step):
        # get things of [*, beam, beam]
        sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state = \
            self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, None, None, beam_size)
        sel_logprobs = sel_tok_logprobs + sel_lab_logprobs  # [*, beam, beam]
        if cstep == 0:
            # special for the first step, only select for the first element
            cur_selections = BK.arange_idx(beam_size).unsqueeze(0).expand(bsize, beam_size)  # [*, beam]
        else:
            # then select the topk in beam*beam (be careful about the frozen ones!!)
            beam_noneos_3d = beam_noneos.unsqueeze(-1)
            # eos can only be followed by eos
            sel_tok_idxes *= beam_noneos_3d.long()
            sel_lab_idxes *= beam_noneos_3d.long()
            # numeric tricks to keep the frozen ones ([0] with 0. score, [1:] with -inf scores)
            sel_logprobs *= beam_noneos_3d
            tmp_exclude_mask = 1. - beam_noneos_3d.expand_as(sel_logprobs)
            tmp_exclude_mask[:, :, 0] = 0.
            sel_logprobs += tmp_exclude_mask * Constants.REAL_PRAC_MIN
            # select the topk
            topk_logprobs = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
            _, cur_selections = topk_logprobs.view([bsize, -1]).topk(beam_size, dim=-1, sorted=True)  # [*, beam]
        # read and write the selections
        # gather previous ones
        cur_sel_previ = cur_selections // beam_size  # [*, beam]
        prev_hard_coverage = beam_hard_coverage[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, slen]
        prev_noneos = beam_noneos[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs = beam_logprobs[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs_paths = beam_logprobs_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_tok_paths = beam_tok_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_lab_paths = beam_lab_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        # prepare new ones
        cur_sel_newi = cur_selections % beam_size
        new_tok_idxes = sel_tok_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_lab_idxes = sel_lab_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_logprobs = sel_logprobs[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_prev_state = [z[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi] for z in next_state]  # [*, beam, ~]
        # update
        prev_hard_coverage[bsize_arange_t_2d, BK.arange_idx(beam_size).unsqueeze(0), new_tok_idxes] += 1.
        beam_hard_coverage = prev_hard_coverage
        beam_prev_state = new_prev_state
        beam_noneos = prev_noneos * (new_tok_idxes != 0).float()
        beam_logprobs = prev_logprobs + new_logprobs
        beam_logprobs_paths = BK.concat([prev_logprobs_paths, new_logprobs.unsqueeze(-1)], -1)
        beam_tok_paths = BK.concat([prev_tok_paths, new_tok_idxes.unsqueeze(-1)], -1)
        beam_lab_paths = BK.concat([prev_lab_paths, new_lab_idxes.unsqueeze(-1)], -1)
    # finally force an extra eos step to get the ending tok-logprob (no need to update other things)
    final_eos_idxes = BK.zeros([bsize, beam_size]).long()
    _, eos_logprobs, _, _, _, _ = self._step(beam_input_expr, beam_input_mask, beam_hard_coverage,
                                             beam_prev_state, final_eos_idxes, final_eos_idxes, None)
    beam_logprobs += eos_logprobs.squeeze(-1) * beam_noneos  # [*, beam]
    # select and return the best one
    beam_tok_valids = (beam_tok_paths > 0).float()  # [*, beam, steps]
    final_scores = beam_logprobs / ((beam_tok_valids.sum(-1) + 1.) ** conf.len_alpha)  # [*, beam]
    _, best_beam_idx = final_scores.max(-1)  # [*]
    # -----
    # prepare returns; cut by max length: [*, all_step] -> [*, max_step]
    ret0_valid_mask = beam_tok_valids[bsize_arange_t_1d, best_beam_idx]
    cur_max_step = ret0_valid_mask.long().sum(-1).max().item()
    ret_valid_mask = ret0_valid_mask[:, :cur_max_step]
    ret_logprobs = beam_logprobs_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_tok_idxes = beam_tok_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_lab_idxes = beam_lab_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    # embeddings
    ret_lab_embeds = self._hl_lookup(ret_lab_idxes)
    return ret_logprobs, ret_tok_idxes, ret_valid_mask, ret_lab_idxes, ret_lab_embeds
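
# ----- illustrative sketch (not part of the original code) -----
# The beam update flattens the [beam, beam] expansion into one axis, takes a
# global top-k, then recovers parent hypothesis and candidate with div/mod:
import torch

bs, beam = 2, 4
scores = torch.randn(bs, beam, beam)                         # [bs, parent, candidate]
top_vals, flat_idx = scores.view(bs, -1).topk(beam, dim=-1)  # [bs, beam]
parent_idx = flat_idx // beam                                # which parent hypothesis
cand_idx = flat_idx % beam                                   # which of its candidates
chosen = scores[torch.arange(bs).unsqueeze(-1), parent_idx, cand_idx]
assert torch.allclose(chosen, top_vals)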
def decode(self, inst: DocInstance):
    conf, model = self.conf, self.model
    model.refresh_batch(False)
    test_constrain_evt_types = self.test_constrain_evt_types
    with BK.no_grad_env():
        # =====
        # init the collections
        flattened_ef_ereprs, flattened_evt_ereprs = [], []
        sent_offsets = [Constants.INT_PRAC_MIN] * len(inst.sents)  # start offset of each sent in the flattened erepr
        cur_offset = 0  # current offset
        all_ef_items, all_evt_items = [], []
        # =====
        # first the basic run, then ef and evt
        all_packs = model.bter.run([inst], training=False)
        for one_pack in all_packs:
            sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
            mask_expr = BK.input_real(mask_arr)
            # =====
            # store the enc reprs and sent offsets
            sent_size, sent_len = BK.get_shape(enc_repr_ef)[:2]
            assert BK.get_shape(enc_repr_ef) == BK.get_shape(enc_repr_evt)
            flattened_ef_ereprs.append(enc_repr_ef.view(sent_size * sent_len, -1))  # [cur_flatten_size, D]
            flattened_evt_ereprs.append(enc_repr_evt.view(sent_size * sent_len, -1))
            for one_sent in sent_insts:
                sent_offsets[one_sent.sid] = cur_offset
                cur_offset += sent_len
            # =====
            lkrc = not conf.dec_debug_mode  # lookup.ret_copy?
            # =====
            # ef
            if conf.lookup_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    model._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, ret_copy=lkrc)
            else:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    model._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, model.ef_creator)
            # collect all valid ones
            all_ef_items.extend(ef_items[BK.get_value(ef_valid_mask).astype(bool)])
            # event
            if conf.lookup_evt:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    model._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, ret_copy=lkrc)
            else:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    model._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, model.evt_creator)
            # collect all valid ones
            if test_constrain_evt_types is None:
                all_evt_items.extend(evt_items[BK.get_value(evt_valid_mask).astype(bool)])
            else:
                all_evt_items.extend([z for z in evt_items[BK.get_value(evt_valid_mask).astype(bool)]
                                      if z.type in test_constrain_evt_types])
        # =====
        # cross-sentence pairwise arg score
        # flatten all enc: [Offset, D]
        flattened_ef_enc_repr = BK.concat(flattened_ef_ereprs, 0)
        flattened_evt_enc_repr = BK.concat(flattened_evt_ereprs, 0)
        # sort by position in doc
        all_ef_items.sort(key=lambda x: x.mention.hard_span.position(True))
        all_evt_items.sort(key=lambda x: x.mention.hard_span.position(True))
        if not conf.dec_debug_mode:
            # todo(note): delete original links!
            for z in all_ef_items:
                if z is not None:
                    z.links.clear()
            for z in all_evt_items:
                if z is not None:
                    z.links.clear()
        # get other info
        # todo(note): currently all using the head word
        all_ef_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid] + x.mention.hard_span.head_wid
                                       for x in all_ef_items])
        all_evt_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid] + x.mention.hard_span.head_wid
                                        for x in all_evt_items])
        all_ef_lab_idxes = BK.input_idx([model.ef_extractor.hlidx2idx(x.type_idx) for x in all_ef_items])
        all_evt_lab_idxes = BK.input_idx([model.evt_extractor.hlidx2idx(x.type_idx) for x in all_evt_items])
        # score all the pairs (with mini-batches)
        mini_batch_size = conf.score_mini_batch
        arg_linker = model.arg_linker
        all_logprobs = BK.zeros([len(all_ef_items), len(all_evt_items), arg_linker.num_label])
        for bidx_ef in range(0, len(all_ef_items), mini_batch_size):
            cur_ef_enc_repr = flattened_ef_enc_repr[all_ef_offsets[bidx_ef:bidx_ef + mini_batch_size]].unsqueeze(0)
            cur_ef_lab_idxes = all_ef_lab_idxes[bidx_ef:bidx_ef + mini_batch_size].unsqueeze(0)
            for bidx_evt in range(0, len(all_evt_items), mini_batch_size):
                cur_evt_enc_repr = flattened_evt_enc_repr[all_evt_offsets[bidx_evt:bidx_evt + mini_batch_size]].unsqueeze(0)
                cur_evt_lab_idxes = all_evt_lab_idxes[bidx_evt:bidx_evt + mini_batch_size].unsqueeze(0)
                all_logprobs[bidx_ef:bidx_ef + mini_batch_size, bidx_evt:bidx_evt + mini_batch_size] = \
                    arg_linker.predict(cur_ef_enc_repr, cur_evt_enc_repr, cur_ef_lab_idxes, cur_evt_lab_idxes,
                                       ret_full_logprobs=True).squeeze(0)
        all_logprobs_arr = BK.get_value(all_logprobs)
        # =====
        # then decode them all using the scores
        self.arg_decode(inst, all_ef_items, all_evt_items, all_logprobs_arr)
        # =====
        # assign and return
        num_pred_arg = 0
        for one_sent in inst.sents:
            one_sent.pred_entity_fillers.clear()
            one_sent.pred_events.clear()
        for z in all_ef_items:
            inst.sents[z.mention.hard_span.sid].pred_entity_fillers.append(z)
        for z in all_evt_items:
            inst.sents[z.mention.hard_span.sid].pred_events.append(z)
            num_pred_arg += len(z.links)
        info = {"doc": 1, "sent": len(inst.sents), "token": sum(s.length - 1 for s in inst.sents),
                "p_ef": len(all_ef_items), "p_evt": len(all_evt_items), "p_arg": num_pred_arg}
        return info
def loss(self, repr_t, orig_map: Dict, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # --
    # specify input
    add_root_token = self.add_root_token
    # get from inputs
    if isinstance(repr_t, (list, tuple)):
        l2r_repr_t, r2l_repr_t = repr_t
    elif self.split_input_blm:
        l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
    else:
        l2r_repr_t, r2l_repr_t = repr_t, None
    # l2r and r2l
    word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
    slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
    if add_root_token:
        l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0s at the front, [bs, 2+rlen-1]
    else:
        l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at the front, [bs, 1+rlen-1]
    # gather the losses
    all_losses = []
    pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
    # get input embeddings for the output layer (tied softmax weights)
    if _tie_input_embeddings:
        pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
    else:
        pred_W = None
    for pred_name, hid_node, pred_node, input_t, trg_t in \
            zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
        if input_t is None:
            continue
        # hidden
        hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
        # pred: [bs, slen, Vsize]
        if _tie_input_embeddings:
            scores_t = BK.matmul(hid_t, pred_W.T)
        else:
            scores_t = pred_node(hid_t)
        # loss
        mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
        trg_t.clamp_(max=pred_range_max)  # make it in range
        losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
        _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
        corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
        # compile leaf loss
        one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(),
                                                loss_lambda=1., corr=corrs_t.sum())
        all_losses.append(one_loss)
    return self._compile_component_loss("plm", all_losses)
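
# ----- illustrative sketch (not part of the original code) -----
# The target construction above is the standard LM shift: position i predicts
# token i+1 (l2r) or i-1 (r2l), with 0-padding marking positions that carry no
# loss. A concrete example of both shifts (the no-root-token case):
import torch

word = torch.tensor([[11, 12, 13, 14]])        # [bs, rlen]
zero = torch.zeros(1, 1, dtype=torch.long)
l2r_trg = torch.cat([word[:, 1:], zero], -1)   # [[12, 13, 14, 0]]
r2l_trg = torch.cat([zero, word[:, :-1]], -1)  # [[0, 11, 12, 13]]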