def aug_word2_vocab(self, stream, extra_stream, extra_embed_file: str):
    zlog(f"Aug another word vocab from streams and extra_embed_file={extra_embed_file}")
    word_builder = VocabBuilder("word2")
    for inst in stream:
        word_builder.feed_stream(inst.word_seq.vals)
    # embeddings
    if len(extra_embed_file) > 0:
        extra_word_set = set(w for inst in extra_stream for w in inst.word_seq.vals)
        w2vec = WordVectors.load(extra_embed_file)
        for w in extra_word_set:
            if w2vec.has_key(w) and (not word_builder.has_key_currently(w)):
                word_builder.feed_one(w)
        word_vocab = word_builder.finish()  # no filtering!!
        word_embed1 = word_vocab.filter_embed(w2vec, init_nohit=1.0, scale=1.0)
    else:
        zwarn("WARNING: No pretrain file for aug node!!")
        word_vocab = word_builder.finish()  # no filtering!!
        word_embed1 = None
    self.put_voc("word2", word_vocab)
    self.put_emb("word2", word_embed1)
def disable_final_dropout(self):
    if len(self.layers) < 1:
        zwarn("Cannot disable final dropout since this Enc layer is empty!!")
    else:
        # get the final one from sequential
        final_layer = self.layers[-1]
        while isinstance(final_layer, Sequential):
            final_layer = final_layer.ns_[-1] if len(final_layer.ns_) else None
        # get final dropout node
        final_drop_node: Dropout = None
        if isinstance(final_layer, RnnLayerBatchFirstWrapper):
            final_drop_nodes = final_layer.rnn_node.drop_nodes
            if final_drop_nodes is not None and len(final_drop_nodes) > 0:
                final_drop_node = final_drop_nodes[-1]
        elif isinstance(final_layer, CnnLayer):
            final_drop_node = final_layer.drop_node
        elif isinstance(final_layer, TransformerEncoder):
            pass  # todo(note): final is LayerNorm?
        if final_drop_node is None:
            zwarn(f"Failed at disabling final enc-layer dropout: type={type(final_layer)}: {final_layer}")
        else:
            final_drop_node.rop.add_fixed_value("hdrop", 0.)
            zlog(f"Ok at disabling final enc-layer dropout: type={type(final_layer)}: {final_layer}")
def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = conf
    # vocab and padder
    self.word_vocab = vpack.get_voc("word")
    self.padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)  # todo(note): <pad>-id is very large
    # models
    self.hid_layer = self.add_sub_node("hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
    self.pred_layer = self.add_sub_node(
        "pred", Affine(pc, conf.hid_dim, conf.max_pred_rank + 1, init_rop=NoDropRop()))
    if conf.init_pred_from_pretrain:
        npvec = vpack.get_emb("word")
        if npvec is None:
            zwarn("Pretrained vector not provided, skip init pred embeddings!!")
        else:
            with BK.no_grad_env():
                self.pred_layer.ws[0].copy_(BK.input_real(npvec[:conf.max_pred_rank + 1].T))
            zlog(f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1}).")
def find_head(posi, sentences: List[Sentence], is_event: bool):
    sid, wid, wlen = posi
    idx_start, idx_end = wid, wid + wlen
    assert wlen > 0
    if wlen == 1:  # only one word
        return wid
    cur_ddists = sentences[sid].get_ddist()
    cur_heads = sentences[sid].ud_heads.vals
    cur_poses = sentences[sid].uposes.vals
    # todo(note): rule 1: simply find the highest node (nearest to root and not punct)
    # first pass by ddist
    min_ddist = min(cur_ddists[z] for z in range(idx_start, idx_end))
    cand_idxes1 = [z for z in range(idx_start, idx_end) if cur_ddists[z] <= min_ddist]
    assert len(cand_idxes1) > 0
    if len(cand_idxes1) == 1:
        return cand_idxes1[0]
    # next pass by POS
    POS_SCORES_MAP = MyDocReader.VERB_HEAD_SCORES if is_event else MyDocReader.NOUN_HEAD_SCORES
    pos_scores = [POS_SCORES_MAP.get(cur_poses[z], -100) for z in cand_idxes1]
    max_pos_score = max(pos_scores)
    cand_idxes2 = [v for i, v in enumerate(cand_idxes1) if pos_scores[i] >= max_pos_score]
    assert len(cand_idxes2) > 0
    if len(cand_idxes2) == 1:
        return cand_idxes2[0]
    # todo(note): rule 2: if same head and same pos, use the rightmost one
    # todo(+N): fine only for English?
    cand_idxes = cand_idxes2
    cand_heads, cand_poses = [cur_heads[z] for z in cand_idxes], [cur_poses[z] for z in cand_idxes]
    if all(z == cand_heads[0] for z in cand_heads) and all(z == cand_poses[0] for z in cand_poses):
        return cand_idxes[-1]
    if all(z == "PROPN" for z in cand_poses):
        return cand_idxes[-1]
    if all(z == "NUM" for z in cand_poses):
        return cand_idxes[-1]
    # todo(note): extra one: AUX+PART like "did not"
    if cand_poses == ["AUX", "PART"]:
        return cand_idxes[0]
    # todo(note): rule final: simply the rightmost
    if 1:
        cur_words = sentences[sid].words.vals
        ranged_words = cur_words[idx_start:idx_end]
        ranged_ddists = cur_ddists[idx_start:idx_end]
        ranged_heads = cur_heads[idx_start:idx_end]
        ranged_poses = cur_poses[idx_start:idx_end]
        zwarn(f"Cannot heuristically set head (is_event={is_event}), use the last one: "
              f"{ranged_words} {ranged_ddists} {ranged_heads} {ranged_poses}")
    return cand_idxes[-1]
def prepare_test(args, ConfType=None):
    # conf
    conf: OverallConf = init_everything(args, ConfType)
    dconf, mconf = conf.dconf, conf.mconf
    iconf = mconf.iconf
    # vocab
    vpack = IEVocabPackage.build_by_reading(conf)
    # prepare data
    test_streamer = get_data_reader(dconf.test, dconf.input_format, dconf.use_label0, dconf.noef_link0,
                                    dconf.aux_repr_test, max_evt_layers=dconf.max_evt_layers)
    # model
    model = build_model(conf.model_type, conf, vpack)
    if dconf.model_load_name != "":
        model.load(dconf.model_load_name)
    else:
        zwarn("No model to load, Debugging mode??")
    # =====
    # augment with extra embeddings
    extra_embed_files = dconf.test_extra_pretrain_files
    if len(extra_embed_files) > 0:
        # get embeddings
        extra_codes = []  # todo(note): ignore this mode for this project
        if len(extra_codes) == 0:
            extra_codes = [""] * len(extra_embed_files)
        extra_embedding = WordVectors.load(extra_embed_files[0], aug_code=extra_codes[0])
        extra_embedding.merge_others([WordVectors.load(one_file, aug_code=one_code)
                                      for one_file, one_code in zip(extra_embed_files[1:], extra_codes[1:])])
        # get extra dictionary (only those words hit in extra-embed)
        extra_vocab = VocabBuilder.build_from_stream(iter_hit_words(test_streamer, extra_embedding),
                                                     sort_by_count=True, pre_list=(), post_list=())
        # give them to the model
        new_vocab = model.aug_words_and_embs(extra_vocab, extra_embedding)
        vpack.put_voc("word", new_vocab)
    # =====
    # use bert? todo(note): no pre-compute here in testing!
    if dconf.use_bert:
        bmodel = get_berter(dconf.bconf)
        test_streamer = BerterDataAuger(test_streamer, bmodel, "aux_repr")
    #
    # No Cache!!
    test_inst_preparer = model.get_inst_preper(False)
    test_iter = batch_stream(index_stream(test_streamer, vpack, False, False, test_inst_preparer), iconf, False)
    return conf, model, vpack, test_iter
def expand_span(self, head_wid, sent):
    doc_id, sid = sent.doc.doc_id, sent.sid
    key = f"{doc_id}_{sid}_{head_wid}"
    if key not in self.tables:
        zwarn("Not covered from external-table!!")
        return head_wid, 1  # simply return singleton
    else:
        head, start, end = self.tables[key]
        assert head == head_wid - 1
        return start + 1, end + 1 - start
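# A tiny worked example of the offset arithmetic above; the table layout is inferred
# from the asserts (0-based head with an inclusive 0-based [start, end] span), while the
# caller uses 1-based word ids with ROOT at position 0. The key and values below are
# hypothetical: an entry (head=4, start=2, end=6) looked up for head_wid=5 yields
# (start+1, end+1-start) = (3, 5), i.e. a 5-word span beginning at 1-based position 3.
_demo_tables = {"docA_0_5": (4, 2, 6)}
_head, _start, _end = _demo_tables["docA_0_5"]
assert _head == 5 - 1                              # external head is 0-based
assert (_start + 1, _end + 1 - _start) == (3, 5)   # (1-based start, span length)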
def __init__(self, pc: BK.ParamCollection, n_dim: int, max_val: int = 5000, min_val: int = None,
             init_sincos: bool = True, freeze: bool = True, no_dropout: bool = True,
             init_scale=1., zero0=False):
    super().__init__(pc, None, None)
    # set range
    assert max_val >= 0
    if min_val is None:
        min_val = -max_val
    self.min_val, self.max_val = min_val, max_val
    # how to init
    self.dim = n_dim
    self.init_sincos = init_sincos
    self.freeze = freeze
    self.no_dropout = no_dropout
    if init_sincos:
        pe = PosiEmbedding2.init_sincos_arr(min_val, max_val, n_dim, init_scale)  # [all_size, d]
        if freeze:
            self.rop.add_fixed_value("trainable", False)
        else:
            zwarn("Init-from sin/cos positional embeddings should be freezed?")
    else:
        pe = None
        assert not freeze
    # params
    self.offset = 0 - self.min_val  # used to make things into positive for indexing
    if pe is not None and freeze:
        self.E = BK.input_real(pe)  # simply static one
    else:
        self.E = self.add_param("E", (max_val - min_val + 1, n_dim), init=pe, lookup=True, scale=init_scale)
    # zero 0
    if zero0:
        if min_val <= 0 and max_val >= 0:
            BK.zero_row(self.E, self.offset)
        else:
            zwarn(f"Cannot zero0 for PosiEmbed2 of [{min_val}, {max_val}]")
    # -----
    if no_dropout:
        self.drop_node = lambda x: x
    else:
        self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.dim, )))
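# A minimal numpy sketch of what an init_sincos_arr(min_val, max_val, n_dim, scale) helper
# could compute. The body of PosiEmbedding2.init_sincos_arr is not shown here, so this
# follows the standard Transformer sin/cos scheme (an assumption): positions min_val..max_val
# inclusive, so that row (p - min_val) encodes position p; assumes an even n_dim.
import numpy as np

def _sincos_arr_sketch(min_val: int, max_val: int, n_dim: int, scale: float = 1.0) -> np.ndarray:
    positions = np.arange(min_val, max_val + 1).reshape([-1, 1])                 # [all_size, 1]
    div_term = np.exp(np.arange(0, n_dim, 2) * -(np.log(10000.0) / n_dim))       # [n_dim/2]
    pe = np.zeros([max_val - min_val + 1, n_dim])
    pe[:, 0::2] = np.sin(positions * div_term)
    pe[:, 1::2] = np.cos(positions * div_term)
    return pe * scale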
def get_data_writer(file_or_fd, output_format):
    if output_format == "conllu":
        return ParserConlluWriter(file_or_fd)
    elif output_format == "plain":
        zwarn("May be meaningless to write plain files for parses!")
        return ParserPlainWriter(file_or_fd)
    elif output_format == "json":
        return ParserJsonWriter(file_or_fd)
    else:
        zfatal("Unknown output_format %s, should select from {conllu,plain,json}" % output_format)
def __init__(self):
    # whether to allow multi-source mode
    self.multi_source = False  # split file names for inputs (use multi_reader)
    # data paths
    self.train = ""
    self.dev = ""
    self.test = ""
    self.cache_data = True  # turn off if large data
    self.to_cache_shuffle = False
    self.dict_dir = "./"
    # cuttings for training (simply for convenience without especially preparing data...)
    self.cut_train = ""
    self.cut_dev = ""
    # extra aux inputs, must be aligned with the corresponding data
    self.aux_repr_train = ""
    self.aux_repr_dev = ""
    self.aux_repr_test = ""
    self.aux_score_train = ""
    self.aux_score_dev = ""
    self.aux_score_test = ""
    # save name for trainer not here!!
    self.model_load_name = "zmodel.best"  # load name
    self.output_file = "zout"
    # format (conllu, plain, json)
    self.input_format = "conllu"
    self.output_format = "conllu"
    # pretrain
    self.pretrain_file = []
    self.init_from_pretrain = False
    self.pretrain_scale = 1.0
    self.pretrain_init_nohit = 1.0
    # thresholds for word
    self.word_rthres = 50000  # rank <= this
    self.word_sthres = 2  # the threshold of considering singleton treating (freq <= this)
    self.word_fthres = 1  # freq >= this
    # special processing
    self.lower_case = False
    self.norm_digit = False  # norm digits to 0
    self.use_label0 = True  # using only first-level label
    # todo(note): change the default behaviour
    zwarn("Note: currently we change default value of 'use_label0' to True!")
    self.vocab_add_prevalues = True  # add pre-defined UDv2 values when building dicts
    # =====
    # for multi-lingual processing (another option is to pre-process suitable data)
    # language code (empty str for no effects)
    self.code_train = ""
    self.code_dev = ""
    self.code_test = ""
    self.code_pretrain = []
    # testing mode extra embeddings
    self.test_extra_pretrain_files = []
    self.test_extra_pretrain_codes = []
def __init__(self, conf: TdParserConf, vpack: VocabPackage):
    super().__init__(conf, vpack)
    # ===== For decoding =====
    self.inferencer = TdInferencer(self.scorer, conf.iconf)
    # ===== For training =====
    sched_depth = ScheduledValue("depth", conf.tconf.sched_depth)
    self.add_scheduled_values(sched_depth)
    self.fber = TdFber(self.scorer, conf.iconf, conf.tconf, self.margin, self.sched_sampling, sched_depth)
    # todo(warn): not elegant, global flag!
    TdState.is_bfs = conf.is_bfs
    # =====
    zcheck(not self.bter.jpos_multitask_enabled(), "Not implemented for joint pos in this mode!!")
    zwarn("WARN: This topdown mode is deprecated!!")
def load(self, prefix="./"):
    for name in self.vocabs:
        fname = prefix + "vv_" + name + ".txt"
        if FileHelper.exists(fname):
            self.vocabs[name] = Vocab.read(fname)
        else:
            zwarn("Cannot find Vocab " + name)
            self.vocabs[name] = None
    for name in self.embeds:
        fname = prefix + "ve_" + name + ".pic"
        if FileHelper.exists(fname):
            self.embeds[name] = PickleRW.from_file(fname)
        else:
            self.embeds[name] = None
def main(args):
    conf = init_everything(args)
    dconf, mconf = conf.dconf, conf.mconf
    # dev/test can be non-existing!
    if not dconf.dev and dconf.test:
        utils.zwarn("No dev but give test, actually use test as dev (for early stopping)!!")
    dt_golds, dt_cuts = [], []
    for file, one_cut in [(dconf.dev, dconf.cut_dev), (dconf.test, "")]:  # no cut for test!
        if len(file) > 0:
            utils.zlog(f"Add file `{file}(cut={one_cut})' as dt-file #{len(dt_golds)}.")
            dt_golds.append(file)
            dt_cuts.append(one_cut)
    if len(dt_golds) == 0:
        utils.zwarn("No dev set, then please specify static lrate schedule!!")
    # data
    train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format, cut=dconf.cut_train),
                                        lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
    dt_streamers = [PreprocessStreamer(get_data_reader(f, dconf.dev_input_format, cut=one_cut),
                                       lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
                    for f, one_cut in zip(dt_golds, dt_cuts)]
    # vocab
    if mconf.no_build_dict:
        vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
    else:
        # include dev/test only for convenience of including words hit in pre-trained embeddings
        vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, MultiCatStreamer(dt_streamers))
        vpack.save(dconf.dict_dir)
    # model
    model = build_model(conf, vpack)
    # index the data
    train_inst_preparer = model.get_inst_preper(True)
    test_inst_preparer = model.get_inst_preper(False)
    to_cache = dconf.cache_data
    to_cache_shuffle = dconf.to_cache_shuffle
    # todo(note): make sure to cache both train and dev to save time for cached computation
    backoff_pos_idx = dconf.backoff_pos_idx
    train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle,
                                           train_inst_preparer, backoff_pos_idx),
                              mconf.train_batch_size, mconf, True)
    dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle,
                                          test_inst_preparer, backoff_pos_idx),
                             mconf.test_batch_size, mconf, False)
                for z in dt_streamers]
    # training runner
    tr = MltTrainingRunner(mconf.rconf, model, vpack, dev_outfs=dconf.output_file,
                           dev_goldfs=dt_golds, dev_out_format=dconf.output_format)
    if mconf.train_preload_model:
        tr.load(dconf.model_load_name, mconf.train_preload_process)
    # =====
    # switch with the linear model
    linear_model = LinearProbeModel(model)
    tr.model = linear_model
    # =====
    # go
    tr.run(train_iter, dt_iters)
    utils.zlog("The end of Training.")
def aug_words_and_embs(model, aug_vocab, aug_wv):
    orig_vocab = model.word_vocab
    word_emb_node = model.embedder.get_node('word')
    if word_emb_node is not None:
        orig_arr = word_emb_node.E.detach().cpu().numpy()
        # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
        # todo(warn): here aug_vocab should be found in aug_wv
        aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
        new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
        # assign
        model.word_vocab = new_vocab
        # there should be more to replace? but since only testing maybe no need...
        word_emb_node.replace_weights(new_arr)
    else:
        zwarn("No need to aug vocab since delexicalized model!!")
        new_vocab = orig_vocab
    return new_vocab
def aug_words_and_embs(self, aug_vocab, aug_wv):
    orig_vocab = self.word_vocab
    if self.emb.has_word:
        orig_arr = self.emb.word_embed.E.detach().cpu().numpy()
        # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
        # todo(warn): here aug_vocab should be found in aug_wv
        aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
        new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
        # assign
        self.word_vocab = new_vocab
        self.emb.word_embed.replace_weights(new_arr)
    else:
        zwarn("No need to aug vocab since delexicalized model!!")
        new_vocab = orig_vocab
    return new_vocab
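# A minimal numpy sketch of what the vocab/embedding augmentation above amounts to.
# MultiHelper.aug_vocab_and_arr itself is not shown here, so the exact override semantics
# are an assumption: rows for genuinely new words are appended after the original embedding
# matrix, and the word -> index map is extended accordingly.
import numpy as np

def _aug_arr_sketch(orig_w2i: dict, orig_arr: np.ndarray, aug_w2i: dict, aug_arr: np.ndarray):
    new_w2i = dict(orig_w2i)
    extra_rows = []
    for w, i in aug_w2i.items():
        if w not in new_w2i:  # only append truly new words
            new_w2i[w] = len(orig_w2i) + len(extra_rows)
            extra_rows.append(aug_arr[i])
    new_arr = np.concatenate([orig_arr, np.stack(extra_rows)], 0) if extra_rows else orig_arr
    return new_w2i, new_arr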
def __init__(self, heads: List[int], labels: List[str], label_idxes: List[int] = None):
    # todo(note): ARTI_ROOT as 0 should be included here, added from the outside!
    self.heads = heads
    self.labels = labels
    self.label_idxes = label_idxes
    # == some cache values
    self._dep_dists = None  # m-h
    self._label_matrix = None  # Arr[m,h] of int
    # -----
    if self.heads is not None:
        if self.heads[0] != 0 or self.labels[0] != "":
            zwarn("Bad values for ARTI_ROOT!!")
        self._build_tree()  # build and check
def _go_index(self, inst: GeneralSentence):
    # word
    # todo(warn): remember to norm word; replace singleton at model's input, not here
    if inst.word_seq.has_vals():
        w_voc = self.w_vocab
        word_idxes = [w_voc.get_else_unk(w) for w in inst.word_seq.vals]  # todo(note): currently unk-idx is large
        if self.backoff_pos_idx >= 0:
            # when using this mode, there must be backoff strings in word vocab
            zwarn("The option of 'backoff_pos_idx' is deprecated, do not mix things in this way!")
            word_backoff_idxes = [w_voc[UD2_POS_UNK_MAP[z]] for z in inst.pos_seq.vals]
            word_idxes = [(zb if z >= self.backoff_pos_idx else z) for z, zb in zip(word_idxes, word_backoff_idxes)]
        inst.word_seq.set_idxes(word_idxes)
    # =====
    # aug extra word set!!
    if self.w2_vocab is not None:
        w2_voc = self.w2_vocab
        inst.word2_seq = deepcopy(inst.word_seq)
        word_idxes = [w2_voc.get_else_unk(w) for w in inst.word2_seq.vals]  # todo(note): currently unk-idx is large
        inst.word2_seq.set_idxes(word_idxes)
    # =====
    # others
    char_seq, pos_seq, dep_tree, ner_seq = [getattr(inst, z, None)
                                            for z in ["char_seq", "pos_seq", "dep_tree", "ner_seq"]]
    if char_seq is not None and char_seq.has_vals():
        char_seq.build_idxes(self.c_vocab)
    if pos_seq is not None and pos_seq.has_vals():
        pos_seq.build_idxes(self.p_vocab)
    if dep_tree is not None and dep_tree.has_vals():
        dep_tree.build_label_idxes(self.l_vocab)
    if ner_seq is not None and ner_seq.has_vals():
        ner_seq.build_idxes(self.n_vocab)
    if self.inst_preparer is not None:
        inst = self.inst_preparer(inst)
    # in fact, inplaced if not wrapping model specific preparer
    return inst
def __init__(self, pc, n_words, n_dim, fix_row0=True, dropout_wordceil=None, npvec=None,
             name=None, init_rop=None, freeze=False, init_scale=1.):
    super(Embedding, self).__init__(pc, name, init_rop)
    if npvec is not None:
        if not (len(npvec.shape) == 2 and npvec.shape[0] == n_words and npvec.shape[1] == n_dim):
            zlog(f"Wrong dimension for init embeddings {npvec.shape} instead of ({n_words}, {n_dim}), use random instead!!")
            npvec = None
        else:
            zlog("Add embed W from npvec %s." % (npvec.shape, ))
            if freeze:
                self.rop.add_fixed_value("trainable", False)
    else:
        if freeze:
            self.rop.add_fixed_value("trainable", False)
            zwarn("Meaningless to freeze random embeddings?")
    self.E = self.add_param("E", (n_words, n_dim), init=npvec, lookup=True, scale=init_scale)
    #
    self.n_words = n_words
    self.n_dim = n_dim
    self.fix_row0 = fix_row0
    self.dropout_wordceil_hp = dropout_wordceil
    self.dropout_wordceil = dropout_wordceil if dropout_wordceil is not None else n_words
    # refreshed values
    self._input_f = None
    self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.n_dim, )))
def _load_txt(fname, sep=" "):
    printing("Going to load pre-trained (txt) w2v from %s ..." % fname)
    one = WordVectors(sep=sep)
    repeated_count = 0
    with zopen(fname) as fd:
        # first line
        line = fd.readline()
        try:
            one.num_words, one.embed_size = [int(x) for x in line.split(sep)]
            printing("Reading w2v num_words=%d, embed_size=%d." % (one.num_words, one.embed_size))
            line = fd.readline()
        except:
            printing("Reading w2v.")
        # the rest
        while len(line) > 0:
            line = line.rstrip()
            fields = line.split(sep)
            word, vec = fields[0], [float(x) for x in fields[1:]]
            # zcheck(word not in one.wmap, "Repeated key.")
            # keep the old one
            if word in one.wmap:
                repeated_count += 1
                zwarn(f"Repeat key {word}")
                line = fd.readline()
                continue
            #
            if one.embed_size is None:
                one.embed_size = len(vec)
            else:
                zcheck(len(vec) == one.embed_size, "Unmatched embed dimension.")
            one.vecs.append(vec)
            one.wmap[word] = len(one.words)
            one.words.append(word)
            line = fd.readline()
    # final
    if one.num_words is not None:
        zcheck(one.num_words == len(one.vecs) + repeated_count, "Unmatched num of words.")
    one.num_words = len(one.vecs)
    printing(f"Read ok: w2v num_words={one.num_words:d}, embed_size={one.embed_size:d}, repeat={repeated_count:d}")
    return one
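# A small usage sketch for the txt loader above: write a toy word2vec-style text file
# (optional "num_words embed_size" header, then one word plus its vector per line) and
# load it back. Calling WordVectors._load_txt directly assumes it is exposed as a
# staticmethod, which the self-less signature above suggests; the file name is hypothetical.
def _demo_load_txt():
    toy_path = "toy_w2v.txt"
    with open(toy_path, "w") as f:
        f.write("3 4\n")
        f.write("the 0.1 0.2 0.3 0.4\n")
        f.write("cat 0.5 0.6 0.7 0.8\n")
        f.write("the 0.9 0.9 0.9 0.9\n")  # repeated key: the first entry is kept
    wv = WordVectors._load_txt(toy_path)
    assert wv.num_words == 2 and wv.embed_size == 4  # repeat counted against the header, then dropped
    return wv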
def __init__(self, pc: BK.ParamCollection, n_dim: int, max_len: int = 5000, init_sincos: bool = True,
             freeze: bool = True, no_dropout: bool = True):
    super().__init__(pc, None, None)
    # init from sin/cos values
    self.init_sincos = init_sincos
    if init_sincos:
        pe = np.zeros([max_len, n_dim])
        position = np.arange(0, max_len).reshape([-1, 1])
        div_term = np.exp((np.arange(0, n_dim, 2) * -(math.log(2 * max_len) / n_dim)))
        div_results = position * div_term
        pe[:, 0::2] = np.sin(div_results)
        pe[:, 1::2] = np.cos(div_results)
        # make it similar to the range of plain Embedding
        pe *= np.sqrt(3. / n_dim)
        if freeze:
            self.rop.add_fixed_value("trainable", False)
        else:
            zwarn("Init-from sin/cos positional embeddings should be freezed?")
    else:
        pe = None
        if freeze:
            self.rop.add_fixed_value("trainable", False)
            zwarn("Meaningless to freeze random embeddings?")
    #
    self.dim = n_dim
    self.max_len = max_len
    self.E = self.add_param("E", (max_len, n_dim), init=pe, lookup=True)
    if no_dropout:
        self.drop_node = lambda x: x
    else:
        self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.dim, )))
def prepare_test(args, ConfType=None):
    # conf
    conf = init_everything(args, ConfType)
    dconf, mconf = conf.dconf, conf.mconf
    # vocab
    vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
    # prepare data
    test_streamer = PreprocessStreamer(get_data_reader(dconf.test, dconf.input_format),
                                       lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
    # model
    model = build_model(conf, vpack)
    if dconf.model_load_name != "":
        model.load(dconf.model_load_name)
    else:
        zwarn("No model to load, Debugging mode??")
    # -----
    # augment with extra embeddings for test stream?
    extra_embed_files = dconf.vconf.test_extra_pretrain_files
    if len(extra_embed_files) > 0:
        # get embeddings
        extra_codes = dconf.vconf.test_extra_pretrain_codes
        if len(extra_codes) == 0:
            extra_codes = [""] * len(extra_embed_files)
        extra_embedding = WordVectors.load(extra_embed_files[0], aug_code=extra_codes[0])
        extra_embedding.merge_others([WordVectors.load(one_file, aug_code=one_code)
                                      for one_file, one_code in zip(extra_embed_files[1:], extra_codes[1:])])
        # get extra dictionary (only those words hit in extra-embed)
        extra_vocab = VocabBuilder.build_from_stream(iter_hit_words(test_streamer, extra_embedding),
                                                     sort_by_count=True, pre_list=(), post_list=())
        # give them to the model
        new_vocab = aug_words_and_embs(model, extra_vocab, extra_embedding)
        vpack.put_voc("word", new_vocab)
    # =====
    # No Cache!!
    test_inst_preparer = model.get_inst_preper(False)
    backoff_pos_idx = dconf.backoff_pos_idx
    test_iter = batch_stream(index_stream(test_streamer, vpack, False, False, test_inst_preparer, backoff_pos_idx),
                             mconf.test_batch_size, mconf, False)
    return conf, model, vpack, test_iter
def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
             conf: EmbedderNodeConf, vpack: VocabPackage):
    super().__init__(pc, comp_name, ec_conf, conf, vpack)
    # -----
    # get embeddings
    npvec = None
    if self.ec_conf.comp_init_from_pretrain:
        npvec = vpack.get_emb(comp_name)
        zlog(f"Try to init InputEmbedNode {comp_name} with npvec.shape={npvec.shape if (npvec is not None) else None}")
        if npvec is None:
            zwarn("Warn: cannot get pre-trained embeddings to init!!")
    # get rare unk range
    # - get freq vals, make sure special ones will not be pruned; todo(note): directly use that field
    voc_rare_mask = [float(z is not None and z <= ec_conf.comp_rare_thr) for z in self.voc.final_vals]
    self.rare_mask = BK.input_real(voc_rare_mask)
    self.use_rare_unk = (ec_conf.comp_rare_unk > 0. and ec_conf.comp_rare_thr > 0)
    # --
    # dropout outside explicitly
    self.E = self.add_sub_node(f"E{self.comp_name}", Embedding(
        pc, len(self.voc), self.comp_dim, fix_row0=conf.embed_fix_row0, npvec=npvec,
        name=comp_name, init_rop=NoDropRop(), init_scale=self.comp_init_scale))
    self.create_dropout_node()
def prepare_test(args, ConfType=None):
    # conf
    conf = init_everything(args, ConfType)
    dconf, pconf = conf.dconf, conf.pconf
    iconf = pconf.iconf
    # vocab
    vpack = ParserVocabPackage.build_by_reading(dconf)
    # prepare data
    test_streamer = get_data_reader(dconf.test, dconf.input_format, dconf.code_test, dconf.use_label0,
                                    dconf.aux_repr_test, dconf.aux_score_test)
    # model
    model = build_model(conf.partype, conf, vpack)
    if dconf.model_load_name != "":
        model.load(dconf.model_load_name)
    else:
        zwarn("No model to load, Debugging mode??")
    # =====
    # augment with extra embeddings
    extra_embed_files = dconf.test_extra_pretrain_files
    if len(extra_embed_files) > 0:
        # get embeddings
        extra_codes = dconf.test_extra_pretrain_codes
        if len(extra_codes) == 0:
            extra_codes = [""] * len(extra_embed_files)
        extra_embedding = WordVectors.load(extra_embed_files[0], aug_code=extra_codes[0])
        extra_embedding.merge_others([WordVectors.load(one_file, aug_code=one_code)
                                      for one_file, one_code in zip(extra_embed_files[1:], extra_codes[1:])])
        # get extra dictionary (only those words hit in extra-embed)
        extra_vocab = VocabBuilder.build_from_stream(iter_hit_words(test_streamer, extra_embedding),
                                                     sort_by_count=True, pre_list=(), post_list=())
        # give them to the model
        new_vocab = model.aug_words_and_embs(extra_vocab, extra_embedding)
        vpack.put_voc("word", new_vocab)
    # =====
    # No Cache!!
    test_inst_preparer = model.get_inst_preper(False)
    test_iter = batch_stream(index_stream(test_streamer, vpack, False, False, test_inst_preparer), iconf, False)
    return conf, model, vpack, test_iter
def build_model(partype, conf, vpack):
    pconf = conf.pconf
    parser = None
    if partype == "graph":
        # original first-order graph with various output constraints
        from ..graph.parser import GraphParser
        parser = GraphParser(pconf, vpack)
    elif partype == "td":
        # re-implementation of the top-down stack-pointer parser
        zwarn("Warning: Current implementation of td-mode is deprecated and outdated.")
        from ..transition.topdown.parser import TdParser
        parser = TdParser(pconf, vpack)
    elif partype == "ef":
        # generalized easy-first parser
        from ..ef.parser import EfParser
        parser = EfParser(pconf, vpack)
    elif partype == "g1":
        # first-order graph parser
        from ..ef.parser import G1Parser
        parser = G1Parser(pconf, vpack)
    elif partype == "g2":
        # higher-order graph parser
        from ..ef.parser import G2Parser
        parser = G2Parser(pconf, vpack)
    elif partype == "s2":
        # two-stage parser
        from ..ef.parser import S2Parser
        parser = S2Parser(pconf, vpack)
    elif partype == "fp":
        # the finale parser
        from ..zfp.fp import FpParser
        parser = FpParser(pconf, vpack)
    else:
        zfatal("Unknown parser type: %s" % partype)
    return parser
def parse_mention(self, mention: Dict, sentences: List[Sentence], is_event: bool):
    # =====
    def _get_posi(posi):
        if posi is None:
            return None
        else:
            assert len(set([z[0] for z in posi])) == 1  # same sid
            sid, wid = posi[0]
            assert all(z[1] == wid + i for i, z in enumerate(posi))
            # todo(note): here += 1 for ROOT offset
            return sid, wid + 1, len(posi)  # sid, wid, length
    # =====
    # first general span: return None if cannot find position (pre-processing problems)
    posi = _get_posi(mention["posi"])
    if posi is None:
        return None
    sid, wid, length = posi
    # then get head word
    head_posi0 = mention.get("head", None)
    if head_posi0 is not None:
        head_posi0 = head_posi0["posi"]
    head_posi = _get_posi(head_posi0)
    if head_posi is None:
        # use the original whole span, guess head by heuristic rules
        head_posi = posi
    head_sid, head_wid, head_length = head_posi
    if not (head_sid == sid and head_wid >= wid and head_wid + head_length <= wid + length):
        zwarn(f"Head span is not inside full span, use full posi instead: {mention}")
        head_posi = posi
    head_wid = self.find_head(head_posi, sentences, is_event)
    # we only have original position from the origin gold file
    if "offset" in mention and "length" in mention:
        origin_char_posi = (mention["offset"], mention["length"])
    else:
        origin_char_posi = None
    ret = Mention(HardSpan(sid, head_wid, wid, length), origin_char_posi=origin_char_posi)
    if is_event and ret.hard_span.length > 3:
        zwarn(f"Strange long event span: {sid}/{wid}/{length}/head={head_wid}: "
              f"{sentences[sid].words.vals[wid:wid+length]}", level=2)
        zwarn("", level=2)
    return ret
def parse_doc(self, doc_dict: Dict):
    doc_id, dataset, source = doc_dict["doc_id"], doc_dict["dataset"], doc_dict["source"]
    label_normer = get_label_normer(dataset)
    # build all the sentences (basic inputs)
    sentences = [self.parse_sent(sid, one_sent) for sid, one_sent in enumerate(doc_dict["sents"])]
    # build entities, events and arguments
    args_maps = {}  # id -> EntityFiller
    # doc = DocInstance(sentences, {k:v for k,v in doc_dict.items() if isinstance(v, str)})  # record simple fields as props
    doc = DocInstance(sentences, doc_dict)
    for s in sentences:
        s.doc = doc  # link back
    # ----- entity and fillers
    sig2ef = {}  # todo(note): used to merge ef, but not for evt
    for cur_name in ["entity_mentions", "fillers"]:
        is_entity = (cur_name == "entity_mentions")
        cur_mentions = doc_dict.get(cur_name, None)
        if cur_mentions is not None:
            # first prepare the lists
            for s in sentences:
                if s.entity_fillers is None:
                    s.entity_fillers = []
            if doc.entity_fillers is None:
                doc.entity_fillers = []
            # then parse them
            for cur_mention in cur_mentions:
                new_mention = self.parse_mention(cur_mention, sentences, is_event=False)
                # assert cur_mention["type"] is not None
                cur_type = label_normer.norm_ef_label(cur_mention["type"])
                mtype = cur_mention.get("mtype", "")
                new_id = cur_mention["id"]
                kwargs = {k: cur_mention.get(k) for k in ["score", "extra_info", "gid"] if k in cur_mention}
                new_ef = EntityFiller(new_id, new_mention, cur_type, mtype, is_entity, **kwargs)
                # todo(note): check for repeat
                if new_mention is not None:
                    cur_sig = (new_mention.hard_span.position(False), cur_type)  # same span and type
                    repeat_ef = sig2ef.get(cur_sig)
                    if repeat_ef is not None:
                        assert new_id not in args_maps
                        args_maps[new_id] = repeat_ef
                        # if is_entity:
                        #     zwarn(f"Repeated entity: {repeat_ef} <- {new_ef}")
                        continue  # not adding new one, only put id for later arg finding
                    else:
                        sig2ef[cur_sig] = new_ef
                # if not repeat
                assert new_id not in args_maps
                args_maps[new_id] = new_ef
                doc.entity_fillers.append(new_ef)
                if new_mention is not None:
                    sentences[new_mention.hard_span.sid].entity_fillers.append(new_ef)
                else:
                    # zwarn(f"Cannot find posi for a entity/filler mention: {cur_mention}")
                    pass
    # --- read entity corefs for solving some cross-sent distance args
    entity_chains = {}
    entity_corefs = doc_dict.get("entities")
    if entity_corefs is not None:
        entity_chains = {z['id']: z['mentions'] for z in entity_corefs}
    # --- events and arguments
    event_mentions = doc_dict.get("event_mentions", None)
    if event_mentions is not None:
        # first prepare the lists
        for s in sentences:
            s.events = []
        doc.events = []
        # then parse them
        for cur_mention in event_mentions:
            new_mention = self.parse_mention(cur_mention["trigger"], sentences, is_event=True)
            cur_type = label_normer.norm_evt_label(cur_mention["type"])
            # -----
            # cutoff event type layers
            cur_type = ".".join(cur_type.split(".")[:self.max_evt_layers])
            # -----
            kwargs = {k: cur_mention.get(k) for k in ["score", "extra_info", "gid", "realis", "realis_score"]
                      if k in cur_mention}
            new_evt = Event(cur_mention["id"], new_mention, cur_type, **kwargs)
            em_args = cur_mention.get("em_arg", None)
            if em_args is None:
                new_evt.links = None  # annotation not available
            else:
                cur_evt_sid = None if new_mention is None else new_mention.hard_span.sid
                for cur_arg in cur_mention["em_arg"]:
                    aid, role = cur_arg["aid"], label_normer.norm_role_label(cur_arg["role"])
                    cur_ef = args_maps.get(aid)
                    if cur_ef is None:
                        zwarn(f"Cannot find event argument: {cur_arg}", level=3)
                        continue
                    # =====
                    # todo(note): change for same-sent args
                    if self.alter_carg_by_coref and cur_evt_sid is not None:
                        cur_ef_sid = None if cur_ef.mention is None else cur_ef.mention.hard_span.sid
                        if cur_ef_sid is not None and cur_ef_sid != cur_evt_sid:
                            coref_chain_ef_mentions = entity_chains.get(cur_ef.gid)
                            if coref_chain_ef_mentions is not None:
                                for alter_aid in coref_chain_ef_mentions:
                                    alter_ef = args_maps[alter_aid]
                                    alter_ef_sid = None if alter_ef.mention is None else alter_ef.mention.hard_span.sid
                                    if alter_ef_sid == cur_evt_sid:
                                        cur_ef = alter_ef  # find the alternative
                                        break
                    # =====
                    # todo(WARN): there can be no-position arguments
                    kwargs = {k: cur_arg.get(k) for k in ["is_aug", "score", "extra_info"] if k in cur_arg}
                    new_evt.add_arg(cur_ef, role, **kwargs)
                    if cur_ef.mention is None:
                        zwarn(f"Cannot find posi for an event argument: {cur_arg}", level=2)
            # add it
            doc.events.append(new_evt)
            if new_mention is not None:
                sentences[new_mention.hard_span.sid].events.append(new_evt)
            else:
                zwarn(f"Cannot find posi for an event mention: {cur_mention}", level=2)
    # =====
    # special mode
    if self.noef_link0:
        tmp_ff = lambda x: len(x.links) > 0
        if doc.entity_fillers is not None:
            doc.entity_fillers = list(filter(tmp_ff, doc.entity_fillers))
            for one_sent in doc.sents:
                one_sent.entity_fillers = list(filter(tmp_ff, one_sent.entity_fillers))
    return doc
def __init__(self, pc: BK.ParamCollection, econf: EmbedConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = econf
    #
    repr_sizes = []
    # word
    self.has_word = (econf.dim_word > 0)
    if self.has_word:
        npvec = vpack.get_emb("word") if econf.init_words_from_pretrain else None
        self.word_embed = self.add_sub_node("ew", Embedding(self.pc, len(vpack.get_voc("word")), econf.dim_word,
                                                            npvec=npvec, name="word", freeze=econf.word_freeze))
        repr_sizes.append(econf.dim_word)
    # char
    self.has_char = (econf.dim_char > 0)
    if self.has_char:
        # todo(warn): cnns will also use emb's drop?
        self.char_embed = self.add_sub_node("ec", Embedding(self.pc, len(vpack.get_voc("char")),
                                                            econf.dim_char, name="char"))
        per_cnn_size = econf.char_cnn_hidden // len(econf.char_cnn_windows)
        self.char_cnns = [self.add_sub_node("cnnc", CnnLayer(self.pc, econf.dim_char, per_cnn_size, z,
                                                             pooling="max", act="tanh"))
                          for z in econf.char_cnn_windows]
        repr_sizes.append(econf.char_cnn_hidden)
    # posi: absolute positional embeddings
    self.has_posi = (econf.dim_posi > 0)
    if self.has_posi:
        self.posi_embed = self.add_sub_node("ep", PosiEmbedding(self.pc, econf.dim_posi, econf.posi_clip,
                                                                econf.posi_fix_sincos, econf.posi_freeze))
        repr_sizes.append(econf.dim_posi)
    # extras: like POS, ...
    self.dim_extras = econf.dim_extras
    self.extra_names = econf.extra_names
    zcheck(len(self.dim_extras) == len(self.extra_names), "Unmatched dims and names!")
    self.extra_embeds = []
    for one_extra_dim, one_name in zip(self.dim_extras, self.extra_names):
        self.extra_embeds.append(self.add_sub_node("ext", Embedding(
            self.pc, len(vpack.get_voc(one_name)), one_extra_dim,
            npvec=vpack.get_emb(one_name, None), name="extra:" + one_name)))
        repr_sizes.append(one_extra_dim)
    # auxes
    self.dim_auxes = econf.dim_auxes
    self.fold_auxes = econf.fold_auxes
    self.aux_overall_gammas = []
    self.aux_fold_lambdas = []
    for one_aux_dim, one_aux_fold in zip(self.dim_auxes, self.fold_auxes):
        repr_sizes.append(one_aux_dim)
        # aux gamma and fold trainable lambdas
        self.aux_overall_gammas.append(self.add_param("AG", (), 1.))  # scalar
        self.aux_fold_lambdas.append(self.add_param(
            "AL", (), [1. / one_aux_fold for _ in range(one_aux_fold)]))  # [#fold]
    # =====
    # another projection layer? & set final dim
    if len(repr_sizes) <= 0:
        zwarn("No inputs??")
        # zcheck(len(repr_sizes)>0, "No inputs?")
    self.repr_sizes = repr_sizes
    self.has_proj = (econf.emb_proj_dim > 0)
    if self.has_proj:
        proj_layer = Affine(self.pc, sum(repr_sizes), econf.emb_proj_dim)
        if econf.emb_proj_norm:
            norm_layer = LayerNorm(self.pc, econf.emb_proj_dim)
            self.final_layer = self.add_sub_node("fl", Sequential(self.pc, [proj_layer, norm_layer]))
        else:
            self.final_layer = self.add_sub_node("fl", proj_layer)
        self.output_dim = econf.emb_proj_dim
    else:
        self.final_layer = None
        self.output_dim = sum(repr_sizes)
    # =====
    # special MdDropout: dropout the entire last dim (for word, char, extras, but not posi)
    self.dropmd_word = self.add_sub_node("md", DropoutLastN(pc, lastn=1))
    self.dropmd_char = self.add_sub_node("md", DropoutLastN(pc, lastn=1))
    self.dropmd_extras = [self.add_sub_node("md", DropoutLastN(pc, lastn=1)) for _ in self.extra_names]
    # dropouts for aux
    self.drop_auxes = [self.add_sub_node("aux", Dropout(pc, (one_aux_dim, ))) for one_aux_dim in self.dim_auxes]
def __init__(self, pc: BK.ParamCollection, input_dim: int, conf: PlainLMNodeConf, inputter: Inputter):
    super().__init__(pc, conf, name="PLM")
    self.conf = conf
    self.inputter = inputter
    self.input_dim = input_dim
    self.split_input_blm = conf.split_input_blm
    # this step is performed at the embedder, thus still does not influence the inputter
    self.add_root_token = self.inputter.embedder.add_root_token
    # vocab and padder
    vpack = inputter.vpack
    vocab_word = vpack.get_voc("word")
    # models
    real_input_dim = input_dim // 2 if self.split_input_blm else input_dim
    if conf.hid_dim <= 0:
        # no hidden layer
        self.l2r_hid_layer = self.r2l_hid_layer = None
        self.pred_input_dim = real_input_dim
    else:
        self.l2r_hid_layer = self.add_sub_node("l2r_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
        self.r2l_hid_layer = self.add_sub_node("r2l_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
        self.pred_input_dim = conf.hid_dim
    # todo(note): unk is the first one above real words
    self.pred_size = min(conf.max_pred_rank + 1, vocab_word.unk)
    if conf.tie_input_embeddings:
        zwarn("Tie all preds in plm with input embeddings!!")
        self.l2r_pred = self.r2l_pred = None
        self.inputter_embed_node = self.inputter.embedder.get_node("word")
    else:
        self.l2r_pred = self.add_sub_node("l2r_p", Affine(pc, self.pred_input_dim, self.pred_size,
                                                          init_rop=NoDropRop()))
        if conf.tie_bidirect_pred:
            self.r2l_pred = self.l2r_pred
        else:
            self.r2l_pred = self.add_sub_node("r2l_p", Affine(pc, self.pred_input_dim, self.pred_size,
                                                              init_rop=NoDropRop()))
        self.inputter_embed_node = None
        if conf.init_pred_from_pretrain:
            npvec = vpack.get_emb("word")
            if npvec is None:
                zwarn("Pretrained vector not provided, skip init pred embeddings!!")
            else:
                with BK.no_grad_env():
                    self.l2r_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
                    self.r2l_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
                zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_size}).")
#
# high-order parsing

from msp.utils import zwarn, zfatal

# todo(note): which version of TurboParser and AD3?
# Updated: now directly use AD3's example and TurboParser
# https://github.com/andre-martins/AD3/commit/22131c7457614dd159546500cd1a0fd8cdf2d282
# https://github.com/andre-martins/TurboParser/commit/a87b8e45694c18b826bb3c42e8344bd32928007d
#

import numpy as np

try:
    from .parser2 import parse2
except:
    zwarn("Cannot find high-order parsing lib, please compile them if needed")

    def parse2(*args, **kwargs):
        raise NotImplementedError("Compile the C++ codes for this one!")

# TODO(WARN): be careful about when there are no trees in the current pruning mask! (especially for proj)
# and here only do unlabeled parsing, since labeled ones will cost more; labels are handled outside

# high order parsing decode
# each *_pack is a list/tuple of related indexes and scores; None means no such features
def hop_decode(slen: int, projective: bool, o1_masks, o1_scores, o2sib_pack, o2g_pack, o3gsib_pack):
    # dummy arguments (passing None would raise an argument error)
    dummy_int_arr = np.array([], dtype=np.int32)
    dummy_double_arr = dummy_int_arr.astype(np.double)
    # prepare inputs
def main(args):
    conf: OverallConf = init_everything(args)
    dconf, mconf = conf.dconf, conf.mconf
    tconf = mconf.tconf
    iconf = mconf.iconf
    #
    # dev/test can be non-existing!
    if not dconf.dev and dconf.test:
        utils.zwarn("No dev but give test, actually use test as dev (for early stopping)!!")
    dt_golds, dt_aux_reprs = [], []
    for file, aux_repr in [(dconf.dev, dconf.aux_repr_dev), (dconf.test, dconf.aux_repr_test)]:
        if len(file) > 0:
            utils.zlog(f"Add file `{file}(aux_repr={aux_repr})' as dt-file #{len(dt_golds)}.")
            dt_golds.append(file)
            dt_aux_reprs.append(aux_repr)
    # data
    if len(dconf.ms_train) > 0:
        # do ms train, ignore dconf.train
        train_streamers = [get_data_reader(f, dconf.input_format, dconf.use_label0, dconf.noef_link0,
                                           dconf.aux_repr_train, max_evt_layers=dconf.max_evt_layers)
                           for f in dconf.ms_train]
        train_streamer = MultiCatStreamer(train_streamers)  # simple concat for building vocab
        ms_budgets = [ScheduledValue(f"ms_budget{i}", c)
                      for i, c in enumerate(dconf.ms_train_budget_list[:len(train_streamers)])]
        assert len(ms_budgets) == len(train_streamers)
        utils.zlog(f"Multi-source training with inputs: {dconf.ms_train}")
    else:
        train_streamers = ms_budgets = None
        train_streamer = get_data_reader(dconf.train, dconf.input_format, dconf.use_label0, dconf.noef_link0,
                                         dconf.aux_repr_train, max_evt_layers=dconf.max_evt_layers)
    dt_streamers = [get_data_reader(f, dconf.input_format, dconf.use_label0, dconf.noef_link0, aux_r)
                    for f, aux_r in zip(dt_golds, dt_aux_reprs)]
    # vocab
    if tconf.no_build_dict:
        vpack = IEVocabPackage.build_by_reading(conf)
    else:
        # include dev/test only for convenience of including words hit in pre-trained embeddings
        vpack = IEVocabPackage.build_from_stream(conf, train_streamer, MultiCatStreamer(dt_streamers))
        vpack.save(dconf.dict_dir)
    # model
    model = build_model(conf.model_type, conf, vpack)
    # use bert? todo(note): pre-compute here?
    if dconf.use_bert:
        bmodel = get_berter(dconf.bconf)
        train_streamer = BerterDataAuger(train_streamer, bmodel, "aux_repr")
        dt_streamers = [BerterDataAuger(z, bmodel, "aux_repr") for z in dt_streamers]
    # index the data
    train_inst_preparer = model.get_inst_preper(True)
    test_inst_preparer = model.get_inst_preper(False)
    to_cache = dconf.cache_data
    to_cache_shuffle = dconf.cache_shuffle
    # -----
    if ms_budgets is None:
        train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle,
                                               train_inst_preparer), tconf, True)
    else:
        indexes_streamers = [index_stream(s, vpack, to_cache, to_cache_shuffle, train_inst_preparer)
                             for s in train_streamers]
        multi_streamer = MultiSpecialJoinStream(indexes_streamers, ms_budgets, dconf.ms_stop_idx)
        train_iter = batch_stream(multi_streamer, tconf, True)
    # -----
    dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer), iconf, False)
                for z in dt_streamers]
    # training runner
    tr = MyIETrainingRunner(tconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds,
                            dev_out_format=dconf.output_format, eval_conf=dconf.eval_conf)
    # -----
    if ms_budgets is not None:
        tr.add_scheduled_values(ms_budgets)  # add s-values
    # -----
    if tconf.load_model:
        tr.load(dconf.model_load_name, tconf.load_process)
    # go
    tr.run(train_iter, dt_iters)
    utils.zlog("The end of Training.")
def forward_batch(self, batched_ids: List, batched_starts: List, batched_typeids: List,
                  training: bool, other_inputs: List[List] = None):
    conf = self.bconf
    tokenizer = self.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    MASK_IDX = tokenizer.mask_token_id
    CLS_IDX = tokenizer.cls_token_id
    SEP_IDX = tokenizer.sep_token_id
    if other_inputs is None:
        other_inputs = []
    # =====
    # batch: here add CLS and SEP
    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
    input_shape = (bsize, max_len)
    # first collect on CPU
    input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
    input_ids_arr[:, 0] = CLS_IDX
    input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
    input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
    input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
    other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
    if conf.bert2_retinc_cls:
        # act as the ROOT word
        input_is_start_arr[:, 0] = 1
    training_mask_rate = conf.bert2_training_mask_rate if training else 0.
    self_sample_stream = self.random_sample_stream
    for bidx in range(bsize):
        cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
        cur_end = len(cur_ids) + 2  # plus CLS and SEP
        if training_mask_rate > 0.:
            # input dropout
            input_ids_arr[bidx, 1:cur_end] = [(MASK_IDX if next(self_sample_stream) < training_mask_rate else z)
                                              for z in cur_ids] + [SEP_IDX]
        else:
            input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
        input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
        input_mask_arr[bidx, :cur_end] = 1.
        if batched_typeids is not None and batched_typeids[bidx] is not None:
            input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
        for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
            one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
    # arr to tensor
    input_ids_t = BK.input_idx(input_ids_arr)
    input_mask_t = BK.input_real(input_mask_arr)
    input_is_start_t = BK.input_idx(input_is_start_arr)
    input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
    other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
    # =====
    # forward (maybe need multiple times to fit maxlen constraint)
    MAX_LEN = 510  # save two for [CLS] and [SEP]
    BACK_LEN = 100  # for splitting cases, still remaining some of previous sub-tokens for context
    if max_len <= MAX_LEN:
        # directly once
        final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
    else:
        all_outputs = []
        cur_sub_idx = 0
        slice_size = [bsize, 1]
        slice_cls, slice_sep = BK.constants(slice_size, CLS_IDX, dtype=BK.int64), \
            BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
        while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
            cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
            cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
            cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
            # here we simply extend extra original masks
            cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_input_typeid_t = None if input_typeid_t is None else \
                input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
            cur_outputs = self.forward_features(cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t,
                                                cur_other_input_ts)
            # only include CLS in the first run, no SEP included
            if cur_sub_idx == 0:
                # include CLS, exclude SEP
                all_outputs.append(cur_outputs[:, :-1])
            else:
                # include only new ones, discard BACK ones, exclude CLS, SEP
                all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
            zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                  f"for all-len={max_len}")
            cur_sub_idx = cur_slice_end
        final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
    start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
    return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
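# A standalone sketch of the sliding-window arithmetic used above for sequences longer than
# the BERT length limit (pure Python, no model calls): each window re-prepends [CLS]/[SEP],
# carries BACK_LEN sub-tokens of left context, and only the non-overlapping part of each
# window's output is kept when stitching segments back together. Window sizes below are
# shrunk for illustration.
def _window_ranges(max_len: int, window: int = 510, back: int = 100):
    ranges = []
    cur_sub_idx = 0
    while cur_sub_idx < max_len - 1:  # content positions 1 .. max_len-2 (CLS at 0, SEP at the end)
        start = max(1, cur_sub_idx - back)
        end = min(start + window, max_len - 1)
        keep_from = 1 if cur_sub_idx == 0 else cur_sub_idx - start + 1  # skip re-fed context (and CLS)
        ranges.append((start, end, keep_from))
        cur_sub_idx = end
    return ranges

# e.g. _window_ranges(20, window=8, back=3) -> [(1, 9, 1), (6, 14, 4), (11, 19, 4)]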