def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = conf
    # vocab and padder
    self.word_vocab = vpack.get_voc("word")
    self.padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)  # todo(note): <pad>-id is very large
    # models
    self.hid_layer = self.add_sub_node("hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
    self.pred_layer = self.add_sub_node("pred", Affine(pc, conf.hid_dim, conf.max_pred_rank + 1, init_rop=NoDropRop()))
    if conf.init_pred_from_pretrain:
        npvec = vpack.get_emb("word")
        if npvec is None:
            zwarn("Pretrained vector not provided, skip init pred embeddings!!")
        else:
            with BK.no_grad_env():
                self.pred_layer.ws[0].copy_(BK.input_real(npvec[:conf.max_pred_rank + 1].T))
            zlog(f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1}).")
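# --- Illustration (not from this repo): a minimal PyTorch sketch of the initialization trick above,
# i.e. copying the first K rows of a pretrained embedding matrix into the output projection so the
# prediction layer starts from word vectors. `K`, `npvec` and `init_pred_from_pretrain` are
# hypothetical stand-ins for conf.max_pred_rank+1, vpack.get_emb("word") and the conf flag.
import numpy as np
import torch

def init_pred_from_pretrain(pred_layer: torch.nn.Linear, npvec: np.ndarray, K: int):
    # pred_layer maps [hid_dim] -> [K] scores; its weight is [K, hid_dim],
    # so the first K pretrained vectors (each of size hid_dim) can be copied row by row.
    assert npvec.shape[0] >= K and npvec.shape[1] == pred_layer.in_features
    with torch.no_grad():
        pred_layer.weight.copy_(torch.from_numpy(npvec[:K]).float())

# usage on toy shapes
npvec = np.random.randn(1000, 300).astype(np.float32)   # pretrained vocab of 1000 words, dim 300
layer = torch.nn.Linear(300, 500, bias=True)             # predict over the 500 most frequent words
init_pred_from_pretrain(layer, npvec, 500)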
def __init__(self, conf: G1ParserConf, vpack: VocabPackage):
    super().__init__(conf, vpack)
    # todo(note): the neural parameters are exactly the same as the EF one
    self.scorer_helper = GScorerHelper(self.scorer)
    self.predict_padder = DataPadder(2, pad_vals=0)
    #
    self.g1_use_aux_scores = conf.debug_use_aux_scores  # assigning here is only for debugging usage, otherwise assigning outside
    self.num_label = self.label_vocab.trg_len(True)  # todo(WARN): use the original idx
    #
    self.loss_hinge = (self.conf.tconf.loss_function == "hinge")
    if not self.loss_hinge:
        assert self.conf.tconf.loss_function == "prob", "This model only supports hinge or prob"
class Inputter:
    def __init__(self, embedder: EmbedderNode, vpack: VocabPackage):
        self.vpack = vpack
        self.embedder = embedder
        # -----
        # prepare the inputter
        self.comp_names = embedder.comp_names
        self.comp_helpers = []
        for comp_name in self.comp_names:
            one_helper = InputHelper.get_input_helper(comp_name, comp_name, vpack)
            self.comp_helpers.append(one_helper)
            if comp_name == "bert":
                assert embedder.berter is not None
                one_helper.set_berter(berter=embedder.berter)
        # ====
        self.mask_padder = DataPadder(2, pad_vals=0., mask_range=2)

    def __call__(self, insts: List[GeneralSentence]):
        # first pad words to get masks
        _, masks_arr = self.mask_padder.pad([z.word_seq.idxes for z in insts])
        # then get each one
        ret_map = {"mask": masks_arr}
        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
            ret_map[comp_name] = comp_helper.prepare(insts)
        return ret_map

    # todo(note): return new masks (input is read only!!)
    def mask_input(self, input_map: Dict, input_erase_mask, nomask_names_set: Set):
        ret_map = {"mask": input_map["mask"]}
        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
            if comp_name in nomask_names_set:
                # directly borrow that one
                ret_map[comp_name] = input_map[comp_name]
            else:
                ret_map[comp_name] = comp_helper.mask(input_map[comp_name], input_erase_mask)
        return ret_map
def __init__(self, conf: G2ParserConf, vpack: VocabPackage):
    super().__init__(conf, vpack)
    # todo(note): the neural parameters are exactly the same as the EF one
    # ===== basic G1 Parser's loading
    # todo(note): there can be parameter mismatch (but all of them in non-trained part, thus will be fine)
    self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf)
    self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training
    self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing
    self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training
    self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing
    #
    self.add_slayer()
    self.dl = G2DL(self.scorer, self.slayer, conf)
    #
    self.predict_padder = DataPadder(2, pad_vals=0)
    self.num_label = self.label_vocab.trg_len(True)  # todo(WARN): use the original idx
def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf):
    super().__init__(pc, None, None)
    self.dim = rconf._input_dim  # both input/output dim
    # padders for child nodes
    self.chs_start_posi = -rconf.chs_num
    self.ch_idx_padder = DataPadder(2, pad_vals=0, mask_range=2)  # [*, num-ch]
    self.ch_label_padder = DataPadder(2, pad_vals=0)
    #
    self.label_embeddings = self.add_sub_node("label", Embedding(pc, rconf._num_label, rconf.dim_label, fix_row0=False))
    self.dim_label = rconf.dim_label
    # todo(note): now adopting flatten groupings for basic, and then that is all, no more recurrent features
    # group 1: [cur, chs, par] -> head_pre_size
    self.use_chs = rconf.use_chs
    self.use_par = rconf.use_par
    self.use_label_feat = rconf.use_label_feat
    # components (add the parameters anyway)
    # todo(note): children features: children + (label of mod->children)
    self.chs_reprer = self.add_sub_node("chs", ChsReprer(pc, rconf))
    self.chs_ff = self.add_sub_node("chs_ff", Affine(pc, self.chs_reprer.get_output_dims()[0], self.dim, act="tanh"))
    # todo(note): parent features: parent + (label of parent->mod)
    # todo(warn): always add label related params
    par_ff_inputs = [self.dim, rconf.dim_label]
    self.par_ff = self.add_sub_node("par_ff", Affine(pc, par_ff_inputs, self.dim, act="tanh"))
    # no other groups anymore!
    if rconf.zero_extra_output_params:
        self.par_ff.zero_params()
        self.chs_ff.zero_params()
def __init__(self, pc: BK.ParamCollection, conf: FpEncConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = conf
    # ===== Vocab =====
    self.word_vocab = vpack.get_voc("word")
    self.char_vocab = vpack.get_voc("char")
    self.pos_vocab = vpack.get_voc("pos")
    # avoid no params error
    self._tmp_v = self.add_param("nope", (1,))
    # ===== Model =====
    # embedding
    self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, conf.emb_conf, vpack))
    self.emb_output_dim = self.emb.get_output_dims()[0]
    # bert
    self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf))
    self.bert_output_dim = self.bert.get_output_dims()[0]
    # make sure there are inputs
    assert self.emb_output_dim > 0 or self.bert_output_dim > 0
    # middle?
    if conf.middle_dim > 0:
        self.middle_node = self.add_sub_node("mid", Affine(self.pc, self.emb_output_dim + self.bert_output_dim, conf.middle_dim, act="elu"))
        self.enc_input_dim = conf.middle_dim
    else:
        self.middle_node = None
        self.enc_input_dim = self.emb_output_dim + self.bert_output_dim  # concat the two parts (if needed)
    # encoder?
    # todo(note): feed compute-on-the-fly hp
    conf.enc_conf._input_dim = self.enc_input_dim
    self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf))
    self.enc_output_dim = self.enc.get_output_dims()[0]
    # ===== Input Specification =====
    # inputs (word, char, pos) and vocabulary
    self.need_word = self.emb.has_word
    self.need_char = self.emb.has_char
    # todo(warn): currently only allow extra fields for POS
    self.need_pos = False
    if len(self.emb.extra_names) > 0:
        assert len(self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
        self.need_pos = True
    #
    self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
    self.char_padder = DataPadder(3, pad_lens=(0, 0, conf.char_max_length), pad_vals=self.char_vocab.pad)
    self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
def __init__(self, pc: BK.ParamCollection, bconf: BTConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.bconf = bconf
    # ===== Vocab =====
    self.word_vocab = vpack.get_voc("word")
    self.char_vocab = vpack.get_voc("char")
    self.pos_vocab = vpack.get_voc("pos")
    # ===== Model =====
    # embedding
    self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
    emb_output_dim = self.emb.get_output_dims()[0]
    # encoder0 for jpos
    # todo(note): will do nothing if not use_jpos
    bconf.jpos_conf._input_dim = emb_output_dim
    self.jpos_enc = self.add_sub_node("enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab))
    enc0_output_dim = self.jpos_enc.get_output_dims()[0]
    # encoder
    # todo(0): feed compute-on-the-fly hp
    bconf.enc_conf._input_dim = enc0_output_dim
    self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
    self.enc_output_dim = self.enc.get_output_dims()[0]
    # ===== Input Specification =====
    # inputs (word, char, pos) and vocabulary
    self.need_word = self.emb.has_word
    self.need_char = self.emb.has_char
    # todo(warn): currently only allow extra fields for POS
    self.need_pos = False
    if len(self.emb.extra_names) > 0:
        assert len(self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
        self.need_pos = True
    # todo(warn): currently only allow one aux field
    self.need_aux = False
    if len(self.emb.dim_auxes) > 0:
        assert len(self.emb.dim_auxes) == 1
        self.need_aux = True
    #
    self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
    self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad)
    self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
    #
    self.random_sample_stream = Random.stream(Random.random_sample)
class InputHelper:
    def __init__(self, comp_name, vpack: VocabPackage):
        self.comp_name = comp_name
        self.comp_seq_name = f"{comp_name}_seq"
        self.voc = vpack.get_voc(comp_name)
        self.padder = DataPadder(2, pad_vals=0)  # pad 0

    # return batched arr
    def prepare(self, insts: List[GeneralSentence]):
        cur_input_list = [getattr(z, self.comp_seq_name).idxes for z in insts]
        cur_input_arr, _ = self.padder.pad(cur_input_list)
        return cur_input_arr

    def mask(self, v, erase_mask):
        IDX_MASK = self.voc.err  # todo(note): this one is unused, thus just take it!
        ret_arr = v * (1 - erase_mask) + IDX_MASK * erase_mask
        return ret_arr

    @staticmethod
    def get_input_helper(name, *args, **kwargs):
        helper_type = {"char": CharCnnHelper, "posi": PosiHelper, "bert": BertInputHelper}.get(name, PlainSeqHelper)
        return helper_type(*args, **kwargs)
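# --- Illustration (not from this repo): a minimal numpy sketch of the masking arithmetic in
# InputHelper.mask above, where positions flagged by a 0/1 erase mask are overwritten with a
# reserved [MASK]-like index while everything else is kept. `MASK_ID` is a hypothetical value.
import numpy as np

def mask_indices(idx_arr: np.ndarray, erase_mask: np.ndarray, MASK_ID: int) -> np.ndarray:
    # idx_arr: [bs, slen] integer token ids; erase_mask: [bs, slen] with 1 at positions to mask
    return (idx_arr * (1 - erase_mask) + MASK_ID * erase_mask).astype(idx_arr.dtype)

idxs = np.array([[5, 8, 2, 0], [7, 3, 0, 0]])
erase = np.array([[0, 1, 0, 0], [1, 0, 0, 0]])
print(mask_indices(idxs, erase, MASK_ID=99))  # [[ 5 99  2  0] [99  3  0  0]]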
class NodeExtractorBase(BasicNode):
    def __init__(self, pc, conf: NodeExtractorConfBase, vocab: HLabelVocab, extract_type: str):
        super().__init__(pc, None, None)
        self.conf = conf
        self.vocab = vocab
        self.hl: HLabelNode = self.add_sub_node("hl", HLabelNode(pc, conf.lab_conf, vocab))
        self.hl_output_size = self.hl.prediction_sizes[self.hl.eff_max_layer - 1]  # num of output labels
        #
        self.extract_type = extract_type
        self.items_getter = {"evt": self.get_events, "ef": lambda sent: sent.entity_fillers}[extract_type]
        self.constrain_evt_types = None
        # 2d pad
        self.padder_mask = DataPadder(2, pad_vals=0.)
        self.padder_idxes = DataPadder(2, pad_vals=0)  # todo(warn): 0 for full-nil
        self.padder_items = DataPadder(2, pad_vals=None)
        # 3d pad
        self.padder_mask_3d = DataPadder(3, pad_vals=0.)
        self.padder_items_3d = DataPadder(3, pad_vals=None)

    # =====
    # idx transforms
    def hlidx2idx(self, hlidx: HLabelIdx) -> int:
        return hlidx.get_idx(self.hl.eff_max_layer - 1)

    def idx2hlidx(self, idx: int) -> HLabelIdx:
        return self.vocab.get_hlidx(idx, self.hl.eff_max_layer)

    # events possibly filtered by constrain_evt_types
    def get_events(self, sent):
        ret = sent.events
        constrain_evt_types = self.constrain_evt_types
        if constrain_evt_types is None:
            return ret
        else:
            return [z for z in ret if z.type in constrain_evt_types]

    def set_constrain_evt_types(self, constrain_evt_types):
        self.constrain_evt_types = constrain_evt_types

    # =====
    # main procedure
    def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
        raise NotImplementedError()

    def predict(self, insts: List, input_lexi, input_expr, input_mask):
        raise NotImplementedError()

    def lookup(self, insts: List, input_lexi, input_expr, input_mask):
        raise NotImplementedError()

    # =====
    # basic input specifying for this module
    # todo(+N): currently entangling things in the least elegant way?
    # batch inputs for head mode
    def batch_inputs_h(self, insts: List[Sentence]):
        key, items_getter = self.extract_type, self.items_getter
        nil_idx = 0
        # get gold/input data and batch
        all_masks, all_idxes, all_items, all_valid = [], [], [], []
        all_idxes2, all_items2 = [], []  # secondary types
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                length = sent.length
                items = items_getter(sent)
                # token-idx -> ...
                prep_masks, prep_idxes, prep_items = [0.] * length, [nil_idx] * length, [None] * length
                prep_idxes2, prep_items2 = [nil_idx] * length, [None] * length
                if items is None:
                    # todo(note): there are samples that do not have entity annotations (KBP15)
                    # final 0/1 indicates valid or not
                    prep_valid = 0.
                else:
                    prep_valid = 1.
                    for one_item in items:
                        this_hwidx = one_item.mention.hard_span.head_wid
                        this_hlidx = one_item.type_idx
                        # todo(+N): ignore except the first two types (already ranked by type-freq)
                        if prep_idxes[this_hwidx] == 0:
                            prep_masks[this_hwidx] = 1.
                            prep_idxes[this_hwidx] = self.hlidx2idx(this_hlidx)  # change to int here!
                            prep_items[this_hwidx] = one_item
                        elif prep_idxes2[this_hwidx] == 0:
                            prep_idxes2[this_hwidx] = self.hlidx2idx(this_hlidx)  # change to int here!
                            prep_items2[this_hwidx] = one_item
                sent.preps[key] = (prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2)
            else:
                prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2 = preps
            # =====
            all_masks.append(prep_masks)
            all_idxes.append(prep_idxes)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
            all_idxes2.append(prep_idxes2)
            all_items2.append(prep_items2)
        # pad and batch
        mention_masks = BK.input_real(self.padder_mask.pad(all_masks)[0])  # [*, slen]
        mention_idxes = BK.input_idx(self.padder_idxes.pad(all_idxes)[0])  # [*, slen]
        mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, slen]
        mention_valid = BK.input_real(all_valid)  # [*]
        mention_idxes2 = BK.input_idx(self.padder_idxes.pad(all_idxes2)[0])  # [*, slen]
        mention_items2_arr, _ = self.padder_items.pad(all_items2)  # [*, slen]
        return mention_masks, mention_idxes, mention_items_arr, mention_valid, mention_idxes2, mention_items2_arr

    # batch inputs for gene0 mode (separate for each label)
    def batch_inputs_g0(self, insts: List[Sentence]):
        # similar to "batch_inputs_h", but further extend for each label
        key, items_getter = self.extract_type, self.items_getter
        # nil_idx = 0
        # get gold/input data and batch
        output_size = self.hl_output_size
        all_masks, all_items, all_valid = [], [], []
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                length = sent.length
                items = items_getter(sent)
                # token-idx -> [slen, out-size]
                prep_masks = [[0. for _i1 in range(output_size)] for _i0 in range(length)]
                prep_items = [[None for _i1 in range(output_size)] for _i0 in range(length)]
                if items is None:
                    # todo(note): there are samples that do not have entity annotations (KBP15)
                    # final 0/1 indicates valid or not
                    prep_valid = 0.
                else:
                    prep_valid = 1.
                    for one_item in items:
                        this_hwidx = one_item.mention.hard_span.head_wid
                        this_hlidx = one_item.type_idx
                        this_tidx = self.hlidx2idx(this_hlidx)  # change to int here!
                        # todo(+N): simply ignore repeated ones with same type and trigger
                        if prep_masks[this_hwidx][this_tidx] == 0.:
                            prep_masks[this_hwidx][this_tidx] = 1.
                            prep_items[this_hwidx][this_tidx] = one_item
                sent.preps[key] = (prep_masks, prep_items, prep_valid)
            else:
                prep_masks, prep_items, prep_valid = preps
            # =====
            all_masks.append(prep_masks)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
        # pad and batch
        mention_masks = BK.input_real(self.padder_mask_3d.pad(all_masks)[0])  # [*, slen, L]
        mention_idxes = None
        mention_items_arr, _ = self.padder_items_3d.pad(all_items)  # [*, slen, L]
        mention_valid = BK.input_real(all_valid)  # [*]
        return mention_masks, mention_idxes, mention_items_arr, mention_valid

    # batch inputs for gene1 mode (seq-gene mode)
    # todo(note): the return is different than previous, here directly idx-based
    def batch_inputs_g1(self, insts: List[Sentence]):
        train_reverse_evetns = self.conf.train_reverse_evetns  # todo(note): this option is from derived class
        if train_reverse_evetns:
            _tmp_f = lambda x: list(reversed(x))
        else:
            _tmp_f = lambda x: x
        key, items_getter = self.extract_type, self.items_getter
        # nil_idx = 0  # nil means eos
        # get gold/input data and batch
        all_widxes, all_lidxes, all_vmasks, all_items, all_valid = [], [], [], [], []
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                items = items_getter(sent)
                # todo(note): directly add, assume they are already sorted in a good way (widx+lidx); 0(nil) as eos
                if items is None:
                    prep_valid = 0.
                    # prep_widxes, prep_lidxes, prep_vmasks, prep_items = [0], [0], [1.], [None]
                    prep_widxes, prep_lidxes, prep_vmasks, prep_items = [], [], [], []
                else:
                    prep_valid = 1.
                    prep_widxes = _tmp_f([z.mention.hard_span.head_wid for z in items]) + [0]
                    prep_lidxes = _tmp_f([self.hlidx2idx(z.type_idx) for z in items]) + [0]
                    prep_vmasks = [1.] * (len(items) + 1)
                    prep_items = _tmp_f(items.copy()) + [None]
                sent.preps[key] = (prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid)
            else:
                prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid = preps
            # =====
            all_widxes.append(prep_widxes)
            all_lidxes.append(prep_lidxes)
            all_vmasks.append(prep_vmasks)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
        # pad and batch
        mention_widxes = BK.input_idx(self.padder_idxes.pad(all_widxes)[0])  # [*, ?]
        mention_lidxes = BK.input_idx(self.padder_idxes.pad(all_lidxes)[0])  # [*, ?]
        mention_vmasks = BK.input_real(self.padder_mask.pad(all_vmasks)[0])  # [*, ?]
        mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, ?]
        mention_valid = BK.input_real(all_valid)  # [*]
        return mention_widxes, mention_lidxes, mention_vmasks, mention_items_arr, mention_valid
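# --- Illustration (not from this repo): a minimal numpy sketch of the pad-and-batch step used by the
# batch_inputs_* helpers above, i.e. turning per-sentence lists of different lengths into one
# rectangular index array plus a 0/1 validity mask (the role DataPadder plays). Names are hypothetical.
import numpy as np

def pad_and_batch(seqs, pad_val=0):
    max_len = max((len(s) for s in seqs), default=0)
    arr = np.full((len(seqs), max_len), pad_val, dtype=np.int64)
    mask = np.zeros((len(seqs), max_len), dtype=np.float32)
    for i, s in enumerate(seqs):
        arr[i, :len(s)] = s
        mask[i, :len(s)] = 1.
    return arr, mask

widxes, vmasks = pad_and_batch([[3, 7, 0], [5, 0]])   # e.g. head-word indices ended by the 0/eos entry
print(widxes)  # [[3 7 0] [5 0 0]]
print(vmasks)  # [[1. 1. 1.] [1. 1. 0.]]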
class FpDecoder(BasicNode):
    def __init__(self, pc, conf: FpDecConf, label_vocab):
        super().__init__(pc, None, None)
        self.conf = conf
        self.label_vocab = label_vocab
        self.predict_padder = DataPadder(2, pad_vals=0)
        # the scorer
        self.use_ablpair = conf.use_ablpair
        conf.sconf._input_dim = conf._input_dim
        conf.sconf._num_label = conf._num_label
        if self.use_ablpair:
            self.scorer = self.add_sub_node("s", FpSingleScorer(pc, conf.sconf))
        else:
            self.scorer = self.add_sub_node("s", FpPairedScorer(pc, conf.sconf))

    # -----
    # scoring
    def _score(self, enc_expr, mask_expr):
        # -----
        def _special_score(one_score):  # specially change ablpair scores into [bs,m,h,*]
            root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
            tmp_shape = BK.get_shape(root_score)
            tmp_shape[1] = 1  # [bs, 1, 1, *]
            padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score], dim=1)  # [bs, rlen+1, 1, *]
            final_score = BK.concat([padded_root_score, one_score.transpose(1, 2)], dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
            return final_score
        # -----
        if self.use_ablpair:
            input_mask_expr = (mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2))[:, 1:]  # [bs, rlen, rlen+1]
            arc_score = self.scorer.transform_and_arc_score(enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, 1]
            lab_score = self.scorer.transform_and_lab_score(enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, Lab]
            # put root-scores for both directions
            arc_score = _special_score(arc_score)
            lab_score = _special_score(lab_score)
        else:
            # todo(+2): for training, we can simply select and lab-score
            arc_score = self.scorer.transform_and_arc_score(enc_expr, mask_expr)  # [bs, m, h, 1]
            lab_score = self.scorer.transform_and_lab_score(enc_expr, mask_expr)  # [bs, m, h, Lab]
        # mask out diag scores
        diag_mask = BK.eye(BK.get_shape(arc_score, 1))
        diag_mask[0, 0] = 0.
        diag_add = Constants.REAL_PRAC_MIN * (diag_mask.unsqueeze(-1).unsqueeze(0))  # [1, m, h, 1]
        arc_score += diag_add
        lab_score += diag_add
        return arc_score, lab_score

    # loss
    # todo(note): no margins here, simply using target-selection cross-entropy
    def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
        conf = self.conf
        # scoring
        arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
        # loss
        bsize, max_len = BK.get_shape(mask_expr)
        # gold heads and labels
        gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts])
        # todo(note): here use the original idx of label, no shift!
        gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts])
        gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
        gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
        # collect the losses
        arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
        arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
        # logsoftmax and losses
        arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1)  # [bs, m, h]
        lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
        arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr]  # [bs, Len]
        lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr]  # [bs, Len]
        # head selection (no root)
        arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
        lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
        final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
        final_loss_count = mask_expr[:, 1:].sum()
        return [[final_loss, final_loss_count]]

    # decode
    def predict(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
        conf = self.conf
        # scoring
        arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
        full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
        mst_heads_arr, mst_labels_arr, mst_scores_arr = \
            nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
        # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
        misc_prefix = "g"
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
            one_scores = mst_scores_arr[one_idx][:cur_length]
            one_inst.pred_par_scores.set_vals(one_scores)
            # extra output
            one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
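# --- Illustration (not from this repo): a small PyTorch sketch of the arange-indexing trick used in
# FpDecoder.loss above to pick out, for every (sentence, modifier) pair, the log-probability of its
# gold head and gold label from the full score tensors. Shapes are toy values.
import torch

bs, slen, n_lab = 2, 4, 5
arc_logprob = torch.randn(bs, slen, slen).log_softmax(-1)          # [bs, m, h]
lab_logprob = torch.randn(bs, slen, slen, n_lab).log_softmax(-1)   # [bs, m, h, Lab]
gold_heads = torch.randint(0, slen, (bs, slen))                    # [bs, m]
gold_labels = torch.randint(0, n_lab, (bs, slen))                  # [bs, m]

arange_bs = torch.arange(bs).unsqueeze(-1)    # [bs, 1], broadcasts against [1, m]
arange_m = torch.arange(slen).unsqueeze(0)    # [1, m]
arc_sel = arc_logprob[arange_bs, arange_m, gold_heads]                 # [bs, m]
lab_sel = lab_logprob[arange_bs, arange_m, gold_heads, gold_labels]    # [bs, m]
loss = -(arc_sel + lab_sel)[:, 1:].sum()      # skip the artificial ROOT position at index 0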
class MyIEBT(BasicNode):
    def __init__(self, pc: BK.ParamCollection, bconf: BTConf, tconf: 'BaseTrainingConf', vpack: VocabPackage):
        super().__init__(pc, None, None)
        self.bconf = bconf
        # ===== Vocab =====
        self.word_vocab = vpack.get_voc("word")
        self.char_vocab = vpack.get_voc("char")
        self.lemma_vocab = vpack.get_voc("lemma")
        self.upos_vocab = vpack.get_voc("upos")
        self.ulabel_vocab = vpack.get_voc("ulabel")
        # ===== Model =====
        # embedding
        self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
        emb_output_dim = self.emb.get_output_dims()[0]
        self.emb_output_dim = emb_output_dim
        # doc hint
        self.use_doc_hint = bconf.use_doc_hint
        self.dh_combine_method = bconf.dh_combine_method
        if self.use_doc_hint:
            assert len(bconf.emb_conf.dim_auxes) > 0
            # todo(note): currently use the concat of them if input multiple layers
            bconf.dh_conf._input_dim = bconf.emb_conf.dim_auxes[0]  # same as input bert dim
            bconf.dh_conf._output_dim = emb_output_dim  # same as emb_output_dim
            self.dh_node = self.add_sub_node("dh", DocHintModule(pc, bconf.dh_conf))
        else:
            self.dh_node = None
        # encoders
        # shared
        # todo(note): feed compute-on-the-fly hp
        bconf.enc_conf._input_dim = emb_output_dim
        self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
        tmp_enc_output_dim = self.enc.get_output_dims()[0]
        # privates
        bconf.enc_ef_conf._input_dim = tmp_enc_output_dim
        self.enc_ef = self.add_sub_node("enc_ef", MyEncoder(self.pc, bconf.enc_ef_conf))
        self.enc_ef_output_dim = self.enc_ef.get_output_dims()[0]
        bconf.enc_evt_conf._input_dim = tmp_enc_output_dim
        self.enc_evt = self.add_sub_node("enc_evt", MyEncoder(self.pc, bconf.enc_evt_conf))
        self.enc_evt_output_dim = self.enc_evt.get_output_dims()[0]
        # ===== Input Specification =====
        # inputs (word, lemma, char, upos, ulabel) and vocabulary
        self.need_word = self.emb.has_word
        self.need_char = self.emb.has_char
        # extra fields
        # todo(warn): need to
        self.need_lemma = False
        self.need_upos = False
        self.need_ulabel = False
        for one_extra_name in self.emb.extra_names:
            if one_extra_name == "lemma":
                self.need_lemma = True
            elif one_extra_name == "upos":
                self.need_upos = True
            elif one_extra_name == "ulabel":
                self.need_ulabel = True
            else:
                raise NotImplementedError("UNK extra input name: " + one_extra_name)
        # todo(warn): currently only allow one aux field
        self.need_aux = False
        if len(self.emb.dim_auxes) > 0:
            assert len(self.emb.dim_auxes) == 1
            self.need_aux = True
        # padders
        self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
        self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad)
        self.lemma_padder = DataPadder(2, pad_vals=self.lemma_vocab.pad)
        self.upos_padder = DataPadder(2, pad_vals=self.upos_vocab.pad)
        self.ulabel_padder = DataPadder(2, pad_vals=self.ulabel_vocab.pad)
        #
        self.random_sample_stream = Random.stream(Random.random_sample)
        self.train_skip_noevt_rate = tconf.train_skip_noevt_rate
        self.train_skip_length = tconf.train_skip_length
        self.train_min_length = tconf.train_min_length
        self.test_min_length = tconf.test_min_length
        self.test_skip_noevt_rate = tconf.test_skip_noevt_rate
        self.train_sent_based = tconf.train_sent_based
        # assert not self.train_sent_based, "The basic model should not use this sent-level mode!"

    def get_output_dims(self, *input_dims):
        return ([self.enc_ef_output_dim, self.enc_evt_output_dim], )

    #
    def refresh(self, rop=None):
        zfatal("Should call special_refresh instead!")

    # ====
    # special routines
    def special_refresh(self, embed_rop, other_rop):
        self.emb.refresh(embed_rop)
        self.enc.refresh(other_rop)
        self.enc_ef.refresh(other_rop)
        self.enc_evt.refresh(other_rop)
        if self.dh_node is not None:
            return self.dh_node.refresh(other_rop)

    def prepare_training_rop(self):
        mconf = self.bconf
        embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop)
        other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, fix_drop=mconf.fix_drop)
        return embed_rop, other_rop

    # =====
    # run
    def _prepare_input(self, sents: List[Sentence], training: bool):
        word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], []
        # ===== specially prepare for the words
        wv = self.word_vocab
        W_UNK = wv.unk
        UNK_REP_RATE = self.bconf.singleton_unk
        UNK_REP_THR = self.bconf.singleton_thr
        word_act_idxes = []
        if training and UNK_REP_RATE > 0.:
            # replace unfreq/singleton words with UNK
            for one_inst in sents:
                one_act_idxes = []
                for one_idx in one_inst.words.idxes:
                    one_freq = wv.idx2val(one_idx)
                    if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR:
                        if next(self.random_sample_stream) < (UNK_REP_RATE / one_freq):
                            one_idx = W_UNK
                    one_act_idxes.append(one_idx)
                word_act_idxes.append(one_act_idxes)
        else:
            word_act_idxes = [z.words.idxes for z in sents]
        # todo(warn): still need the masks
        word_arr, mask_arr = self.word_padder.pad(word_act_idxes)
        # =====
        if not self.need_word:
            word_arr = None
        if self.need_char:
            chars = [z.chars.idxes for z in sents]
            char_arr, _ = self.char_padder.pad(chars)
        # extra ones: lemma, upos, ulabel
        if self.need_lemma:
            lemmas = [z.lemmas.idxes for z in sents]
            lemmas_arr, _ = self.lemma_padder.pad(lemmas)
            extra_arrs.append(lemmas_arr)
        if self.need_upos:
            uposes = [z.uposes.idxes for z in sents]
            upos_arr, _ = self.upos_padder.pad(uposes)
            extra_arrs.append(upos_arr)
        if self.need_ulabel:
            ulabels = [z.ud_labels.idxes for z in sents]
            ulabels_arr, _ = self.ulabel_padder.pad(ulabels)
            extra_arrs.append(ulabels_arr)
        # aux ones
        if self.need_aux:
            aux_arr_list = [z.extra_features["aux_repr"] for z in sents]
            # pad
            padded_seq_len = int(mask_arr.shape[1])
            final_aux_arr_list = []
            for cur_arr in aux_arr_list:
                cur_len = len(cur_arr)
                if cur_len > padded_seq_len:
                    final_aux_arr_list.append(cur_arr[:padded_seq_len])
                else:
                    final_aux_arr_list.append(np.pad(cur_arr, ((0, padded_seq_len - cur_len), (0, 0)), 'constant'))
            aux_arrs.append(np.stack(final_aux_arr_list, 0))
        #
        input_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs)
        # [BS, Len, Dim], [BS, Len]
        return input_repr, mask_arr

    # input: Tuple(Sentence, ...)
    def _bucket_sents_by_length(self, all_sents, enc_bucket_range: int, getlen_f=lambda x: x[0].length, max_bsize=None):
        # split into buckets
        all_buckets = []
        cur_local_sidx = 0
        use_max_bsize = (max_bsize is not None)
        while cur_local_sidx < len(all_sents):
            cur_bucket = []
            starting_slen = getlen_f(all_sents[cur_local_sidx])
            ending_slen = starting_slen + enc_bucket_range
            # searching forward
            tmp_sidx = cur_local_sidx
            while tmp_sidx < len(all_sents):
                one_sent = all_sents[tmp_sidx]
                one_slen = getlen_f(one_sent)
                if one_slen > ending_slen or (use_max_bsize and len(cur_bucket) >= max_bsize):
                    break
                else:
                    cur_bucket.append(one_sent)
                    tmp_sidx += 1
            # put bucket and next
            all_buckets.append(cur_bucket)
            cur_local_sidx = tmp_sidx
        return all_buckets

    # todo(warn): for rnn, need to transpose masks, thus need np.array
    # TODO(+N): here we encode only at sentence level, encoding doc level might be helpful, but much harder to batch
    # therefore, still take DOC as input, since may be extended to doc-level encoding
    # return input_repr, enc_repr, mask_arr
    def run(self, insts: List[DocInstance], training: bool):
        # make it as sentence level processing (group the sentences by length, and ignore doc level for now)
        # skip no content sentences in training?
        # assert not self.train_sent_based, "The basic model should not use this sent-level mode!"
        all_sents = []  # (inst, d_idx, s_idx)
        for d_idx, one_doc in enumerate(insts):
            for s_idx, x in enumerate(one_doc.sents):
                if training:
                    if x.length < self.train_skip_length and x.length >= self.train_min_length \
                            and (len(x.events) > 0 or next(self.random_sample_stream) > self.train_skip_noevt_rate):
                        all_sents.append((x, d_idx, s_idx))
                else:
                    if x.length >= self.test_min_length:
                        all_sents.append((x, d_idx, s_idx))
        return self.run_sents(all_sents, insts, training)

    # input interested sents: Tuple[Sentence, DocId, SentId]
    def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
        if use_one_bucket:
            all_buckets = [all_sents]  # when we do not want to split if we know the input lengths do not vary too much
        else:
            all_sents.sort(key=lambda x: x[0].length)
            all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
        # doc hint
        use_doc_hint = self.use_doc_hint
        if use_doc_hint:
            dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
        else:
            dh_sent_repr = None
        # encoding for each of the buckets
        rets = []
        dh_add, dh_both, dh_cls = [self.dh_combine_method == z for z in ["add", "both", "cls"]]
        for one_bucket in all_buckets:
            one_sents = [z[0] for z in one_bucket]
            # [BS, Len, Di], [BS, Len]
            input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
            if use_doc_hint:
                one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
                one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
                one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
                if dh_add:
                    input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                    mask_arr = mask_arr0
                elif dh_both:
                    input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                    mask_arr = np.pad(mask_arr0, ((0, 0), (1, 1)), 'constant', constant_values=1.)  # [BS, 2+slen]
                elif dh_cls:
                    input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                    mask_arr = mask_arr0
                else:
                    raise NotImplementedError()
            else:
                input_repr, mask_arr = input_repr0, mask_arr0
            # [BS, Len, De]
            enc_repr = self.enc(input_repr, mask_arr)
            # separate ones (possibly using detach to avoid gradients for some of them)
            enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
            enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
            if use_doc_hint and dh_both:
                one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(), enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
            else:
                one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
            rets.append(one_ret)
        # todo(note): returning tuple is (List[Sentence], Tensor, Tensor, Tensor)
        return rets

    # special routine
    def aug_words_and_embs(self, aug_vocab, aug_wv):
        orig_vocab = self.word_vocab
        orig_arr = self.emb.word_embed.E.detach().cpu().numpy()
        # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
        # todo(warn): here aug_vocab should be found in aug_wv
        aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
        new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
        # assign
        self.word_vocab = new_vocab
        self.emb.word_embed.replace_weights(new_arr)
        return new_vocab
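# --- Illustration (not from this repo): a standalone sketch of the length-bucketing idea in
# _bucket_sents_by_length above; items sorted by length are greedily grouped so that each bucket
# spans at most `bucket_range` of length difference. Inputs here are toy integers, not Sentence tuples.
def bucket_by_length(lengths, bucket_range, max_bsize=None):
    lengths = sorted(lengths)
    buckets, i = [], 0
    while i < len(lengths):
        limit = lengths[i] + bucket_range
        cur = []
        while i < len(lengths) and lengths[i] <= limit and (max_bsize is None or len(cur) < max_bsize):
            cur.append(lengths[i])
            i += 1
        buckets.append(cur)
    return buckets

print(bucket_by_length([3, 4, 5, 9, 10, 31], bucket_range=4))
# [[3, 4, 5], [9, 10], [31]]  -> padding inside each bucket stays cheap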
def main():
    np.random.seed(1234)
    NUM_POS = 10
    # build vocabs
    reader = TextReader("./test_utils.py")
    vb_word = VocabBuilder("w")
    vb_char = VocabBuilder("c")
    for one in reader:
        vb_word.feed_stream(one.tokens)
        vb_char.feed_stream((c for w in one.tokens for c in w))
    voc_word = vb_word.finish()
    voc_char = vb_char.finish()
    voc_pos = VocabBuilder.build_from_stream(range(NUM_POS), name="pos")
    vpack = VocabPackage({"word": voc_word, "char": voc_char, "pos": voc_pos}, {"word": None})
    # build model
    pc = BK.ParamCollection()
    conf_emb = EmbedConf().init_from_kwargs(init_words_from_pretrain=False, dim_char=10, dim_posi=10,
                                            emb_proj_dim=400, dim_extras="50", extra_names="pos")
    conf_emb.do_validate()
    mod_emb = MyEmbedder(pc, conf_emb, vpack)
    conf_enc = EncConf().init_from_kwargs(enc_rnn_type="lstm2", enc_cnn_layer=1, enc_att_layer=1)
    conf_enc._input_dim = mod_emb.get_output_dims()[0]
    mod_enc = MyEncoder(pc, conf_enc)
    enc_output_dim = mod_enc.get_output_dims()[0]
    mod_scorer = BiAffineScorer(pc, enc_output_dim, enc_output_dim, 10)
    # build data
    word_padder = DataPadder(2, pad_lens=(0, 50), mask_range=2)
    char_padder = DataPadder(3, pad_lens=(0, 50, 20))
    word_idxes = []
    char_idxes = []
    pos_idxes = []
    for toks in reader:
        one_words = []
        one_chars = []
        for w in toks.tokens:
            one_words.append(voc_word.get_else_unk(w))
            one_chars.append([voc_char.get_else_unk(c) for c in w])
        word_idxes.append(one_words)
        char_idxes.append(one_chars)
        pos_idxes.append(np.random.randint(voc_pos.trg_len(), size=len(one_words)) + 1)  # pred->trg
    word_arr, word_mask_arr = word_padder.pad(word_idxes)
    pos_arr, _ = word_padder.pad(pos_idxes)
    char_arr, _ = char_padder.pad(char_idxes)
    #
    # run
    rop = layers.RefreshOptions(hdrop=0.2, gdrop=0.2, fix_drop=True)
    for _ in range(5):
        mod_emb.refresh(rop)
        mod_enc.refresh(rop)
        mod_scorer.refresh(rop)
        #
        expr_emb = mod_emb(word_arr, char_arr, [pos_arr])
        zlog(BK.get_shape(expr_emb))
        expr_enc = mod_enc(expr_emb, word_mask_arr)
        zlog(BK.get_shape(expr_enc))
        #
        mask_expr = BK.input_real(word_mask_arr)
        score0 = mod_scorer.paired_score(expr_enc, expr_enc, mask_expr, mask_expr)
        score1 = mod_scorer.plain_score(expr_enc.unsqueeze(-2), expr_enc.unsqueeze(-3),
                                        mask_expr.unsqueeze(-1), mask_expr.unsqueeze(-2))
        #
        zmiss = float(BK.avg(score0 - score1))
        assert zmiss < 0.0001
    zlog("OK")
    pass
def __init__(self, comp_name, vpack: VocabPackage):
    super().__init__(comp_name, vpack)
    self.padder = DataPadder(3, pad_vals=0)  # replace the padder
class GraphParser(BaseParser): def __init__(self, conf: GraphParserConf, vpack: VocabPackage): super().__init__(conf, vpack) # ===== Input Specification ===== # both head/label padding with 0 (does not matter what, since will be masked) self.predict_padder = DataPadder(2, pad_vals=0) self.hlocal_padder = DataPadder(3, pad_vals=0.) # # todo(warn): adding-styled hlocal has problems intuitively, maybe not suitable for graph-parser self.norm_single, self.norm_local, self.norm_global, self.norm_hlocal = \ [conf.output_normalizing==z for z in ["single", "local", "global", "hlocal"]] self.loss_prob, self.loss_hinge, self.loss_mr = [ conf.tconf.loss_function == z for z in ["prob", "hinge", "mr"] ] self.alg_proj, self.alg_unproj, self.alg_greedy = [ conf.iconf.dec_algorithm == z for z in ["proj", "unproj", "greedy"] ] # def build_decoder(self): conf = self.conf conf.sc_conf._input_dim = self.enc_output_dim conf.sc_conf._num_label = self.label_vocab.trg_len(False) return GraphScorer(self.pc, conf.sc_conf) # shared calculations for final scoring # -> the scores are masked with PRAC_MIN (by the scorer) for the paddings, but not handling diag here! def _prepare_score(self, insts, training): input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( insts, training) mask_expr = BK.input_real(mask_arr) # am_expr, ah_expr, lm_expr, lh_expr = self.scorer.transform_space(enc_repr) scoring_expr_pack = self.scorer.transform_space(enc_repr) return scoring_expr_pack, mask_expr, jpos_pack # scoring procedures def _score_arc_full(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr=None): am_expr, ah_expr, _, _ = scoring_expr_pack # [BS, len-m, len-h] full_arc_score = self.scorer.score_arc_all(am_expr, ah_expr, mask_expr, mask_expr) # # set diag to small values # todo(warn): handled specifically in algorithms # maxlen = BK.get_shape(full_arc_score, 1) # full_arc_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)) # margin? if training and margin > 0.: full_arc_score = BK.minus_margin(full_arc_score, gold_heads_expr, margin) return full_arc_score def _score_label_full(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr=None, gold_labels_expr=None): _, _, lm_expr, lh_expr = scoring_expr_pack # [BS, len-m, len-h, L] full_label_score = self.scorer.score_label_all(lm_expr, lh_expr, mask_expr, mask_expr) # # set diag to small values # todo(warn): handled specifically in algorithms # maxlen = BK.get_shape(full_label_score, 1) # full_label_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1) # margin? -- specially reshaping if training and margin > 0.: full_shape = BK.get_shape(full_label_score) # combine last two dim combiend_score_expr = full_label_score.view(full_shape[:-2] + [-1]) combined_idx_expr = gold_heads_expr * full_shape[ -1] + gold_labels_expr combined_changed_score = BK.minus_margin(combiend_score_expr, combined_idx_expr, margin) full_label_score = combined_changed_score.view(full_shape) return full_label_score def _score_label_selected(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr=None): _, _, lm_expr, lh_expr = scoring_expr_pack # [BS, len-m, D] lh_expr_shape = BK.get_shape(lh_expr) selected_lh_expr = BK.gather( lh_expr, gold_heads_expr.unsqueeze(-1).expand(*lh_expr_shape), dim=len(lh_expr_shape) - 2) # [BS, len-m, L] select_label_score = self.scorer.score_label_select( lm_expr, selected_lh_expr, mask_expr) # margin? 
if training and margin > 0.: select_label_score = BK.minus_margin(select_label_score, gold_labels_expr, margin) return select_label_score # for global-norm + hinge(perceptron-like)-loss # [*, m, h, L], [*, m] def _losses_global_hinge(self, full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr): return GraphParser.get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr) # ===== export @staticmethod def get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr, clamping=True): # combine the last two dimension full_shape = BK.get_shape(full_score_expr) # [*, m, h*L] last_size = full_shape[-1] combiend_score_expr = full_score_expr.view(full_shape[:-2] + [-1]) # [*, m] gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr # [*, m] gold_scores = BK.gather_one_lastdim(combiend_score_expr, gold_combined_idx_expr).squeeze(-1) pred_scores = BK.gather_one_lastdim(combiend_score_expr, pred_combined_idx_expr).squeeze(-1) # todo(warn): be aware of search error! # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.) # this is previous version hinge_losses = pred_scores - gold_scores # [*, len] if clamping: valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1) # [*, 1] return hinge_losses * valid_losses else: # for this mode, will there be problems of search error? Maybe rare. return hinge_losses # ===== # for global-norm + prob-loss def _losses_global_prob(self, full_score_expr, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr): # combine the last two dimension full_shape = BK.get_shape(full_score_expr) last_size = full_shape[-1] # [*, m, h*L] combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1]) # # todo(warn): make sure sum to 1., handled in algorithm instead # combined_marginals_expr = combined_marginals_expr / combined_marginals_expr.sum(dim=-1, keepdim=True) # [*, m] gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr # [*, m, h, L] gradients = BK.minus_margin(combined_marginals_expr, gold_combined_idx_expr, 1.).view(full_shape) # the gradients on h are already 0. from the marginal algorithm gradients_masked = gradients * mask_expr.unsqueeze(-1).unsqueeze( -1) * mask_expr.unsqueeze(-2).unsqueeze(-1) # for the h-dimension, need to divide by the real length. # todo(warn): this values should be directly summed rather than averaged, since directly from loss fake_losses = (full_score_expr * gradients_masked).sum(-1).sum( -1) # [BS, m] # todo(warn): be aware of search-error-like output constrains; # but this clamp for all is not good for loss-prob, dealt at outside with unproj-mask. # <bad> fake_losses = BK.clamp(fake_losses, min=0.) return fake_losses # for single-norm: 0-1 loss # [*, L], [*], float def _losses_single(self, score_expr, gold_idxes_expr, single_sample, is_hinge=False, margin=0.): # expand the idxes to 0/1 score_shape = BK.get_shape(score_expr) expanded_idxes_expr = BK.constants(score_shape, 0.) expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr, gold_idxes_expr, -1.) # minus -1 means +1 # todo(+N): first adjust margin, since previously only minus margin for golds? 
if margin > 0.: adjusted_scores = margin + BK.minus_margin(score_expr, gold_idxes_expr, margin) else: adjusted_scores = score_expr # [*, L] if is_hinge: # multiply pos instances with -1 flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr) losses_all = BK.clamp(flipped_scores, min=0.) else: losses_all = BK.binary_cross_entropy_with_logits( adjusted_scores, expanded_idxes_expr, reduction='none') # special interpretation (todo(+2): there can be better implementation) if single_sample < 1.: # todo(warn): lower bound of sample_rate, ensure 2 samples real_sample_rate = max(single_sample, 2. / score_shape[-1]) elif single_sample >= 2.: # including the positive one real_sample_rate = max(single_sample, 2.) / score_shape[-1] else: # [1., 2.) real_sample_rate = single_sample # if real_sample_rate < 1.: sample_weight = BK.random_bernoulli(score_shape, real_sample_rate, 1.) # make sure positive is valid sample_weight = (sample_weight + expanded_idxes_expr.float()).clamp_(0., 1.) # final_losses = (losses_all * sample_weight).sum(-1) / sample_weight.sum(-1) else: final_losses = losses_all.mean(-1) return final_losses # ===== # expr[BS, m, h, L], arr[BS] -> arr[BS, m] def _decode(self, full_score_expr, maske_expr, lengths_arr): if self.alg_unproj: return nmst_unproj(full_score_expr, maske_expr, lengths_arr, labeled=True, ret_arr=True) elif self.alg_proj: return nmst_proj(full_score_expr, maske_expr, lengths_arr, labeled=True, ret_arr=True) elif self.alg_greedy: return nmst_greedy(full_score_expr, maske_expr, lengths_arr, labeled=True, ret_arr=True) else: zfatal("Unknown decoding algorithm " + self.conf.iconf.dec_algorithm) return None # expr[BS, m, h, L], arr[BS] -> expr[BS, m, h, L] def _marginal(self, full_score_expr, maske_expr, lengths_arr): if self.alg_unproj: marginals_expr = nmarginal_unproj(full_score_expr, maske_expr, lengths_arr, labeled=True) elif self.alg_proj: marginals_expr = nmarginal_proj(full_score_expr, maske_expr, lengths_arr, labeled=True) else: zfatal( "Unsupported marginal-calculation for the decoding algorithm of " + self.conf.iconf.dec_algorithm) marginals_expr = None return marginals_expr # ===== main methods: training and decoding # only getting scores def score_on_batch(self, insts: List[ParseInstance]): with BK.no_grad_env(): self.refresh_batch(False) scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score( insts, False) full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.) full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.) return full_arc_score, full_label_score # full score and inference def inference_on_batch(self, insts: List[ParseInstance], **kwargs): with BK.no_grad_env(): self.refresh_batch(False) # ===== calculate scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score( insts, False) full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.) full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.) # normalizing scores full_score = None final_exp_score = False # whether to provide PROB by exp if self.norm_local and self.loss_prob: full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax( full_label_score, -1) final_exp_score = True elif self.norm_hlocal and self.loss_prob: # normalize at m dimension, ignore each nodes's self-finish step. 
full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax( full_label_score, -1) elif self.norm_single and self.loss_prob: if self.conf.iconf.dec_single_neg: # todo(+2): add all-neg for prob explanation full_arc_probs = BK.sigmoid(full_arc_score) full_label_probs = BK.sigmoid(full_label_score) fake_arc_scores = BK.log(full_arc_probs) - BK.log( 1. - full_arc_probs) fake_label_scores = BK.log(full_label_probs) - BK.log( 1. - full_label_probs) full_score = fake_arc_scores.unsqueeze( -1) + fake_label_scores else: full_score = BK.logsigmoid(full_arc_score).unsqueeze( -1) + BK.logsigmoid(full_label_score) final_exp_score = True else: full_score = full_arc_score.unsqueeze(-1) + full_label_score # decode mst_lengths = [len(z) + 1 for z in insts ] # +=1 to include ROOT for mst decoding mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode( full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32)) if final_exp_score: mst_scores_arr = np.exp(mst_scores_arr) # jpos prediction (directly index, no converting as in parsing) jpos_preds_expr = jpos_pack[2] has_jpos_pred = jpos_preds_expr is not None jpos_preds_arr = BK.get_value( jpos_preds_expr) if has_jpos_pred else None # ===== assign info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)} mst_real_labels = self.pred2real_labels(mst_labels_arr) for one_idx, one_inst in enumerate(insts): cur_length = mst_lengths[one_idx] one_inst.pred_heads.set_vals( mst_heads_arr[one_idx] [:cur_length]) # directly int-val for heads one_inst.pred_labels.build_vals( mst_real_labels[one_idx][:cur_length], self.label_vocab) one_inst.pred_par_scores.set_vals( mst_scores_arr[one_idx][:cur_length]) if has_jpos_pred: one_inst.pred_poses.build_vals( jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab) return info # list(mini-batch) of annotated instances # optional results are written in-place? return info. 
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs): self.refresh_batch(training) margin = self.margin.value # gold heads and labels gold_heads_arr, _ = self.predict_padder.pad( [z.heads.vals for z in annotated_insts]) gold_labels_arr, _ = self.predict_padder.pad( [self.real2pred_labels(z.labels.idxes) for z in annotated_insts]) gold_heads_expr = BK.input_idx(gold_heads_arr) # [BS, Len] gold_labels_expr = BK.input_idx(gold_labels_arr) # [BS, Len] # ===== calculate scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score( annotated_insts, training) full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr) # final_losses = None if self.norm_local or self.norm_single: select_label_score = self._score_label_selected( scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr) # already added margin previously losses_heads = losses_labels = None if self.loss_prob: if self.norm_local: losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr) losses_labels = BK.loss_nll(select_label_score, gold_labels_expr) elif self.norm_single: single_sample = self.conf.tconf.loss_single_sample losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False) losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False) # simply adding final_losses = losses_heads + losses_labels elif self.loss_hinge: if self.norm_local: losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr) losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr) elif self.norm_single: single_sample = self.conf.tconf.loss_single_sample losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=True, margin=margin) losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=True, margin=margin) # simply adding final_losses = losses_heads + losses_labels elif self.loss_mr: # special treatment! probs_heads = BK.softmax(full_arc_score, dim=-1) # [bs, m, h] probs_labels = BK.softmax(select_label_score, dim=-1) # [bs, m, h] # select probs_head_gold = BK.gather_one_lastdim( probs_heads, gold_heads_expr).squeeze(-1) # [bs, m] probs_label_gold = BK.gather_one_lastdim( probs_labels, gold_labels_expr).squeeze(-1) # [bs, m] # root and pad will be excluded later # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions # todo(warn): have problem since steps will be quite small, not used! final_losses = (mask_expr - probs_head_gold * probs_label_gold ) # let loss>=0 elif self.norm_global: full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr) # for this one, use the merged full score full_score = full_arc_score.unsqueeze( -1) + full_label_score # [BS, m, h, L] # +=1 to include ROOT for mst decoding mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32) # do inference if self.loss_prob: marginals_expr = self._marginal( full_score, mask_expr, mst_lengths_arr) # [BS, m, h, L] final_losses = self._losses_global_prob( full_score, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr) if self.alg_proj: # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg), # but this might be too loose, although the unproj edges are few? 
gold_unproj_arr, _ = self.predict_padder.pad( [z.unprojs for z in annotated_insts]) gold_unproj_expr = BK.input_real( gold_unproj_arr) # [BS, Len] comparing_expr = Constants.REAL_PRAC_MIN * ( 1. - gold_unproj_expr) final_losses = BK.max_elem(final_losses, comparing_expr) elif self.loss_hinge: pred_heads_arr, pred_labels_arr, _ = self._decode( full_score, mask_expr, mst_lengths_arr) pred_heads_expr = BK.input_idx(pred_heads_arr) # [BS, Len] pred_labels_expr = BK.input_idx(pred_labels_arr) # [BS, Len] # final_losses = self._losses_global_hinge( full_score, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr) elif self.loss_mr: # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges raise NotImplementedError( "Not implemented for global-loss + mr.") elif self.norm_hlocal: # firstly label losses are the same select_label_score = self._score_label_selected( scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr) losses_labels = BK.loss_nll(select_label_score, gold_labels_expr) # then specially for arc loss children_masks_arr, _ = self.hlocal_padder.pad( [z.get_children_mask_arr() for z in annotated_insts]) children_masks_expr = BK.input_real( children_masks_arr) # [bs, h, m] # [bs, h] # todo(warn): use prod rather than sum, but still only an approximation for the top-down # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr)) losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose( -1, -2) * children_masks_expr, dim=-1) # including the root-head is important losses_arc[:, 1] += losses_arc[:, 0] final_losses = losses_arc + losses_labels # # jpos loss? (the same mask as parsing) jpos_losses_expr = jpos_pack[1] if jpos_losses_expr is not None: final_losses += jpos_losses_expr # collect loss with mask, also excluding the first symbol of ROOT final_losses_masked = (final_losses * mask_expr)[:, 1:] final_loss_sum = BK.sum(final_losses_masked) # divide loss by what? num_sent = len(annotated_insts) num_valid_tok = sum(len(z) for z in annotated_insts) if self.conf.tconf.loss_div_tok: final_loss = final_loss_sum / num_valid_tok else: final_loss = final_loss_sum / num_sent # final_loss_sum_val = float(BK.get_value(final_loss_sum)) info = { "sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val } if training: BK.backward(final_loss, loss_factor) return info
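# ----- illustration (not part of the original code) -----
# The local-normalization branch of fb_on_batch above computes, for each modifier, an NLL loss
# over candidate heads, masks out padding and the artificial ROOT position, and divides by the
# number of tokens or sentences (loss_div_tok). Below is a minimal standalone numpy sketch of
# that computation; the function name and the toy shapes/values are illustrative only.
import numpy as np

def demo_local_arc_loss(arc_scores, gold_heads, mask, loss_div_tok=True):
    # arc_scores: [bs, m, h]; gold_heads: [bs, m]; mask: [bs, m], 1. for real tokens (incl. ROOT at 0)
    shifted = arc_scores - arc_scores.max(-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(-1, keepdims=True))  # log-softmax over heads
    bs, m = gold_heads.shape
    nll = -log_probs[np.arange(bs)[:, None], np.arange(m)[None, :], gold_heads]  # [bs, m]
    masked = (nll * mask)[:, 1:]  # exclude the first (ROOT) position, as in the original
    loss_sum = masked.sum()
    num_tok = mask[:, 1:].sum()
    return loss_sum / (num_tok if loss_div_tok else bs)

# toy batch: one sentence of length 3 (+ROOT); head indices point into [0..3]
_scores = np.random.randn(1, 4, 4)
_heads = np.array([[0, 0, 1, 1]])
_mask = np.ones((1, 4))
print(demo_local_arc_loss(_scores, _heads, _mask))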
class M3Encoder(MyIEBT): def __init__(self, pc: BK.ParamCollection, conf: M3EncConf, tconf, vpack: VocabPackage): super().__init__(pc, conf, tconf, vpack) # self.conf = conf # ----- bert # modify bert_conf for other input BERT_OTHER_VSIZE = 50 # todo(+N): this should be enough for small inputs! conf.bert_conf.bert2_other_input_names = conf.bert_other_inputs conf.bert_conf.bert2_other_input_vsizes = [BERT_OTHER_VSIZE] * len( conf.bert_other_inputs) self.berter = self.add_sub_node("bert", Berter2(pc, conf.bert_conf)) # ----- # index fake sent self.index_helper = IndexerHelper(vpack) # extra encoder over bert? self.bert_dim, self.bert_fold = self.berter.get_output_dims() conf.m3_enc_conf._input_dim = self.bert_dim self.m3_encs = [ self.add_sub_node("m3e", MyEncoder(pc, conf.m3_enc_conf)) for _ in range(self.bert_fold) ] self.m3_enc_out_dim = self.m3_encs[0].get_output_dims()[0] # skip m3_enc? self.m3_enc_is_empty = all(len(z.layers) == 0 for z in self.m3_encs) if self.m3_enc_is_empty: assert all(z.get_output_dims()[0] == self.bert_dim for z in self.m3_encs) zlog("For m3_enc, we will skip it since it is empty!!") # dep as basic? if conf.m2e_use_basic_dep: MAX_LABEL_NUM = 200 # this should be enough self.dep_label_emb = self.add_sub_node( "dlab", Embedding(self.pc, MAX_LABEL_NUM, conf.dep_label_dim, name="dlab")) self.dep_layer = self.add_sub_node( "dep", TaskSpecAdp(pc, [(self.m3_enc_out_dim, self.bert_fold), None], [conf.dep_label_dim], conf.dep_output_dim)) else: self.dep_label_emb = self.dep_layer = None self.dep_padder = DataPadder( 2, pad_vals=0) # 0 for both head-idx and label # multi-sentence encoding def run(self, insts: List[DocInstance], training: bool): conf = self.conf BERT_MAX_LEN = 510 # save 2 for CLS and SEP # ===== # encoder 1: the basic encoder # todo(note): only DocInstane input for this mode, otherwise will break if conf.m2e_use_basic: reidx_pad_len = conf.ms_extend_budget # enc the basic part + also get some indexes sentid2offset = {} # id(sent)->overall_seq_offset seq_offset = 0 # if look at the docs in one seq all_sents = [] # (inst, d_idx, s_idx) for d_idx, one_doc in enumerate(insts): assert isinstance(one_doc, DocInstance) for s_idx, one_sent in enumerate(one_doc.sents): # todo(note): here we encode all the sentences all_sents.append((one_sent, d_idx, s_idx)) sentid2offset[id(one_sent)] = seq_offset seq_offset += one_sent.length - 1 # exclude extra ROOT node sent_reprs = self.run_sents(all_sents, insts, training) # flatten and concatenate and re-index reidxes_arr = np.zeros( seq_offset + reidx_pad_len, dtype=np.long ) # todo(note): extra padding to avoid out of boundary all_flattened_reprs = [] all_flatten_offset = 0 # the local offset for batched basic encoding for one_pack in sent_reprs: one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode" one_repr_t = one_repr_evt _, one_slen, one_ldim = BK.get_shape(one_repr_t) all_flattened_reprs.append(one_repr_t.view([-1, one_ldim])) # fill in the indexes for one_sent in one_sents: cur_start_offset = sentid2offset[id(one_sent)] cur_real_slen = one_sent.length - 1 # again, +1 to get rid of extra ROOT reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \ np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset+1) all_flatten_offset += one_slen # here add the slen in batched version # re-idxing seq_sent_repr0 = BK.concat(all_flattened_reprs, 0) seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0) # [all_seq_len, D] 
else: sentid2offset = defaultdict(int) seq_sent_repr = None # ===== # repack and prepare for multiple sent enc # todo(note): here, the criterion is based on bert's tokenizer all_ms_info = [] if isinstance(insts[0], DocInstance): for d_idx, one_doc in enumerate(insts): for s_idx, x in enumerate(one_doc.sents): # the basic criterion is the same as the basic one include_flag = False if training: if x.length<self.train_skip_length and x.length>=self.train_min_length \ and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate): include_flag = True else: if x.length >= self.test_min_length: include_flag = True if include_flag: all_ms_info.append( x.preps["ms"]) # use the pre-calculated one else: # multisent based all_ms_info = insts.copy() # shallow copy # ===== # encoder 2: the bert one (multi-sent encoding) ms_size_f = lambda x: x.subword_size all_ms_info.sort(key=ms_size_f) all_ms_buckets = self._bucket_sents_by_length( all_ms_info, conf.benc_bucket_range, ms_size_f, max_bsize=conf.benc_bucket_msize) berter = self.berter rets = [] bert_use_center_typeids = conf.bert_use_center_typeids bert_use_special_typeids = conf.bert_use_special_typeids bert_other_inputs = conf.bert_other_inputs for one_bucket in all_ms_buckets: # prepare batched_ids = [] batched_starts = [] batched_seq_offset = [] batched_typeids = [] batched_other_inputs_list: List = [ [] for _ in bert_other_inputs ] # List(comp) of List(batch) of List(idx) for one_item in one_bucket: one_sents = one_item.sents one_center_sid = one_item.center_idx one_ids, one_starts, one_typeids = [], [], [] one_other_inputs_list = [[] for _ in bert_other_inputs ] # List(comp) of List(idx) for one_sid, one_sent in enumerate(one_sents): # for bert one_bidxes = one_sent.preps["bidx"] one_ids.extend(one_bidxes.subword_ids) one_starts.extend(one_bidxes.subword_is_start) # prepare other inputs for this_field_name, this_tofill_list in zip( bert_other_inputs, one_other_inputs_list): this_tofill_list.extend( one_sent.preps["sub_" + this_field_name]) # todo(note): special procedure if bert_use_center_typeids: if one_sid != one_center_sid: one_typeids.extend([0] * len(one_bidxes.subword_ids)) else: this_typeids = [1] * len(one_bidxes.subword_ids) if bert_use_special_typeids: # todo(note): this is the special mode that we are given the events!! 
for this_event in one_sents[ one_center_sid].events: _, this_wid, this_wlen = this_event.mention.hard_span.position( headed=False) for a, b in one_item.center_word2sub[ this_wid - 1:this_wid - 1 + this_wlen]: this_typeids[a:b] = [0] * (b - a) one_typeids.extend(this_typeids) batched_ids.append(one_ids) batched_starts.append(one_starts) batched_typeids.append(one_typeids) for comp_one_oi, comp_batched_oi in zip( one_other_inputs_list, batched_other_inputs_list): comp_batched_oi.append(comp_one_oi) # for basic part batched_seq_offset.append(sentid2offset[id(one_sents[0])]) # bert forward: [bs, slen, fold, D] if not bert_use_center_typeids: batched_typeids = None bert_expr0, mask_expr = berter.forward_batch( batched_ids, batched_starts, batched_typeids, training=training, other_inputs=batched_other_inputs_list) if self.m3_enc_is_empty: bert_expr = bert_expr0 else: mask_arr = BK.get_value(mask_expr) # [bs, slen] m3e_exprs = [ cur_enc(bert_expr0[:, :, cur_i], mask_arr) for cur_i, cur_enc in enumerate(self.m3_encs) ] bert_expr = BK.stack(m3e_exprs, -2) # on the fold dim again # collect basic ones: [bs, slen, D'] or None if seq_sent_repr is not None: arange_idxes_t = BK.arange_idx(BK.get_shape( mask_expr, -1)).unsqueeze(0) # [1, slen] offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze( -1) + arange_idxes_t # [bs, slen] basic_expr = seq_sent_repr[offset_idxes_t] # [bs, slen, D'] elif conf.m2e_use_basic_dep: # collect each token's head-bert and ud-label, then forward with adp fake_sents = [one_item.fake_sent for one_item in one_bucket] # head idx and labels, no artificial ROOT padded_head_arr, _ = self.dep_padder.pad( [s.ud_heads.vals[1:] for s in fake_sents]) padded_label_arr, _ = self.dep_padder.pad( [s.ud_labels.idxes[1:] for s in fake_sents]) # get tensor padded_head_t = (BK.input_idx(padded_head_arr) - 1 ) # here, the idx exclude root padded_head_t.clamp_(min=0) # [bs, slen] padded_label_t = BK.input_idx(padded_label_arr) # get inputs input_head_bert_t = bert_expr[ BK.arange_idx(len(fake_sents)).unsqueeze(-1), padded_head_t] # [bs, slen, fold, D] input_label_emb_t = self.dep_label_emb( padded_label_t) # [bs, slen, D'] basic_expr = self.dep_layer( input_head_bert_t, None, [input_label_emb_t]) # [bs, slen, ?] 
elif conf.m2e_use_basic_plus: sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket], insts, training, use_one_bucket=True) assert len( sent_reprs ) == 1, "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range" _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0] assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode" basic_expr = one_repr_evt[:, 1:] # exclude ROOT, [bs, slen, D] assert BK.get_shape(basic_expr)[:2] == BK.get_shape( bert_expr)[:2] else: basic_expr = None # pack: (List[ms_item], bert_expr, basic_expr) rets.append((one_bucket, bert_expr, basic_expr)) return rets # prepare instance def prepare_inst(self, inst: DocInstance): berter = self.berter conf = self.conf ms_extend_budget = conf.ms_extend_budget ms_extend_step = conf.ms_extend_step # ----- # prepare bert tokenization results # print(inst.doc_id) for sent in inst.sents: real_words = sent.words.vals[1:] # no special ROOT bidxes = berter.subword_tokenize(real_words, True) sent.preps["bidx"] = bidxes # prepare subword expanded fields subword2word = np.cumsum(bidxes.subword_is_start).tolist( ) # -1 and +1 happens to cancel out for field_name in conf.bert_other_inputs: field_idxes = getattr( sent, field_name ).idxes # use full one since idxes are set in this way sent.preps["sub_" + field_name] = [ field_idxes[z] for z in subword2word ] # ----- # prepare others (another loop since we need cross-sent bert tokens) for sent in inst.sents: # ----- # prepare multi-sent # include this multiple sent pack, extend to both sides until window limit or bert limit cur_center_sent = sent cur_sid, cur_doc = sent.sid, sent.doc cur_doc_sents = cur_doc.sents cur_doc_nsent = len(cur_doc.sents) cur_sid_left = cur_sid_right = cur_sid cur_subword_size = len(cur_center_sent.preps["bidx"].subword_ids) for step in range(ms_extend_step): # first left then right if cur_sid_left > 0: this_subword_size = len( cur_doc_sents[cur_sid_left - 1].preps["bidx"].subword_ids) if cur_subword_size + this_subword_size <= ms_extend_budget: cur_sid_left -= 1 cur_subword_size += this_subword_size if cur_sid_right < cur_doc_nsent - 1: this_subword_size = len( cur_doc_sents[cur_sid_right + 1].preps["bidx"].subword_ids) if cur_subword_size + this_subword_size <= ms_extend_budget: cur_sid_right += 1 cur_subword_size += this_subword_size # List[Sentence], List[int], center_local_idx, all_subword_size cur_sents = cur_doc_sents[cur_sid_left:cur_sid_right + 1] cur_offsets = [0] for s in cur_sents: cur_offsets.append( s.length - 1 + cur_offsets[-1]) # does not include ROOT here!! 
one_ms = MultiSentItem(cur_sents, cur_offsets, cur_sid - cur_sid_left, cur_subword_size, None, None, None) sent.preps["ms"] = one_ms # ----- # subword idx for center sent center_word2sub = [] prev_start = -1 center_subword_is_start = cur_center_sent.preps[ "bidx"].subword_is_start for cur_end, one_is_start in enumerate(center_subword_is_start): if one_is_start: if prev_start >= 0: center_word2sub.append((prev_start, cur_end)) prev_start = cur_end if prev_start >= 0: center_word2sub.append( (prev_start, len(center_subword_is_start))) one_ms.center_word2sub = center_word2sub # ----- # fake a concat sent for basic plus modeling if conf.m2e_use_basic_plus or conf.m2e_use_basic_dep: concat_words, concat_lemmas, concat_uposes, concat_ud_heads, concat_ud_labels = [], [], [], [], [] cur_fake_offset = 0 # overall offset in fake sent prev_root = None for one_fake_inner_sent in cur_sents: # exclude root concat_words.extend(one_fake_inner_sent.words.vals[1:]) concat_lemmas.extend(one_fake_inner_sent.lemmas.vals[1:]) concat_uposes.extend(one_fake_inner_sent.uposes.vals[1:]) # todo(note): make the heads look like a real sent; the actual heads already +=1; root points to prev root for local_i, local_h in enumerate( one_fake_inner_sent.ud_heads.vals[1:]): if local_h == 0: if prev_root is None: global_h = cur_fake_offset + local_i + 1 # +1 here for offset else: global_h = prev_root prev_root = cur_fake_offset + local_i + 1 # +1 here for offset else: global_h = cur_fake_offset + local_h # already +=1 concat_ud_heads.append(global_h) concat_ud_labels.extend( one_fake_inner_sent.ud_labels.vals[1:]) cur_fake_offset += len(one_fake_inner_sent.words.vals) - 1 one_fake_sent = Sentence(None, concat_words, concat_lemmas, concat_uposes, concat_ud_heads, concat_ud_labels, None, None) one_ms.fake_sent = one_fake_sent self.index_helper.index_sent(one_fake_sent) def special_refresh(self, embed_rop, other_rop): super().special_refresh(embed_rop, other_rop) self.berter.refresh(other_rop) for one in self.m3_encs + [self.dep_layer, self.dep_label_emb]: if one is not None: one.refresh(other_rop) # # # def get_output_dims(self, *input_dims): # raise RuntimeError("Complex output, thus not using this one") def speical_output_dims(self): conf = self.conf if conf.m2e_use_basic_dep: basic_dim = conf.dep_output_dim elif conf.m2e_use_basic or conf.m2e_use_basic_plus: basic_dim = self.enc_evt_output_dim else: basic_dim = None # bert_outputs, basic_output return (self.m3_enc_out_dim, self.bert_fold), basic_dim
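# ----- illustration (not part of the original code) -----
# prepare_inst above grows a multi-sentence window around each center sentence by alternately
# trying to add the left and then the right neighbor, as long as the total subword count stays
# within ms_extend_budget, for at most ms_extend_step rounds. A minimal standalone sketch of
# that greedy expansion; the function and argument names here are illustrative only.
from typing import List, Tuple

def demo_extend_window(subword_sizes: List[int], center: int, budget: int, max_steps: int) -> Tuple[int, int, int]:
    left = right = center
    total = subword_sizes[center]
    for _ in range(max_steps):
        # first left, then right (mirroring the original loop order)
        if left > 0 and total + subword_sizes[left - 1] <= budget:
            left -= 1
            total += subword_sizes[left]
        if right < len(subword_sizes) - 1 and total + subword_sizes[right + 1] <= budget:
            right += 1
            total += subword_sizes[right]
    return left, right, total

# e.g., sentences with these subword lengths, centered at index 2, with a budget of 60 subwords:
print(demo_extend_window([20, 15, 25, 10, 40], center=2, budget=60, max_steps=2))  # -> (1, 3, 50)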
class SL0Layer(BasicNode): def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf): super().__init__(pc, None, None) self.dim = rconf._input_dim # both input/output dim # padders for child nodes self.chs_start_posi = -rconf.chs_num self.ch_idx_padder = DataPadder(2, pad_vals=0, mask_range=2) # [*, num-ch] self.ch_label_padder = DataPadder(2, pad_vals=0) # self.label_embeddings = self.add_sub_node( "label", Embedding(pc, rconf._num_label, rconf.dim_label, fix_row0=False)) self.dim_label = rconf.dim_label # todo(note): now adopting flatten groupings for basic, and then that is all, no more recurrent features # group 1: [cur, chs, par] -> head_pre_size self.use_chs = rconf.use_chs self.use_par = rconf.use_par self.use_label_feat = rconf.use_label_feat # components (add the parameters anyway) # todo(note): children features: children + (label of mod->children) self.chs_reprer = self.add_sub_node("chs", ChsReprer(pc, rconf)) self.chs_ff = self.add_sub_node( "chs_ff", Affine(pc, self.chs_reprer.get_output_dims()[0], self.dim, act="tanh")) # todo(note): parent features: parent + (label of parent->mod) # todo(warn): always add label related params par_ff_inputs = [self.dim, rconf.dim_label] self.par_ff = self.add_sub_node( "par_ff", Affine(pc, par_ff_inputs, self.dim, act="tanh")) # no other groups anymore! if rconf.zero_extra_output_params: self.par_ff.zero_params() self.chs_ff.zero_params() # calculating the structured representations, giving raw repr-tensor # 1) [*, D]; 2) [*, D], [*], [*]; 3) [*, chs-len, D], [*, chs-len], [*, chs-len], [*] def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t): ret_t = cur_t # [*, D] # padding 0 if not using labels dim_label = self.dim_label # child features if self.use_chs and chs_t is not None: if self.use_label_feat: chs_label_rt = self.label_embeddings( chs_label_t) # [*, max-chs, dlab] else: labels_shape = BK.get_shape(chs_t) labels_shape[-1] = dim_label chs_label_rt = BK.zeros(labels_shape) chs_input_t = BK.concat([chs_t, chs_label_rt], -1) chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t) chs_feat = self.chs_ff(chs_feat0) ret_t += chs_feat # parent features if self.use_par and par_t is not None: if self.use_label_feat: cur_label_t = self.label_embeddings(label_t) # [*, dlab] else: labels_shape = BK.get_shape(par_t) labels_shape[-1] = dim_label cur_label_t = BK.zeros(labels_shape) par_feat = self.par_ff([par_t, cur_label_t]) if par_mask_t is not None: par_feat *= par_mask_t.unsqueeze(-1) ret_t += par_feat return ret_t # todo(note): if no other features, then no change for the repr! def forward_repr(self, cur_t): return cur_t # preparation: padding for chs/par def pad_chs(self, idxes_list: List[List], labels_list: List[List]): start_posi = self.chs_start_posi if start_posi < 0: # truncate idxes_list = [x[start_posi:] for x in idxes_list] # overall valid mask chs_valid = [(0. if len(z) == 0 else 1.) 
                     for z in idxes_list]  # if any valid children in the batch
        if all(x > 0 for x in chs_valid):
            padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(idxes_list)  # [*, max-ch], [*, max-ch]
            if self.use_label_feat:
                if start_posi < 0:  # truncate
                    labels_list = [x[start_posi:] for x in labels_list]
                padded_chs_labels, _ = self.ch_label_padder.pad(labels_list)  # [*, max-ch]
                chs_label_t = BK.input_idx(padded_chs_labels)
            else:
                chs_label_t = None
            chs_idxes_t, chs_mask_t, chs_valid_mask_t = \
                BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid)
            return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t
        else:
            return None, None, None, None

    def pad_par(self, idxes: List, labels: List):
        par_idxes_t = BK.input_idx(idxes)
        labels_t = BK.input_idx(labels)
        # todo(note): specifically, <0 means non-exist
        # todo(note): there was an interesting bug here: ">=" was once wrongly written as "<"; in that case,
        #  index 0 acted as the parent of tokens that do not actually have parents yet (those still to be
        #  attached), so patterns like "parent=0" could receive overly positive scores
        # todo(note): ACTUALLY, mainly because of the difference in search and forward-backward!!
        par_mask_t = (par_idxes_t >= 0).float()
        par_idxes_t.clamp_(0)  # since -1 would be an illegal index
        labels_t.clamp_(0)
        return par_idxes_t, labels_t, par_mask_t
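# ----- illustration (not part of the original code) -----
# A tiny torch sketch of the pad_par convention above: a parent index < 0 means "no parent yet",
# so the validity mask is taken from (idx >= 0) *before* clamping the indices to 0, which keeps
# them legal for later embedding/gather lookups. The values below are made up.
import torch

par_idxes_t = torch.tensor([3, -1, 0, 2])   # -1 marks a token whose parent is not decided yet
par_mask_t = (par_idxes_t >= 0).float()     # tensor([1., 0., 1., 1.])
par_idxes_t = par_idxes_t.clamp(min=0)      # tensor([3, 0, 0, 2]) -- now safe to use as an index
print(par_mask_t, par_idxes_t)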
class ParserBT(BasicNode): def __init__(self, pc: BK.ParamCollection, bconf: BTConf, vpack: VocabPackage): super().__init__(pc, None, None) self.bconf = bconf # ===== Vocab ===== self.word_vocab = vpack.get_voc("word") self.char_vocab = vpack.get_voc("char") self.pos_vocab = vpack.get_voc("pos") # ===== Model ===== # embedding self.emb = self.add_sub_node( "emb", MyEmbedder(self.pc, bconf.emb_conf, vpack)) emb_output_dim = self.emb.get_output_dims()[0] # encoder0 for jpos # todo(note): will do nothing if not use_jpos bconf.jpos_conf._input_dim = emb_output_dim self.jpos_enc = self.add_sub_node( "enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab)) enc0_output_dim = self.jpos_enc.get_output_dims()[0] # encoder # todo(0): feed compute-on-the-fly hp bconf.enc_conf._input_dim = enc0_output_dim self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf)) self.enc_output_dim = self.enc.get_output_dims()[0] # ===== Input Specification ===== # inputs (word, char, pos) and vocabulary self.need_word = self.emb.has_word self.need_char = self.emb.has_char # todo(warn): currently only allow extra fields for POS self.need_pos = False if len(self.emb.extra_names) > 0: assert len( self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos" self.need_pos = True # todo(warn): currently only allow one aux field self.need_aux = False if len(self.emb.dim_auxes) > 0: assert len(self.emb.dim_auxes) == 1 self.need_aux = True # self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2) self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad) self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad) # self.random_sample_stream = Random.stream(Random.random_sample) def get_output_dims(self, *input_dims): return (self.enc_output_dim, ) # def refresh(self, rop=None): zfatal("Should call special_refresh instead!") # whether enabling joint-pos multitask def jpos_multitask_enabled(self): return self.jpos_enc.jpos_multitask # ==== # special routines def special_refresh(self, embed_rop, other_rop): self.emb.refresh(embed_rop) self.enc.refresh(other_rop) self.jpos_enc.refresh(other_rop) def prepare_training_rop(self): mconf = self.bconf embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop) other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, fix_drop=mconf.fix_drop) return embed_rop, other_rop # ===== # run def _prepare_input(self, insts, training): word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], [] # ===== specially prepare for the words wv = self.word_vocab W_UNK = wv.unk UNK_REP_RATE = self.bconf.singleton_unk UNK_REP_THR = self.bconf.singleton_thr word_act_idxes = [] if training and UNK_REP_RATE > 0.: # replace unfreq/singleton words with UNK for one_inst in insts: one_act_idxes = [] for one_idx in one_inst.words.idxes: one_freq = wv.idx2val(one_idx) if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR: if next(self.random_sample_stream) < (UNK_REP_RATE / one_freq): one_idx = W_UNK one_act_idxes.append(one_idx) word_act_idxes.append(one_act_idxes) else: word_act_idxes = [z.words.idxes for z in insts] # todo(warn): still need the masks word_arr, mask_arr = self.word_padder.pad(word_act_idxes) # ===== if not self.need_word: word_arr = None if self.need_char: chars = [z.chars.idxes for z in insts] char_arr, _ = self.char_padder.pad(chars) if self.need_pos or self.jpos_multitask_enabled(): poses = [z.poses.idxes for 
z in insts] pos_arr, _ = self.pos_padder.pad(poses) if self.need_pos: extra_arrs.append(pos_arr) else: pos_arr = None if self.need_aux: aux_arr_list = [z.extra_features["aux_repr"] for z in insts] # pad padded_seq_len = int(mask_arr.shape[1]) final_aux_arr_list = [] for cur_arr in aux_arr_list: cur_len = len(cur_arr) if cur_len > padded_seq_len: final_aux_arr_list.append(cur_arr[:padded_seq_len]) else: final_aux_arr_list.append( np.pad(cur_arr, ((0, padded_seq_len - cur_len), (0, 0)), 'constant')) aux_arrs.append(np.stack(final_aux_arr_list, 0)) # input_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs) # [BS, Len, Dim], [BS, Len] return input_repr, mask_arr, pos_arr # todo(warn): for rnn, need to transpose masks, thus need np.array # return input_repr, enc_repr, mask_arr def run(self, insts, training): # ===== calculate # [BS, Len, Di], [BS, Len], [BS, len] input_repr, mask_arr, gold_pos_arr = self._prepare_input( insts, training) # enc0 for joint-pos multitask input_repr0, jpos_pack = self.jpos_enc(input_repr, mask_arr, require_loss=training, require_pred=(not training), gold_pos_arr=gold_pos_arr) # [BS, Len, De] enc_repr = self.enc(input_repr0, mask_arr) return input_repr, enc_repr, jpos_pack, mask_arr # special routine def aug_words_and_embs(self, aug_vocab, aug_wv): orig_vocab = self.word_vocab if self.emb.has_word: orig_arr = self.emb.word_embed.E.detach().cpu().numpy() # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed? # todo(warn): here aug_vocab should be find in aug_wv aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True) new_vocab, new_arr = MultiHelper.aug_vocab_and_arr( orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True) # assign self.word_vocab = new_vocab self.emb.word_embed.replace_weights(new_arr) else: zwarn("No need to aug vocab since delexicalized model!!") new_vocab = orig_vocab return new_vocab
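# ----- illustration (not part of the original code) -----
# _prepare_input above replaces low-frequency training words with UNK: a word whose frequency f
# satisfies 1 <= f <= singleton_thr is replaced with probability singleton_unk / f, so rarer
# words are dropped to UNK more often. A minimal standalone sketch; the helper name and the toy
# vocabulary below are illustrative only.
import random

def demo_unk_replace(idxes, freqs, unk_idx, unk_rate=0.5, unk_thr=2, rng=random.Random(0)):
    out = []
    for idx in idxes:
        f = freqs.get(idx)
        if f is not None and 1 <= f <= unk_thr and rng.random() < (unk_rate / f):
            idx = unk_idx  # replace this occurrence with UNK
        out.append(idx)
    return out

# word-id -> training frequency (ids 7 and 9 are singletons, id 5 is frequent)
toy_freqs = {5: 100, 7: 1, 9: 1, 11: 2}
print(demo_unk_replace([5, 7, 9, 11, 5], toy_freqs, unk_idx=1))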
class MaskLMNode(BasicNode): def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage): super().__init__(pc, None, None) self.conf = conf # vocab and padder self.word_vocab = vpack.get_voc("word") self.padder = DataPadder( 2, pad_vals=self.word_vocab.pad, mask_range=2) # todo(note): <pad>-id is very large # models self.hid_layer = self.add_sub_node( "hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act)) self.pred_layer = self.add_sub_node( "pred", Affine(pc, conf.hid_dim, conf.max_pred_rank + 1, init_rop=NoDropRop())) if conf.init_pred_from_pretrain: npvec = vpack.get_emb("word") if npvec is None: zwarn( "Pretrained vector not provided, skip init pred embeddings!!" ) else: with BK.no_grad_env(): self.pred_layer.ws[0].copy_( BK.input_real(npvec[:conf.max_pred_rank + 1].T)) zlog( f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1})." ) # return (input_word_mask_repl, output_pred_mask_repl, ouput_pred_idx) def prepare(self, insts: List[ParseInstance], training): conf = self.conf word_idxes = [z.words.idxes for z in insts] word_arr, input_mask = self.padder.pad(word_idxes) # [bsize, slen] # prepare for the masks input_word_mask = (Random.random_sample(word_arr.shape) < conf.mask_rate) & (input_mask > 0.) input_word_mask &= (word_arr >= conf.min_mask_rank) input_word_mask[:, 0] = False # no masking for special ROOT output_pred_mask = (input_word_mask & (word_arr <= conf.max_pred_rank)) return input_word_mask.astype(np.float32), output_pred_mask.astype( np.float32), word_arr # [bsize, slen, *] def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr): mask_idxes, mask_valids = BK.mask2idx( BK.input_real(pred_mask_repl_arr)) # [bsize, ?] if BK.get_shape(mask_idxes, -1) == 0: # no loss zzz = BK.zeros([]) return [[zzz, zzz, zzz]] else: target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1) # [bsize, ?, *] target_hids = self.hid_layer(target_reprs) target_scores = self.pred_layer(target_hids) # [bsize, ?, V] pred_idx_t = BK.input_idx(pred_idx_arr) # [bsize, slen] target_idx_t = pred_idx_t.gather(-1, mask_idxes) # [bsize, ?] target_idx_t[(mask_valids < 1.)] = 0 # make sure invalid ones in range # get loss pred_losses = BK.loss_nll(target_scores, target_idx_t) # [bsize, ?] pred_loss_sum = (pred_losses * mask_valids).sum() pred_loss_count = mask_valids.sum() # argmax _, argmax_idxes = target_scores.max(-1) pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids pred_corr_count = pred_corrs.sum() return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
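# ----- illustration (not part of the original code) -----
# A minimal numpy sketch of the mask construction in MaskLMNode.prepare above: sample positions
# at mask_rate, restrict to real (non-pad) tokens whose word rank is >= min_mask_rank, never mask
# position 0 (the artificial ROOT), and only request predictions where rank <= max_pred_rank.
# The toy ranks and rates below are made up.
import numpy as np

rng = np.random.RandomState(0)
word_arr = np.array([[1, 7, 30005, 12, 0]])          # word-rank ids; the large id is the <pad>
input_mask = np.array([[1., 1., 1., 1., 0.]])
mask_rate, min_mask_rank, max_pred_rank = 0.5, 2, 20000

input_word_mask = (rng.random_sample(word_arr.shape) < mask_rate) & (input_mask > 0.)
input_word_mask &= (word_arr >= min_mask_rank)       # do not mask the most special/frequent ranks
input_word_mask[:, 0] = False                        # never mask the ROOT position
output_pred_mask = input_word_mask & (word_arr <= max_pred_rank)  # only predict in-vocab ranks
print(input_word_mask.astype(np.float32), output_pred_mask.astype(np.float32))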
class G2Parser(BaseParser): def __init__(self, conf: G2ParserConf, vpack: VocabPackage): super().__init__(conf, vpack) # todo(note): the neural parameters are exactly the same as the EF one # ===== basic G1 Parser's loading # todo(note): there can be parameter mismatch (but all of them in non-trained part, thus will be fine) self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf) self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing # self.add_slayer() self.dl = G2DL(self.scorer, self.slayer, conf) # self.predict_padder = DataPadder(2, pad_vals=0) self.num_label = self.label_vocab.trg_len( True) # todo(WARN): use the original idx def build_decoder(self): conf = self.conf # ===== Decoding Scorer ===== conf.sc_conf._input_dim = self.enc_output_dim conf.sc_conf._num_label = self.label_vocab.trg_len( True) # todo(WARN): use the original idx return Scorer(self.pc, conf.sc_conf) def build_slayer(self): conf = self.conf # ===== Structured Layer ===== conf.sl_conf._input_dim = self.enc_output_dim conf.sl_conf._num_label = self.label_vocab.trg_len( True) # todo(WARN): use the original idx return SL0Layer(self.pc, conf.sl_conf) # ===== # main procedures # get g1 prunings and extra-scores # -> if no g1parser provided, then use aux scores; otherwise, it depends on g1_use_aux_scores def _get_g1_pack(self, insts: List[ParseInstance], score_arc_lambda: float, score_lab_lambda: float): pconf = self.conf.iconf.pruning_conf if self.g1parser is None or self.g1parser.g1_use_aux_scores: valid_mask, arc_score, lab_score, _, _ = G1Parser.score_and_prune( insts, self.num_label, pconf) else: valid_mask, arc_score, lab_score, _, _ = self.g1parser.prune_on_batch( insts, pconf) if score_arc_lambda <= 0. and score_lab_lambda <= 0.: go1_pack = None else: arc_score *= score_arc_lambda lab_score *= score_lab_lambda go1_pack = (arc_score, lab_score) # [*, slen, slen], ([*, slen, slen], [*, slen, slen, Lab]) return valid_mask, go1_pack # make real valid masks (inplaced): valid(byte): [bs, len-m, len-h]; mask(float): [bs, len] def _make_final_valid(self, valid_expr, mask_expr): maxlen = BK.get_shape(mask_expr, -1) # first apply masks mask_expr_byte = mask_expr.byte() valid_expr &= mask_expr_byte.unsqueeze(-1) valid_expr &= mask_expr_byte.unsqueeze(-2) # then diag mask_diag = 1 - BK.eye(maxlen).byte() valid_expr &= mask_diag # root not as mod valid_expr[:, 0] = 0 # only allow root->root (for grandparent feature) valid_expr[:, 0, 0] = 1 return valid_expr # decoding def inference_on_batch(self, insts: List[ParseInstance], **kwargs): # iconf = self.conf.iconf with BK.no_grad_env(): self.refresh_batch(False) # pruning and scores from g1 valid_mask, go1_pack = self._get_g1_pack( insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing) # encode input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( insts, False) mask_expr = BK.input_real(mask_arr) # decode final_valid_expr = self._make_final_valid(valid_mask, mask_expr) ret_heads, ret_labels, _, _ = self.dl.decode( insts, enc_repr, final_valid_expr, go1_pack, False, 0.) # collect the results together all_heads = Helper.join_list(ret_heads) if ret_labels is None: # todo(note): simply get labels from the go1-label classifier; must provide g1parser if go1_pack is None: _, go1_pack = self._get_g1_pack(insts, 1., 1.) 
_, go1_label_max_idxes = go1_pack[1].max( -1) # [bs, slen, slen] pred_heads_arr, _ = self.predict_padder.pad( all_heads) # [bs, slen] pred_heads_expr = BK.input_idx(pred_heads_arr) pred_labels_expr = BK.gather_one_lastdim( go1_label_max_idxes, pred_heads_expr).squeeze(-1) all_labels = BK.get_value(pred_labels_expr) # [bs, slen] else: all_labels = np.concatenate(ret_labels, 0) # ===== assign, todo(warn): here, the labels are directly original idx, no need to change for one_idx, one_inst in enumerate(insts): cur_length = len(one_inst) + 1 one_inst.pred_heads.set_vals( all_heads[one_idx] [:cur_length]) # directly int-val for heads one_inst.pred_labels.build_vals( all_labels[one_idx][:cur_length], self.label_vocab) # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length]) # ===== # put jpos result (possibly) self.jpos_decode(insts, jpos_pack) # ----- info = {"sent": len(insts), "tok": sum(map(len, insts))} return info # training def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): self.refresh_batch(training) # pruning and scores from g1 valid_mask, go1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training) # encode input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( annotated_insts, training) mask_expr = BK.input_real(mask_arr) # the parsing loss final_valid_expr = self._make_final_valid(valid_mask, mask_expr) parsing_loss, parsing_scores, info = \ self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value) info["loss_parse"] = BK.get_value(parsing_loss).item() final_loss = parsing_loss # other loss? jpos_loss = self.jpos_loss(jpos_pack, mask_expr) if jpos_loss is not None: info["loss_jpos"] = BK.get_value(jpos_loss).item() final_loss = parsing_loss + jpos_loss if parsing_scores is not None: reg_loss = self.reg_scores_loss(*parsing_scores) if reg_loss is not None: final_loss = final_loss + reg_loss info["fb"] = 1 if training: BK.backward(final_loss, loss_factor) return info
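# ----- illustration (not part of the original code) -----
# A small torch sketch of the valid-mask construction in G2Parser._make_final_valid above:
# apply the sentence mask on both the modifier and head axes, zero the diagonal, forbid ROOT
# (index 0) as a modifier, but keep the special root->root cell for the grandparent feature.
# Shapes and values here are toy ones, and bool tensors are used instead of the original byte masks.
import torch

mask_expr = torch.tensor([[1., 1., 1., 0.]])               # [bs, len], last token is padding
valid = torch.ones(1, 4, 4, dtype=torch.bool)              # [bs, len-m, len-h], e.g. from the pruner
mask_b = mask_expr.bool()
valid &= mask_b.unsqueeze(-1)                              # mask out padded modifiers
valid &= mask_b.unsqueeze(-2)                              # mask out padded heads
valid &= ~torch.eye(4, dtype=torch.bool)                   # no self-loops
valid[:, 0] = False                                        # ROOT is never a modifier ...
valid[:, 0, 0] = True                                      # ... except for the root->root entry
print(valid.long())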
class G1Parser(BaseParser): def __init__(self, conf: G1ParserConf, vpack: VocabPackage): super().__init__(conf, vpack) # todo(note): the neural parameters are exactly the same as the EF one self.scorer_helper = GScorerHelper(self.scorer) self.predict_padder = DataPadder(2, pad_vals=0) # self.g1_use_aux_scores = conf.debug_use_aux_scores # assining here is only for debugging usage, otherwise assigning outside self.num_label = self.label_vocab.trg_len( True) # todo(WARN): use the original idx # self.loss_hinge = (self.conf.tconf.loss_function == "hinge") if not self.loss_hinge: assert self.conf.tconf.loss_function == "prob", "This model only supports hinge or prob" def build_decoder(self): conf = self.conf # ===== Decoding Scorer ===== conf.sc_conf._input_dim = self.enc_output_dim conf.sc_conf._num_label = self.label_vocab.trg_len( True) # todo(WARN): use the original idx return Scorer(self.pc, conf.sc_conf) # ===== # main procedures # decoding def inference_on_batch(self, insts: List[ParseInstance], **kwargs): iconf = self.conf.iconf pconf = iconf.pruning_conf with BK.no_grad_env(): self.refresh_batch(False) if iconf.use_pruning: # todo(note): for the testing of pruning mode, use the scores instead if self.g1_use_aux_scores: valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune( insts, self.num_label, pconf) else: valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch( insts, pconf) valid_mask_f = valid_mask.float() # [*, len, len] mask_value = Constants.REAL_PRAC_MIN full_score = arc_score.unsqueeze(-1) + label_score full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1) info_pruning = G1Parser.collect_pruning_info( insts, valid_mask_f) jpos_pack = [None, None, None] else: input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( insts, False) mask_expr = BK.input_real(mask_arr) full_score = self.scorer_helper.score_full(enc_repr) info_pruning = None # ===== self._decode(insts, full_score, mask_expr, "g1") # put jpos result (possibly) self.jpos_decode(insts, jpos_pack) # ----- info = {"sent": len(insts), "tok": sum(map(len, insts))} if info_pruning is not None: info.update(info_pruning) return info # training def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): self.refresh_batch(training) # encode input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( annotated_insts, training) mask_expr = BK.input_real(mask_arr) # the parsing loss arc_score = self.scorer_helper.score_arc(enc_repr) lab_score = self.scorer_helper.score_label(enc_repr) full_score = arc_score + lab_score parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr) # other loss? 
jpos_loss = self.jpos_loss(jpos_pack, mask_expr) reg_loss = self.reg_scores_loss(arc_score, lab_score) # info["loss_parse"] = BK.get_value(parsing_loss).item() final_loss = parsing_loss if jpos_loss is not None: info["loss_jpos"] = BK.get_value(jpos_loss).item() final_loss = parsing_loss + jpos_loss if reg_loss is not None: final_loss = final_loss + reg_loss info["fb"] = 1 if training: BK.backward(final_loss, loss_factor) return info # ===== def _decode(self, insts: List[ParseInstance], full_score, mask_expr, misc_prefix): # decode mst_lengths = [len(z) + 1 for z in insts] # +=1 to include ROOT for mst decoding mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32) mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj( full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True) if self.conf.iconf.output_marginals: # todo(note): here, we care about marginals for arc # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True) arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1) bsize, max_len = BK.get_shape(mask_expr) idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1) idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0) output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr, BK.input_idx(mst_heads_arr)] mst_marg_arr = BK.get_value(output_marg) else: mst_marg_arr = None # ===== assign, todo(warn): here, the labels are directly original idx, no need to change for one_idx, one_inst in enumerate(insts): cur_length = mst_lengths[one_idx] one_inst.pred_heads.set_vals( mst_heads_arr[one_idx] [:cur_length]) # directly int-val for heads one_inst.pred_labels.build_vals( mst_labels_arr[one_idx][:cur_length], self.label_vocab) one_scores = mst_scores_arr[one_idx][:cur_length] one_inst.pred_par_scores.set_vals(one_scores) # extra output one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist() if mst_marg_arr is not None: one_inst.extra_pred_misc[ misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist() # here, only adopt hinge(max-margin) loss; mostly adopted from previous graph parser def _loss(self, annotated_insts: List[ParseInstance], full_score_expr, mask_expr, valid_expr=None): bsize, max_len = BK.get_shape(mask_expr) # gold heads and labels gold_heads_arr, _ = self.predict_padder.pad( [z.heads.vals for z in annotated_insts]) # todo(note): here use the original idx of label, no shift! 
gold_labels_arr, _ = self.predict_padder.pad( [z.labels.idxes for z in annotated_insts]) gold_heads_expr = BK.input_idx(gold_heads_arr) # [BS, Len] gold_labels_expr = BK.input_idx(gold_labels_arr) # [BS, Len] # idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1) idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0) # scores for decoding or marginal margin = self.margin.value decoding_scores = full_score_expr.clone().detach() decoding_scores = self.scorer_helper.postprocess_scores( decoding_scores, mask_expr, margin, gold_heads_expr, gold_labels_expr) if self.loss_hinge: mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32) pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores, mask_expr, mst_lengths_arr, labeled=True, ret_arr=False) # ===== add margin*cost, [bs, len] gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] pred_final_scores = full_score_expr[ idxes_bs_expr, idxes_m_expr, pred_heads_expr, pred_labels_expr] + margin * ( gold_heads_expr != pred_heads_expr).float() + margin * ( gold_labels_expr != pred_labels_expr).float() # plus margin hinge_losses = pred_final_scores - gold_final_scores valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1) # [*, 1] final_losses = hinge_losses * valid_losses else: lab_marginals = nmarginal_unproj(decoding_scores, mask_expr, None, labeled=True) lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] -= 1. grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze( -1) * mask_expr.unsqueeze(-2).unsqueeze(-1) final_losses = (full_score_expr * grads_masked).sum(-1).sum( -1) # [bs, m] # divide loss by what? num_sent = len(annotated_insts) num_valid_tok = sum(len(z) for z in annotated_insts) # exclude non-valid ones: there can be pruning error if valid_expr is not None: final_valids = valid_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr] # [bs, m] of (0. or 1.) 
final_losses = final_losses * final_valids tok_valid = float(BK.get_value(final_valids[:, 1:].sum())) assert tok_valid <= num_valid_tok tok_prune_err = num_valid_tok - tok_valid else: tok_prune_err = 0 # collect loss with mask, also excluding the first symbol of ROOT final_losses_masked = (final_losses * mask_expr)[:, 1:] final_loss_sum = BK.sum(final_losses_masked) if self.conf.tconf.loss_div_tok: final_loss = final_loss_sum / num_valid_tok else: final_loss = final_loss_sum / num_sent final_loss_sum_val = float(BK.get_value(final_loss_sum)) info = { "sent": num_sent, "tok": num_valid_tok, "tok_prune_err": tok_prune_err, "loss_sum": final_loss_sum_val } return final_loss, info # ===== # special preloading @staticmethod def special_pretrain_load(m: BaseParser, path, strict): if FileHelper.isfile(path): try: zlog(f"Trying to load pretrained model from {path}") m.load(path, strict) zlog(f"Finished loading pretrained model from {path}") return True except: zlog(traceback.format_exc()) zlog("Failed loading, keep the original ones.") else: zlog(f"File does not exist for pretraining loading: {path}") return False # init for the pre-trained G1, return g1parser and also modify m's params @staticmethod def pre_g1_init(m: BaseParser, pg1_conf: PreG1Conf, strict=True): # ===== basic G1 Parser's loading # todo(WARN): construct the g1conf here instead of loading for simplicity, since the scorer architecture should be the same g1conf = G1ParserConf() g1conf.bt_conf = deepcopy(m.conf.bt_conf) g1conf.sc_conf = deepcopy(m.conf.sc_conf) g1conf.validate() # todo(note): specific setting g1conf.g1_use_aux_scores = pg1_conf.g1_use_aux_scores # g1parser = G1Parser(g1conf, m.vpack) if not G1Parser.special_pretrain_load( g1parser, pg1_conf.g1_pretrain_path, strict): g1parser = None # current init if pg1_conf.g1_pretrain_init: G1Parser.special_pretrain_load(m, pg1_conf.g1_pretrain_path, strict) return g1parser # collect and batch scores @staticmethod def collect_aux_scores(insts: List[ParseInstance], output_num_label): score_tuples = [z.extra_features["aux_score"] for z in insts] num_label = score_tuples[0][1].shape[-1] max_len = max(len(z) + 1 for z in insts) mask_value = Constants.REAL_PRAC_MIN bsize = len(insts) arc_score_arr = np.full([bsize, max_len, max_len], mask_value, dtype=np.float32) lab_score_arr = np.full([bsize, max_len, max_len, output_num_label], mask_value, dtype=np.float32) mask_arr = np.full([bsize, max_len], 0., dtype=np.float32) for bidx, one_tuple in enumerate(score_tuples): one_score_arc, one_score_lab = one_tuple one_len = one_score_arc.shape[1] arc_score_arr[bidx, :one_len, :one_len] = one_score_arc lab_score_arr[bidx, :one_len, :one_len, -num_label:] = one_score_lab mask_arr[bidx, :one_len] = 1. 
return BK.input_real(arc_score_arr).unsqueeze(-1), BK.input_real( lab_score_arr), BK.input_real(mask_arr) # pruner: [bs, slen, slen], [bs, slen, slen, Lab] @staticmethod def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None): prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \ pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \ pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel full_score = arc_score + label_score final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1) # (put as argument) arc_marginals = None # [*, mlen, hlen] if prune_use_marginal: if arc_marginals is None: # does not provided, calculate from scores if prune_labeled: # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0] # use sum of label marginals instead of max arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1) else: arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1) if prune_mthresh_rel: # relative value max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze( -1) m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float( np.log(prune_mthresh)) else: # absolute value m_valid_mask = (arc_marginals > prune_mthresh ) # [*, len-m, len-h] final_valid_mask |= m_valid_mask if prune_use_topk: # prune by "in topk" and "gap-to-top less than gap" for each mod if prune_labeled: # take argmax among label dim tmp_arc_score, _ = full_score.max(-1) else: # todo(note): may be modified inplaced, but does not matter since will finally be masked later tmp_arc_score = arc_score.squeeze(-1) # first apply mask mask_value = Constants.REAL_PRAC_MIN mask_mul = (mask_value * (1. - mask_expr)) # [*, len] tmp_arc_score += mask_mul.unsqueeze(-1) tmp_arc_score += mask_mul.unsqueeze(-2) maxlen = BK.get_shape(tmp_arc_score, -1) tmp_arc_score += mask_value * BK.eye(maxlen) prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen) if prune_topk >= maxlen: topk_arc_score = tmp_arc_score else: topk_arc_score, _ = BK.topk(tmp_arc_score, prune_topk, dim=-1, sorted=False) # [*, len, k] min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze( -1) # [*, len, 1] max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze( -1) # [*, len, 1] arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap) # [*, len, 1] t_valid_mask = (tmp_arc_score > arc_score_thresh ) # [*, len-m, len-h] final_valid_mask |= t_valid_mask return final_valid_mask, arc_marginals # combining the above two @staticmethod def score_and_prune(insts: List[ParseInstance], output_num_label, pconf: PruneG1Conf): arc_score, lab_score, mask_expr = G1Parser.collect_aux_scores( insts, output_num_label) valid_mask, arc_marginals = G1Parser.prune_with_scores( arc_score, lab_score, mask_expr, pconf) return valid_mask, arc_score.squeeze( -1), lab_score, mask_expr, arc_marginals # ===== # special mode: first-order model as scorer/pruner def score_on_batch(self, insts: List[ParseInstance]): with BK.no_grad_env(): self.refresh_batch(False) input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run( insts, False) # mask_expr = BK.input_real(mask_arr) arc_score = self.scorer_helper.score_arc(enc_repr) label_score = self.scorer_helper.score_label(enc_repr) return arc_score.squeeze(-1), label_score # todo(note): the union of two types of pruner! 
    def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf):
        with BK.no_grad_env():
            self.refresh_batch(False)
            # encode
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
            mask_expr = BK.input_real(mask_arr)
            arc_score = self.scorer_helper.score_arc(enc_repr)
            label_score = self.scorer_helper.score_label(enc_repr)
            final_valid_mask, arc_marginals = G1Parser.prune_with_scores(arc_score, label_score, mask_expr, pconf)
            return final_valid_mask, arc_score.squeeze(-1), label_score, mask_expr, arc_marginals

    # =====
    @staticmethod
    def collect_pruning_info(insts: List[ParseInstance], valid_mask_f):
        # two dimensions: coverage and pruning-effect
        maxlen = BK.get_shape(valid_mask_f, -1)
        # 1. coverage
        valid_mask_f_flattened = valid_mask_f.view([-1, maxlen])  # [bs*len, len]
        cur_mod_base = 0
        all_mods, all_heads = [], []
        for cur_idx, cur_inst in enumerate(insts):
            for m, h in enumerate(cur_inst.heads.vals[1:], 1):
                all_mods.append(m + cur_mod_base)
                all_heads.append(h)
            cur_mod_base += maxlen
        cov_count = len(all_mods)
        cov_valid = BK.get_value(valid_mask_f_flattened[all_mods, all_heads].sum()).item()
        # 2. pruning-rate
        # todo(warn): to speed up, these stats are approximate because of including paddings
        # edges
        pr_edges = int(np.prod(BK.get_shape(valid_mask_f)))
        pr_edges_valid = BK.get_value(valid_mask_f.sum()).item()
        # valid as structured heads
        pr_o2_sib = pr_o2_g = pr_edges
        pr_o3_gsib = maxlen * pr_edges
        valid_chs_counts, valid_par_counts = valid_mask_f.sum(-2), valid_mask_f.sum(-1)  # [*, len]
        valid_gsibs = valid_chs_counts * valid_par_counts
        pr_o2_sib_valid = BK.get_value(valid_chs_counts.sum()).item()
        pr_o2_g_valid = BK.get_value(valid_par_counts.sum()).item()
        pr_o3_gsib_valid = BK.get_value(valid_gsibs.sum()).item()
        return {"cov_count": cov_count, "cov_valid": cov_valid,
                "pr_edges": pr_edges, "pr_edges_valid": pr_edges_valid,
                "pr_o2_sib": pr_o2_sib, "pr_o2_g": pr_o2_g, "pr_o3_gsib": pr_o3_gsib,
                "pr_o2_sib_valid": pr_o2_sib_valid, "pr_o2_g_valid": pr_o2_g_valid,
                "pr_o3_gsib_valid": pr_o3_gsib_valid}
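# ----- illustration (not part of the original code) -----
# A small numpy sketch of the "topk + gap" criterion in prune_with_scores above: for each
# modifier row, keep a head candidate if its arc score is strictly above max(k-th best, best - gap).
# The marginal-threshold branch (not shown) OR-s in a second keep-mask. Toy scores below.
import numpy as np

def demo_topk_gap_prune(arc_scores, k=2, gap=3.0):
    # arc_scores: [len_m, len_h]
    topk = -np.sort(-arc_scores, axis=-1)[:, :k]           # k best scores per modifier
    thresh = np.maximum(topk[:, -1], topk[:, 0] - gap)      # max(k-th best, best - gap)
    return arc_scores > thresh[:, None]                     # boolean keep-mask, [len_m, len_h]

_scores = np.array([[9., 8.5, 2., 1.],
                    [3., 2.9, 2.8, -5.]])
print(demo_topk_gap_prune(_scores))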
class FpEncoder(BasicNode): def __init__(self, pc: BK.ParamCollection, conf: FpEncConf, vpack: VocabPackage): super().__init__(pc, None, None) self.conf = conf # ===== Vocab ===== self.word_vocab = vpack.get_voc("word") self.char_vocab = vpack.get_voc("char") self.pos_vocab = vpack.get_voc("pos") # avoid no params error self._tmp_v = self.add_param("nope", (1, )) # ===== Model ===== # embedding self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, conf.emb_conf, vpack)) self.emb_output_dim = self.emb.get_output_dims()[0] # bert self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf)) self.bert_output_dim = self.bert.get_output_dims()[0] # make sure there are inputs assert self.emb_output_dim > 0 or self.bert_output_dim > 0 # middle? if conf.middle_dim > 0: self.middle_node = self.add_sub_node( "mid", Affine(self.pc, self.emb_output_dim + self.bert_output_dim, conf.middle_dim, act="elu")) self.enc_input_dim = conf.middle_dim else: self.middle_node = None self.enc_input_dim = self.emb_output_dim + self.bert_output_dim # concat the two parts (if needed) # encoder? # todo(note): feed compute-on-the-fly hp conf.enc_conf._input_dim = self.enc_input_dim self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf)) self.enc_output_dim = self.enc.get_output_dims()[0] # ===== Input Specification ===== # inputs (word, char, pos) and vocabulary self.need_word = self.emb.has_word self.need_char = self.emb.has_char # todo(warn): currently only allow extra fields for POS self.need_pos = False if len(self.emb.extra_names) > 0: assert len( self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos" self.need_pos = True # self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2) self.char_padder = DataPadder(3, pad_lens=(0, 0, conf.char_max_length), pad_vals=self.char_vocab.pad) self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad) def get_output_dims(self, *input_dims): return (self.enc_output_dim, ) # def refresh(self, rop=None): zfatal("Should call special_refresh instead!") # ==== # special routines def special_refresh(self, embed_rop, other_rop): self.emb.refresh(embed_rop) self.enc.refresh(other_rop) self.bert.refresh(other_rop) if self.middle_node is not None: self.middle_node.refresh(other_rop) def prepare_training_rop(self): mconf = self.conf embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop) other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, fix_drop=mconf.fix_drop) return embed_rop, other_rop def aug_words_and_embs(self, aug_vocab, aug_wv): orig_vocab = self.word_vocab if self.emb.has_word: orig_arr = self.emb.word_embed.E.detach().cpu().numpy() # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed? 
# todo(warn): here aug_vocab should be find in aug_wv aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True) new_vocab, new_arr = MultiHelper.aug_vocab_and_arr( orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True) # assign self.word_vocab = new_vocab self.emb.word_embed.replace_weights(new_arr) else: zwarn("No need to aug vocab since delexicalized model!!") new_vocab = orig_vocab return new_vocab # ===== # run def prepare_inputs(self, insts, training, input_word_mask_repl=None): word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], [] # ===== specially prepare for the words wv = self.word_vocab W_UNK = wv.unk word_act_idxes = [z.words.idxes for z in insts] # todo(warn): still need the masks word_arr, mask_arr = self.word_padder.pad(word_act_idxes) if input_word_mask_repl is not None: input_word_mask_repl = input_word_mask_repl.astype(np.int) word_arr = word_arr * ( 1 - input_word_mask_repl ) + W_UNK * input_word_mask_repl # replace with UNK # ===== if not self.need_word: word_arr = None if self.need_char: chars = [z.chars.idxes for z in insts] char_arr, _ = self.char_padder.pad(chars) if self.need_pos: poses = [z.poses.idxes for z in insts] pos_arr, _ = self.pos_padder.pad(poses) extra_arrs.append(pos_arr) return word_arr, char_arr, extra_arrs, aux_arrs, mask_arr # todo(note): for rnn, need to transpose masks, thus need np.array # return input_repr, enc_repr, mask_arr def run(self, insts, training, input_word_mask_repl=None): self._cache_subword_tokens(insts) # prepare inputs word_arr, char_arr, extra_arrs, aux_arrs, mask_arr = \ self.prepare_inputs(insts, training, input_word_mask_repl=input_word_mask_repl) # layer0: emb + bert layer0_reprs = [] if self.emb_output_dim > 0: emb_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs) # [BS, Len, Dim] layer0_reprs.append(emb_repr) if self.bert_output_dim > 0: # prepare bert inputs BERT_MASK_ID = self.bert.tokenizer.mask_token_id batch_subword_ids, batch_subword_is_starts = [], [] for bidx, one_inst in enumerate(insts): st = one_inst.extra_features["st"] if input_word_mask_repl is not None: cur_subword_ids, cur_subword_is_start, _ = \ st.mask_and_return(input_word_mask_repl[bidx][1:], BERT_MASK_ID) # todo(note): exclude ROOT for bert tokens else: cur_subword_ids, cur_subword_is_start = st.subword_ids, st.subword_is_start batch_subword_ids.append(cur_subword_ids) batch_subword_is_starts.append(cur_subword_is_start) bert_repr, _ = self.bert.forward_batch( batch_subword_ids, batch_subword_is_starts, batched_typeids=None, training=training) # [BS, Len, D'] layer0_reprs.append(bert_repr) # layer1: enc enc_input_repr = BK.concat(layer0_reprs, -1) # [BS, Len, D+D'] if self.middle_node is not None: enc_input_repr = self.middle_node(enc_input_repr) # [BS, Len, D??] enc_repr = self.enc(enc_input_repr, mask_arr) mask_repr = BK.input_real(mask_arr) return enc_repr, mask_repr # [bs, len, *], [bs, len] # ===== # caching def _cache_subword_tokens(self, insts): for one_inst in insts: if "st" not in one_inst.extra_features: one_inst.extra_features["st"] = self.bert.subword_tokenize2( one_inst.words.vals[1:], True)
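# ----- illustration (not part of the original code) -----
# A tiny numpy sketch of how prepare_inputs above replaces masked positions at the word-id level
# with UNK (the BERT branch does the analogous replacement with the [MASK] subword id via
# mask_and_return). The ids below are toy values.
import numpy as np

word_arr = np.array([[4, 17, 256, 9]])
repl_mask = np.array([[0, 1, 0, 1]])       # 1 = replace this position
W_UNK = 1
masked_word_arr = word_arr * (1 - repl_mask) + W_UNK * repl_mask
print(masked_word_arr)                     # [[  4   1 256   1]]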
def __init__(self, name, conf: HLabelConf, keys: Dict[str, int], nil_as_zero=True): assert nil_as_zero, "Currently assume nil as zero for all layers" self.conf = conf self.name = name # from original vocab self.orig_counts = {k: v for k, v in keys.items()} keys = sorted(set( keys.keys())) # for example, can be "vocab.trg_keys()" max_layer = 0 # ===== # part 1: layering v = {} keys = [None] + keys # ad None as NIL for k in keys: cur_idx = HLabelIdx.construct_hlidx(k, conf.layered) max_layer = max(max_layer, len(cur_idx)) v[k] = cur_idx # collect all the layered types and put idxes self.max_layer = max_layer self.layered_v = [{} for _ in range(max_layer) ] # key -> int-idx for each layer self.layered_k = [[] for _ in range(max_layer) ] # int-idx -> key for each layer self.layered_prei = [ [] for _ in range(max_layer) ] # int-idx -> int-idx: idx of prefix in previous layer self.layered_hlidx = [[] for _ in range(max_layer) ] # int-idx -> hlidx for each layer for k in keys: cur_hidx = v[k] cur_types = cur_hidx.pad_types(max_layer) assert len(cur_types) == max_layer cur_idxes = [] # int-idxes for each layer # assign from 0 to max-layer for cur_layer_i in range(max_layer): cur_layered_v, cur_layered_k, cur_layered_prei, cur_layered_hlidx = \ self.layered_v[cur_layer_i], self.layered_k[cur_layer_i], \ self.layered_prei[cur_layer_i], self.layered_hlidx[cur_layer_i] # also put empty classes here for valid_until_idx in range(cur_layer_i + 1): cur_layer_types = cur_types[:valid_until_idx + 1] + ( None, ) * (cur_layer_i - valid_until_idx) # for cur_layer_types in [cur_types[:cur_layer_i]+(None,), cur_types[:cur_layer_i+1]]: if cur_layer_types not in cur_layered_v: new_idx = len(cur_layered_k) cur_layered_v[cur_layer_types] = new_idx cur_layered_k.append(cur_layer_types) cur_layered_prei.append(0 if cur_layer_i == 0 else cur_idxes[-1]) # previous idx cur_layered_hlidx.append( HLabelIdx(cur_layer_types, None) ) # make a new hlidx, need to fill idxes later # put the actual idx cur_idxes.append(cur_layered_v[cur_types[:cur_layer_i + 1]]) cur_hidx.idxes = cur_idxes # put the idxes (actually not useful here) self.nil_as_zero = nil_as_zero if nil_as_zero: assert all( z[0].is_nil() for z in self.layered_hlidx) # make sure each layer's 0 is all-Nil # put the idxes for layered_hlidx self.v = {} for cur_layer_i in range(max_layer): cur_layered_hlidx = self.layered_hlidx[cur_layer_i] for one_hlidx in cur_layered_hlidx: one_types = one_hlidx.pad_types(max_layer) one_hlidx.idxes = [ self.layered_v[i][one_types[:i + 1]] for i in range(max_layer) ] self.v[str( one_hlidx )] = one_hlidx # todo(note): further layers will over-written previous ones self.nil_idx = self.v[""] # ===== # (main) part 2: representation # link each type representation to the overall pool self.pools_v = {None: 0} # NIL as 0 self.pools_k = [None] self.pools_hint_lexicon = [ [] ] # hit lexicon to look up in pre-trained embeddings: List[pool] of List[str] # links for each label-embeddings to the pool-embeddings: List(layer) of List(label) of List(idx-in-pool) self.layered_pool_links = [[] for _ in range(max_layer)] # masks indicating local-NIL(None) self.layered_pool_isnil = [[] for _ in range(max_layer)] for cur_layer_i in range(max_layer): cur_layered_pool_links = self.layered_pool_links[ cur_layer_i] # List(label) of List cur_layered_k = self.layered_k[cur_layer_i] # List(full-key-tuple) cur_layered_pool_isnil = self.layered_pool_isnil[ cur_layer_i] # List[int] for one_k in cur_layered_k: one_k_final = one_k[ cur_layer_i] # use the final token 
at least for lexicon hint if one_k_final is None: cur_layered_pool_links.append( [0]) # todo(note): None is always zero cur_layered_pool_isnil.append(1) continue cur_layered_pool_isnil.append(0) # either splitting into pools or splitting for lexicon hint one_k_final_elems: List[str] = split_camel(one_k_final) # adding prefix according to the strategy if conf.pool_sharing == "nope": one_prefix = f"L{cur_layer_i}-" + ".".join( one_k[:-1]) + "." elif conf.pool_sharing == "layered": one_prefix = f"L{cur_layer_i}-" elif conf.pool_sharing == "shared": one_prefix = "" else: raise NotImplementedError( f"UNK pool-sharing strategy {conf.pool_sharing}!") # put in the pools (also two types of strategies) if conf.pool_split_camel: # possibly multiple mappings to the pool, each pool-elem gets only one hint-lexicon cur_layered_pool_links.append([]) for this_pool_key in one_k_final_elems: this_pool_key1 = one_prefix + this_pool_key if this_pool_key1 not in self.pools_v: self.pools_v[this_pool_key1] = len(self.pools_k) self.pools_k.append(this_pool_key1) self.pools_hint_lexicon.append( [self.get_lexicon_hint(this_pool_key)]) cur_layered_pool_links[-1].append( self.pools_v[this_pool_key1]) else: # only one mapping to the pool, each pool-elem can get multiple hint-lexicons this_pool_key1 = one_prefix + one_k_final if this_pool_key1 not in self.pools_v: self.pools_v[this_pool_key1] = len(self.pools_k) self.pools_k.append(this_pool_key1) self.pools_hint_lexicon.append([ self.get_lexicon_hint(z) for z in one_k_final_elems ]) cur_layered_pool_links.append( [self.pools_v[this_pool_key1]]) assert self.pools_v[None] == 0, "Internal error!" # padding and masking for the links (in numpy) self.layered_pool_links_padded = [] # List[arr(#layered-label, #elem)] self.layered_pool_links_mask = [] # List[arr(...)] padder = DataPadder(2, mask_range=2) # separately for each layer for cur_layer_i in range(max_layer): cur_arr, cur_mask = padder.pad( self.layered_pool_links[cur_layer_i] ) # [each-sublabel, padded-max-elems] self.layered_pool_links_padded.append(cur_arr) self.layered_pool_links_mask.append(cur_mask) self.layered_prei = [np.asarray(z) for z in self.layered_prei] self.layered_pool_isnil = [ np.asarray(z) for z in self.layered_pool_isnil ] self.pool_init_vec = None # zlog( f"Build HLabelVocab {name} (max-layer={max_layer} from pools={len(self.pools_k)}): " + "; ".join( [f"L{i}={len(self.layered_k[i])}" for i in range(max_layer)]))
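# ----- illustration (not part of the original code) -----
# A small standalone sketch of the layering idea in the vocab-building code above: a hierarchical
# label is split into per-layer key tuples (padded with None up to max_layer, with NIL mapping to
# the all-None tuple), and the final token of each key can further be camel-split into pool keys
# shared across labels. The '.'-separated label format, the helper names, and the toy label set
# below are assumptions for illustration only.
import re

def demo_split_layers(label, max_layer=2):
    parts = tuple(label.split(".")) if label else ()
    return parts[:max_layer] + (None,) * (max_layer - len(parts))

def demo_split_camel(tok):
    return re.findall(r'[A-Z][a-z0-9]*|[a-z0-9]+', tok)

for lab in ["", "Conflict.Attack", "Movement.TransportPerson"]:
    layers = demo_split_layers(lab)
    pools = demo_split_camel(layers[-1]) if layers[-1] else []
    print(lab or "<NIL>", layers, pools)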