Example #1
 def _get_basic_score(self, mb_enc_expr, batch_idxes, m_idxes, h_idxes,
                      sib_idxes, gp_idxes):
     allp_size = BK.get_shape(batch_idxes, 0)
     all_arc_scores, all_lab_scores = [], []
     cur_pidx = 0
     while cur_pidx < allp_size:
         next_pidx = min(allp_size, cur_pidx + self.mb_dec_sb)
         # first calculate srepr
         s_enc = self.slayer
         cur_batch_idxes = batch_idxes[cur_pidx:next_pidx]
         h_expr = mb_enc_expr[cur_batch_idxes, h_idxes[cur_pidx:next_pidx]]
         m_expr = mb_enc_expr[cur_batch_idxes, m_idxes[cur_pidx:next_pidx]]
         s_expr = mb_enc_expr[cur_batch_idxes, sib_idxes[cur_pidx:next_pidx]].unsqueeze(-2) \
             if (sib_idxes is not None) else None  # [*, 1, D]
         g_expr = mb_enc_expr[cur_batch_idxes, gp_idxes[cur_pidx:next_pidx]] \
             if (gp_idxes is not None) else None
         head_srepr = s_enc.calculate_repr(h_expr, g_expr, None, None,
                                           s_expr, None, None, None)
         mod_srepr = s_enc.forward_repr(m_expr)
         # then get the scores
         arc_score = self.scorer.transform_and_arc_score_plain(
             mod_srepr, head_srepr).squeeze(-1)
         all_arc_scores.append(arc_score)
         if self.system_labeled:
             lab_score = self.scorer.transform_and_label_score_plain(
                 mod_srepr, head_srepr)
             all_lab_scores.append(lab_score)
         cur_pidx = next_pidx
     final_arc_score = BK.concat(all_arc_scores, 0)
     final_lab_score = BK.concat(all_lab_scores,
                                 0) if self.system_labeled else None
     return final_arc_score, final_lab_score
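The loop above scores all head-modifier pairs in fixed-size slices (self.mb_dec_sb) to bound peak memory, then concatenates the per-slice results. A minimal standalone sketch of that chunking pattern in plain PyTorch (score_fn, items and chunk_size are illustrative names, not identifiers from this codebase):
 import torch
 def chunked_score(score_fn, items: torch.Tensor, chunk_size: int) -> torch.Tensor:
     # score `items` slice by slice and concatenate, so peak memory stays bounded
     pieces, cur, total = [], 0, items.shape[0]
     while cur < total:
         nxt = min(total, cur + chunk_size)
         pieces.append(score_fn(items[cur:nxt]))  # [nxt-cur, ...]
         cur = nxt
     return torch.cat(pieces, dim=0)  # [total, ...]
 # e.g. scores = chunked_score(lambda x: x.sum(-1), torch.randn(1000, 8), chunk_size=128)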
Example #2
 def __call__(self,
              word_arr: np.ndarray = None,
              char_arr: np.ndarray = None,
              extra_arrs: Iterable[np.ndarray] = (),
              aux_arrs: Iterable[np.ndarray] = ()):
     exprs = []
     # word/char/extras/posi
     seq_shape = None
     if self.has_word:
         # todo(warn): singleton-UNK-dropout should be done outside before
         seq_shape = word_arr.shape
         word_expr = self.dropmd_word(self.word_embed(word_arr))
         exprs.append(word_expr)
     if self.has_char:
         seq_shape = char_arr.shape[:-1]
         char_embeds = self.char_embed(
             char_arr)  # [*, seq-len, word-len, D]
         char_cat_expr = self.dropmd_char(
             BK.concat([z(char_embeds) for z in self.char_cnns]))
         exprs.append(char_cat_expr)
     zcheck(
         len(extra_arrs) == len(self.extra_embeds),
         "Unmatched extra fields.")
     for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(
             extra_arrs, self.extra_embeds, self.dropmd_extras):
         seq_shape = one_extra_arr.shape
         exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
     if self.has_posi:
         seq_len = seq_shape[-1]
         posi_idxes = BK.arange_idx(seq_len)
         posi_input0 = self.posi_embed(posi_idxes)
         for _ in range(len(seq_shape) - 1):
             posi_input0 = BK.unsqueeze(posi_input0, 0)
         posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1, ))
         exprs.append(posi_input1)
     #
     assert len(aux_arrs) == len(self.drop_auxes)
     for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
             zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes, self.aux_overall_gammas, self.aux_fold_lambdas):
         # fold and apply trainable lambdas
         input_aux_repr = BK.input_real(one_aux_arr)
         input_shape = BK.get_shape(input_aux_repr)
         # todo(note): assume the original concat is [fold/layer, D]
         reshaped_aux_repr = input_aux_repr.view(input_shape[:-1] + [one_fold, one_aux_dim])  # [*, slen, fold, D]
         lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
         weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma  # [*, slen, D]
         one_aux_expr = one_aux_drop(weighted_aux_repr)
         exprs.append(one_aux_expr)
     #
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr
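The aux branch above reshapes each [*, slen, fold*D] input into folds, softmaxes the trainable per-fold lambdas, and scales the weighted sum by an overall gamma (an ELMo-style scalar mix). A standalone sketch of that mixing in plain PyTorch rather than the BK wrapper; the argument names here are illustrative:
 import torch
 def mix_folds(aux_arr: torch.Tensor, fold: int, lambdas: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
     bs, slen, fold_d = aux_arr.shape
     reshaped = aux_arr.view(bs, slen, fold, fold_d // fold)       # [bs, slen, fold, D]
     weights = torch.softmax(lambdas, dim=-1).view(1, 1, fold, 1)  # normalized per-fold weights
     return (reshaped * weights).sum(dim=-2) * gamma               # [bs, slen, D]
 # e.g. out = mix_folds(torch.randn(2, 5, 3 * 4), fold=3,
 #                      lambdas=torch.zeros(3, requires_grad=True), gamma=torch.ones(1))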
Example #3
 def lookup(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     bsize = len(insts)
     # first get gold/input info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(
         insts)
     # step 1: no selection, simply forward using gold_masks
     sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(
         sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = self.hl.lookup(
             sel_lab_idxes)  # todo(note): here no softlookup?
         ret_items = sel_items
         # second type
         if self.use_secondary_type:
             sel2_lab_idxes = sel_gold_idxes2
             sel2_lab_embeds = self.hl.lookup(
                 sel2_lab_idxes)  # todo(note): here no softlookup?
             sel2_valid_mask = (sel2_lab_idxes > 0).float()
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat(
                     [sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes],
                                           -1)
                 sel_lab_embeds = BK.concat(
                     [sel_lab_embeds, sel2_lab_embeds], -2)
     # step 3: exclude nil assuming no deliberate nil in gold/inputs
     if conf.exclude_nil:  # [*, mc', ...]
         sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
             self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # step 4: return
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool)] = None
     return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Example #4
 def __call__(self, char_input, add_root_token: bool):
     char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
     if add_root_token:
         slice_shape = BK.get_shape(char_input_t)
         slice_shape[-2] = 1
         char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
         char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
     else:
         char_input_t1 = char_input_t
     char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
     char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
     return self.dropout(char_cat_expr)  # todo(note): only final dropout
Example #5
 def _special_score(one_score):  # specially change ablpair scores into [bs,m,h,*]
     root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
     tmp_shape = BK.get_shape(root_score)
     tmp_shape[1] = 1  # [bs, 1, 1, *]
     padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score],
                                   dim=1)  # [bs, rlen+1, 1, *]
     final_score = BK.concat(
         [padded_root_score,
          one_score.transpose(1, 2)],
         dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
     return final_score
Example #6
 def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
     if use_one_bucket:
         all_buckets = [all_sents]  # when we do not want to split if we know the input lengths do not vary too much
     else:
         all_sents.sort(key=lambda x: x[0].length)
         all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
     # doc hint
     use_doc_hint = self.use_doc_hint
     if use_doc_hint:
         dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
     else:
         dh_sent_repr = None
     # encoding for each of the bucket
     rets = []
     dh_add, dh_both, dh_cls = [self.dh_combine_method==z for z in ["add", "both", "cls"]]
     for one_bucket in all_buckets:
         one_sents = [z[0] for z in one_bucket]
         # [BS, Len, Di], [BS, Len]
         input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
         if use_doc_hint:
             one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
             one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
             one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
             if dh_add:
                 input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                 mask_arr = mask_arr0
             elif dh_both:
                 input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                 mask_arr = np.pad(mask_arr0, ((0,0),(1,1)), 'constant', constant_values=1.)  # [BS, 2+slen]
             elif dh_cls:
                 input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                 mask_arr = mask_arr0
             else:
                 raise NotImplementedError()
         else:
             input_repr, mask_arr = input_repr0, mask_arr0
         # [BS, Len, De]
         enc_repr = self.enc(input_repr, mask_arr)
         # separate ones (possibly using detach to avoid gradients for some of them)
         enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
         enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
         if use_doc_hint and dh_both:
             one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(), enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
         else:
             one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
         rets.append(one_ret)
     # todo(note): each returned tuple is (List[Sentence], Tensor, Tensor, Tensor, Tensor)
     return rets
Example #7
 def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t,
                    chs_label_t, chs_mask_t, chs_valid_mask_t):
     ret_t = cur_t  # [*, D]
     # padding 0 if not using labels
     dim_label = self.dim_label
     # child features
     if self.use_chs and chs_t is not None:
         if self.use_label_feat:
             chs_label_rt = self.label_embeddings(
                 chs_label_t)  # [*, max-chs, dlab]
         else:
             labels_shape = BK.get_shape(chs_t)
             labels_shape[-1] = dim_label
             chs_label_rt = BK.zeros(labels_shape)
         chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
         chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t,
                                     chs_valid_mask_t)
         chs_feat = self.chs_ff(chs_feat0)
         ret_t += chs_feat
     # parent features
     if self.use_par and par_t is not None:
         if self.use_label_feat:
             cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
         else:
             labels_shape = BK.get_shape(par_t)
             labels_shape[-1] = dim_label
             cur_label_t = BK.zeros(labels_shape)
         par_feat = self.par_ff([par_t, cur_label_t])
         if par_mask_t is not None:
             par_feat *= par_mask_t.unsqueeze(-1)
         ret_t += par_feat
     return ret_t
Example #8
 def _ts_self_f(_list_attn_info):
     _rets = []
     for _t, _d in zip(_list_attn_info[3], _list_attn_info[4]):
         # extra dim at idx 0; todo(note): must repeat insts at outmost idx: repeated = insts * copy_num
         _t1 = _t.view([copy_num, -1] + BK.get_shape(_t)[1:])  # [copy, bs, ...]
         # roll it by 1
         _t2 = BK.concat([_t1[-1].unsqueeze(0), _t1[:-1]], dim=0)
         _rets.append((_t1, _t2, _d))
     return _rets
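The "roll it by 1" concat (last copy moved to the front) is just a circular shift along the copy dimension; a quick standalone check in plain PyTorch:
 import torch
 t1 = torch.arange(12.).view(3, 2, 2)                   # [copy=3, bs=2, ...]
 t2 = torch.cat([t1[-1].unsqueeze(0), t1[:-1]], dim=0)  # last copy first, rest shifted down
 assert torch.equal(t2, torch.roll(t1, shifts=1, dims=0))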
Example #9
 def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
     conf = self.conf
     self.refresh_batch(False)
     # print(f"{len(insts)}: {insts[0].sid}")
     with BK.no_grad_env():
         # decode for dpar
         input_map = self.model.inputter(insts)
         emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
         input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
         self.dpar.predict(insts, BK.zeros([1,1]), input_t, mask_t)
     return {}
Example #10
 def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
     if self.dmxnn:
         bsize, slen = BK.get_shape(input_mask)
         if sel_idxes is None:
             sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
         ncand = BK.get_shape(sel_idxes, -1)
         # enc_expr aug with PE
         rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
         pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
         aug_enc_expr = BK.concat([
             pe_embeds.expand(bsize, -1, -1, -1),
             input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)
         ], -1)  # [*, ?, slen, D+Dpe]
         # [*, ?, slen, Denc]
         hidden_expr = self.e_encoder(
             aug_enc_expr.view(bsize * ncand, slen, -1),
             input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
         hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
         # dynamic max-pooling (dist<0, dist=0, dist>0)
         NEG = Constants.REAL_PRAC_MIN
         mp_hiddens = []
         mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
         for mp_mask in mp_masks:
             float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
             valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
             mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
             # todo(+2): or do we simply multiply mask?
             mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
             mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
             mp_hiddens.append(self.special_drop(mp_hid))
             # mp_hiddens.append(mp_hid)
         final_hiddens = mp_hiddens
     else:
         hidden_expr = self.e_encoder(input_expr,
                                      input_mask)  # [*, slen, D']
         if sel_idxes is None:
             hidden_expr1 = hidden_expr
         else:
             hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes,
                                                 -2)  # [*, ?, D']
         final_hiddens = [self.special_drop(hidden_expr1)]
     if self.lab_f_use_lexi:
         final_hiddens.append(
             BK.gather_first_dims(input_lexi, sel_idxes,
                                  -2))  # [*, ?, DLex]
     ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
     return ret_expr
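The dmxnn branch performs dynamic max-pooling over three regions relative to each candidate position (left of, at, right of it), using additive large-negative masking so padded or empty regions contribute nothing. A standalone sketch of that pooling in plain PyTorch, with NEG standing in for Constants.REAL_PRAC_MIN:
 import torch
 def dynamic_max_pool(hidden, rel_dist, seq_mask, NEG=-1e30):
     # hidden: [bs, ncand, slen, D]; rel_dist: [bs, ncand, slen]; seq_mask: [bs, slen]
     pooled = []
     for region in (rel_dist < 0, rel_dist == 0, rel_dist > 0):
         fmask = region.float() * seq_mask.unsqueeze(-2)        # [bs, ncand, slen]
         valid = (fmask.sum(-1) > 0.).float().unsqueeze(-1)     # [bs, ncand, 1]
         neg_pad = (1. - fmask).unsqueeze(-1) * NEG             # [bs, ncand, slen, 1]
         pooled.append((hidden + neg_pad).max(-2)[0] * valid)   # [bs, ncand, D]
     return pooled  # three tensors of [bs, ncand, D]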
Example #11
 def _flatten_packs(self, packs):
     NUM_RET_PACK = 6  # discard the first mb-size
     ret_packs = [[] for _ in range(NUM_RET_PACK)]
     cur_base_idx = 0
     for one_pack in packs:
         mb_size = one_pack[0]
         ret_packs[0].append(one_pack[1] + cur_base_idx)
         for i in range(1, NUM_RET_PACK):
             ret_packs[i].append(one_pack[i + 1])
         cur_base_idx += mb_size
     ret = [(None if z[0] is None else BK.concat(z, 0)) for z in ret_packs]
     return ret
Example #12
 def predict(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     # step 1: select mention candidates
     if conf.use_selector:
         sel_mask = self.sel.predict(input_expr, input_mask)
     else:
         sel_mask = input_mask
     sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
     # step 2: encoding and labeling
     sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask,
                               sel_idxes)
     sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(
         sel_hid_exprs, None)  # [*, mc], [*, mc, D]
     # =====
     if self.use_secondary_type:
         sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
         sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
         sel2_lab_logprobs, sel2_lab_idxes, sel2_lab_embeds = self.hl.predict(
             sel2_input, None)
         if conf.sectype_t2ift1:
             sel2_lab_idxes *= (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
         # first concat here and then exclude nil at one pass # [*, mc*2, ~]
         if sel2_lab_idxes.sum().item() > 0:  # if there are any predictions
             sel_lab_logprobs = BK.concat(
                 [sel_lab_logprobs, sel2_lab_logprobs], -1)
             sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
             sel_valid_mask = BK.concat([sel_valid_mask, sel_valid_mask],
                                        -1)
             sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
             sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds],
                                        -2)
     # =====
     # step 3: exclude nil and return
     if conf.exclude_nil:  # [*, mc', ...]
         sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_lab_logprobs, _ = \
             self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs=sel_lab_logprobs)
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Example #13
 def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
                 rand_gen=None, assign_attns=False, **kwargs):
     self.refresh_batch(training)
     # get inputs with models
     with BK.no_grad_env():
         input_map = self.model.inputter(insts)
         emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
         input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
     losses = [self.dpar.loss(insts, BK.zeros([1,1]), input_t, mask_t)]
     # -----
     info = self.collect_loss_and_backward(losses, training, loss_factor)
     info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
     return info
Example #14
 def __call__(self, input_map: Dict):
     exprs = []
     # get masks: this mask indicates the valid positions from instance batching
     final_masks = BK.input_real(input_map["mask"])  # [*, slen]
     if self.add_root_token:  # append 1
         slice_t = BK.constants(BK.get_shape(final_masks)[:-1]+[1], 1.)
         final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
     # -----
     # for each component
     for idx, name in enumerate(self.comp_names):
         cur_node = self.nodes[idx]
         cur_input = input_map[name]
         cur_expr = cur_node(cur_input, self.add_root_token)
         exprs.append(cur_expr)
     # -----
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr, final_masks
Example #15
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     # todo(+N): currently margin is not used
     conf = self.conf
     bsize = len(insts)
     arange_t = BK.arange_idx(bsize)
     assert conf.train_force, "currently only have forced training"
     # get the gold ones
     gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
     # for all the steps
     num_step = BK.get_shape(gold_widxes, -1)
     # recurrent states
     hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
     prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
     all_tok_logprobs, all_lab_logprobs = [], []
     for cstep in range(num_step):
         slice_widx, slice_lidx = gold_widxes[:,cstep], gold_lidxes[:,cstep]
         _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
             self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
         all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
         all_lab_logprobs.append(sel_lab_logprobs)
         hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
         hard_coverage[arange_t, slice_widx] += 1.
         prev_state = [z.squeeze(-2) for z in next_state]
     # concat all the loss and mask
     # todo(note): no need to use gold_valid since that information is already reflected in vmasks
     cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
     cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
     loss_sum = - (cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
     # todo(+N): here we are dividing lab_logprobs with the all-count, do we need to separate?
     loss_count = gold_vmasks.sum()
     ret_losses = [[loss_sum, loss_count]]
     # =====
     # make eos unvalid for return
     ret_valid_mask = gold_vmasks * (gold_widxes>0).float()
     # embeddings
     sel_lab_embeds = self._hl_lookup(gold_lidxes)
     return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
Example #16
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use mask to make it able to be backward
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
Example #17
 def __call__(self, input, add_root_token: bool):
     voc = self.voc
     # todo(note): append a [cls/root] idx, currently use "bos"
     input_t = BK.input_idx(input)  # [*, slen]
     # rare unk in training
     if self.rop.training and self.use_rare_unk:
         rare_unk_rate = self.ec_conf.comp_rare_unk
         cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t))<rare_unk_rate)).detach().long()
         input_t = input_t * (1-cur_unk_imask) + self.voc.unk * cur_unk_imask
     # root
     if add_root_token:
         input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1]+[1], voc.bos, dtype=input_t.dtype)  # [*, 1]
         input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
     else:
         input_t_p1 = input_t
     expr = self.E(input_t_p1)  # [*, 1?+slen]
     return self.dropout(expr)
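The rare-UNK step above swaps tokens flagged as rare to UNK with a given probability during training. A standalone sketch in plain PyTorch, assuming rare_mask is a 0/1 vector over the vocabulary (mirroring self.rare_mask) and unk_id stands in for voc.unk:
 import torch
 def rare_unk_dropout(input_t: torch.Tensor, rare_mask: torch.Tensor, unk_id: int, rate: float) -> torch.Tensor:
     hit = (rare_mask[input_t] * (torch.rand(input_t.shape) < rate)).long()  # 1 where we replace
     return input_t * (1 - hit) + unk_id * hit
 # e.g. ids = rare_unk_dropout(torch.randint(0, 10, (2, 6)), torch.ones(10).long(), unk_id=1, rate=0.5)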
Example #18
 def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool,
                noop_fixed_val: float, temperature: float, dim: int):
     cur_shape = BK.get_shape(orig_scores)  # original
     orig_that_dim = cur_shape[dim]
     cur_shape[dim] = 1
     if use_noop:
         noop_scores = BK.constants(cur_shape,
                                    value=noop_fixed_val)  # [*, 1, *]
         to_norm_scores = BK.concat([orig_scores, noop_scores],
                                    dim=dim)  # [*, D+1, *]
     else:
         to_norm_scores = orig_scores  # [*, D, *]
     # normalize
     prob_full = cnode(to_norm_scores, temperature=temperature,
                       dim=dim)  # [*, ?, *]
     if use_noop:
         prob_valid, prob_noop = BK.split(prob_full, [orig_that_dim, 1],
                                          dim)  # [*, D|1, *]
     else:
         prob_valid, prob_noop = prob_full, None
     return prob_valid, prob_noop, prob_full
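The noop trick above appends a constant logit along the normalization dimension, normalizes, then splits the result back into the valid probabilities and the noop probability. A standalone sketch in plain PyTorch that uses a temperature softmax in place of the ConcreteNode:
 import torch
 def normalize_with_noop(scores: torch.Tensor, noop_val: float, temperature: float, dim: int):
     shape = list(scores.shape)
     n = shape[dim]
     shape[dim] = 1
     noop = torch.full(shape, noop_val)                   # [*, 1, *]
     full = torch.softmax(torch.cat([scores, noop], dim=dim) / temperature, dim=dim)
     valid, p_noop = torch.split(full, [n, 1], dim=dim)   # [*, n, *], [*, 1, *]
     return valid, p_noop, full
 # e.g. pv, pn, pf = normalize_with_noop(torch.randn(2, 4), noop_val=0., temperature=1., dim=-1)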
Example #19
 def run(self, insts, training, input_word_mask_repl=None):
     self._cache_subword_tokens(insts)
     # prepare inputs
     word_arr, char_arr, extra_arrs, aux_arrs, mask_arr = \
         self.prepare_inputs(insts, training, input_word_mask_repl=input_word_mask_repl)
     # layer0: emb + bert
     layer0_reprs = []
     if self.emb_output_dim > 0:
         emb_repr = self.emb(word_arr, char_arr, extra_arrs,
                             aux_arrs)  # [BS, Len, Dim]
         layer0_reprs.append(emb_repr)
     if self.bert_output_dim > 0:
         # prepare bert inputs
         BERT_MASK_ID = self.bert.tokenizer.mask_token_id
         batch_subword_ids, batch_subword_is_starts = [], []
         for bidx, one_inst in enumerate(insts):
             st = one_inst.extra_features["st"]
             if input_word_mask_repl is not None:
                 cur_subword_ids, cur_subword_is_start, _ = \
                     st.mask_and_return(input_word_mask_repl[bidx][1:], BERT_MASK_ID)  # todo(note): exclude ROOT for bert tokens
             else:
                 cur_subword_ids, cur_subword_is_start = st.subword_ids, st.subword_is_start
             batch_subword_ids.append(cur_subword_ids)
             batch_subword_is_starts.append(cur_subword_is_start)
         bert_repr, _ = self.bert.forward_batch(
             batch_subword_ids,
             batch_subword_is_starts,
             batched_typeids=None,
             training=training)  # [BS, Len, D']
         layer0_reprs.append(bert_repr)
     # layer1: enc
     enc_input_repr = BK.concat(layer0_reprs, -1)  # [BS, Len, D+D']
     if self.middle_node is not None:
         enc_input_repr = self.middle_node(enc_input_repr)  # [BS, Len, D??]
     enc_repr = self.enc(enc_input_repr, mask_arr)
     mask_repr = BK.input_real(mask_arr)
     return enc_repr, mask_repr  # [bs, len, *], [bs, len]
Example #20
 def run(self, insts: List[DocInstance], training: bool):
     conf = self.conf
     BERT_MAX_LEN = 510  # save 2 for CLS and SEP
     # =====
     # encoder 1: the basic encoder
     # todo(note): only DocInstance input for this mode, otherwise it will break
     if conf.m2e_use_basic:
         reidx_pad_len = conf.ms_extend_budget
         # enc the basic part + also get some indexes
         sentid2offset = {}  # id(sent)->overall_seq_offset
         seq_offset = 0  # if look at the docs in one seq
         all_sents = []  # (inst, d_idx, s_idx)
         for d_idx, one_doc in enumerate(insts):
             assert isinstance(one_doc, DocInstance)
             for s_idx, one_sent in enumerate(one_doc.sents):
                 # todo(note): here we encode all the sentences
                 all_sents.append((one_sent, d_idx, s_idx))
                 sentid2offset[id(one_sent)] = seq_offset
                 seq_offset += one_sent.length - 1  # exclude extra ROOT node
         sent_reprs = self.run_sents(all_sents, insts, training)
         # flatten and concatenate and re-index
         reidxes_arr = np.zeros(
             seq_offset + reidx_pad_len, dtype=np.long
         )  # todo(note): extra padding to avoid out of boundary
         all_flattened_reprs = []
         all_flatten_offset = 0  # the local offset for batched basic encoding
         for one_pack in sent_reprs:
             one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             one_repr_t = one_repr_evt
             _, one_slen, one_ldim = BK.get_shape(one_repr_t)
             all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
             # fill in the indexes
             for one_sent in one_sents:
                 cur_start_offset = sentid2offset[id(one_sent)]
                 cur_real_slen = one_sent.length - 1
                 # again, +1 to get rid of extra ROOT
                 reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \
                     np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset+1)
                 all_flatten_offset += one_slen  # here add the slen in batched version
         # re-idxing
         seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
         seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr,
                                   0)  # [all_seq_len, D]
     else:
         sentid2offset = defaultdict(int)
         seq_sent_repr = None
     # =====
     # repack and prepare for multiple sent enc
     # todo(note): here, the criterion is based on bert's tokenizer
     all_ms_info = []
     if isinstance(insts[0], DocInstance):
         for d_idx, one_doc in enumerate(insts):
             for s_idx, x in enumerate(one_doc.sents):
                 # the basic criterion is the same as the basic one
                 include_flag = False
                 if training:
                     if x.length<self.train_skip_length and x.length>=self.train_min_length \
                             and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate):
                         include_flag = True
                 else:
                     if x.length >= self.test_min_length:
                         include_flag = True
                 if include_flag:
                     all_ms_info.append(
                         x.preps["ms"])  # use the pre-calculated one
     else:
         # multisent based
         all_ms_info = insts.copy()  # shallow copy
     # =====
     # encoder 2: the bert one (multi-sent encoding)
     ms_size_f = lambda x: x.subword_size
     all_ms_info.sort(key=ms_size_f)
     all_ms_buckets = self._bucket_sents_by_length(
         all_ms_info,
         conf.benc_bucket_range,
         ms_size_f,
         max_bsize=conf.benc_bucket_msize)
     berter = self.berter
     rets = []
     bert_use_center_typeids = conf.bert_use_center_typeids
     bert_use_special_typeids = conf.bert_use_special_typeids
     bert_other_inputs = conf.bert_other_inputs
     for one_bucket in all_ms_buckets:
         # prepare
         batched_ids = []
         batched_starts = []
         batched_seq_offset = []
         batched_typeids = []
         batched_other_inputs_list: List = [
             [] for _ in bert_other_inputs
         ]  # List(comp) of List(batch) of List(idx)
         for one_item in one_bucket:
             one_sents = one_item.sents
             one_center_sid = one_item.center_idx
             one_ids, one_starts, one_typeids = [], [], []
             one_other_inputs_list = [[] for _ in bert_other_inputs
                                      ]  # List(comp) of List(idx)
             for one_sid, one_sent in enumerate(one_sents):  # for bert
                 one_bidxes = one_sent.preps["bidx"]
                 one_ids.extend(one_bidxes.subword_ids)
                 one_starts.extend(one_bidxes.subword_is_start)
                 # prepare other inputs
                 for this_field_name, this_tofill_list in zip(
                         bert_other_inputs, one_other_inputs_list):
                     this_tofill_list.extend(
                         one_sent.preps["sub_" + this_field_name])
                 # todo(note): special procedure
                 if bert_use_center_typeids:
                     if one_sid != one_center_sid:
                         one_typeids.extend([0] *
                                            len(one_bidxes.subword_ids))
                     else:
                         this_typeids = [1] * len(one_bidxes.subword_ids)
                         if bert_use_special_typeids:
                             # todo(note): this is the special mode that we are given the events!!
                             for this_event in one_sents[one_center_sid].events:
                                 _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                 for a, b in one_item.center_word2sub[this_wid - 1:this_wid - 1 + this_wlen]:
                                     this_typeids[a:b] = [0] * (b - a)
                         one_typeids.extend(this_typeids)
             batched_ids.append(one_ids)
             batched_starts.append(one_starts)
             batched_typeids.append(one_typeids)
             for comp_one_oi, comp_batched_oi in zip(
                     one_other_inputs_list, batched_other_inputs_list):
                 comp_batched_oi.append(comp_one_oi)
             # for basic part
             batched_seq_offset.append(sentid2offset[id(one_sents[0])])
         # bert forward: [bs, slen, fold, D]
         if not bert_use_center_typeids:
             batched_typeids = None
         bert_expr0, mask_expr = berter.forward_batch(
             batched_ids,
             batched_starts,
             batched_typeids,
             training=training,
             other_inputs=batched_other_inputs_list)
         if self.m3_enc_is_empty:
             bert_expr = bert_expr0
         else:
             mask_arr = BK.get_value(mask_expr)  # [bs, slen]
             m3e_exprs = [
                 cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                 for cur_i, cur_enc in enumerate(self.m3_encs)
             ]
             bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
         # collect basic ones: [bs, slen, D'] or None
         if seq_sent_repr is not None:
             arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
             offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
             basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
         elif conf.m2e_use_basic_dep:
             # collect each token's head-bert and ud-label, then forward with adp
             fake_sents = [one_item.fake_sent for one_item in one_bucket]
             # head idx and labels, no artificial ROOT
             padded_head_arr, _ = self.dep_padder.pad(
                 [s.ud_heads.vals[1:] for s in fake_sents])
             padded_label_arr, _ = self.dep_padder.pad(
                 [s.ud_labels.idxes[1:] for s in fake_sents])
             # get tensor
             padded_head_t = BK.input_idx(padded_head_arr) - 1  # here, the idx excludes root
             padded_head_t.clamp_(min=0)  # [bs, slen]
             padded_label_t = BK.input_idx(padded_label_arr)
             # get inputs
             input_head_bert_t = bert_expr[BK.arange_idx(len(fake_sents)).unsqueeze(-1),
                                           padded_head_t]  # [bs, slen, fold, D]
             input_label_emb_t = self.dep_label_emb(padded_label_t)  # [bs, slen, D']
             basic_expr = self.dep_layer(input_head_bert_t, None, [input_label_emb_t])  # [bs, slen, ?]
         elif conf.m2e_use_basic_plus:
             sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                         insts, training, use_one_bucket=True)
             assert len(sent_reprs) == 1, \
                 "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
             _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
             assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
         else:
             basic_expr = None
         # pack: (List[ms_item], bert_expr, basic_expr)
         rets.append((one_bucket, bert_expr, basic_expr))
     return rets
Example #21
 def forward_batch(self,
                   batched_ids: List,
                   batched_starts: List,
                   batched_typeids: List,
                   training: bool,
                   other_inputs: List[List] = None):
     conf = self.bconf
     tokenizer = self.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     MASK_IDX = tokenizer.mask_token_id
     CLS_IDX = tokenizer.cls_token_id
     SEP_IDX = tokenizer.sep_token_id
     if other_inputs is None:
         other_inputs = []
     # =====
     # batch: here add CLS and SEP
     bsize = len(batched_ids)
     max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
     input_shape = (bsize, max_len)
     # first collect on CPU
     input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
     input_ids_arr[:, 0] = CLS_IDX
     input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
     input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
     input_typeids = None if batched_typeids is None else np.full(
         input_shape, 0, dtype=np.int64)
     other_input_arrs = [
         np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs
     ]
     if conf.bert2_retinc_cls:  # act as the ROOT word
         input_is_start_arr[:, 0] = 1
     training_mask_rate = conf.bert2_training_mask_rate if training else 0.
     self_sample_stream = self.random_sample_stream
     for bidx in range(bsize):
         cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
         cur_end = len(cur_ids) + 2  # plus CLS and SEP
         if training_mask_rate > 0.:
             # input dropout
             input_ids_arr[bidx, 1:cur_end] = [
                 (MASK_IDX
                  if next(self_sample_stream) < training_mask_rate else z)
                 for z in cur_ids
             ] + [SEP_IDX]
         else:
             input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
         input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
         input_mask_arr[bidx, :cur_end] = 1.
         if batched_typeids is not None and batched_typeids[bidx] is not None:
             input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
         for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
             one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
     # arr to tensor
     input_ids_t = BK.input_idx(input_ids_arr)
     input_mask_t = BK.input_real(input_mask_arr)
     input_is_start_t = BK.input_idx(input_is_start_arr)
     input_typeid_t = None if input_typeids is None else BK.input_idx(
         input_typeids)
     other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
     # =====
     # forward (maybe need multiple times to fit maxlen constraint)
     MAX_LEN = 510  # save two for [CLS] and [SEP]
     BACK_LEN = 100  # for splitting cases, still remaining some of previous sub-tokens for context
     if max_len <= MAX_LEN:
         # directly once
         final_outputs = self.forward_features(
             input_ids_t, input_mask_t, input_typeid_t,
             other_input_ts)  # [bs, slen, *...]
         start_idxes, start_masks = BK.mask2idx(
             input_is_start_t.float())  # [bsize, ?]
     else:
         all_outputs = []
         cur_sub_idx = 0
         slice_size = [bsize, 1]
         slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
         slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
         while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
             cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
             cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
             cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
             # here we simply extend extra original masks
             cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_input_typeid_t = None if input_typeid_t is None else \
                 input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
             cur_outputs = self.forward_features(cur_input_ids_t,
                                                 cur_input_mask_t,
                                                 cur_input_typeid_t,
                                                 cur_other_input_ts)
             # only include CLS in the first run, no SEP included
             if cur_sub_idx == 0:
                 # include CLS, exclude SEP
                 all_outputs.append(cur_outputs[:, :-1])
             else:
                 # include only new ones, discard BACK ones, exclude CLS, SEP
                 all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
                 zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                       f"for all-len={max_len}")
             cur_sub_idx = cur_slice_end
         final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
         start_idxes, start_masks = BK.mask2idx(
             input_is_start_t[:, :-1].float())  # [bsize, ?]
     start_expr = BK.gather_first_dims(final_outputs, start_idxes,
                                       1)  # [bsize, ?, *...]
     return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
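BK.mask2idx and BK.gather_first_dims are used above to pull out only the subword-start positions from the padded subword outputs. A standalone sketch of that pattern in plain PyTorch (not the BK API): turn a 0/1 start mask into right-padded index lists plus a validity mask, then gather along the sequence dimension:
 import torch
 def mask2idx(mask: torch.Tensor, pad: int = 0):
     # mask: [bs, slen] of 0/1 -> idxes: [bs, max_count] (right-padded), valid: [bs, max_count]
     bs = mask.shape[0]
     max_count = max(int(mask.sum(-1).max().item()), 1)
     idxes = torch.full((bs, max_count), pad, dtype=torch.long)
     valid = torch.zeros(bs, max_count)
     for b in range(bs):
         pos = mask[b].nonzero(as_tuple=True)[0]
         idxes[b, :len(pos)] = pos
         valid[b, :len(pos)] = 1.
     return idxes, valid
 # gather the start positions' vectors from outputs [bs, slen, D]:
 # start_expr = outputs.gather(1, idxes.unsqueeze(-1).expand(-1, -1, outputs.shape[-1]))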
Example #22
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     conf = self.conf
     bsize = len(insts)
     # first get gold info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(
         insts)
     input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
     # step 1: selector
     if conf.use_selector:
         sel_loss, sel_mask = self.sel.loss(input_expr,
                                            input_mask,
                                            gold_masks,
                                            margin=margin)
     else:
         sel_loss, sel_mask = None, self._select_cands_training(
             input_mask, gold_masks, conf.train_min_rate)
     sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(
         sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     # if we select nothing
     # ----- debug
     # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
     # -----
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         lab_loss = [[BK.zeros([]), BK.zeros([])]]
         sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask,
                                   sel_idxes)  # [*, mc, DLab]
         lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(
             sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
         if conf.train_gold_corr:
             sel_lab_idxes = sel_gold_idxes
             if not self.hl.conf.use_lookup_soft:
                 sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
         ret_items = sel_items
         # =====
         if self.use_secondary_type:
             sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
             if conf.sectype_noback_enc:
                 sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
             else:
                 sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
             # =====
             # special for the sectype mask (sample it within the gold ones)
             sel2_valid_mask = self._select_cands_training(
                 (sel_gold_idxes > 0).float(),
                 (sel_gold_idxes2 > 0).float(), conf.train_min_rate_s2)
             # =====
             sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(
                 sel2_input,
                 sel2_valid_mask,
                 sel_gold_idxes2,
                 margin=margin)
             if conf.train_gold_corr:
                 sel2_lab_idxes = sel_gold_idxes2
                 if not self.hl.conf.use_lookup_soft:
                     sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
             if conf.sectype_t2ift1:
                 sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat(
                     [sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes],
                                           -1)
                 sel_lab_embeds = BK.concat(
                     [sel_lab_embeds, sel2_lab_embeds], -2)
         else:
             sel2_lab_loss = None
         # =====
         # step 3: exclude nil and return
         if conf.exclude_nil:  # [*, mc', ...]
             sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
                 self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # step 4: finally prepare loss and items
     for one_loss in lab_loss:
         one_loss[0] *= conf.lambda_ne
     ret_losses = lab_loss
     if sel2_lab_loss is not None:
         for one_loss in sel2_lab_loss:
             one_loss[0] *= conf.lambda_ne2
         ret_losses = ret_losses + sel2_lab_loss
     if sel_loss is not None:
         for one_loss in sel_loss:
             one_loss[0] *= conf.lambda_ns
         ret_losses = ret_losses + sel_loss
     # ----- debug
     # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
     # -----
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool)] = None
     return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Example #23
 def predict(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     bsize, slen = BK.get_shape(input_mask)
     bsize_arange_t_1d = BK.arange_idx(bsize)  # [*]
     bsize_arange_t_2d = bsize_arange_t_1d.unsqueeze(-1)  # [*, 1]
     beam_size = conf.beam_size
     # prepare things with an extra beam dimension
     beam_input_expr, beam_input_mask = input_expr.unsqueeze(-3).expand(-1, beam_size, -1, -1).contiguous(), \
                                        input_mask.unsqueeze(-2).expand(-1, beam_size, -1).contiguous()  # [*, beam, slen, D?]
     # -----
     # recurrent states
     beam_hard_coverage = BK.zeros([bsize, beam_size, slen])  # [*, beam, slen]
     # tuple([*, beam, D], )
     beam_prev_state = [z.unsqueeze(-2).expand(-1, beam_size, -1) for z in self.rnn_unit.zero_init_hidden(bsize)]
     # frozen after reach eos
     beam_noneos = 1.-BK.zeros([bsize, beam_size])  # [*, beam]
     beam_logprobs = BK.zeros([bsize, beam_size])  # [*, beam], sum of logprobs
     beam_logprobs_paths = BK.zeros([bsize, beam_size, 0])  # [*, beam, step]
     beam_tok_paths = BK.zeros([bsize, beam_size, 0]).long()
     beam_lab_paths = BK.zeros([bsize, beam_size, 0]).long()
     # -----
     for cstep in range(conf.max_step):
         # get things of [*, beam, beam]
         sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state = \
             self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, None, None, beam_size)
         sel_logprobs = sel_tok_logprobs + sel_lab_logprobs  # [*, beam, beam]
         if cstep == 0:
             # special for the first step, only select for the first element
             cur_selections = BK.arange_idx(beam_size).unsqueeze(0).expand(bsize, beam_size)  # [*, beam]
         else:
             # then select the topk in beam*beam (be careful about the frozen ones!!)
             beam_noneos_3d = beam_noneos.unsqueeze(-1)
             # eos can only followed by eos
             sel_tok_idxes *= beam_noneos_3d.long()
             sel_lab_idxes *= beam_noneos_3d.long()
             # numeric tricks to keep the frozen ones ([0] with 0. score, [1:] with -inf scores)
             sel_logprobs *= beam_noneos_3d
             tmp_exclude_mask = 1. - beam_noneos_3d.expand_as(sel_logprobs)
             tmp_exclude_mask[:, :, 0] = 0.
             sel_logprobs += tmp_exclude_mask * Constants.REAL_PRAC_MIN
             # select for topk
             topk_logprobs = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
             _, cur_selections = topk_logprobs.view([bsize, -1]).topk(beam_size, dim=-1, sorted=True)  # [*, beam]
         # read and write the selections
         # gathering previous ones
         cur_sel_previ = cur_selections // beam_size  # [*, beam]
         prev_hard_coverage = beam_hard_coverage[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
         prev_noneos = beam_noneos[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
         prev_logprobs = beam_logprobs[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
         prev_logprobs_paths = beam_logprobs_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         prev_tok_paths = beam_tok_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         prev_lab_paths = beam_lab_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         # prepare new ones
         cur_sel_newi = cur_selections % beam_size
         new_tok_idxes = sel_tok_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_lab_idxes = sel_lab_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_logprobs = sel_logprobs[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_prev_state = [z[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi] for z in next_state]  # [*, beam, ~]
         # update
         prev_hard_coverage[bsize_arange_t_2d, BK.arange_idx(beam_size).unsqueeze(0), new_tok_idxes] += 1.
         beam_hard_coverage = prev_hard_coverage
         beam_prev_state = new_prev_state
         beam_noneos = prev_noneos * (new_tok_idxes!=0).float()
         beam_logprobs = prev_logprobs + new_logprobs
         beam_logprobs_paths = BK.concat([prev_logprobs_paths, new_logprobs.unsqueeze(-1)], -1)
         beam_tok_paths = BK.concat([prev_tok_paths, new_tok_idxes.unsqueeze(-1)], -1)
         beam_lab_paths = BK.concat([prev_lab_paths, new_lab_idxes.unsqueeze(-1)], -1)
     # finally force an extra eos step to get ending tok-logprob (no need to update other things)
     final_eos_idxes = BK.zeros([bsize, beam_size]).long()
     _, eos_logprobs, _, _, _, _ = self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, final_eos_idxes, final_eos_idxes, None)
     beam_logprobs += eos_logprobs.squeeze(-1) * beam_noneos  # [*, beam]
     # select and return the best one
     beam_tok_valids = (beam_tok_paths > 0).float()  # [*, beam, steps]
     final_scores = beam_logprobs / ((beam_tok_valids.sum(-1) + 1.) ** conf.len_alpha)  # [*, beam]
     _, best_beam_idx = final_scores.max(-1)  # [*]
     # -----
     # prepare returns; cut by max length: [*, all_step] -> [*, max_step]
     ret0_valid_mask = beam_tok_valids[bsize_arange_t_1d, best_beam_idx]
     cur_max_step = ret0_valid_mask.long().sum(-1).max().item()
     ret_valid_mask = ret0_valid_mask[:, :cur_max_step]
     ret_logprobs = beam_logprobs_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     ret_tok_idxes = beam_tok_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     ret_lab_idxes = beam_lab_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     # embeddings
     ret_lab_embeds = self._hl_lookup(ret_lab_idxes)
     return ret_logprobs, ret_tok_idxes, ret_valid_mask, ret_lab_idxes, ret_lab_embeds
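A note on the selection step above: the candidate scores are flattened over beam*beam, a single topk is taken, and `//` / `%` recover which previous beam and which expansion each winner came from; the frozen-beam trick keeps exactly one zero-scored copy of every finished hypothesis so it survives the topk unchanged. A minimal standalone sketch of both ideas in plain PyTorch (the sizes and names below are made up for illustration; the real code goes through the BK backend wrappers):
 import torch

 bsize, beam_size = 2, 3
 # per-candidate log-probs for each (previous beam, expansion): [bsize, beam, beam]
 sel_logprobs = torch.randn(bsize, beam_size, beam_size)
 beam_noneos = torch.ones(bsize, beam_size)     # 1. = still alive, 0. = already ended with eos
 beam_noneos[0, 1] = 0.                         # pretend one beam is frozen
 beam_logprobs = torch.zeros(bsize, beam_size)  # accumulated scores so far
 # frozen beams: expansion [0] keeps score 0., the rest are pushed towards -inf
 noneos_3d = beam_noneos.unsqueeze(-1)
 sel_logprobs = sel_logprobs * noneos_3d
 exclude = 1. - noneos_3d.expand_as(sel_logprobs)
 exclude[:, :, 0] = 0.
 sel_logprobs = sel_logprobs + exclude * -1e30
 # one topk over the flattened beam*beam candidates
 topk_scores = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
 _, cur_selections = topk_scores.view(bsize, -1).topk(beam_size, dim=-1)  # [bsize, beam]
 # decompose the flat index: which previous beam, and which expansion inside it
 cur_sel_previ = cur_selections // beam_size    # [bsize, beam]
 cur_sel_newi = cur_selections % beam_size      # [bsize, beam]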
Example #24
0
 def decode(self, inst: DocInstance):
     conf, model = self.conf, self.model
     model.refresh_batch(False)
     test_constrain_evt_types = self.test_constrain_evt_types
     with BK.no_grad_env():
         # =====
         # init the collections
         flattened_ef_ereprs, flattened_evt_ereprs = [], []
         sent_offsets = [Constants.INT_PRAC_MIN]*len(inst.sents)  # start offset of sent in the flattened erepr
         cur_offset = 0  # current offset
         all_ef_items, all_evt_items = [], []
         # =====
         # first basic run and ef and evt
         all_packs = model.bter.run([inst], training=False)
         for one_pack in all_packs:
             sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
             mask_expr = BK.input_real(mask_arr)
             # =====
             # store the enc reprs and sent offsets
             sent_size, sent_len = BK.get_shape(enc_repr_ef)[:2]
             assert BK.get_shape(enc_repr_ef) == BK.get_shape(enc_repr_evt)
             flattened_ef_ereprs.append(enc_repr_ef.view(sent_size*sent_len, -1))  # [cur_flatten_size, D]
             flattened_evt_ereprs.append(enc_repr_evt.view(sent_size*sent_len, -1))
             for one_sent in sent_insts:
                 sent_offsets[one_sent.sid] = cur_offset
                 cur_offset += sent_len
             # =====
             lkrc = not conf.dec_debug_mode  # ret_copy flag passed to _lookup_mentions (no copy in debug mode)
             # =====
             # ef
             if conf.lookup_ef:
                 ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                     model._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, ret_copy=lkrc)
             else:
                 ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                     model._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, model.ef_creator)
             # collect all valid ones
             all_ef_items.extend(ef_items[BK.get_value(ef_valid_mask).astype(bool)])
             # event
             if conf.lookup_evt:
                 evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                     model._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, ret_copy=lkrc)
             else:
                 evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                     model._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, model.evt_creator)
             # collect all valid ones
             if test_constrain_evt_types is None:
                 all_evt_items.extend(evt_items[BK.get_value(evt_valid_mask).astype(bool)])
             else:
                 all_evt_items.extend([z for z in evt_items[BK.get_value(evt_valid_mask).astype(bool)]
                                       if z.type in test_constrain_evt_types])
         # =====
         # cross-sentence pairwise arg score
         # flattened all enc: [Offset, D]
         flattened_ef_enc_repr, flattened_evt_enc_repr = BK.concat(flattened_ef_ereprs, 0), BK.concat(flattened_evt_ereprs, 0)
         # sort by position in doc
         all_ef_items.sort(key=lambda x: x.mention.hard_span.position(True))
         all_evt_items.sort(key=lambda x: x.mention.hard_span.position(True))
         if not conf.dec_debug_mode:
             # todo(note): delete the original links!
             for z in all_ef_items:
                 if z is not None:
                     z.links.clear()
             for z in all_evt_items:
                 if z is not None:
                     z.links.clear()
         # get flat positions in the concatenated reprs: sentence start offset + head-word index
         # todo(note): currently all mentions use the head word only
         all_ef_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid]+x.mention.hard_span.head_wid for x in all_ef_items])
         all_evt_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid]+x.mention.hard_span.head_wid for x in all_evt_items])
         all_ef_lab_idxes = BK.input_idx([model.ef_extractor.hlidx2idx(x.type_idx) for x in all_ef_items])
         all_evt_lab_idxes = BK.input_idx([model.evt_extractor.hlidx2idx(x.type_idx) for x in all_evt_items])
         # score all the pairs (with mini-batch)
         mini_batch_size = conf.score_mini_batch
         arg_linker = model.arg_linker
         all_logprobs = BK.zeros([len(all_ef_items), len(all_evt_items), arg_linker.num_label])
         for bidx_ef in range(0, len(all_ef_items), mini_batch_size):
             cur_ef_enc_repr = flattened_ef_enc_repr[all_ef_offsets[bidx_ef:bidx_ef+mini_batch_size]].unsqueeze(0)
             cur_ef_lab_idxes = all_ef_lab_idxes[bidx_ef:bidx_ef+mini_batch_size].unsqueeze(0)
             for bidx_evt in range(0, len(all_evt_items), mini_batch_size):
                 cur_evt_enc_repr = flattened_evt_enc_repr[all_evt_offsets[bidx_evt:bidx_evt+mini_batch_size]].unsqueeze(0)
                 cur_evt_lab_idxes = all_evt_lab_idxes[bidx_evt:bidx_evt + mini_batch_size].unsqueeze(0)
                 all_logprobs[bidx_ef:bidx_ef+mini_batch_size,bidx_evt:bidx_evt+mini_batch_size] = \
                     arg_linker.predict(cur_ef_enc_repr, cur_evt_enc_repr, cur_ef_lab_idxes, cur_evt_lab_idxes,
                                        ret_full_logprobs=True).squeeze(0)
         all_logprobs_arr = BK.get_value(all_logprobs)
     # =====
     # then decode them all using the scores
     self.arg_decode(inst, all_ef_items, all_evt_items, all_logprobs_arr)
     # =====
     # assign and return
     num_pred_arg = 0
     for one_sent in inst.sents:
         one_sent.pred_entity_fillers.clear()
         one_sent.pred_events.clear()
     for z in all_ef_items:
         inst.sents[z.mention.hard_span.sid].pred_entity_fillers.append(z)
     for z in all_evt_items:
         inst.sents[z.mention.hard_span.sid].pred_events.append(z)
         num_pred_arg += len(z.links)
     info = {"doc": 1, "sent": len(inst.sents), "token": sum(s.length-1 for s in inst.sents),
             "p_ef": len(all_ef_items), "p_evt": len(all_evt_items), "p_arg": num_pred_arg}
     return info
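The double loop over mini-batches above bounds peak memory by scoring the [N_ef, N_evt] pairs one block at a time and writing each block into the preallocated full log-prob tensor. A toy version of the same pattern with a stand-in scorer (the shapes, sizes and the per-label bilinear scorer below are illustrative assumptions, not the model's actual arg_linker):
 import torch

 def fill_pairwise_scores(ef_repr, evt_repr, score_fn, num_label, mb=4):
     # ef_repr: [N_ef, D], evt_repr: [N_evt, D] -> full score tensor [N_ef, N_evt, num_label]
     n_ef, n_evt = ef_repr.shape[0], evt_repr.shape[0]
     out = torch.zeros(n_ef, n_evt, num_label)
     for i in range(0, n_ef, mb):
         for j in range(0, n_evt, mb):
             # score only an [<=mb, <=mb] block of pairs at a time
             out[i:i+mb, j:j+mb] = score_fn(ef_repr[i:i+mb], evt_repr[j:j+mb])
     return out

 # stand-in for arg_linker.predict(..., ret_full_logprobs=True): one random bilinear map per label
 num_label, dim = 5, 16
 W = torch.randn(num_label, dim, dim)
 score_fn = lambda a, b: torch.log_softmax(torch.einsum('ad,lde,be->abl', a, W, b), dim=-1)
 scores = fill_pairwise_scores(torch.randn(7, dim), torch.randn(9, dim), score_fn, num_label)
 print(scores.shape)  # torch.Size([7, 9, 5])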
Example #25
0
 def loss(self, repr_t, orig_map: Dict, **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # --
     # specify input
     add_root_token = self.add_root_token
     # get from inputs
     if isinstance(repr_t, (list, tuple)):
         l2r_repr_t, r2l_repr_t = repr_t
     elif self.split_input_blm:
         l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
     else:
         l2r_repr_t, r2l_repr_t = repr_t, None
     # l2r and r2l
     word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
     slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
     if add_root_token:
         l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0 at the end, [bs, rlen+1]
         r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0 at the front, [bs, 2+rlen-1]
     else:
         l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # drop the first token and pad one extra 0, [bs, -1+rlen+1]
         r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at the front, [bs, 1+rlen-1]
     # gather the losses
     all_losses = []
     pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
     if _tie_input_embeddings:
         # weight tying: reuse the input embedding matrix as the output projection
         pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
     else:
         pred_W = None
     # loop over the two directions (l2r, r2l) and gather their losses
     for pred_name, hid_node, pred_node, input_t, trg_t in \
                 zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                     [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
         if input_t is None:
             continue
         # hidden
         hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
         # pred: [bs, slen, Vsize]
         if _tie_input_embeddings:
             scores_t = BK.matmul(hid_t, pred_W.T)
         else:
             scores_t = pred_node(hid_t)
         # loss
         mask_t = ((trg_t >= pred_range_min) &
                   (trg_t <= pred_range_max)).float()  # [bs, slen]
         trg_t.clamp_(max=pred_range_max)  # make it in range
         losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
         _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
         corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
         # compile leaf loss
         one_loss = LossHelper.compile_leaf_info(pred_name,
                                                 losses_t.sum(),
                                                 mask_t.sum(),
                                                 loss_lambda=1.,
                                                 corr=corrs_t.sum())
         all_losses.append(one_loss)
     return self._compile_component_loss("plm", all_losses)
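The recipe above in a self-contained form: build shifted targets per direction (l2r predicts the next word, r2l the previous one, with 0 padded where there is no target), optionally tie the output projection to the input embedding matrix, then sum a per-token NLL masked to targets inside [min_pred_rank, pred_size-1]. A sketch of the no-root-token case in plain PyTorch standing in for the BK helpers (the vocabulary size, hidden size, min rank and random tensors below are made up):
 import torch
 import torch.nn.functional as F

 word_t = torch.tensor([[3, 7, 2, 5],
                        [4, 9, 1, 0]])                     # [bs, rlen], 0 = padding
 zero_col = torch.zeros(word_t.shape[0], 1, dtype=torch.long)
 l2r_trg = torch.cat([word_t[:, 1:], zero_col], dim=-1)    # predict the next word
 r2l_trg = torch.cat([zero_col, word_t[:, :-1]], dim=-1)   # predict the previous word
 # tied output projection: scores are hidden states times the transposed input embedding matrix
 pred_size, hid_dim, min_pred_rank = 12, 8, 1
 emb_W = torch.randn(pred_size, hid_dim)                   # [V, D], would be the input embeddings
 hid = torch.randn(word_t.shape[0], word_t.shape[1], hid_dim)  # [bs, rlen, D]
 scores = hid @ emb_W.T                                    # [bs, rlen, V]
 # masked NLL for the l2r direction: only targets in [min_pred_rank, pred_size-1] contribute
 mask = ((l2r_trg >= min_pred_rank) & (l2r_trg <= pred_size - 1)).float()
 trg = l2r_trg.clamp(max=pred_size - 1)
 nll = F.cross_entropy(scores.transpose(1, 2), trg, reduction='none') * mask  # [bs, rlen]
 loss = nll.sum() / mask.sum().clamp(min=1.)
 # (the r2l branch is symmetric, using r2l_trg as the target)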