def _cascade_scores(self, raw_scores: List):
    all_scores = []
    for i, one_scores in enumerate(raw_scores):
        if i > 0:
            cur_prei = self.layered_prei[i]  # [?]
            prev_score = BK.select(all_scores[-1], cur_prei, -1)  # [*, ?]
            cur_score = prev_score + one_scores
        else:
            cur_score = one_scores
        all_scores.append(cur_score)
    return all_scores
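# ----- illustrative sketch (not part of the model) -----
# A minimal, self-contained demo of the cascading idea above, assuming
# `BK.select` behaves like torch.index_select and that `layered_prei[i]`
# maps each layer-i label to its parent label in layer i-1; all names and
# shapes here are hypothetical and only for illustration.
def _demo_cascade_scores():
    import torch
    # two layers of a label hierarchy: 2 coarse labels -> 3 fine labels
    layered_prei = [None, torch.tensor([0, 0, 1])]  # fine label j's parent id
    raw_scores = [torch.randn(4, 2), torch.randn(4, 3)]  # [*, n_labels_i]
    all_scores = []
    for i, one_scores in enumerate(raw_scores):
        if i > 0:
            # broadcast each parent's accumulated score onto its children
            prev_score = all_scores[-1].index_select(-1, layered_prei[i])
            one_scores = prev_score + one_scores
        all_scores.append(one_scores)
    return all_scores  # layer-i scores now include all ancestor scores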
def run(self, insts: List[DocInstance], training: bool):
    conf = self.conf
    BERT_MAX_LEN = 510  # save 2 for CLS and SEP
    # =====
    # encoder 1: the basic encoder
    # todo(note): only DocInstance input for this mode, otherwise will break
    if conf.m2e_use_basic:
        reidx_pad_len = conf.ms_extend_budget
        # enc the basic part + also get some indexes
        sentid2offset = {}  # id(sent) -> overall_seq_offset
        seq_offset = 0  # if look at the docs in one seq
        all_sents = []  # (inst, d_idx, s_idx)
        for d_idx, one_doc in enumerate(insts):
            assert isinstance(one_doc, DocInstance)
            for s_idx, one_sent in enumerate(one_doc.sents):
                # todo(note): here we encode all the sentences
                all_sents.append((one_sent, d_idx, s_idx))
                sentid2offset[id(one_sent)] = seq_offset
                seq_offset += one_sent.length - 1  # exclude extra ROOT node
        sent_reprs = self.run_sents(all_sents, insts, training)
        # flatten and concatenate and re-index
        # todo(note): extra padding to avoid out of boundary
        reidxes_arr = np.zeros(seq_offset + reidx_pad_len, dtype=np.int64)
        all_flattened_reprs = []
        all_flatten_offset = 0  # the local offset for batched basic encoding
        for one_pack in sent_reprs:
            one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            one_repr_t = one_repr_evt
            _, one_slen, one_ldim = BK.get_shape(one_repr_t)
            all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
            # fill in the indexes
            for one_sent in one_sents:
                cur_start_offset = sentid2offset[id(one_sent)]
                cur_real_slen = one_sent.length - 1
                # again, +1 to get rid of extra ROOT
                reidxes_arr[cur_start_offset:cur_start_offset + cur_real_slen] = \
                    np.arange(cur_real_slen, dtype=np.int64) + (all_flatten_offset + 1)
            all_flatten_offset += one_slen  # here add the slen in batched version
        # re-idxing
        seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
        seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0)  # [all_seq_len, D]
    else:
        sentid2offset = defaultdict(int)
        seq_sent_repr = None
    # =====
    # repack and prepare for multiple sent enc
    # todo(note): here, the criterion is based on bert's tokenizer
    all_ms_info = []
    if isinstance(insts[0], DocInstance):
        for d_idx, one_doc in enumerate(insts):
            for s_idx, x in enumerate(one_doc.sents):
                # the basic criterion is the same as the basic one
                include_flag = False
                if training:
                    if x.length < self.train_skip_length and x.length >= self.train_min_length \
                            and (len(x.events) > 0 or next(self.random_sample_stream) > self.train_skip_noevt_rate):
                        include_flag = True
                else:
                    if x.length >= self.test_min_length:
                        include_flag = True
                if include_flag:
                    all_ms_info.append(x.preps["ms"])  # use the pre-calculated one
    else:
        # multisent based
        all_ms_info = insts.copy()  # shallow copy
    # =====
    # encoder 2: the bert one (multi-sent encoding)
    ms_size_f = lambda x: x.subword_size
    all_ms_info.sort(key=ms_size_f)
    all_ms_buckets = self._bucket_sents_by_length(all_ms_info, conf.benc_bucket_range,
                                                  ms_size_f, max_bsize=conf.benc_bucket_msize)
    berter = self.berter
    rets = []
    bert_use_center_typeids = conf.bert_use_center_typeids
    bert_use_special_typeids = conf.bert_use_special_typeids
    bert_other_inputs = conf.bert_other_inputs
    for one_bucket in all_ms_buckets:
        # prepare
        batched_ids = []
        batched_starts = []
        batched_seq_offset = []
        batched_typeids = []
        # List(comp) of List(batch) of List(idx)
        batched_other_inputs_list: List = [[] for _ in bert_other_inputs]
        for one_item in one_bucket:
            one_sents = one_item.sents
            one_center_sid = one_item.center_idx
            one_ids, one_starts, one_typeids = [], [], []
            one_other_inputs_list = [[] for _ in bert_other_inputs]  # List(comp) of List(idx)
            for one_sid, one_sent in enumerate(one_sents):
                # for bert
                one_bidxes = one_sent.preps["bidx"]
                one_ids.extend(one_bidxes.subword_ids)
                one_starts.extend(one_bidxes.subword_is_start)
                # prepare other inputs
                for this_field_name, this_tofill_list in zip(bert_other_inputs, one_other_inputs_list):
                    this_tofill_list.extend(one_sent.preps["sub_" + this_field_name])
                # todo(note): special procedure
                if bert_use_center_typeids:
                    if one_sid != one_center_sid:
                        one_typeids.extend([0] * len(one_bidxes.subword_ids))
                    else:
                        this_typeids = [1] * len(one_bidxes.subword_ids)
                        if bert_use_special_typeids:
                            # todo(note): this is the special mode that we are given the events!!
                            for this_event in one_sents[one_center_sid].events:
                                _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                for a, b in one_item.center_word2sub[this_wid - 1:this_wid - 1 + this_wlen]:
                                    this_typeids[a:b] = [0] * (b - a)
                        one_typeids.extend(this_typeids)
            batched_ids.append(one_ids)
            batched_starts.append(one_starts)
            batched_typeids.append(one_typeids)
            for comp_one_oi, comp_batched_oi in zip(one_other_inputs_list, batched_other_inputs_list):
                comp_batched_oi.append(comp_one_oi)
            # for basic part
            batched_seq_offset.append(sentid2offset[id(one_sents[0])])
        # bert forward: [bs, slen, fold, D]
        if not bert_use_center_typeids:
            batched_typeids = None
        bert_expr0, mask_expr = berter.forward_batch(batched_ids, batched_starts, batched_typeids,
                                                     training=training, other_inputs=batched_other_inputs_list)
        if self.m3_enc_is_empty:
            bert_expr = bert_expr0
        else:
            mask_arr = BK.get_value(mask_expr)  # [bs, slen]
            m3e_exprs = [cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                         for cur_i, cur_enc in enumerate(self.m3_encs)]
            bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
        # collect basic ones: [bs, slen, D'] or None
        if seq_sent_repr is not None:
            arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
            offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
            basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
        elif conf.m2e_use_basic_dep:
            # collect each token's head-bert and ud-label, then forward with adp
            fake_sents = [one_item.fake_sent for one_item in one_bucket]
            # head idx and labels, no artificial ROOT
            padded_head_arr, _ = self.dep_padder.pad([s.ud_heads.vals[1:] for s in fake_sents])
            padded_label_arr, _ = self.dep_padder.pad([s.ud_labels.idxes[1:] for s in fake_sents])
            # get tensor
            padded_head_t = BK.input_idx(padded_head_arr) - 1  # here, the idx exclude root
            padded_head_t.clamp_(min=0)  # [bs, slen]
            padded_label_t = BK.input_idx(padded_label_arr)
            # get inputs
            input_head_bert_t = bert_expr[BK.arange_idx(len(fake_sents)).unsqueeze(-1), padded_head_t]  # [bs, slen, fold, D]
            input_label_emb_t = self.dep_label_emb(padded_label_t)  # [bs, slen, D']
            basic_expr = self.dep_layer(input_head_bert_t, None, [input_label_emb_t])  # [bs, slen, ?]
        elif conf.m2e_use_basic_plus:
            sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                        insts, training, use_one_bucket=True)
            assert len(sent_reprs) == 1, \
                "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
            _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
            assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
            basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
            assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
        else:
            basic_expr = None
        # pack: (List[ms_item], bert_expr, basic_expr)
        rets.append((one_bucket, bert_expr, basic_expr))
    return rets
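# ----- illustrative sketch (not part of the model) -----
# The `basic_expr` collection above gathers a [bs, slen, D'] window for each
# batch item out of one flat document-level sequence by adding a per-item
# offset to a shared arange. A minimal torch version, assuming plain tensors;
# `flat_repr`, `offsets`, and `slen` are hypothetical stand-ins for
# `seq_sent_repr`, `batched_seq_offset`, and the padded bucket length.
def _demo_window_gather():
    import torch
    flat_repr = torch.randn(100, 8)       # [all_seq_len, D']
    offsets = torch.tensor([0, 17, 60])   # start of each item's window
    slen = 10                             # shared (padded) window length
    arange_t = torch.arange(slen).unsqueeze(0)   # [1, slen]
    idxes_t = offsets.unsqueeze(-1) + arange_t   # [bs, slen]
    return flat_repr[idxes_t]                    # [bs, slen, D']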
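# ----- illustrative sketch (not part of the model) -----
# The m2e_use_basic_dep branch uses paired advanced indexing to fetch each
# token's syntactic-head representation: a row index of shape [bs, 1]
# broadcasts against a head-index tensor of shape [bs, slen]. A minimal torch
# version with hypothetical shapes:
def _demo_head_gather():
    import torch
    bs, slen, d = 2, 5, 4
    bert_expr = torch.randn(bs, slen, d)        # token reprs
    heads = torch.randint(0, slen, (bs, slen))  # head idx per token
    rows = torch.arange(bs).unsqueeze(-1)       # [bs, 1]
    return bert_expr[rows, heads]               # [bs, slen, d]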