Example #1
 def _cascade_scores(self, raw_scores: List):
     # accumulate layered scores: each layer adds its raw scores onto the
     # cumulative scores of its parent labels from the previous layer
     all_scores = []
     for i, one_scores in enumerate(raw_scores):
         if i > 0:
             cur_prei = self.layered_prei[i]  # [?], parent idx of each layer-i label
             prev_score = BK.select(all_scores[-1], cur_prei, -1)  # [*, ?]
             cur_score = prev_score + one_scores
         else:
             cur_score = one_scores
         all_scores.append(cur_score)
     return all_scores
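The cascade accumulates scores down a label hierarchy: at each layer i > 0, every label first inherits the cumulative score of its parent label in layer i-1 (looked up through layered_prei[i]) and then adds its own raw score. Below is a minimal standalone sketch of the same idea, assuming plain PyTorch in place of the project's BK backend; the names cascade_scores and prei1 are illustrative.

 import torch

 def cascade_scores(raw_scores, layered_prei):
     # raw_scores: per-layer score tensors [*, n_labels_i];
     # layered_prei[i]: LongTensor mapping each layer-i label to its layer-(i-1) parent
     all_scores = []
     for i, one_scores in enumerate(raw_scores):
         if i > 0:
             prev_score = all_scores[-1].index_select(-1, layered_prei[i])  # [*, n_labels_i]
             cur_score = prev_score + one_scores
         else:
             cur_score = one_scores
         all_scores.append(cur_score)
     return all_scores

 # toy usage: batch of 3, two layers; layer-1 labels {0,1} have parent 0, {2,3} parent 1
 scores0, scores1 = torch.randn(3, 2), torch.randn(3, 4)
 prei1 = torch.tensor([0, 0, 1, 1])
 out = cascade_scores([scores0, scores1], [None, prei1])
 assert out[1].shape == (3, 4)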
Example #2
 def run(self, insts: List[DocInstance], training: bool):
     conf = self.conf
     BERT_MAX_LEN = 510  # save 2 for CLS and SEP
     # =====
     # encoder 1: the basic encoder
     # todo(note): only DocInstance input for this mode, otherwise it will break
     if conf.m2e_use_basic:
         reidx_pad_len = conf.ms_extend_budget
         # enc the basic part + also get some indexes
         sentid2offset = {}  # id(sent) -> overall seq offset
         seq_offset = 0  # overall offset if all the docs are viewed as one long sequence
         all_sents = []  # (inst, d_idx, s_idx)
         for d_idx, one_doc in enumerate(insts):
             assert isinstance(one_doc, DocInstance)
             for s_idx, one_sent in enumerate(one_doc.sents):
                 # todo(note): here we encode all the sentences
                 all_sents.append((one_sent, d_idx, s_idx))
                 sentid2offset[id(one_sent)] = seq_offset
                 seq_offset += one_sent.length - 1  # exclude extra ROOT node
         sent_reprs = self.run_sents(all_sents, insts, training)
         # flatten, concatenate, and re-index
         # todo(note): extra padding to avoid out-of-boundary indexing
         reidxes_arr = np.zeros(seq_offset + reidx_pad_len, dtype=np.int64)
         all_flattened_reprs = []
         all_flatten_offset = 0  # the local offset for batched basic encoding
         for one_pack in sent_reprs:
             one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             one_repr_t = one_repr_evt
             _, one_slen, one_ldim = BK.get_shape(one_repr_t)
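             # flatten [bs, slen, D] -> [bs*slen, D]; sentence j of this pack starts at row j*one_slen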
             all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
             # fill in the indexes
             for one_sent in one_sents:
                 cur_start_offset = sentid2offset[id(one_sent)]
                 cur_real_slen = one_sent.length - 1
                 # again, +1 to get rid of extra ROOT
                 reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \
                     np.arange(cur_real_slen, dtype=np.int64) + (all_flatten_offset+1)
                 all_flatten_offset += one_slen  # advance by the padded slen of the batched version
         # re-indexing: gather the flattened reprs into one document-ordered sequence
         seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
         seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0)  # [all_seq_len, D]
     else:
         sentid2offset = defaultdict(int)  # every sentence's offset defaults to 0
         seq_sent_repr = None
     # =====
     # repack and prepare for multiple sent enc
     # todo(note): here, the criterion is based on bert's tokenizer
     all_ms_info = []
     if isinstance(insts[0], DocInstance):
         for d_idx, one_doc in enumerate(insts):
             for s_idx, x in enumerate(one_doc.sents):
                 # the inclusion criterion is the same as in the basic mode
                 include_flag = False
                 if training:
                     if x.length < self.train_skip_length and x.length >= self.train_min_length \
                             and (len(x.events) > 0 or next(self.random_sample_stream) > self.train_skip_noevt_rate):
                         include_flag = True
                 else:
                     if x.length >= self.test_min_length:
                         include_flag = True
                 if include_flag:
                     all_ms_info.append(x.preps["ms"])  # use the pre-calculated one
     else:
         # multisent based
         all_ms_info = insts.copy()  # shallow copy
     # =====
     # encoder 2: the bert one (multi-sent encoding)
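     # sort by subword size so each bucket batches similarly-sized items (less padding)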
     ms_size_f = lambda x: x.subword_size
     all_ms_info.sort(key=ms_size_f)
     all_ms_buckets = self._bucket_sents_by_length(
         all_ms_info,
         conf.benc_bucket_range,
         ms_size_f,
         max_bsize=conf.benc_bucket_msize)
     berter = self.berter
     rets = []
     bert_use_center_typeids = conf.bert_use_center_typeids
     bert_use_special_typeids = conf.bert_use_special_typeids
     bert_other_inputs = conf.bert_other_inputs
     for one_bucket in all_ms_buckets:
         # prepare
         batched_ids = []
         batched_starts = []
         batched_seq_offset = []
         batched_typeids = []
         batched_other_inputs_list: List = [[] for _ in bert_other_inputs]  # List(comp) of List(batch) of List(idx)
         for one_item in one_bucket:
             one_sents = one_item.sents
             one_center_sid = one_item.center_idx
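             # the center sentence is the one this multi-sent item is built around;
             # with bert_use_center_typeids it alone gets token-type-id 1 below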
             one_ids, one_starts, one_typeids = [], [], []
             one_other_inputs_list = [[] for _ in bert_other_inputs]  # List(comp) of List(idx)
             for one_sid, one_sent in enumerate(one_sents):  # for bert
                 one_bidxes = one_sent.preps["bidx"]
                 one_ids.extend(one_bidxes.subword_ids)
                 one_starts.extend(one_bidxes.subword_is_start)
                 # prepare other inputs
                 for this_field_name, this_tofill_list in zip(bert_other_inputs, one_other_inputs_list):
                     this_tofill_list.extend(one_sent.preps["sub_" + this_field_name])
                 # todo(note): special procedure
                 if bert_use_center_typeids:
                     if one_sid != one_center_sid:
                         one_typeids.extend([0] * len(one_bidxes.subword_ids))
                     else:
                         this_typeids = [1] * len(one_bidxes.subword_ids)
                         if bert_use_special_typeids:
                             # todo(note): this is the special mode where the events are given!!
                             for this_event in one_sents[one_center_sid].events:
                                 _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                 # set the event span's subwords back to type-id 0
                                 for a, b in one_item.center_word2sub[this_wid-1 : this_wid-1+this_wlen]:
                                     this_typeids[a:b] = [0] * (b - a)
                         one_typeids.extend(this_typeids)
             batched_ids.append(one_ids)
             batched_starts.append(one_starts)
             batched_typeids.append(one_typeids)
             for comp_one_oi, comp_batched_oi in zip(one_other_inputs_list, batched_other_inputs_list):
                 comp_batched_oi.append(comp_one_oi)
             # for the basic part: where this item's first sentence starts in the doc-level sequence
             batched_seq_offset.append(sentid2offset[id(one_sents[0])])
         # bert forward: [bs, slen, fold, D]
         if not bert_use_center_typeids:
             batched_typeids = None
         bert_expr0, mask_expr = berter.forward_batch(
             batched_ids,
             batched_starts,
             batched_typeids,
             training=training,
             other_inputs=batched_other_inputs_list)
         if self.m3_enc_is_empty:
             bert_expr = bert_expr0
         else:
             mask_arr = BK.get_value(mask_expr)  # [bs, slen]
             m3e_exprs = [
                 cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                 for cur_i, cur_enc in enumerate(self.m3_encs)
             ]
             bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
         # collect basic ones: [bs, slen, D'] or None
         if seq_sent_repr is not None:
             arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
             offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
             basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
         elif conf.m2e_use_basic_dep:
             # collect each token's head-bert and ud-label, then forward with adp
             fake_sents = [one_item.fake_sent for one_item in one_bucket]
             # head idxes and labels, no artificial ROOT
             padded_head_arr, _ = self.dep_padder.pad([s.ud_heads.vals[1:] for s in fake_sents])
             padded_label_arr, _ = self.dep_padder.pad([s.ud_labels.idxes[1:] for s in fake_sents])
             # get tensors
             padded_head_t = BK.input_idx(padded_head_arr) - 1  # here, the idxes exclude root
             padded_head_t.clamp_(min=0)  # [bs, slen]
             padded_label_t = BK.input_idx(padded_label_arr)
             # get inputs
             input_head_bert_t = bert_expr[BK.arange_idx(len(fake_sents)).unsqueeze(-1), padded_head_t]  # [bs, slen, fold, D]
             input_label_emb_t = self.dep_label_emb(padded_label_t)  # [bs, slen, D']
             basic_expr = self.dep_layer(input_head_bert_t, None, [input_label_emb_t])  # [bs, slen, ?]
         elif conf.m2e_use_basic_plus:
             sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                         insts, training, use_one_bucket=True)
             assert len(sent_reprs) == 1, \
                 "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
             _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
             assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
         else:
             basic_expr = None
         # pack: (List[ms_item], bert_expr, basic_expr)
         rets.append((one_bucket, bert_expr, basic_expr))
     return rets
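Two indexing tricks in Example #2 are easy to miss: the basic encoder flattens the padded per-sentence representations into one document-level sequence through a precomputed index array (reidxes_arr plus BK.select), and each bucket later slices its span back out with offset-plus-arange indexing (offset_idxes_t). Below is a minimal standalone sketch of both steps, assuming plain PyTorch in place of the project's BK wrapper; all names are illustrative, and the real code additionally pads the index array and shifts by +1 to skip each sentence's ROOT column.

 import torch

 B, S, D = 2, 5, 8                                # a padded batch of sentence reprs
 sent_repr = torch.randn(B, S, D)
 sent_lens = [5, 3]                               # real lengths (no padding)

 # (1) flatten to [B*S, D], then gather the real tokens in document order
 flat = sent_repr.reshape(-1, D)
 reidx = torch.zeros(sum(sent_lens), dtype=torch.long)
 offsets, pos = [], 0
 for b, L in enumerate(sent_lens):
     offsets.append(pos)                          # like sentid2offset above
     reidx[pos:pos + L] = torch.arange(L) + b * S
     pos += L
 doc_repr = flat.index_select(0, reidx)           # [total_len, D]

 # (2) slice per-item windows back out: offset + arange, like offset_idxes_t
 win = 3
 offset_t = torch.tensor(offsets).unsqueeze(-1)   # [B, 1]
 arange_t = torch.arange(win).unsqueeze(0)        # [1, win]
 windows = doc_repr[offset_t + arange_t]          # [B, win, D]
 assert windows.shape == (B, win, D)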