Example No. 1
 def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.conf = conf
     # vocab and padder
     self.word_vocab = vpack.get_voc("word")
     self.padder = DataPadder(
         2, pad_vals=self.word_vocab.pad,
         mask_range=2)  # todo(note): <pad>-id is very large
     # models
     self.hid_layer = self.add_sub_node(
         "hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
     self.pred_layer = self.add_sub_node(
         "pred",
         Affine(pc,
                conf.hid_dim,
                conf.max_pred_rank + 1,
                init_rop=NoDropRop()))
     if conf.init_pred_from_pretrain:
         npvec = vpack.get_emb("word")
         if npvec is None:
             zwarn(
                 "Pretrained vector not provided, skip init pred embeddings!!"
             )
         else:
             with BK.no_grad_env():
                 self.pred_layer.ws[0].copy_(
                     BK.input_real(npvec[:conf.max_pred_rank + 1].T))
             zlog(
                 f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1})."
             )
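Every example in this set builds one or more DataPadder objects and later calls .pad(list_of_index_lists) to get a rectangular array plus a 0/1 mask. As a rough illustration of that contract, here is a minimal 2D padding sketch; pad_2d is a hypothetical stand-in written for this note, not the library's DataPadder, and the pad_vals/mask_range semantics are only inferred from the calls above.

import numpy as np

def pad_2d(seqs, pad_val=0):
    # seqs: list of variable-length index lists -> (data [bs, max_len], mask [bs, max_len])
    max_len = max((len(s) for s in seqs), default=0)
    data = np.full((len(seqs), max_len), pad_val, dtype=np.int64)
    mask = np.zeros((len(seqs), max_len), dtype=np.float32)
    for i, s in enumerate(seqs):
        data[i, :len(s)] = s   # copy the real tokens
        mask[i, :len(s)] = 1.  # mark non-padding positions
    return data, mask

# usage: batch two token-index sequences of different lengths
arr, mask_arr = pad_2d([[3, 4, 5], [7, 8]], pad_val=0)  # both of shape (2, 3)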
Example No. 2
 def __init__(self, conf: G1ParserConf, vpack: VocabPackage):
     super().__init__(conf, vpack)
     # todo(note): the neural parameters are exactly the same as the EF one
     self.scorer_helper = GScorerHelper(self.scorer)
     self.predict_padder = DataPadder(2, pad_vals=0)
     #
     self.g1_use_aux_scores = conf.debug_use_aux_scores  # assigning here is only for debugging usage; otherwise it is assigned outside
     self.num_label = self.label_vocab.trg_len(
         True)  # todo(WARN): use the original idx
     #
     self.loss_hinge = (self.conf.tconf.loss_function == "hinge")
     if not self.loss_hinge:
         assert self.conf.tconf.loss_function == "prob", "This model only supports hinge or prob"
Example No. 3
 def __init__(self, embedder: EmbedderNode, vpack: VocabPackage):
     self.vpack = vpack
     self.embedder = embedder
     # -----
     # prepare the inputter
     self.comp_names = embedder.comp_names
     self.comp_helpers = []
     for comp_name in self.comp_names:
         one_helper = InputHelper.get_input_helper(comp_name, comp_name, vpack)
         self.comp_helpers.append(one_helper)
         if comp_name == "bert":
             assert embedder.berter is not None
             one_helper.set_berter(berter=embedder.berter)
     # ====
     self.mask_padder = DataPadder(2, pad_vals=0., mask_range=2)
Example No. 4
 def __init__(self, pc, conf: FpDecConf, label_vocab):
     super().__init__(pc, None, None)
     self.conf = conf
     self.label_vocab = label_vocab
     self.predict_padder = DataPadder(2, pad_vals=0)
     # the scorer
     self.use_ablpair = conf.use_ablpair
     conf.sconf._input_dim = conf._input_dim
     conf.sconf._num_label = conf._num_label
     if self.use_ablpair:
         self.scorer = self.add_sub_node("s",
                                         FpSingleScorer(pc, conf.sconf))
     else:
         self.scorer = self.add_sub_node("s",
                                         FpPairedScorer(pc, conf.sconf))
Example No. 5
class Inputter:
    def __init__(self, embedder: EmbedderNode, vpack: VocabPackage):
        self.vpack = vpack
        self.embedder = embedder
        # -----
        # prepare the inputter
        self.comp_names = embedder.comp_names
        self.comp_helpers = []
        for comp_name in self.comp_names:
            one_helper = InputHelper.get_input_helper(comp_name, comp_name, vpack)
            self.comp_helpers.append(one_helper)
            if comp_name == "bert":
                assert embedder.berter is not None
                one_helper.set_berter(berter=embedder.berter)
        # ====
        self.mask_padder = DataPadder(2, pad_vals=0., mask_range=2)

    def __call__(self, insts: List[GeneralSentence]):
        # first pad words to get masks
        _, masks_arr = self.mask_padder.pad([z.word_seq.idxes for z in insts])
        # then get each one
        ret_map = {"mask": masks_arr}
        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
            ret_map[comp_name] = comp_helper.prepare(insts)
        return ret_map

    # todo(note): return new masks (input is read only!!)
    def mask_input(self, input_map: Dict, input_erase_mask, nomask_names_set: Set):
        ret_map = {"mask": input_map["mask"]}
        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
            if comp_name in nomask_names_set:  # direct borrow that one
                ret_map[comp_name] = input_map[comp_name]
            else:
                ret_map[comp_name] = comp_helper.mask(input_map[comp_name], input_erase_mask)
        return ret_map
Example No. 6
 def __init__(self, conf: GraphParserConf, vpack: VocabPackage):
     super().__init__(conf, vpack)
     # ===== Input Specification =====
     # both head/label padded with 0 (the value does not matter since it will be masked)
     self.predict_padder = DataPadder(2, pad_vals=0)
     self.hlocal_padder = DataPadder(3, pad_vals=0.)
     #
     # todo(warn): adding-styled hlocal has problems intuitively, maybe not suitable for graph-parser
     self.norm_single, self.norm_local, self.norm_global, self.norm_hlocal = \
         [conf.output_normalizing==z for z in ["single", "local", "global", "hlocal"]]
     self.loss_prob, self.loss_hinge, self.loss_mr = [
         conf.tconf.loss_function == z for z in ["prob", "hinge", "mr"]
     ]
     self.alg_proj, self.alg_unproj, self.alg_greedy = [
         conf.iconf.dec_algorithm == z
         for z in ["proj", "unproj", "greedy"]
     ]
Example No. 7
 def __init__(self, conf: G2ParserConf, vpack: VocabPackage):
     super().__init__(conf, vpack)
     # todo(note): the neural parameters are exactly the same as the EF one
     # ===== basic G1 Parser's loading
     # todo(note): there can be parameter mismatch (but all of them in non-trained part, thus will be fine)
     self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf)
     self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training
     self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing
     self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training
     self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing
     #
     self.add_slayer()
     self.dl = G2DL(self.scorer, self.slayer, conf)
     #
     self.predict_padder = DataPadder(2, pad_vals=0)
     self.num_label = self.label_vocab.trg_len(
         True)  # todo(WARN): use the original idx
Example No. 8
 def __init__(self, pc: BK.ParamCollection, conf: M3EncConf, tconf,
              vpack: VocabPackage):
     super().__init__(pc, conf, tconf, vpack)
     #
     self.conf = conf
     # ----- bert
     # modify bert_conf for other input
     BERT_OTHER_VSIZE = 50  # todo(+N): this should be enough for small inputs!
     conf.bert_conf.bert2_other_input_names = conf.bert_other_inputs
     conf.bert_conf.bert2_other_input_vsizes = [BERT_OTHER_VSIZE] * len(
         conf.bert_other_inputs)
     self.berter = self.add_sub_node("bert", Berter2(pc, conf.bert_conf))
     # -----
     # index fake sent
     self.index_helper = IndexerHelper(vpack)
     # extra encoder over bert?
     self.bert_dim, self.bert_fold = self.berter.get_output_dims()
     conf.m3_enc_conf._input_dim = self.bert_dim
     self.m3_encs = [
         self.add_sub_node("m3e", MyEncoder(pc, conf.m3_enc_conf))
         for _ in range(self.bert_fold)
     ]
     self.m3_enc_out_dim = self.m3_encs[0].get_output_dims()[0]
     # skip m3_enc?
     self.m3_enc_is_empty = all(len(z.layers) == 0 for z in self.m3_encs)
     if self.m3_enc_is_empty:
         assert all(z.get_output_dims()[0] == self.bert_dim
                    for z in self.m3_encs)
         zlog("For m3_enc, we will skip it since it is empty!!")
     # dep as basic?
     if conf.m2e_use_basic_dep:
         MAX_LABEL_NUM = 200  # this should be enough
         self.dep_label_emb = self.add_sub_node(
             "dlab",
             Embedding(self.pc,
                       MAX_LABEL_NUM,
                       conf.dep_label_dim,
                       name="dlab"))
         self.dep_layer = self.add_sub_node(
             "dep",
             TaskSpecAdp(pc, [(self.m3_enc_out_dim, self.bert_fold), None],
                         [conf.dep_label_dim], conf.dep_output_dim))
     else:
         self.dep_label_emb = self.dep_layer = None
     self.dep_padder = DataPadder(
         2, pad_vals=0)  # 0 for both head-idx and label
Example No. 9
 def __init__(self, pc, conf: NodeExtractorConfBase, vocab: HLabelVocab,
              extract_type: str):
     super().__init__(pc, None, None)
     self.conf = conf
     self.vocab = vocab
     self.hl: HLabelNode = self.add_sub_node(
         "hl", HLabelNode(pc, conf.lab_conf, vocab))
     self.hl_output_size = self.hl.prediction_sizes[
         self.hl.eff_max_layer - 1]  # num of output labels
     #
     self.extract_type = extract_type
     self.items_getter = {
         "evt": self.get_events,
         "ef": lambda sent: sent.entity_fillers
     }[extract_type]
     self.constrain_evt_types = None
     # 2d pad
     self.padder_mask = DataPadder(2, pad_vals=0.)
     self.padder_idxes = DataPadder(
         2, pad_vals=0)  # todo(warn): 0 for full-nil
     self.padder_items = DataPadder(2, pad_vals=None)
     # 3d pad
     self.padder_mask_3d = DataPadder(3, pad_vals=0.)
     self.padder_items_3d = DataPadder(3, pad_vals=None)
Example No. 10
 def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf):
     super().__init__(pc, None, None)
     self.dim = rconf._input_dim  # both input/output dim
     # padders for child nodes
     self.chs_start_posi = -rconf.chs_num
     self.ch_idx_padder = DataPadder(2, pad_vals=0,
                                     mask_range=2)  # [*, num-ch]
     self.ch_label_padder = DataPadder(2, pad_vals=0)
     #
     self.label_embeddings = self.add_sub_node(
         "label",
         Embedding(pc, rconf._num_label, rconf.dim_label, fix_row0=False))
     self.dim_label = rconf.dim_label
     # todo(note): now adopting flatten groupings for basic, and then that is all, no more recurrent features
     # group 1: [cur, chs, par] -> head_pre_size
     self.use_chs = rconf.use_chs
     self.use_par = rconf.use_par
     self.use_label_feat = rconf.use_label_feat
     # components (add the parameters anyway)
     # todo(note): children features: children + (label of mod->children)
     self.chs_reprer = self.add_sub_node("chs", ChsReprer(pc, rconf))
     self.chs_ff = self.add_sub_node(
         "chs_ff",
         Affine(pc,
                self.chs_reprer.get_output_dims()[0],
                self.dim,
                act="tanh"))
     # todo(note): parent features: parent + (label of parent->mod)
     # todo(warn): always add label related params
     par_ff_inputs = [self.dim, rconf.dim_label]
     self.par_ff = self.add_sub_node(
         "par_ff", Affine(pc, par_ff_inputs, self.dim, act="tanh"))
     # no other groups anymore!
     if rconf.zero_extra_output_params:
         self.par_ff.zero_params()
         self.chs_ff.zero_params()
Example No. 11
 def __init__(self, pc: BK.ParamCollection, conf: FpEncConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.conf = conf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.pos_vocab = vpack.get_voc("pos")
     # avoid no params error
     self._tmp_v = self.add_param("nope", (1, ))
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node("emb",
                                  MyEmbedder(self.pc, conf.emb_conf, vpack))
     self.emb_output_dim = self.emb.get_output_dims()[0]
     # bert
     self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf))
     self.bert_output_dim = self.bert.get_output_dims()[0]
     # make sure there are inputs
     assert self.emb_output_dim > 0 or self.bert_output_dim > 0
     # middle?
     if conf.middle_dim > 0:
         self.middle_node = self.add_sub_node(
             "mid",
             Affine(self.pc,
                    self.emb_output_dim + self.bert_output_dim,
                    conf.middle_dim,
                    act="elu"))
         self.enc_input_dim = conf.middle_dim
     else:
         self.middle_node = None
         self.enc_input_dim = self.emb_output_dim + self.bert_output_dim  # concat the two parts (if needed)
     # encoder?
     # todo(note): feed compute-on-the-fly hp
     conf.enc_conf._input_dim = self.enc_input_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf))
     self.enc_output_dim = self.enc.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, char, pos) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # todo(warn): currently only allow extra fields for POS
     self.need_pos = False
     if len(self.emb.extra_names) > 0:
         assert len(
             self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
         self.need_pos = True
     #
     self.word_padder = DataPadder(2,
                                   pad_vals=self.word_vocab.pad,
                                   mask_range=2)
     self.char_padder = DataPadder(3,
                                   pad_lens=(0, 0, conf.char_max_length),
                                   pad_vals=self.char_vocab.pad)
     self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
Example No. 12
 def __init__(self, pc: BK.ParamCollection, bconf: BTConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.bconf = bconf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.pos_vocab = vpack.get_voc("pos")
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node(
         "emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
     emb_output_dim = self.emb.get_output_dims()[0]
     # encoder0 for jpos
     # todo(note): will do nothing if not use_jpos
     bconf.jpos_conf._input_dim = emb_output_dim
     self.jpos_enc = self.add_sub_node(
         "enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab))
     enc0_output_dim = self.jpos_enc.get_output_dims()[0]
     # encoder
     # todo(0): feed compute-on-the-fly hp
     bconf.enc_conf._input_dim = enc0_output_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
     self.enc_output_dim = self.enc.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, char, pos) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # todo(warn): currently only allow extra fields for POS
     self.need_pos = False
     if len(self.emb.extra_names) > 0:
         assert len(
             self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
         self.need_pos = True
     # todo(warn): currently only allow one aux field
     self.need_aux = False
     if len(self.emb.dim_auxes) > 0:
         assert len(self.emb.dim_auxes) == 1
         self.need_aux = True
     #
     self.word_padder = DataPadder(2,
                                   pad_vals=self.word_vocab.pad,
                                   mask_range=2)
     self.char_padder = DataPadder(3,
                                   pad_lens=(0, 0, bconf.char_max_length),
                                   pad_vals=self.char_vocab.pad)
     self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
     #
     self.random_sample_stream = Random.stream(Random.random_sample)
Example No. 13
class InputHelper:
    def __init__(self, comp_name, vpack: VocabPackage):
        self.comp_name = comp_name
        self.comp_seq_name = f"{comp_name}_seq"
        self.voc = vpack.get_voc(comp_name)
        self.padder = DataPadder(2, pad_vals=0)  # pad 0

    # return batched arr
    def prepare(self, insts: List[GeneralSentence]):
        cur_input_list = [getattr(z, self.comp_seq_name).idxes for z in insts]
        cur_input_arr, _ = self.padder.pad(cur_input_list)
        return cur_input_arr

    def mask(self, v, erase_mask):
        IDX_MASK = self.voc.err  # todo(note): this one is unused, thus just take it!
        ret_arr = v * (1-erase_mask) + IDX_MASK * erase_mask
        return ret_arr

    @staticmethod
    def get_input_helper(name, *args, **kwargs):
        helper_type = {"char": CharCnnHelper, "posi": PosiHelper, "bert": BertInputHelper}.get(name, PlainSeqHelper)
        return helper_type(*args, **kwargs)
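The mask() method above implements word corruption with plain arithmetic: positions where erase_mask is 1 are overwritten by a reserved index, the rest are kept. A small self-contained numpy sketch of that arithmetic (apply_erase_mask and the mask index 1 are illustrative choices for this note, not names from the library):

import numpy as np

def apply_erase_mask(idx_arr, erase_mask, mask_idx):
    # idx_arr: [bs, slen] int indices; erase_mask: [bs, slen] of 0/1
    return idx_arr * (1 - erase_mask) + mask_idx * erase_mask

idx_arr = np.array([[3, 4, 5], [7, 8, 0]])
erase_mask = np.array([[0, 1, 0], [1, 0, 0]])
masked = apply_erase_mask(idx_arr, erase_mask, mask_idx=1)  # -> [[3, 1, 5], [1, 8, 0]]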
Example No. 14
class NodeExtractorBase(BasicNode):
    def __init__(self, pc, conf: NodeExtractorConfBase, vocab: HLabelVocab,
                 extract_type: str):
        super().__init__(pc, None, None)
        self.conf = conf
        self.vocab = vocab
        self.hl: HLabelNode = self.add_sub_node(
            "hl", HLabelNode(pc, conf.lab_conf, vocab))
        self.hl_output_size = self.hl.prediction_sizes[
            self.hl.eff_max_layer - 1]  # num of output labels
        #
        self.extract_type = extract_type
        self.items_getter = {
            "evt": self.get_events,
            "ef": lambda sent: sent.entity_fillers
        }[extract_type]
        self.constrain_evt_types = None
        # 2d pad
        self.padder_mask = DataPadder(2, pad_vals=0.)
        self.padder_idxes = DataPadder(
            2, pad_vals=0)  # todo(warn): 0 for full-nil
        self.padder_items = DataPadder(2, pad_vals=None)
        # 3d pad
        self.padder_mask_3d = DataPadder(3, pad_vals=0.)
        self.padder_items_3d = DataPadder(3, pad_vals=None)

    # =====
    # idx transforms
    def hlidx2idx(self, hlidx: HLabelIdx) -> int:
        return hlidx.get_idx(self.hl.eff_max_layer - 1)

    def idx2hlidx(self, idx: int) -> HLabelIdx:
        return self.vocab.get_hlidx(idx, self.hl.eff_max_layer)

    # events possibly filtered by constrain_evt_types
    def get_events(self, sent):
        ret = sent.events
        constrain_evt_types = self.constrain_evt_types
        if constrain_evt_types is None:
            return ret
        else:
            return [z for z in ret if z.type in constrain_evt_types]

    def set_constrain_evt_types(self, constrain_evt_types):
        self.constrain_evt_types = constrain_evt_types

    # =====
    # main procedure

    def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
        raise NotImplementedError()

    def predict(self, insts: List, input_lexi, input_expr, input_mask):
        raise NotImplementedError()

    def lookup(self, insts: List, input_lexi, input_expr, input_mask):
        raise NotImplementedError()

    # =====
    # basic input specifying for this module
    # todo(+N): currently entangling things in the least elegant way?

    # batch inputs for head mode
    def batch_inputs_h(self, insts: List[Sentence]):
        key, items_getter = self.extract_type, self.items_getter
        nil_idx = 0
        # get gold/input data and batch
        all_masks, all_idxes, all_items, all_valid = [], [], [], []
        all_idxes2, all_items2 = [], []  # secondary types
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                length = sent.length
                items = items_getter(sent)
                # token-idx -> ...
                prep_masks = [0.] * length
                prep_idxes = [nil_idx] * length
                prep_items = [None] * length
                prep_idxes2, prep_items2 = [nil_idx] * length, [None] * length
                if items is None:
                    # todo(note): there are samples that do not have entity annotations (KBP15)
                    #  final 0/1 indicates valid or not
                    prep_valid = 0.
                else:
                    prep_valid = 1.
                    for one_item in items:
                        this_hwidx = one_item.mention.hard_span.head_wid
                        this_hlidx = one_item.type_idx
                        # todo(+N): ignore all but the first two types (already ranked by type-freq)
                        if prep_idxes[this_hwidx] == 0:
                            prep_masks[this_hwidx] = 1.
                            prep_idxes[this_hwidx] = self.hlidx2idx(
                                this_hlidx)  # change to int here!
                            prep_items[this_hwidx] = one_item
                        elif prep_idxes2[this_hwidx] == 0:
                            prep_idxes2[this_hwidx] = self.hlidx2idx(
                                this_hlidx)  # change to int here!
                            prep_items2[this_hwidx] = one_item
                sent.preps[key] = (prep_masks, prep_idxes, prep_items,
                                   prep_valid, prep_idxes2, prep_items2)
            else:
                prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2 = preps
            # =====
            all_masks.append(prep_masks)
            all_idxes.append(prep_idxes)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
            all_idxes2.append(prep_idxes2)
            all_items2.append(prep_items2)
        # pad and batch
        mention_masks = BK.input_real(
            self.padder_mask.pad(all_masks)[0])  # [*, slen]
        mention_idxes = BK.input_idx(
            self.padder_idxes.pad(all_idxes)[0])  # [*, slen]
        mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, slen]
        mention_valid = BK.input_real(all_valid)  # [*]
        mention_idxes2 = BK.input_idx(
            self.padder_idxes.pad(all_idxes2)[0])  # [*, slen]
        mention_items2_arr, _ = self.padder_items.pad(all_items2)  # [*, slen]
        return mention_masks, mention_idxes, mention_items_arr, mention_valid, mention_idxes2, mention_items2_arr

    # batch inputs for gene0 mode (separate for each label)
    def batch_inputs_g0(self, insts: List[Sentence]):
        # similar to "batch_inputs_h", but further extend for each label
        key, items_getter = self.extract_type, self.items_getter
        # nil_idx = 0
        # get gold/input data and batch
        output_size = self.hl_output_size
        all_masks, all_items, all_valid = [], [], []
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                length = sent.length
                items = items_getter(sent)
                # token-idx -> [slen, out-size]
                prep_masks = [[0. for _i1 in range(output_size)]
                              for _i0 in range(length)]
                prep_items = [[None for _i1 in range(output_size)]
                              for _i0 in range(length)]
                if items is None:
                    # todo(note): there are samples that do not have entity annotations (KBP15)
                    #  final 0/1 indicates valid or not
                    prep_valid = 0.
                else:
                    prep_valid = 1.
                    for one_item in items:
                        this_hwidx = one_item.mention.hard_span.head_wid
                        this_hlidx = one_item.type_idx
                        this_tidx = self.hlidx2idx(
                            this_hlidx)  # change to int here!
                        # todo(+N): simply ignore repeated ones with same type and trigger
                        if prep_masks[this_hwidx][this_tidx] == 0.:
                            prep_masks[this_hwidx][this_tidx] = 1.
                            prep_items[this_hwidx][this_tidx] = one_item
                sent.preps[key] = (prep_masks, prep_items, prep_valid)
            else:
                prep_masks, prep_items, prep_valid = preps
            # =====
            all_masks.append(prep_masks)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
        # pad and batch
        mention_masks = BK.input_real(
            self.padder_mask_3d.pad(all_masks)[0])  # [*, slen, L]
        mention_idxes = None
        mention_items_arr, _ = self.padder_items_3d.pad(
            all_items)  # [*, slen, L]
        mention_valid = BK.input_real(all_valid)  # [*]
        return mention_masks, mention_idxes, mention_items_arr, mention_valid

    # batch inputs for gene1 mode (seq-gene mode)
    # todo(note): the return is different than previous, here directly idx-based
    def batch_inputs_g1(self, insts: List[Sentence]):
        train_reverse_evetns = self.conf.train_reverse_evetns  # todo(note): this option is from derived class
        _tmp_f = (lambda x: list(reversed(x))) if train_reverse_evetns else (lambda x: x)
        key, items_getter = self.extract_type, self.items_getter
        # nil_idx = 0  # nil means eos
        # get gold/input data and batch
        all_widxes, all_lidxes, all_vmasks, all_items, all_valid = [], [], [], [], []
        for sent in insts:
            preps = sent.preps.get(key)
            # not cached, rebuild them
            if preps is None:
                items = items_getter(sent)
                # todo(note): directly add, assume they are already sorted in a good way (widx+lidx); 0(nil) as eos
                if items is None:
                    prep_valid = 0.
                    # prep_widxes, prep_lidxes, prep_vmasks, prep_items = [0], [0], [1.], [None]
                    prep_widxes, prep_lidxes, prep_vmasks, prep_items = [], [], [], []
                else:
                    prep_valid = 1.
                    prep_widxes = _tmp_f(
                        [z.mention.hard_span.head_wid for z in items]) + [0]
                    prep_lidxes = _tmp_f(
                        [self.hlidx2idx(z.type_idx) for z in items]) + [0]
                    prep_vmasks = [1.] * (len(items) + 1)
                    prep_items = _tmp_f(items.copy()) + [None]
                sent.preps[key] = (prep_widxes, prep_lidxes, prep_vmasks,
                                   prep_items, prep_valid)
            else:
                prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid = preps
            # =====
            all_widxes.append(prep_widxes)
            all_lidxes.append(prep_lidxes)
            all_vmasks.append(prep_vmasks)
            all_items.append(prep_items)
            all_valid.append(prep_valid)
        # pad and batch
        mention_widxes = BK.input_idx(
            self.padder_idxes.pad(all_widxes)[0])  # [*, ?]
        mention_lidxes = BK.input_idx(
            self.padder_idxes.pad(all_lidxes)[0])  # [*, ?]
        mention_vmasks = BK.input_real(
            self.padder_mask.pad(all_vmasks)[0])  # [*, ?]
        mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, ?]
        mention_valid = BK.input_real(all_valid)  # [*]
        return mention_widxes, mention_lidxes, mention_vmasks, mention_items_arr, mention_valid
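All three batch_inputs_* methods above share one caching idiom: per-sentence prepared tuples are built once, stored under sent.preps[key], and reused on later batches. A toy sketch of that compute-once pattern (Sent and get_prepared are illustrative stand-ins for this note, not classes from the codebase):

class Sent:
    def __init__(self, tokens):
        self.tokens = tokens
        self.preps = {}  # cache: key -> prepared data

def get_prepared(sent, key, build_f):
    cached = sent.preps.get(key)
    if cached is None:          # not cached, rebuild and store
        cached = build_f(sent)
        sent.preps[key] = cached
    return cached               # cached, reuse directly

s = Sent(["a", "b", "c"])
n1 = get_prepared(s, "len", lambda x: len(x.tokens))  # built once
n2 = get_prepared(s, "len", lambda x: len(x.tokens))  # reused from s.preps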
Example No. 15
class FpDecoder(BasicNode):
    def __init__(self, pc, conf: FpDecConf, label_vocab):
        super().__init__(pc, None, None)
        self.conf = conf
        self.label_vocab = label_vocab
        self.predict_padder = DataPadder(2, pad_vals=0)
        # the scorer
        self.use_ablpair = conf.use_ablpair
        conf.sconf._input_dim = conf._input_dim
        conf.sconf._num_label = conf._num_label
        if self.use_ablpair:
            self.scorer = self.add_sub_node("s",
                                            FpSingleScorer(pc, conf.sconf))
        else:
            self.scorer = self.add_sub_node("s",
                                            FpPairedScorer(pc, conf.sconf))

    # -----
    # scoring
    def _score(self, enc_expr, mask_expr):
        # -----
        def _special_score(
                one_score):  # specially change ablpair scores into [bs,m,h,*]
            root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
            tmp_shape = BK.get_shape(root_score)
            tmp_shape[1] = 1  # [bs, 1, 1, *]
            padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score],
                                          dim=1)  # [bs, rlen+1, 1, *]
            final_score = BK.concat(
                [padded_root_score,
                 one_score.transpose(1, 2)],
                dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
            return final_score

        # -----
        if self.use_ablpair:
            input_mask_expr = (
                mask_expr.unsqueeze(-1) *
                mask_expr.unsqueeze(-2))[:, 1:]  # [bs, rlen, rlen+1]
            arc_score = self.scorer.transform_and_arc_score(
                enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, 1]
            lab_score = self.scorer.transform_and_lab_score(
                enc_expr, input_mask_expr)  # [bs, rlen, rlen+1, Lab]
            # put root-scores for both directions
            arc_score = _special_score(arc_score)
            lab_score = _special_score(lab_score)
        else:
            # todo(+2): for training, we can simply select and lab-score
            arc_score = self.scorer.transform_and_arc_score(
                enc_expr, mask_expr)  # [bs, m, h, 1]
            lab_score = self.scorer.transform_and_lab_score(
                enc_expr, mask_expr)  # [bs, m, h, Lab]
        # mask out diag scores
        diag_mask = BK.eye(BK.get_shape(arc_score, 1))
        diag_mask[0, 0] = 0.
        diag_add = Constants.REAL_PRAC_MIN * (
            diag_mask.unsqueeze(-1).unsqueeze(0))  # [1, m, h, 1]
        arc_score += diag_add
        lab_score += diag_add
        return arc_score, lab_score

    # loss
    # todo(note): no margins here, simply using target-selection cross-entropy
    def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
        conf = self.conf
        # scoring
        arc_score, lab_score = self._score(enc_expr,
                                           mask_expr)  # [bs, m, h, *]
        # loss
        bsize, max_len = BK.get_shape(mask_expr)
        # gold heads and labels
        gold_heads_arr, _ = self.predict_padder.pad(
            [z.heads.vals for z in insts])
        # todo(note): here use the original idx of label, no shift!
        gold_labels_arr, _ = self.predict_padder.pad(
            [z.labels.idxes for z in insts])
        gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
        gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
        # collect the losses
        arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
        arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
        # logsoftmax and losses
        arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1),
                                         -1)  # [bs, m, h]
        lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
        arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                     gold_heads_expr]  # [bs, Len]
        lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                     gold_heads_expr,
                                     gold_labels_expr]  # [bs, Len]
        # head selection (no root)
        arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
        lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
        final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
        final_loss_count = mask_expr[:, 1:].sum()
        return [[final_loss, final_loss_count]]

    # decode
    def predict(self, insts: List[ParseInstance], enc_expr, mask_expr,
                **kwargs):
        conf = self.conf
        # scoring
        arc_score, lab_score = self._score(enc_expr,
                                           mask_expr)  # [bs, m, h, *]
        full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(
            lab_score, -1)  # [bs, m, h, Lab]
        # decode
        mst_lengths = [len(z) + 1
                       for z in insts]  # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
        mst_heads_arr, mst_labels_arr, mst_scores_arr = \
            nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
        # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
        misc_prefix = "g"
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(
                mst_heads_arr[one_idx]
                [:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(
                mst_labels_arr[one_idx][:cur_length], self.label_vocab)
            one_scores = mst_scores_arr[one_idx][:cur_length]
            one_inst.pred_par_scores.set_vals(one_scores)
            # extra output
            one_inst.extra_pred_misc[misc_prefix +
                                     "_score"] = one_scores.tolist()
Example No. 16
 def __init__(self, pc: BK.ParamCollection, bconf: BTConf, tconf: 'BaseTrainingConf', vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.bconf = bconf
     # ===== Vocab =====
     self.word_vocab = vpack.get_voc("word")
     self.char_vocab = vpack.get_voc("char")
     self.lemma_vocab = vpack.get_voc("lemma")
     self.upos_vocab = vpack.get_voc("upos")
     self.ulabel_vocab = vpack.get_voc("ulabel")
     # ===== Model =====
     # embedding
     self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
     emb_output_dim = self.emb.get_output_dims()[0]
     self.emb_output_dim = emb_output_dim
     # doc hint
     self.use_doc_hint = bconf.use_doc_hint
     self.dh_combine_method = bconf.dh_combine_method
     if self.use_doc_hint:
         assert len(bconf.emb_conf.dim_auxes)>0
         # todo(note): currently use the concat of them if input multiple layers
         bconf.dh_conf._input_dim = bconf.emb_conf.dim_auxes[0]  # same as input bert dim
         bconf.dh_conf._output_dim = emb_output_dim  # same as emb_output_dim
         self.dh_node = self.add_sub_node("dh", DocHintModule(pc, bconf.dh_conf))
     else:
         self.dh_node = None
     # encoders
     # shared
     # todo(note): feed compute-on-the-fly hp
     bconf.enc_conf._input_dim = emb_output_dim
     self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
     tmp_enc_output_dim = self.enc.get_output_dims()[0]
     # privates
     bconf.enc_ef_conf._input_dim = tmp_enc_output_dim
     self.enc_ef = self.add_sub_node("enc_ef", MyEncoder(self.pc, bconf.enc_ef_conf))
     self.enc_ef_output_dim = self.enc_ef.get_output_dims()[0]
     bconf.enc_evt_conf._input_dim = tmp_enc_output_dim
     self.enc_evt = self.add_sub_node("enc_evt", MyEncoder(self.pc, bconf.enc_evt_conf))
     self.enc_evt_output_dim = self.enc_evt.get_output_dims()[0]
     # ===== Input Specification =====
     # inputs (word, lemma, char, upos, ulabel) and vocabulary
     self.need_word = self.emb.has_word
     self.need_char = self.emb.has_char
     # extra fields
     # todo(warn): need to
     self.need_lemma = False
     self.need_upos = False
     self.need_ulabel = False
     for one_extra_name in self.emb.extra_names:
         if one_extra_name == "lemma":
             self.need_lemma = True
         elif one_extra_name == "upos":
             self.need_upos = True
         elif one_extra_name == "ulabel":
             self.need_ulabel = True
         else:
             raise NotImplementedError("UNK extra input name: " + one_extra_name)
     # todo(warn): currently only allow one aux field
     self.need_aux = False
     if len(self.emb.dim_auxes) > 0:
         assert len(self.emb.dim_auxes) == 1
         self.need_aux = True
     # padders
     self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
     self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad)
     self.lemma_padder = DataPadder(2, pad_vals=self.lemma_vocab.pad)
     self.upos_padder = DataPadder(2, pad_vals=self.upos_vocab.pad)
     self.ulabel_padder = DataPadder(2, pad_vals=self.ulabel_vocab.pad)
     #
     self.random_sample_stream = Random.stream(Random.random_sample)
     self.train_skip_noevt_rate = tconf.train_skip_noevt_rate
     self.train_skip_length = tconf.train_skip_length
     self.train_min_length = tconf.train_min_length
     self.test_min_length = tconf.test_min_length
     self.test_skip_noevt_rate = tconf.test_skip_noevt_rate
     self.train_sent_based = tconf.train_sent_based
     #
     assert not self.train_sent_based, "The basic model should not use this sent-level mode!"
Example No. 17
class MyIEBT(BasicNode):
    def __init__(self, pc: BK.ParamCollection, bconf: BTConf, tconf: 'BaseTrainingConf', vpack: VocabPackage):
        super().__init__(pc, None, None)
        self.bconf = bconf
        # ===== Vocab =====
        self.word_vocab = vpack.get_voc("word")
        self.char_vocab = vpack.get_voc("char")
        self.lemma_vocab = vpack.get_voc("lemma")
        self.upos_vocab = vpack.get_voc("upos")
        self.ulabel_vocab = vpack.get_voc("ulabel")
        # ===== Model =====
        # embedding
        self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
        emb_output_dim = self.emb.get_output_dims()[0]
        self.emb_output_dim = emb_output_dim
        # doc hint
        self.use_doc_hint = bconf.use_doc_hint
        self.dh_combine_method = bconf.dh_combine_method
        if self.use_doc_hint:
            assert len(bconf.emb_conf.dim_auxes)>0
            # todo(note): currently use the concat of them if input multiple layers
            bconf.dh_conf._input_dim = bconf.emb_conf.dim_auxes[0]  # same as input bert dim
            bconf.dh_conf._output_dim = emb_output_dim  # same as emb_output_dim
            self.dh_node = self.add_sub_node("dh", DocHintModule(pc, bconf.dh_conf))
        else:
            self.dh_node = None
        # encoders
        # shared
        # todo(note): feed compute-on-the-fly hp
        bconf.enc_conf._input_dim = emb_output_dim
        self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
        tmp_enc_output_dim = self.enc.get_output_dims()[0]
        # privates
        bconf.enc_ef_conf._input_dim = tmp_enc_output_dim
        self.enc_ef = self.add_sub_node("enc_ef", MyEncoder(self.pc, bconf.enc_ef_conf))
        self.enc_ef_output_dim = self.enc_ef.get_output_dims()[0]
        bconf.enc_evt_conf._input_dim = tmp_enc_output_dim
        self.enc_evt = self.add_sub_node("enc_evt", MyEncoder(self.pc, bconf.enc_evt_conf))
        self.enc_evt_output_dim = self.enc_evt.get_output_dims()[0]
        # ===== Input Specification =====
        # inputs (word, lemma, char, upos, ulabel) and vocabulary
        self.need_word = self.emb.has_word
        self.need_char = self.emb.has_char
        # extra fields
        # todo(warn): need to
        self.need_lemma = False
        self.need_upos = False
        self.need_ulabel = False
        for one_extra_name in self.emb.extra_names:
            if one_extra_name == "lemma":
                self.need_lemma = True
            elif one_extra_name == "upos":
                self.need_upos = True
            elif one_extra_name == "ulabel":
                self.need_ulabel = True
            else:
                raise NotImplementedError("UNK extra input name: " + one_extra_name)
        # todo(warn): currently only allow one aux field
        self.need_aux = False
        if len(self.emb.dim_auxes) > 0:
            assert len(self.emb.dim_auxes) == 1
            self.need_aux = True
        # padders
        self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)
        self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad)
        self.lemma_padder = DataPadder(2, pad_vals=self.lemma_vocab.pad)
        self.upos_padder = DataPadder(2, pad_vals=self.upos_vocab.pad)
        self.ulabel_padder = DataPadder(2, pad_vals=self.ulabel_vocab.pad)
        #
        self.random_sample_stream = Random.stream(Random.random_sample)
        self.train_skip_noevt_rate = tconf.train_skip_noevt_rate
        self.train_skip_length = tconf.train_skip_length
        self.train_min_length = tconf.train_min_length
        self.test_min_length = tconf.test_min_length
        self.test_skip_noevt_rate = tconf.test_skip_noevt_rate
        self.train_sent_based = tconf.train_sent_based
        #
        assert not self.train_sent_based, "The basic model should not use this sent-level mode!"

    def get_output_dims(self, *input_dims):
        return ([self.enc_ef_output_dim, self.enc_evt_output_dim], )

    #
    def refresh(self, rop=None):
        zfatal("Should call special_refresh instead!")

    # ====
    # special routines

    def special_refresh(self, embed_rop, other_rop):
        self.emb.refresh(embed_rop)
        self.enc.refresh(other_rop)
        self.enc_ef.refresh(other_rop)
        self.enc_evt.refresh(other_rop)
        if self.dh_node is not None:
            return self.dh_node.refresh(other_rop)

    def prepare_training_rop(self):
        mconf = self.bconf
        embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop)
        other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn,
                                   fix_drop=mconf.fix_drop)
        return embed_rop, other_rop

    # =====
    # run
    def _prepare_input(self, sents: List[Sentence], training: bool):
        word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], []
        # ===== specially prepare for the words
        wv = self.word_vocab
        W_UNK = wv.unk
        UNK_REP_RATE = self.bconf.singleton_unk
        UNK_REP_THR = self.bconf.singleton_thr
        word_act_idxes = []
        if training and UNK_REP_RATE>0.:    # replace unfreq/singleton words with UNK
            for one_inst in sents:
                one_act_idxes = []
                for one_idx in one_inst.words.idxes:
                    one_freq = wv.idx2val(one_idx)
                    if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR:
                        if next(self.random_sample_stream) < (UNK_REP_RATE/one_freq):
                            one_idx = W_UNK
                    one_act_idxes.append(one_idx)
                word_act_idxes.append(one_act_idxes)
        else:
            word_act_idxes = [z.words.idxes for z in sents]
        # todo(warn): still need the masks
        word_arr, mask_arr = self.word_padder.pad(word_act_idxes)
        # =====
        if not self.need_word:
            word_arr = None
        if self.need_char:
            chars = [z.chars.idxes for z in sents]
            char_arr, _ = self.char_padder.pad(chars)
        # extra ones: lemma, upos, ulabel
        if self.need_lemma:
            lemmas = [z.lemmas.idxes for z in sents]
            lemmas_arr, _ = self.lemma_padder.pad(lemmas)
            extra_arrs.append(lemmas_arr)
        if self.need_upos:
            uposes = [z.uposes.idxes for z in sents]
            upos_arr, _ = self.upos_padder.pad(uposes)
            extra_arrs.append(upos_arr)
        if self.need_ulabel:
            ulabels = [z.ud_labels.idxes for z in sents]
            ulabels_arr, _ = self.ulabel_padder.pad(ulabels)
            extra_arrs.append(ulabels_arr)
        # aux ones
        if self.need_aux:
            aux_arr_list = [z.extra_features["aux_repr"] for z in sents]
            # pad
            padded_seq_len = int(mask_arr.shape[1])
            final_aux_arr_list = []
            for cur_arr in aux_arr_list:
                cur_len = len(cur_arr)
                if cur_len > padded_seq_len:
                    final_aux_arr_list.append(cur_arr[:padded_seq_len])
                else:
                    final_aux_arr_list.append(np.pad(cur_arr, ((0,padded_seq_len-cur_len),(0,0)), 'constant'))
            aux_arrs.append(np.stack(final_aux_arr_list, 0))
        #
        input_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs)
        # [BS, Len, Dim], [BS, Len]
        return input_repr, mask_arr

    # input: Tuple(Sentence, ...)
    def _bucket_sents_by_length(self, all_sents, enc_bucket_range: int, getlen_f=lambda x: x[0].length, max_bsize=None):
        # split into buckets
        all_buckets = []
        cur_local_sidx = 0
        use_max_bsize = (max_bsize is not None)
        while cur_local_sidx < len(all_sents):
            cur_bucket = []
            starting_slen = getlen_f(all_sents[cur_local_sidx])
            ending_slen = starting_slen + enc_bucket_range
            # searching forward
            tmp_sidx = cur_local_sidx
            while tmp_sidx < len(all_sents):
                one_sent = all_sents[tmp_sidx]
                one_slen = getlen_f(one_sent)
                if one_slen>ending_slen or (use_max_bsize and len(cur_bucket)>=max_bsize):
                    break
                else:
                    cur_bucket.append(one_sent)
                tmp_sidx += 1
            # put bucket and next
            all_buckets.append(cur_bucket)
            cur_local_sidx = tmp_sidx
        return all_buckets

    # todo(warn): for rnn, need to transpose masks, thus need np.array
    # TODO(+N): here we encode only at sentence level, encoding doc level might be helpful, but much harder to batch
    #  therefore, still take DOC as input, since may be extended to doc-level encoding
    # return input_repr, enc_repr, mask_arr
    def run(self, insts: List[DocInstance], training: bool):
        # make it as sentence level processing (group the sentences by length, and ignore doc level for now)
        # skip no content sentences in training?
        # assert not self.train_sent_based, "The basic model should not use this sent-level mode!"
        all_sents = []  # (inst, d_idx, s_idx)
        for d_idx, one_doc in enumerate(insts):
            for s_idx, x in enumerate(one_doc.sents):
                if training:
                    if x.length<self.train_skip_length and x.length>=self.train_min_length \
                            and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate):
                        all_sents.append((x, d_idx, s_idx))
                else:
                    if x.length >= self.test_min_length:
                        all_sents.append((x, d_idx, s_idx))
        return self.run_sents(all_sents, insts, training)

    # input interested sents: Tuple[Sentence, DocId, SentId]
    def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
        if use_one_bucket:
            all_buckets = [all_sents]  # used when we know the input lengths do not vary much and do not want to split
        else:
            all_sents.sort(key=lambda x: x[0].length)
            all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
        # doc hint
        use_doc_hint = self.use_doc_hint
        if use_doc_hint:
            dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
        else:
            dh_sent_repr = None
        # encoding for each of the bucket
        rets = []
        dh_add, dh_both, dh_cls = [self.dh_combine_method==z for z in ["add", "both", "cls"]]
        for one_bucket in all_buckets:
            one_sents = [z[0] for z in one_bucket]
            # [BS, Len, Di], [BS, Len]
            input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
            if use_doc_hint:
                one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
                one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
                one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
                if dh_add:
                    input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                    mask_arr = mask_arr0
                elif dh_both:
                    input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                    mask_arr = np.pad(mask_arr0, ((0,0),(1,1)), 'constant', constant_values=1.)  # [BS, 2+slen]
                elif dh_cls:
                    input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                    mask_arr = mask_arr0
                else:
                    raise NotImplementedError()
            else:
                input_repr, mask_arr = input_repr0, mask_arr0
            # [BS, Len, De]
            enc_repr = self.enc(input_repr, mask_arr)
            # separate ones (possibly using detach to avoid gradients for some of them)
            enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
            enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
            if use_doc_hint and dh_both:
                one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(), enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
            else:
                one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
            rets.append(one_ret)
        # todo(note): each returned tuple is (List[Sentence], input_repr, enc_repr_ef, enc_repr_evt, mask_arr)
        return rets

    # special routine
    def aug_words_and_embs(self, aug_vocab, aug_wv):
        orig_vocab = self.word_vocab
        orig_arr = self.emb.word_embed.E.detach().cpu().numpy()
        # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
        # todo(warn): here aug_vocab should be found in aug_wv
        aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
        new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
        # assign
        self.word_vocab = new_vocab
        self.emb.word_embed.replace_weights(new_arr)
        return new_vocab
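In _prepare_input, training-time word indices are stochastically replaced by UNK: a word with frequency f at or below the singleton threshold is swapped with probability rate / f, so rarer words are corrupted more often. A small sketch of that rule (maybe_unk and the numeric defaults are illustrative for this note, not the library's API):

import random

def maybe_unk(idx, freq, unk_idx, rep_rate=0.5, rep_thr=2, rng=random.random):
    # replace infrequent/singleton words with UNK at rate rep_rate / freq
    if freq is not None and 1 <= freq <= rep_thr and rng() < rep_rate / freq:
        return unk_idx
    return idx

word_idxes = [10, 11, 12]
freqs = [1, 100, 2]  # e.g. looked up via the vocab's idx2val
replaced = [maybe_unk(i, f, unk_idx=1) for i, f in zip(word_idxes, freqs)]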
Example No. 18
def main():
    np.random.seed(1234)
    NUM_POS = 10
    # build vocabs
    reader = TextReader("./test_utils.py")
    vb_word = VocabBuilder("w")
    vb_char = VocabBuilder("c")
    for one in reader:
        vb_word.feed_stream(one.tokens)
        vb_char.feed_stream((c for w in one.tokens for c in w))
    voc_word = vb_word.finish()
    voc_char = vb_char.finish()
    voc_pos = VocabBuilder.build_from_stream(range(NUM_POS), name="pos")
    vpack = VocabPackage({
        "word": voc_word,
        "char": voc_char,
        "pos": voc_pos
    }, {"word": None})
    # build model
    pc = BK.ParamCollection()
    conf_emb = EmbedConf().init_from_kwargs(init_words_from_pretrain=False,
                                            dim_char=10,
                                            dim_posi=10,
                                            emb_proj_dim=400,
                                            dim_extras="50",
                                            extra_names="pos")
    conf_emb.do_validate()
    mod_emb = MyEmbedder(pc, conf_emb, vpack)
    conf_enc = EncConf().init_from_kwargs(enc_rnn_type="lstm2",
                                          enc_cnn_layer=1,
                                          enc_att_layer=1)
    conf_enc._input_dim = mod_emb.get_output_dims()[0]
    mod_enc = MyEncoder(pc, conf_enc)
    enc_output_dim = mod_enc.get_output_dims()[0]
    mod_scorer = BiAffineScorer(pc, enc_output_dim, enc_output_dim, 10)
    # build data
    word_padder = DataPadder(2, pad_lens=(0, 50), mask_range=2)
    char_padder = DataPadder(3, pad_lens=(0, 50, 20))
    word_idxes = []
    char_idxes = []
    pos_idxes = []
    for toks in reader:
        one_words = []
        one_chars = []
        for w in toks.tokens:
            one_words.append(voc_word.get_else_unk(w))
            one_chars.append([voc_char.get_else_unk(c) for c in w])
        word_idxes.append(one_words)
        char_idxes.append(one_chars)
        pos_idxes.append(
            np.random.randint(voc_pos.trg_len(), size=len(one_words)) +
            1)  # pred->trg
    word_arr, word_mask_arr = word_padder.pad(word_idxes)
    pos_arr, _ = word_padder.pad(pos_idxes)
    char_arr, _ = char_padder.pad(char_idxes)
    #
    # run
    rop = layers.RefreshOptions(hdrop=0.2, gdrop=0.2, fix_drop=True)
    for _ in range(5):
        mod_emb.refresh(rop)
        mod_enc.refresh(rop)
        mod_scorer.refresh(rop)
        #
        expr_emb = mod_emb(word_arr, char_arr, [pos_arr])
        zlog(BK.get_shape(expr_emb))
        expr_enc = mod_enc(expr_emb, word_mask_arr)
        zlog(BK.get_shape(expr_enc))
        #
        mask_expr = BK.input_real(word_mask_arr)
        score0 = mod_scorer.paired_score(expr_enc, expr_enc, mask_expr,
                                         mask_expr)
        score1 = mod_scorer.plain_score(expr_enc.unsqueeze(-2),
                                        expr_enc.unsqueeze(-3),
                                        mask_expr.unsqueeze(-1),
                                        mask_expr.unsqueeze(-2))
        #
        zmiss = float(BK.avg(score0 - score1))
        assert zmiss < 0.0001
    zlog("OK")
    pass
Example No. 19
 def __init__(self, comp_name, vpack: VocabPackage):
     super().__init__(comp_name, vpack)
     self.padder = DataPadder(3, pad_vals=0)  # replace the padder
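Several examples also build 3D padders for characters, e.g. DataPadder(3, pad_lens=(0, 0, char_max_length), pad_vals=...), which pads a batch of sentences of character-index lists to [bs, max_words, max_chars]. A rough sketch of that behavior, with truncation to a fixed last dimension inferred from the pad_lens argument (pad_3d is a hypothetical helper, not the library's implementation):

import numpy as np

def pad_3d(batch, pad_val=0, max_char=None):
    # batch: list of sentences, each a list of char-index lists
    max_words = max((len(sent) for sent in batch), default=0)
    if max_char is None:
        max_char = max((len(w) for sent in batch for w in sent), default=0)
    out = np.full((len(batch), max_words, max_char), pad_val, dtype=np.int64)
    for i, sent in enumerate(batch):
        for j, w in enumerate(sent):
            trunc = w[:max_char]           # truncate overlong words
            out[i, j, :len(trunc)] = trunc
    return out

chars = [[[3, 4], [5]], [[6, 7, 8]]]
arr = pad_3d(chars, pad_val=0, max_char=4)  # shape (2, 2, 4)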
Example No. 20
 def __init__(self, comp_name, vpack: VocabPackage):
     self.comp_name = comp_name
     self.comp_seq_name = f"{comp_name}_seq"
     self.voc = vpack.get_voc(comp_name)
     self.padder = DataPadder(2, pad_vals=0)  # pad 0
Example No. 21
class GraphParser(BaseParser):
    def __init__(self, conf: GraphParserConf, vpack: VocabPackage):
        super().__init__(conf, vpack)
        # ===== Input Specification =====
        # both head/label padded with 0 (the value does not matter since it will be masked)
        self.predict_padder = DataPadder(2, pad_vals=0)
        self.hlocal_padder = DataPadder(3, pad_vals=0.)
        #
        # todo(warn): adding-styled hlocal has problems intuitively, maybe not suitable for graph-parser
        self.norm_single, self.norm_local, self.norm_global, self.norm_hlocal = \
            [conf.output_normalizing==z for z in ["single", "local", "global", "hlocal"]]
        self.loss_prob, self.loss_hinge, self.loss_mr = [
            conf.tconf.loss_function == z for z in ["prob", "hinge", "mr"]
        ]
        self.alg_proj, self.alg_unproj, self.alg_greedy = [
            conf.iconf.dec_algorithm == z
            for z in ["proj", "unproj", "greedy"]
        ]

    #

    def build_decoder(self):
        conf = self.conf
        conf.sc_conf._input_dim = self.enc_output_dim
        conf.sc_conf._num_label = self.label_vocab.trg_len(False)
        return GraphScorer(self.pc, conf.sc_conf)

    # shared calculations for final scoring
    # -> the scores are masked with PRAC_MIN (by the scorer) for the paddings, but not handling diag here!

    def _prepare_score(self, insts, training):
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
            insts, training)
        mask_expr = BK.input_real(mask_arr)
        # am_expr, ah_expr, lm_expr, lh_expr = self.scorer.transform_space(enc_repr)
        scoring_expr_pack = self.scorer.transform_space(enc_repr)
        return scoring_expr_pack, mask_expr, jpos_pack

    # scoring procedures
    def _score_arc_full(self,
                        scoring_expr_pack,
                        mask_expr,
                        training,
                        margin,
                        gold_heads_expr=None):
        am_expr, ah_expr, _, _ = scoring_expr_pack
        # [BS, len-m, len-h]
        full_arc_score = self.scorer.score_arc_all(am_expr, ah_expr, mask_expr,
                                                   mask_expr)
        # # set diag to small values # todo(warn): handled specifically in algorithms
        # maxlen = BK.get_shape(full_arc_score, 1)
        # full_arc_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN))
        # margin?
        if training and margin > 0.:
            full_arc_score = BK.minus_margin(full_arc_score, gold_heads_expr,
                                             margin)
        return full_arc_score

    def _score_label_full(self,
                          scoring_expr_pack,
                          mask_expr,
                          training,
                          margin,
                          gold_heads_expr=None,
                          gold_labels_expr=None):
        _, _, lm_expr, lh_expr = scoring_expr_pack
        # [BS, len-m, len-h, L]
        full_label_score = self.scorer.score_label_all(lm_expr, lh_expr,
                                                       mask_expr, mask_expr)
        # # set diag to small values # todo(warn): handled specifically in algorithms
        # maxlen = BK.get_shape(full_label_score, 1)
        # full_label_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # margin? -- specially reshaping
        if training and margin > 0.:
            full_shape = BK.get_shape(full_label_score)
            # combine last two dim
            combined_score_expr = full_label_score.view(full_shape[:-2] + [-1])
            combined_idx_expr = gold_heads_expr * full_shape[
                -1] + gold_labels_expr
            combined_changed_score = BK.minus_margin(combined_score_expr,
                                                     combined_idx_expr, margin)
            full_label_score = combined_changed_score.view(full_shape)
        return full_label_score

    def _score_label_selected(self,
                              scoring_expr_pack,
                              mask_expr,
                              training,
                              margin,
                              gold_heads_expr,
                              gold_labels_expr=None):
        _, _, lm_expr, lh_expr = scoring_expr_pack
        # [BS, len-m, D]
        lh_expr_shape = BK.get_shape(lh_expr)
        selected_lh_expr = BK.gather(
            lh_expr,
            gold_heads_expr.unsqueeze(-1).expand(*lh_expr_shape),
            dim=len(lh_expr_shape) - 2)
        # [BS, len-m, L]
        select_label_score = self.scorer.score_label_select(
            lm_expr, selected_lh_expr, mask_expr)
        # margin?
        if training and margin > 0.:
            select_label_score = BK.minus_margin(select_label_score,
                                                 gold_labels_expr, margin)
        return select_label_score

    # for global-norm + hinge(perceptron-like)-loss
    # [*, m, h, L], [*, m]
    def _losses_global_hinge(self, full_score_expr, gold_heads_expr,
                             gold_labels_expr, pred_heads_expr,
                             pred_labels_expr, mask_expr):
        return GraphParser.get_losses_global_hinge(full_score_expr,
                                                   gold_heads_expr,
                                                   gold_labels_expr,
                                                   pred_heads_expr,
                                                   pred_labels_expr, mask_expr)

    # ===== export
    @staticmethod
    def get_losses_global_hinge(full_score_expr,
                                gold_heads_expr,
                                gold_labels_expr,
                                pred_heads_expr,
                                pred_labels_expr,
                                mask_expr,
                                clamping=True):
        # combine the last two dimension
        full_shape = BK.get_shape(full_score_expr)
        # [*, m, h*L]
        last_size = full_shape[-1]
        combined_score_expr = full_score_expr.view(full_shape[:-2] + [-1])
        # [*, m]
        gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr
        pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr
        # [*, m]
        gold_scores = BK.gather_one_lastdim(combined_score_expr,
                                            gold_combined_idx_expr).squeeze(-1)
        pred_scores = BK.gather_one_lastdim(combined_score_expr,
                                            pred_combined_idx_expr).squeeze(-1)
        # todo(warn): be aware of search error!
        # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.)  # this is previous version
        hinge_losses = pred_scores - gold_scores  # [*, len]
        if clamping:
            valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) >
                            0.).float().unsqueeze(-1)  # [*, 1]
            return hinge_losses * valid_losses
        else:
            # for this mode, will there be problems of search error? Maybe rare.
            return hinge_losses

    # =====

    # for global-norm + prob-loss
    def _losses_global_prob(self, full_score_expr, gold_heads_expr,
                            gold_labels_expr, marginals_expr, mask_expr):
        # combine the last two dimension
        full_shape = BK.get_shape(full_score_expr)
        last_size = full_shape[-1]
        # [*, m, h*L]
        combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1])
        # # todo(warn): make sure sum to 1., handled in algorithm instead
        # combined_marginals_expr = combined_marginals_expr / combined_marginals_expr.sum(dim=-1, keepdim=True)
        # [*, m]
        gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr
        # [*, m, h, L]
        gradients = BK.minus_margin(combined_marginals_expr,
                                    gold_combined_idx_expr,
                                    1.).view(full_shape)
        # the gradients on h are already 0. from the marginal algorithm
        gradients_masked = gradients * mask_expr.unsqueeze(-1).unsqueeze(
            -1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
        # for the h-dimension, need to divide by the real length.
        # todo(warn): these values should be summed directly rather than averaged, since they come directly from the loss
        fake_losses = (full_score_expr * gradients_masked).sum(-1).sum(
            -1)  # [BS, m]
        # todo(warn): be aware of search-error-like output constraints;
        #  but clamping for all is not good for loss-prob; this is dealt with outside via the unproj-mask.
        # <bad> fake_losses = BK.clamp(fake_losses, min=0.)
        return fake_losses

    # for single-norm: 0-1 loss
    # [*, L], [*], float
    def _losses_single(self,
                       score_expr,
                       gold_idxes_expr,
                       single_sample,
                       is_hinge=False,
                       margin=0.):
        # expand the idxes to 0/1
        score_shape = BK.get_shape(score_expr)
        expanded_idxes_expr = BK.constants(score_shape, 0.)
        expanded_idxes_expr = BK.minus_margin(expanded_idxes_expr,
                                              gold_idxes_expr,
                                              -1.)  # minus -1 means +1
        # todo(+N): first adjust margin, since previously only minus margin for golds?
        if margin > 0.:
            adjusted_scores = margin + BK.minus_margin(score_expr,
                                                       gold_idxes_expr, margin)
        else:
            adjusted_scores = score_expr
        # [*, L]
        if is_hinge:
            # multiply pos instances with -1
            flipped_scores = adjusted_scores * (1. - 2 * expanded_idxes_expr)
            losses_all = BK.clamp(flipped_scores, min=0.)
        else:
            losses_all = BK.binary_cross_entropy_with_logits(
                adjusted_scores, expanded_idxes_expr, reduction='none')
        # special interpretation (todo(+2): there could be a better implementation)
        if single_sample < 1.:
            # todo(warn): lower bound of sample_rate, ensure 2 samples
            real_sample_rate = max(single_sample, 2. / score_shape[-1])
        elif single_sample >= 2.:
            # including the positive one
            real_sample_rate = max(single_sample, 2.) / score_shape[-1]
        else:  # [1., 2.)
            real_sample_rate = single_sample
        #
        if real_sample_rate < 1.:
            sample_weight = BK.random_bernoulli(score_shape, real_sample_rate,
                                                1.)
            # make sure positive is valid
            sample_weight = (sample_weight +
                             expanded_idxes_expr.float()).clamp_(0., 1.)
            #
            final_losses = (losses_all *
                            sample_weight).sum(-1) / sample_weight.sum(-1)
        else:
            final_losses = losses_all.mean(-1)
        return final_losses

    # =====
    # expr[BS, m, h, L], arr[BS] -> arr[BS, m]
    def _decode(self, full_score_expr, mask_expr, lengths_arr):
        if self.alg_unproj:
            return nmst_unproj(full_score_expr,
                               mask_expr,
                               lengths_arr,
                               labeled=True,
                               ret_arr=True)
        elif self.alg_proj:
            return nmst_proj(full_score_expr,
                             mask_expr,
                             lengths_arr,
                             labeled=True,
                             ret_arr=True)
        elif self.alg_greedy:
            return nmst_greedy(full_score_expr,
                               mask_expr,
                               lengths_arr,
                               labeled=True,
                               ret_arr=True)
        else:
            zfatal("Unknown decoding algorithm " +
                   self.conf.iconf.dec_algorithm)
            return None

    # expr[BS, m, h, L], arr[BS] -> expr[BS, m, h, L]
    def _marginal(self, full_score_expr, mask_expr, lengths_arr):
        if self.alg_unproj:
            marginals_expr = nmarginal_unproj(full_score_expr,
                                              mask_expr,
                                              lengths_arr,
                                              labeled=True)
        elif self.alg_proj:
            marginals_expr = nmarginal_proj(full_score_expr,
                                            mask_expr,
                                            lengths_arr,
                                            labeled=True)
        else:
            zfatal(
                "Unsupported marginal-calculation for the decoding algorithm of "
                + self.conf.iconf.dec_algorithm)
            marginals_expr = None
        return marginals_expr

    # ===== main methods: training and decoding
    # only getting scores
    def score_on_batch(self, insts: List[ParseInstance]):
        with BK.no_grad_env():
            self.refresh_batch(False)
            scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
                insts, False)
            full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                                  False, 0.)
            full_label_score = self._score_label_full(scoring_expr_pack,
                                                      mask_expr, False, 0.)
            return full_arc_score, full_label_score

    # full score and inference
    def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
        with BK.no_grad_env():
            self.refresh_batch(False)
            # ===== calculate
            scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
                insts, False)
            full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                                  False, 0.)
            full_label_score = self._score_label_full(scoring_expr_pack,
                                                      mask_expr, False, 0.)
            # normalizing scores
            full_score = None
            final_exp_score = False  # whether to provide PROB by exp
            if self.norm_local and self.loss_prob:
                full_score = BK.log_softmax(full_arc_score,
                                            -1).unsqueeze(-1) + BK.log_softmax(
                                                full_label_score, -1)
                final_exp_score = True
            elif self.norm_hlocal and self.loss_prob:
                # normalize at the m dimension, ignoring each node's self-finish step.
                full_score = BK.log_softmax(full_arc_score,
                                            -2).unsqueeze(-1) + BK.log_softmax(
                                                full_label_score, -1)
            elif self.norm_single and self.loss_prob:
                if self.conf.iconf.dec_single_neg:
                    # todo(+2): add all-neg for prob explanation
                    full_arc_probs = BK.sigmoid(full_arc_score)
                    full_label_probs = BK.sigmoid(full_label_score)
                    fake_arc_scores = BK.log(full_arc_probs) - BK.log(
                        1. - full_arc_probs)
                    fake_label_scores = BK.log(full_label_probs) - BK.log(
                        1. - full_label_probs)
                    full_score = fake_arc_scores.unsqueeze(
                        -1) + fake_label_scores
                else:
                    full_score = BK.logsigmoid(full_arc_score).unsqueeze(
                        -1) + BK.logsigmoid(full_label_score)
                    final_exp_score = True
            else:
                full_score = full_arc_score.unsqueeze(-1) + full_label_score
            # decode
            mst_lengths = [len(z) + 1 for z in insts
                           ]  # +=1 to include ROOT for mst decoding
            mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
                full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
            if final_exp_score:
                mst_scores_arr = np.exp(mst_scores_arr)
            # jpos prediction (directly index, no converting as in parsing)
            jpos_preds_expr = jpos_pack[2]
            has_jpos_pred = jpos_preds_expr is not None
            jpos_preds_arr = BK.get_value(
                jpos_preds_expr) if has_jpos_pred else None
            # ===== assign
            info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
            mst_real_labels = self.pred2real_labels(mst_labels_arr)
            for one_idx, one_inst in enumerate(insts):
                cur_length = mst_lengths[one_idx]
                one_inst.pred_heads.set_vals(
                    mst_heads_arr[one_idx]
                    [:cur_length])  # directly int-val for heads
                one_inst.pred_labels.build_vals(
                    mst_real_labels[one_idx][:cur_length], self.label_vocab)
                one_inst.pred_par_scores.set_vals(
                    mst_scores_arr[one_idx][:cur_length])
                if has_jpos_pred:
                    one_inst.pred_poses.build_vals(
                        jpos_preds_arr[one_idx][:cur_length],
                        self.bter.pos_vocab)
            return info

    # list(mini-batch) of annotated instances
    # optional results are written in-place? return info.
    def fb_on_batch(self,
                    annotated_insts,
                    training=True,
                    loss_factor=1,
                    **kwargs):
        self.refresh_batch(training)
        margin = self.margin.value
        # gold heads and labels
        gold_heads_arr, _ = self.predict_padder.pad(
            [z.heads.vals for z in annotated_insts])
        gold_labels_arr, _ = self.predict_padder.pad(
            [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
        gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
        gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
        # ===== calculate
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
            annotated_insts, training)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                              training, margin,
                                              gold_heads_expr)
        #
        final_losses = None
        if self.norm_local or self.norm_single:
            select_label_score = self._score_label_selected(
                scoring_expr_pack, mask_expr, training, margin,
                gold_heads_expr, gold_labels_expr)
            # already added margin previously
            losses_heads = losses_labels = None
            if self.loss_prob:
                if self.norm_local:
                    losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                    losses_labels = BK.loss_nll(select_label_score,
                                                gold_labels_expr)
                elif self.norm_single:
                    single_sample = self.conf.tconf.loss_single_sample
                    losses_heads = self._losses_single(full_arc_score,
                                                       gold_heads_expr,
                                                       single_sample,
                                                       is_hinge=False)
                    losses_labels = self._losses_single(select_label_score,
                                                        gold_labels_expr,
                                                        single_sample,
                                                        is_hinge=False)
                # simply adding
                final_losses = losses_heads + losses_labels
            elif self.loss_hinge:
                if self.norm_local:
                    losses_heads = BK.loss_hinge(full_arc_score,
                                                 gold_heads_expr)
                    losses_labels = BK.loss_hinge(select_label_score,
                                                  gold_labels_expr)
                elif self.norm_single:
                    single_sample = self.conf.tconf.loss_single_sample
                    losses_heads = self._losses_single(full_arc_score,
                                                       gold_heads_expr,
                                                       single_sample,
                                                       is_hinge=True,
                                                       margin=margin)
                    losses_labels = self._losses_single(select_label_score,
                                                        gold_labels_expr,
                                                        single_sample,
                                                        is_hinge=True,
                                                        margin=margin)
                # simply adding
                final_losses = losses_heads + losses_labels
            elif self.loss_mr:
                # special treatment!
                probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
                probs_labels = BK.softmax(select_label_score,
                                          dim=-1)  # [bs, m, h]
                # select
                probs_head_gold = BK.gather_one_lastdim(
                    probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
                probs_label_gold = BK.gather_one_lastdim(
                    probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
                # root and pad will be excluded later
                # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
                # todo(warn): have problem since steps will be quite small, not used!
                final_losses = (mask_expr - probs_head_gold * probs_label_gold
                                )  # let loss>=0
        elif self.norm_global:
            full_label_score = self._score_label_full(scoring_expr_pack,
                                                      mask_expr, training,
                                                      margin, gold_heads_expr,
                                                      gold_labels_expr)
            # for this one, use the merged full score
            full_score = full_arc_score.unsqueeze(
                -1) + full_label_score  # [BS, m, h, L]
            # +=1 to include ROOT for mst decoding
            mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                         dtype=np.int32)
            # do inference
            if self.loss_prob:
                marginals_expr = self._marginal(
                    full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
                final_losses = self._losses_global_prob(
                    full_score, gold_heads_expr, gold_labels_expr,
                    marginals_expr, mask_expr)
                if self.alg_proj:
                    # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                    #  but this might be too loose, although the unproj edges are few?
                    gold_unproj_arr, _ = self.predict_padder.pad(
                        [z.unprojs for z in annotated_insts])
                    gold_unproj_expr = BK.input_real(
                        gold_unproj_arr)  # [BS, Len]
                    comparing_expr = Constants.REAL_PRAC_MIN * (
                        1. - gold_unproj_expr)
                    final_losses = BK.max_elem(final_losses, comparing_expr)
            elif self.loss_hinge:
                pred_heads_arr, pred_labels_arr, _ = self._decode(
                    full_score, mask_expr, mst_lengths_arr)
                pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
                pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
                #
                final_losses = self._losses_global_hinge(
                    full_score, gold_heads_expr, gold_labels_expr,
                    pred_heads_expr, pred_labels_expr, mask_expr)
            elif self.loss_mr:
                # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
                raise NotImplementedError(
                    "Not implemented for global-loss + mr.")
        elif self.norm_hlocal:
            # firstly label losses are the same
            select_label_score = self._score_label_selected(
                scoring_expr_pack, mask_expr, training, margin,
                gold_heads_expr, gold_labels_expr)
            losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            # then specially for arc loss
            children_masks_arr, _ = self.hlocal_padder.pad(
                [z.get_children_mask_arr() for z in annotated_insts])
            children_masks_expr = BK.input_real(
                children_masks_arr)  # [bs, h, m]
            # [bs, h]
            # todo(warn): use prod rather than sum, but still only an approximation for the top-down
            # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
            losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
                -1, -2) * children_masks_expr,
                                 dim=-1)
            # including the root-head is important
            losses_arc[:, 1] += losses_arc[:, 0]
            final_losses = losses_arc + losses_labels
        #
        # jpos loss? (the same mask as parsing)
        jpos_losses_expr = jpos_pack[1]
        if jpos_losses_expr is not None:
            final_losses += jpos_losses_expr
        # collect loss with mask, also excluding the first symbol of ROOT
        final_losses_masked = (final_losses * mask_expr)[:, 1:]
        final_loss_sum = BK.sum(final_losses_masked)
        # divide loss by what?
        num_sent = len(annotated_insts)
        num_valid_tok = sum(len(z) for z in annotated_insts)
        if self.conf.tconf.loss_div_tok:
            final_loss = final_loss_sum / num_valid_tok
        else:
            final_loss = final_loss_sum / num_sent
        #
        final_loss_sum_val = float(BK.get_value(final_loss_sum))
        info = {
            "sent": num_sent,
            "tok": num_valid_tok,
            "loss_sum": final_loss_sum_val
        }
        if training:
            BK.backward(final_loss, loss_factor)
        return info
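Both `get_losses_global_hinge` and `_losses_global_prob` above lean on the same flat-index trick: the head and label dimensions are merged so that one gather per modifier retrieves the score of a (head, label) pair addressed as `head * L + label`. The self-contained sketch below reproduces that step in plain PyTorch with toy shapes; the BK wrappers are replaced by direct tensor calls and all values are illustrative.

# Toy sketch of the flat-index trick used above (plain PyTorch; illustrative shapes only):
# scores over (head, label) pairs are flattened to head*L+label so a single gather
# picks out the gold and predicted entries per modifier.
import torch

bs, m, h, L = 2, 4, 4, 3
torch.manual_seed(0)
full_score = torch.randn(bs, m, h, L)
gold_heads = torch.randint(0, h, (bs, m))
gold_labels = torch.randint(0, L, (bs, m))
pred_heads = torch.randint(0, h, (bs, m))
pred_labels = torch.randint(0, L, (bs, m))

combined = full_score.view(bs, m, h * L)              # [bs, m, h*L]
gold_idx = gold_heads * L + gold_labels               # [bs, m]
pred_idx = pred_heads * L + pred_labels               # [bs, m]
gold_scores = combined.gather(-1, gold_idx.unsqueeze(-1)).squeeze(-1)
pred_scores = combined.gather(-1, pred_idx.unsqueeze(-1)).squeeze(-1)
hinge = pred_scores - gold_scores                     # per-token structured hinge term

The structured hinge term is then simply the per-token difference between the predicted-tree and gold-tree scores, which the method above masks and gates before summing.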
Ejemplo n.º 22
0
class M3Encoder(MyIEBT):
    def __init__(self, pc: BK.ParamCollection, conf: M3EncConf, tconf,
                 vpack: VocabPackage):
        super().__init__(pc, conf, tconf, vpack)
        #
        self.conf = conf
        # ----- bert
        # modify bert_conf for other input
        BERT_OTHER_VSIZE = 50  # todo(+N): this should be enough for small inputs!
        conf.bert_conf.bert2_other_input_names = conf.bert_other_inputs
        conf.bert_conf.bert2_other_input_vsizes = [BERT_OTHER_VSIZE] * len(
            conf.bert_other_inputs)
        self.berter = self.add_sub_node("bert", Berter2(pc, conf.bert_conf))
        # -----
        # index fake sent
        self.index_helper = IndexerHelper(vpack)
        # extra encoder over bert?
        self.bert_dim, self.bert_fold = self.berter.get_output_dims()
        conf.m3_enc_conf._input_dim = self.bert_dim
        self.m3_encs = [
            self.add_sub_node("m3e", MyEncoder(pc, conf.m3_enc_conf))
            for _ in range(self.bert_fold)
        ]
        self.m3_enc_out_dim = self.m3_encs[0].get_output_dims()[0]
        # skip m3_enc?
        self.m3_enc_is_empty = all(len(z.layers) == 0 for z in self.m3_encs)
        if self.m3_enc_is_empty:
            assert all(z.get_output_dims()[0] == self.bert_dim
                       for z in self.m3_encs)
            zlog("For m3_enc, we will skip it since it is empty!!")
        # dep as basic?
        if conf.m2e_use_basic_dep:
            MAX_LABEL_NUM = 200  # this should be enough
            self.dep_label_emb = self.add_sub_node(
                "dlab",
                Embedding(self.pc,
                          MAX_LABEL_NUM,
                          conf.dep_label_dim,
                          name="dlab"))
            self.dep_layer = self.add_sub_node(
                "dep",
                TaskSpecAdp(pc, [(self.m3_enc_out_dim, self.bert_fold), None],
                            [conf.dep_label_dim], conf.dep_output_dim))
        else:
            self.dep_label_emb = self.dep_layer = None
        self.dep_padder = DataPadder(
            2, pad_vals=0)  # 0 for both head-idx and label

    # multi-sentence encoding
    def run(self, insts: List[DocInstance], training: bool):
        conf = self.conf
        BERT_MAX_LEN = 510  # save 2 for CLS and SEP
        # =====
        # encoder 1: the basic encoder
        # todo(note): only DocInstance input works for this mode, otherwise it will break
        if conf.m2e_use_basic:
            reidx_pad_len = conf.ms_extend_budget
            # enc the basic part + also get some indexes
            sentid2offset = {}  # id(sent)->overall_seq_offset
            seq_offset = 0  # if look at the docs in one seq
            all_sents = []  # (inst, d_idx, s_idx)
            for d_idx, one_doc in enumerate(insts):
                assert isinstance(one_doc, DocInstance)
                for s_idx, one_sent in enumerate(one_doc.sents):
                    # todo(note): here we encode all the sentences
                    all_sents.append((one_sent, d_idx, s_idx))
                    sentid2offset[id(one_sent)] = seq_offset
                    seq_offset += one_sent.length - 1  # exclude extra ROOT node
            sent_reprs = self.run_sents(all_sents, insts, training)
            # flatten and concatenate and re-index
            reidxes_arr = np.zeros(
                seq_offset + reidx_pad_len, dtype=np.long
            )  # todo(note): extra padding to avoid going out of bounds
            all_flattened_reprs = []
            all_flatten_offset = 0  # the local offset for batched basic encoding
            for one_pack in sent_reprs:
                one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
                assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
                one_repr_t = one_repr_evt
                _, one_slen, one_ldim = BK.get_shape(one_repr_t)
                all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
                # fill in the indexes
                for one_sent in one_sents:
                    cur_start_offset = sentid2offset[id(one_sent)]
                    cur_real_slen = one_sent.length - 1
                    # again, +1 to get rid of extra ROOT
                    reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \
                        np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset+1)
                    all_flatten_offset += one_slen  # here add the slen in batched version
            # re-idxing
            seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
            seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr,
                                      0)  # [all_seq_len, D]
        else:
            sentid2offset = defaultdict(int)
            seq_sent_repr = None
        # =====
        # repack and prepare for multiple sent enc
        # todo(note): here, the criterion is based on bert's tokenizer
        all_ms_info = []
        if isinstance(insts[0], DocInstance):
            for d_idx, one_doc in enumerate(insts):
                for s_idx, x in enumerate(one_doc.sents):
                    # the inclusion criterion is the same as the basic one
                    include_flag = False
                    if training:
                        if x.length<self.train_skip_length and x.length>=self.train_min_length \
                                and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate):
                            include_flag = True
                    else:
                        if x.length >= self.test_min_length:
                            include_flag = True
                    if include_flag:
                        all_ms_info.append(
                            x.preps["ms"])  # use the pre-calculated one
        else:
            # multisent based
            all_ms_info = insts.copy()  # shallow copy
        # =====
        # encoder 2: the bert one (multi-sent encoding)
        ms_size_f = lambda x: x.subword_size
        all_ms_info.sort(key=ms_size_f)
        all_ms_buckets = self._bucket_sents_by_length(
            all_ms_info,
            conf.benc_bucket_range,
            ms_size_f,
            max_bsize=conf.benc_bucket_msize)
        berter = self.berter
        rets = []
        bert_use_center_typeids = conf.bert_use_center_typeids
        bert_use_special_typeids = conf.bert_use_special_typeids
        bert_other_inputs = conf.bert_other_inputs
        for one_bucket in all_ms_buckets:
            # prepare
            batched_ids = []
            batched_starts = []
            batched_seq_offset = []
            batched_typeids = []
            batched_other_inputs_list: List = [
                [] for _ in bert_other_inputs
            ]  # List(comp) of List(batch) of List(idx)
            for one_item in one_bucket:
                one_sents = one_item.sents
                one_center_sid = one_item.center_idx
                one_ids, one_starts, one_typeids = [], [], []
                one_other_inputs_list = [[] for _ in bert_other_inputs
                                         ]  # List(comp) of List(idx)
                for one_sid, one_sent in enumerate(one_sents):  # for bert
                    one_bidxes = one_sent.preps["bidx"]
                    one_ids.extend(one_bidxes.subword_ids)
                    one_starts.extend(one_bidxes.subword_is_start)
                    # prepare other inputs
                    for this_field_name, this_tofill_list in zip(
                            bert_other_inputs, one_other_inputs_list):
                        this_tofill_list.extend(
                            one_sent.preps["sub_" + this_field_name])
                    # todo(note): special procedure
                    if bert_use_center_typeids:
                        if one_sid != one_center_sid:
                            one_typeids.extend([0] *
                                               len(one_bidxes.subword_ids))
                        else:
                            this_typeids = [1] * len(one_bidxes.subword_ids)
                            if bert_use_special_typeids:
                                # todo(note): this is the special mode that we are given the events!!
                                for this_event in one_sents[
                                        one_center_sid].events:
                                    _, this_wid, this_wlen = this_event.mention.hard_span.position(
                                        headed=False)
                                    for a, b in one_item.center_word2sub[
                                            this_wid - 1:this_wid - 1 +
                                            this_wlen]:
                                        this_typeids[a:b] = [0] * (b - a)
                            one_typeids.extend(this_typeids)
                batched_ids.append(one_ids)
                batched_starts.append(one_starts)
                batched_typeids.append(one_typeids)
                for comp_one_oi, comp_batched_oi in zip(
                        one_other_inputs_list, batched_other_inputs_list):
                    comp_batched_oi.append(comp_one_oi)
                # for basic part
                batched_seq_offset.append(sentid2offset[id(one_sents[0])])
            # bert forward: [bs, slen, fold, D]
            if not bert_use_center_typeids:
                batched_typeids = None
            bert_expr0, mask_expr = berter.forward_batch(
                batched_ids,
                batched_starts,
                batched_typeids,
                training=training,
                other_inputs=batched_other_inputs_list)
            if self.m3_enc_is_empty:
                bert_expr = bert_expr0
            else:
                mask_arr = BK.get_value(mask_expr)  # [bs, slen]
                m3e_exprs = [
                    cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                    for cur_i, cur_enc in enumerate(self.m3_encs)
                ]
                bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
            # collect basic ones: [bs, slen, D'] or None
            if seq_sent_repr is not None:
                arange_idxes_t = BK.arange_idx(BK.get_shape(
                    mask_expr, -1)).unsqueeze(0)  # [1, slen]
                offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(
                    -1) + arange_idxes_t  # [bs, slen]
                basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
            elif conf.m2e_use_basic_dep:
                # collect each token's head-bert and ud-label, then forward with adp
                fake_sents = [one_item.fake_sent for one_item in one_bucket]
                # head idx and labels, no artificial ROOT
                padded_head_arr, _ = self.dep_padder.pad(
                    [s.ud_heads.vals[1:] for s in fake_sents])
                padded_label_arr, _ = self.dep_padder.pad(
                    [s.ud_labels.idxes[1:] for s in fake_sents])
                # get tensor
                padded_head_t = (BK.input_idx(padded_head_arr) - 1
                                 )  # here, the idx exclude root
                padded_head_t.clamp_(min=0)  # [bs, slen]
                padded_label_t = BK.input_idx(padded_label_arr)
                # get inputs
                input_head_bert_t = bert_expr[
                    BK.arange_idx(len(fake_sents)).unsqueeze(-1),
                    padded_head_t]  # [bs, slen, fold, D]
                input_label_emb_t = self.dep_label_emb(
                    padded_label_t)  # [bs, slen, D']
                basic_expr = self.dep_layer(
                    input_head_bert_t, None,
                    [input_label_emb_t])  # [bs, slen, ?]
            elif conf.m2e_use_basic_plus:
                sent_reprs = self.run_sents([(one_item.fake_sent, None, None)
                                             for one_item in one_bucket],
                                            insts,
                                            training,
                                            use_one_bucket=True)
                assert len(
                    sent_reprs
                ) == 1, "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
                _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
                assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
                basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
                assert BK.get_shape(basic_expr)[:2] == BK.get_shape(
                    bert_expr)[:2]
            else:
                basic_expr = None
            # pack: (List[ms_item], bert_expr, basic_expr)
            rets.append((one_bucket, bert_expr, basic_expr))
        return rets

    # prepare instance
    def prepare_inst(self, inst: DocInstance):
        berter = self.berter
        conf = self.conf
        ms_extend_budget = conf.ms_extend_budget
        ms_extend_step = conf.ms_extend_step
        # -----
        # prepare bert tokenization results
        # print(inst.doc_id)
        for sent in inst.sents:
            real_words = sent.words.vals[1:]  # no special ROOT
            bidxes = berter.subword_tokenize(real_words, True)
            sent.preps["bidx"] = bidxes
            # prepare subword expanded fields
            subword2word = np.cumsum(bidxes.subword_is_start).tolist(
            )  # -1 and +1 happen to cancel out
            for field_name in conf.bert_other_inputs:
                field_idxes = getattr(
                    sent, field_name
                ).idxes  # use full one since idxes are set in this way
                sent.preps["sub_" + field_name] = [
                    field_idxes[z] for z in subword2word
                ]
        # -----
        # prepare others (another loop since we need cross-sent bert tokens)
        for sent in inst.sents:
            # -----
            # prepare multi-sent
            # include this multiple sent pack, extend to both sides until window limit or bert limit
            cur_center_sent = sent
            cur_sid, cur_doc = sent.sid, sent.doc
            cur_doc_sents = cur_doc.sents
            cur_doc_nsent = len(cur_doc.sents)
            cur_sid_left = cur_sid_right = cur_sid
            cur_subword_size = len(cur_center_sent.preps["bidx"].subword_ids)
            for step in range(ms_extend_step):
                # first left then right
                if cur_sid_left > 0:
                    this_subword_size = len(
                        cur_doc_sents[cur_sid_left -
                                      1].preps["bidx"].subword_ids)
                    if cur_subword_size + this_subword_size <= ms_extend_budget:
                        cur_sid_left -= 1
                        cur_subword_size += this_subword_size
                if cur_sid_right < cur_doc_nsent - 1:
                    this_subword_size = len(
                        cur_doc_sents[cur_sid_right +
                                      1].preps["bidx"].subword_ids)
                    if cur_subword_size + this_subword_size <= ms_extend_budget:
                        cur_sid_right += 1
                        cur_subword_size += this_subword_size
            # List[Sentence], List[int], center_local_idx, all_subword_size
            cur_sents = cur_doc_sents[cur_sid_left:cur_sid_right + 1]
            cur_offsets = [0]
            for s in cur_sents:
                cur_offsets.append(
                    s.length - 1 +
                    cur_offsets[-1])  # does not include ROOT here!!
            one_ms = MultiSentItem(cur_sents, cur_offsets,
                                   cur_sid - cur_sid_left, cur_subword_size,
                                   None, None, None)
            sent.preps["ms"] = one_ms
            # -----
            # subword idx for center sent
            center_word2sub = []
            prev_start = -1
            center_subword_is_start = cur_center_sent.preps[
                "bidx"].subword_is_start
            for cur_end, one_is_start in enumerate(center_subword_is_start):
                if one_is_start:
                    if prev_start >= 0:
                        center_word2sub.append((prev_start, cur_end))
                    prev_start = cur_end
            if prev_start >= 0:
                center_word2sub.append(
                    (prev_start, len(center_subword_is_start)))
            one_ms.center_word2sub = center_word2sub
            # -----
            # fake a concat sent for basic plus modeling
            if conf.m2e_use_basic_plus or conf.m2e_use_basic_dep:
                concat_words, concat_lemmas, concat_uposes, concat_ud_heads, concat_ud_labels = [], [], [], [], []
                cur_fake_offset = 0  # overall offset in fake sent
                prev_root = None
                for one_fake_inner_sent in cur_sents:  # exclude root
                    concat_words.extend(one_fake_inner_sent.words.vals[1:])
                    concat_lemmas.extend(one_fake_inner_sent.lemmas.vals[1:])
                    concat_uposes.extend(one_fake_inner_sent.uposes.vals[1:])
                    # todo(note): make the heads look like a real sent; the actual heads already +=1; root points to prev root
                    for local_i, local_h in enumerate(
                            one_fake_inner_sent.ud_heads.vals[1:]):
                        if local_h == 0:
                            if prev_root is None:
                                global_h = cur_fake_offset + local_i + 1  # +1 here for offset
                            else:
                                global_h = prev_root
                            prev_root = cur_fake_offset + local_i + 1  # +1 here for offset
                        else:
                            global_h = cur_fake_offset + local_h  # already +=1
                        concat_ud_heads.append(global_h)
                    concat_ud_labels.extend(
                        one_fake_inner_sent.ud_labels.vals[1:])
                    cur_fake_offset += len(one_fake_inner_sent.words.vals) - 1
                one_fake_sent = Sentence(None, concat_words, concat_lemmas,
                                         concat_uposes, concat_ud_heads,
                                         concat_ud_labels, None, None)
                one_ms.fake_sent = one_fake_sent
                self.index_helper.index_sent(one_fake_sent)

    def special_refresh(self, embed_rop, other_rop):
        super().special_refresh(embed_rop, other_rop)
        self.berter.refresh(other_rop)
        for one in self.m3_encs + [self.dep_layer, self.dep_label_emb]:
            if one is not None:
                one.refresh(other_rop)

    # #
    # def get_output_dims(self, *input_dims):
    #     raise RuntimeError("Complex output, thus not using this one")

    def speical_output_dims(self):
        conf = self.conf
        if conf.m2e_use_basic_dep:
            basic_dim = conf.dep_output_dim
        elif conf.m2e_use_basic or conf.m2e_use_basic_plus:
            basic_dim = self.enc_evt_output_dim
        else:
            basic_dim = None
        # bert_outputs, basic_output
        return (self.m3_enc_out_dim, self.bert_fold), basic_dim
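The subword bookkeeping in `prepare_inst` above boils down to two small mappings derived from `subword_is_start`: a subword-to-word index via `np.cumsum`, and word-to-subword (start, end) spans as built for `center_word2sub`. A standalone sketch with illustrative data (the tokenizer output here is made up):

# Sketch of the subword bookkeeping in prepare_inst (illustrative data, plain NumPy):
# subword_is_start marks the first wordpiece of each word.
import numpy as np

subword_is_start = [1, 0, 0, 1, 1, 0]   # e.g. one word split into 3 pieces, then two more words

# subword -> word index: cumsum gives 1-based word ids, which are used above to copy
# word-level fields (e.g. POS idxes) onto every wordpiece ("sub_" + field_name).
subword2word = np.cumsum(subword_is_start).tolist()      # [1, 1, 1, 2, 3, 3]

# word -> subword span (start, end), the same construction as center_word2sub
word2sub, prev_start = [], -1
for cur_end, is_start in enumerate(subword_is_start):
    if is_start:
        if prev_start >= 0:
            word2sub.append((prev_start, cur_end))
        prev_start = cur_end
if prev_start >= 0:
    word2sub.append((prev_start, len(subword_is_start)))
# word2sub -> [(0, 3), (3, 4), (4, 6)]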
Ejemplo n.º 23
0
class SL0Layer(BasicNode):
    def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf):
        super().__init__(pc, None, None)
        self.dim = rconf._input_dim  # both input/output dim
        # padders for child nodes
        self.chs_start_posi = -rconf.chs_num
        self.ch_idx_padder = DataPadder(2, pad_vals=0,
                                        mask_range=2)  # [*, num-ch]
        self.ch_label_padder = DataPadder(2, pad_vals=0)
        #
        self.label_embeddings = self.add_sub_node(
            "label",
            Embedding(pc, rconf._num_label, rconf.dim_label, fix_row0=False))
        self.dim_label = rconf.dim_label
        # todo(note): now adopting flattened groupings for the basic features, and that is all, no more recurrent features
        # group 1: [cur, chs, par] -> head_pre_size
        self.use_chs = rconf.use_chs
        self.use_par = rconf.use_par
        self.use_label_feat = rconf.use_label_feat
        # components (add the parameters anyway)
        # todo(note): children features: children + (label of mod->children)
        self.chs_reprer = self.add_sub_node("chs", ChsReprer(pc, rconf))
        self.chs_ff = self.add_sub_node(
            "chs_ff",
            Affine(pc,
                   self.chs_reprer.get_output_dims()[0],
                   self.dim,
                   act="tanh"))
        # todo(note): parent features: parent + (label of parent->mod)
        # todo(warn): always add label related params
        par_ff_inputs = [self.dim, rconf.dim_label]
        self.par_ff = self.add_sub_node(
            "par_ff", Affine(pc, par_ff_inputs, self.dim, act="tanh"))
        # no other groups anymore!
        if rconf.zero_extra_output_params:
            self.par_ff.zero_params()
            self.chs_ff.zero_params()

    # calculating the structured representations, giving raw repr-tensor
    # 1) [*, D]; 2) [*, D], [*], [*]; 3) [*, chs-len, D], [*, chs-len], [*, chs-len], [*]
    def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t,
                       chs_label_t, chs_mask_t, chs_valid_mask_t):
        ret_t = cur_t  # [*, D]
        # padding 0 if not using labels
        dim_label = self.dim_label
        # child features
        if self.use_chs and chs_t is not None:
            if self.use_label_feat:
                chs_label_rt = self.label_embeddings(
                    chs_label_t)  # [*, max-chs, dlab]
            else:
                labels_shape = BK.get_shape(chs_t)
                labels_shape[-1] = dim_label
                chs_label_rt = BK.zeros(labels_shape)
            chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
            chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t,
                                        chs_valid_mask_t)
            chs_feat = self.chs_ff(chs_feat0)
            ret_t += chs_feat
        # parent features
        if self.use_par and par_t is not None:
            if self.use_label_feat:
                cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
            else:
                labels_shape = BK.get_shape(par_t)
                labels_shape[-1] = dim_label
                cur_label_t = BK.zeros(labels_shape)
            par_feat = self.par_ff([par_t, cur_label_t])
            if par_mask_t is not None:
                par_feat *= par_mask_t.unsqueeze(-1)
            ret_t += par_feat
        return ret_t

    # todo(note): if no other features, then no change for the repr!
    def forward_repr(self, cur_t):
        return cur_t

    # preparation: padding for chs/par
    def pad_chs(self, idxes_list: List[List], labels_list: List[List]):
        start_posi = self.chs_start_posi
        if start_posi < 0:  # truncate
            idxes_list = [x[start_posi:] for x in idxes_list]
        # overall valid mask
        chs_valid = [(0. if len(z) == 0 else 1.) for z in idxes_list]
        # if any valid children in the batch
        if all(x > 0 for x in chs_valid):
            padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(
                idxes_list)  # [*, max-ch], [*, max-ch]
            if self.use_label_feat:
                if start_posi < 0:  # truncate
                    labels_list = [x[start_posi:] for x in labels_list]
                padded_chs_labels, _ = self.ch_label_padder.pad(
                    labels_list)  # [*, max-ch]
                chs_label_t = BK.input_idx(padded_chs_labels)
            else:
                chs_label_t = None
            chs_idxes_t, chs_mask_t, chs_valid_mask_t = \
                BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid)
            return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t
        else:
            return None, None, None, None

    def pad_par(self, idxes: List, labels: List):
        par_idxes_t = BK.input_idx(idxes)
        labels_t = BK.input_idx(labels)
        # todo(note): specifically, <0 means non-exist
        # todo(note): an interesting bug: ">=" was once wrongly written as "<"; in that case, 0 acts as the parent of nodes that actually have no parent yet and are still to be attached, so patterns of "parent=0" may get overly positive scores
        # todo(note): ACTUALLY, mainly because of the difference in search and forward-backward!!
        par_mask_t = (par_idxes_t >= 0).float()
        par_idxes_t.clamp_(0)  # since -1 will be illegal idx
        labels_t.clamp_(0)
        return par_idxes_t, labels_t, par_mask_t
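`pad_par` encodes "no parent" as a negative index: a float mask records which entries are real, and the indices are then clamped to 0 so they stay legal for lookups, with the mask zeroing out the fake contributions downstream. A tiny plain-PyTorch sketch of that convention (toy values, BK calls replaced by direct tensor ops):

# Sketch of the pad_par convention (plain PyTorch): -1 marks "no parent yet";
# the mask remembers which entries are real while clamping keeps indices legal.
import torch

par_idxes = torch.tensor([3, -1, 0, 2])
labels = torch.tensor([5, -1, 1, 4])

par_mask = (par_idxes >= 0).float()     # [1., 0., 1., 1.]
par_idxes = par_idxes.clamp(min=0)      # [3, 0, 0, 2] -- now safe to use as a lookup index
labels = labels.clamp(min=0)            # [5, 0, 1, 4]
# downstream, features built from the clamped indices are multiplied by par_mask,
# so the fake parent 0 contributes nothing for the second token.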
Ejemplo n.º 24
0
class ParserBT(BasicNode):
    def __init__(self, pc: BK.ParamCollection, bconf: BTConf,
                 vpack: VocabPackage):
        super().__init__(pc, None, None)
        self.bconf = bconf
        # ===== Vocab =====
        self.word_vocab = vpack.get_voc("word")
        self.char_vocab = vpack.get_voc("char")
        self.pos_vocab = vpack.get_voc("pos")
        # ===== Model =====
        # embedding
        self.emb = self.add_sub_node(
            "emb", MyEmbedder(self.pc, bconf.emb_conf, vpack))
        emb_output_dim = self.emb.get_output_dims()[0]
        # encoder0 for jpos
        # todo(note): will do nothing if not use_jpos
        bconf.jpos_conf._input_dim = emb_output_dim
        self.jpos_enc = self.add_sub_node(
            "enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab))
        enc0_output_dim = self.jpos_enc.get_output_dims()[0]
        # encoder
        # todo(0): feed compute-on-the-fly hp
        bconf.enc_conf._input_dim = enc0_output_dim
        self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf))
        self.enc_output_dim = self.enc.get_output_dims()[0]
        # ===== Input Specification =====
        # inputs (word, char, pos) and vocabulary
        self.need_word = self.emb.has_word
        self.need_char = self.emb.has_char
        # todo(warn): currently only allow extra fields for POS
        self.need_pos = False
        if len(self.emb.extra_names) > 0:
            assert len(
                self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
            self.need_pos = True
        # todo(warn): currently only allow one aux field
        self.need_aux = False
        if len(self.emb.dim_auxes) > 0:
            assert len(self.emb.dim_auxes) == 1
            self.need_aux = True
        #
        self.word_padder = DataPadder(2,
                                      pad_vals=self.word_vocab.pad,
                                      mask_range=2)
        self.char_padder = DataPadder(3,
                                      pad_lens=(0, 0, bconf.char_max_length),
                                      pad_vals=self.char_vocab.pad)
        self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)
        #
        self.random_sample_stream = Random.stream(Random.random_sample)

    def get_output_dims(self, *input_dims):
        return (self.enc_output_dim, )

    #
    def refresh(self, rop=None):
        zfatal("Should call special_refresh instead!")

    # whether enabling joint-pos multitask
    def jpos_multitask_enabled(self):
        return self.jpos_enc.jpos_multitask

    # ====
    # special routines

    def special_refresh(self, embed_rop, other_rop):
        self.emb.refresh(embed_rop)
        self.enc.refresh(other_rop)
        self.jpos_enc.refresh(other_rop)

    def prepare_training_rop(self):
        mconf = self.bconf
        embed_rop = RefreshOptions(hdrop=mconf.drop_embed,
                                   dropmd=mconf.dropmd_embed,
                                   fix_drop=mconf.fix_drop)
        other_rop = RefreshOptions(hdrop=mconf.drop_hidden,
                                   idrop=mconf.idrop_rnn,
                                   gdrop=mconf.gdrop_rnn,
                                   fix_drop=mconf.fix_drop)
        return embed_rop, other_rop

    # =====
    # run
    def _prepare_input(self, insts, training):
        word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], []
        # ===== specially prepare for the words
        wv = self.word_vocab
        W_UNK = wv.unk
        UNK_REP_RATE = self.bconf.singleton_unk
        UNK_REP_THR = self.bconf.singleton_thr
        word_act_idxes = []
        if training and UNK_REP_RATE > 0.:  # replace unfreq/singleton words with UNK
            for one_inst in insts:
                one_act_idxes = []
                for one_idx in one_inst.words.idxes:
                    one_freq = wv.idx2val(one_idx)
                    if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR:
                        if next(self.random_sample_stream) < (UNK_REP_RATE /
                                                              one_freq):
                            one_idx = W_UNK
                    one_act_idxes.append(one_idx)
                word_act_idxes.append(one_act_idxes)
        else:
            word_act_idxes = [z.words.idxes for z in insts]
        # todo(warn): still need the masks
        word_arr, mask_arr = self.word_padder.pad(word_act_idxes)
        # =====
        if not self.need_word:
            word_arr = None
        if self.need_char:
            chars = [z.chars.idxes for z in insts]
            char_arr, _ = self.char_padder.pad(chars)
        if self.need_pos or self.jpos_multitask_enabled():
            poses = [z.poses.idxes for z in insts]
            pos_arr, _ = self.pos_padder.pad(poses)
            if self.need_pos:
                extra_arrs.append(pos_arr)
        else:
            pos_arr = None
        if self.need_aux:
            aux_arr_list = [z.extra_features["aux_repr"] for z in insts]
            # pad
            padded_seq_len = int(mask_arr.shape[1])
            final_aux_arr_list = []
            for cur_arr in aux_arr_list:
                cur_len = len(cur_arr)
                if cur_len > padded_seq_len:
                    final_aux_arr_list.append(cur_arr[:padded_seq_len])
                else:
                    final_aux_arr_list.append(
                        np.pad(cur_arr,
                               ((0, padded_seq_len - cur_len), (0, 0)),
                               'constant'))
            aux_arrs.append(np.stack(final_aux_arr_list, 0))
        #
        input_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs)
        # [BS, Len, Dim], [BS, Len]
        return input_repr, mask_arr, pos_arr

    # todo(warn): for rnn, need to transpose masks, thus need np.array
    # return input_repr, enc_repr, mask_arr
    def run(self, insts, training):
        # ===== calculate
        # [BS, Len, Di], [BS, Len], [BS, len]
        input_repr, mask_arr, gold_pos_arr = self._prepare_input(
            insts, training)
        # enc0 for joint-pos multitask
        input_repr0, jpos_pack = self.jpos_enc(input_repr,
                                               mask_arr,
                                               require_loss=training,
                                               require_pred=(not training),
                                               gold_pos_arr=gold_pos_arr)
        # [BS, Len, De]
        enc_repr = self.enc(input_repr0, mask_arr)
        return input_repr, enc_repr, jpos_pack, mask_arr

    # special routine
    def aug_words_and_embs(self, aug_vocab, aug_wv):
        orig_vocab = self.word_vocab
        if self.emb.has_word:
            orig_arr = self.emb.word_embed.E.detach().cpu().numpy()
            # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
            # todo(warn): here aug_vocab should be found in aug_wv
            aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
            new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(
                orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
            # assign
            self.word_vocab = new_vocab
            self.emb.word_embed.replace_weights(new_arr)
        else:
            zwarn("No need to aug vocab since delexicalized model!!")
            new_vocab = orig_vocab
        return new_vocab
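The aux-repr branch of `_prepare_input` has to fit externally precomputed [len, D] arrays to the batch's padded length, truncating long ones and zero-padding short ones with `np.pad`. A minimal NumPy sketch of that step with toy shapes (the helper name is made up for illustration):

# Sketch of fitting a precomputed [len, D] aux-repr array to the padded batch length,
# mirroring the truncate-or-np.pad branch in _prepare_input (plain NumPy, toy shapes).
import numpy as np

def fit_to_length(arr, padded_len):
    cur_len = len(arr)
    if cur_len > padded_len:
        return arr[:padded_len]
    return np.pad(arr, ((0, padded_len - cur_len), (0, 0)), 'constant')

aux = np.ones((3, 4), dtype=np.float32)      # a sentence of length 3 with 4-dim reprs
print(fit_to_length(aux, 5).shape)           # (5, 4): two zero rows appended
print(fit_to_length(aux, 2).shape)           # (2, 4): truncated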
Ejemplo n.º 25
0
class MaskLMNode(BasicNode):
    def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf,
                 vpack: VocabPackage):
        super().__init__(pc, None, None)
        self.conf = conf
        # vocab and padder
        self.word_vocab = vpack.get_voc("word")
        self.padder = DataPadder(
            2, pad_vals=self.word_vocab.pad,
            mask_range=2)  # todo(note): <pad>-id is very large
        # models
        self.hid_layer = self.add_sub_node(
            "hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
        self.pred_layer = self.add_sub_node(
            "pred",
            Affine(pc,
                   conf.hid_dim,
                   conf.max_pred_rank + 1,
                   init_rop=NoDropRop()))
        if conf.init_pred_from_pretrain:
            npvec = vpack.get_emb("word")
            if npvec is None:
                zwarn(
                    "Pretrained vector not provided, skip init pred embeddings!!"
                )
            else:
                with BK.no_grad_env():
                    self.pred_layer.ws[0].copy_(
                        BK.input_real(npvec[:conf.max_pred_rank + 1].T))
                zlog(
                    f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1})."
                )

    # return (input_word_mask_repl, output_pred_mask_repl, output_pred_idx)
    def prepare(self, insts: List[ParseInstance], training):
        conf = self.conf
        word_idxes = [z.words.idxes for z in insts]
        word_arr, input_mask = self.padder.pad(word_idxes)  # [bsize, slen]
        # prepare for the masks
        input_word_mask = (Random.random_sample(word_arr.shape) <
                           conf.mask_rate) & (input_mask > 0.)
        input_word_mask &= (word_arr >= conf.min_mask_rank)
        input_word_mask[:, 0] = False  # no masking for special ROOT
        output_pred_mask = (input_word_mask & (word_arr <= conf.max_pred_rank))
        return input_word_mask.astype(np.float32), output_pred_mask.astype(
            np.float32), word_arr
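
As a standalone illustration of the masking rule in prepare (random sampling at mask_rate, restricted to ids between min_mask_rank and max_pred_rank, with ROOT never masked), here is a minimal numpy sketch; the ids, rates and the <pad> convention below are made-up stand-ins for the MaskLMNodeConf and vocab settings, not the real configuration:

import numpy as np

# hypothetical stand-ins for MaskLMNodeConf fields and a padded batch
mask_rate, min_mask_rank, max_pred_rank = 0.3, 2, 100
rng = np.random.RandomState(0)
word_arr = np.array([[0, 7, 153, 4, 99999], [0, 3, 12, 99999, 99999]])  # pretend 99999 is <pad>
input_mask = (word_arr != 99999).astype(np.float32)

input_word_mask = (rng.random_sample(word_arr.shape) < mask_rate) & (input_mask > 0.)
input_word_mask &= (word_arr >= min_mask_rank)      # never mask special/very frequent ids
input_word_mask[:, 0] = False                       # no masking for the ROOT column
output_pred_mask = input_word_mask & (word_arr <= max_pred_rank)  # only predict ids within the pred vocab
print(input_word_mask.astype(np.float32))
print(output_pred_mask.astype(np.float32))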

    # [bsize, slen, *]
    def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
        mask_idxes, mask_valids = BK.mask2idx(
            BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
        if BK.get_shape(mask_idxes, -1) == 0:  # no loss
            zzz = BK.zeros([])
            return [[zzz, zzz, zzz]]
        else:
            target_reprs = BK.gather_first_dims(repr_t, mask_idxes,
                                                1)  # [bsize, ?, *]
            target_hids = self.hid_layer(target_reprs)
            target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
            pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
            target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
            target_idx_t[(mask_valids <
                          1.)] = 0  # make sure invalid ones in range
            # get loss
            pred_losses = BK.loss_nll(target_scores,
                                      target_idx_t)  # [bsize, ?]
            pred_loss_sum = (pred_losses * mask_valids).sum()
            pred_loss_count = mask_valids.sum()
            # argmax
            _, argmax_idxes = target_scores.max(-1)
            pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
            pred_corr_count = pred_corrs.sum()
            return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
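
The loss above deliberately returns unnormalized sums ([loss_sum, loss_count, corr_count]) so that the trainer can decide how to average; below is a sketch of how a caller might reduce those numbers once they have been pulled back to plain Python floats (the helper name is hypothetical, not part of the library):

def reduce_masklm_stats(loss_groups):
    # loss_groups: list of [loss_sum, loss_count, corr_count] triples as floats
    loss_sum = sum(g[0] for g in loss_groups)
    count = sum(g[1] for g in loss_groups)
    corr = sum(g[2] for g in loss_groups)
    denom = max(count, 1.)
    return loss_sum / denom, corr / denom  # (mean loss, prediction accuracy)

print(reduce_masklm_stats([[6.2, 4., 3.]]))  # -> (1.55, 0.75)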
Ejemplo n.º 26
0
class G2Parser(BaseParser):
    def __init__(self, conf: G2ParserConf, vpack: VocabPackage):
        super().__init__(conf, vpack)
        # todo(note): the neural parameters are exactly the same as the EF one
        # ===== basic G1 Parser's loading
        # todo(note): there can be parameter mismatch (but all of them in non-trained part, thus will be fine)
        self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf)
        self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training
        self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing
        self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training
        self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing
        #
        self.add_slayer()
        self.dl = G2DL(self.scorer, self.slayer, conf)
        #
        self.predict_padder = DataPadder(2, pad_vals=0)
        self.num_label = self.label_vocab.trg_len(
            True)  # todo(WARN): use the original idx

    def build_decoder(self):
        conf = self.conf
        # ===== Decoding Scorer =====
        conf.sc_conf._input_dim = self.enc_output_dim
        conf.sc_conf._num_label = self.label_vocab.trg_len(
            True)  # todo(WARN): use the original idx
        return Scorer(self.pc, conf.sc_conf)

    def build_slayer(self):
        conf = self.conf
        # ===== Structured Layer =====
        conf.sl_conf._input_dim = self.enc_output_dim
        conf.sl_conf._num_label = self.label_vocab.trg_len(
            True)  # todo(WARN): use the original idx
        return SL0Layer(self.pc, conf.sl_conf)

    # =====
    # main procedures

    # get g1 prunings and extra-scores
    # -> if no g1parser is provided, use aux scores; otherwise it depends on g1_use_aux_scores
    def _get_g1_pack(self, insts: List[ParseInstance], score_arc_lambda: float,
                     score_lab_lambda: float):
        pconf = self.conf.iconf.pruning_conf
        if self.g1parser is None or self.g1parser.g1_use_aux_scores:
            valid_mask, arc_score, lab_score, _, _ = G1Parser.score_and_prune(
                insts, self.num_label, pconf)
        else:
            valid_mask, arc_score, lab_score, _, _ = self.g1parser.prune_on_batch(
                insts, pconf)
        if score_arc_lambda <= 0. and score_lab_lambda <= 0.:
            go1_pack = None
        else:
            arc_score *= score_arc_lambda
            lab_score *= score_lab_lambda
            go1_pack = (arc_score, lab_score)
        # [*, slen, slen], ([*, slen, slen], [*, slen, slen, Lab])
        return valid_mask, go1_pack

    # make real valid masks (in place): valid(byte): [bs, len-m, len-h]; mask(float): [bs, len]
    def _make_final_valid(self, valid_expr, mask_expr):
        maxlen = BK.get_shape(mask_expr, -1)
        # first apply masks
        mask_expr_byte = mask_expr.byte()
        valid_expr &= mask_expr_byte.unsqueeze(-1)
        valid_expr &= mask_expr_byte.unsqueeze(-2)
        # then diag
        mask_diag = 1 - BK.eye(maxlen).byte()
        valid_expr &= mask_diag
        # root not as mod
        valid_expr[:, 0] = 0
        # only allow root->root (for grandparent feature)
        valid_expr[:, 0, 0] = 1
        return valid_expr
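
A numpy stand-in for the mask construction above, to make the broadcasting explicit (BK.eye and byte tensors are replaced by plain boolean arrays; this is only a sketch of the logic, not the backend code):

import numpy as np

def make_final_valid_np(valid, mask):
    # valid: [bs, len, len] bool, mask: [bs, len] float (1. for real tokens, 0. for padding)
    maxlen = mask.shape[-1]
    m = mask > 0.
    valid = valid & m[:, :, None] & m[:, None, :]   # drop padded modifiers and heads
    valid = valid & ~np.eye(maxlen, dtype=bool)     # no self-loops on the diagonal
    valid[:, 0] = False                             # ROOT is never a modifier...
    valid[:, 0, 0] = True                           # ...except the root->root cell kept for the grandparent feature
    return valid

print(make_final_valid_np(np.ones([1, 4, 4], dtype=bool),
                          np.array([[1., 1., 1., 0.]])).astype(int))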

    # decoding
    def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
        # iconf = self.conf.iconf
        with BK.no_grad_env():
            self.refresh_batch(False)
            # pruning and scores from g1
            valid_mask, go1_pack = self._get_g1_pack(
                insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
            # encode
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                insts, False)
            mask_expr = BK.input_real(mask_arr)
            # decode
            final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
            ret_heads, ret_labels, _, _ = self.dl.decode(
                insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
            # collect the results together
            all_heads = Helper.join_list(ret_heads)
            if ret_labels is None:
                # todo(note): simply get labels from the go1-label classifier; must provide g1parser
                if go1_pack is None:
                    _, go1_pack = self._get_g1_pack(insts, 1., 1.)
                _, go1_label_max_idxes = go1_pack[1].max(
                    -1)  # [bs, slen, slen]
                pred_heads_arr, _ = self.predict_padder.pad(
                    all_heads)  # [bs, slen]
                pred_heads_expr = BK.input_idx(pred_heads_arr)
                pred_labels_expr = BK.gather_one_lastdim(
                    go1_label_max_idxes, pred_heads_expr).squeeze(-1)
                all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
            else:
                all_labels = np.concatenate(ret_labels, 0)
            # ===== assign, todo(warn): here, the labels are already the original idx, no need to change
            for one_idx, one_inst in enumerate(insts):
                cur_length = len(one_inst) + 1
                one_inst.pred_heads.set_vals(
                    all_heads[one_idx]
                    [:cur_length])  # directly int-val for heads
                one_inst.pred_labels.build_vals(
                    all_labels[one_idx][:cur_length], self.label_vocab)
                # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
            # =====
            # put jpos result (possibly)
            self.jpos_decode(insts, jpos_pack)
            # -----
            info = {"sent": len(insts), "tok": sum(map(len, insts))}
            return info

    # training
    def fb_on_batch(self,
                    annotated_insts: List[ParseInstance],
                    training=True,
                    loss_factor=1.,
                    **kwargs):
        self.refresh_batch(training)
        # pruning and scores from g1
        valid_mask, go1_pack = self._get_g1_pack(annotated_insts,
                                                 self.lambda_g1_arc_training,
                                                 self.lambda_g1_lab_training)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
            annotated_insts, training)
        mask_expr = BK.input_real(mask_arr)
        # the parsing loss
        final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
        parsing_loss, parsing_scores, info = \
            self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value)
        info["loss_parse"] = BK.get_value(parsing_loss).item()
        final_loss = parsing_loss
        # other loss?
        jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
        if jpos_loss is not None:
            info["loss_jpos"] = BK.get_value(jpos_loss).item()
            final_loss = parsing_loss + jpos_loss
        if parsing_scores is not None:
            reg_loss = self.reg_scores_loss(*parsing_scores)
            if reg_loss is not None:
                final_loss = final_loss + reg_loss
        info["fb"] = 1
        if training:
            BK.backward(final_loss, loss_factor)
        return info
Ejemplo n.º 27
0
class G1Parser(BaseParser):
    def __init__(self, conf: G1ParserConf, vpack: VocabPackage):
        super().__init__(conf, vpack)
        # todo(note): the neural parameters are exactly the same as the EF one
        self.scorer_helper = GScorerHelper(self.scorer)
        self.predict_padder = DataPadder(2, pad_vals=0)
        #
        self.g1_use_aux_scores = conf.debug_use_aux_scores  # assigning here is only for debugging usage; otherwise it is assigned outside
        self.num_label = self.label_vocab.trg_len(
            True)  # todo(WARN): use the original idx
        #
        self.loss_hinge = (self.conf.tconf.loss_function == "hinge")
        if not self.loss_hinge:
            assert self.conf.tconf.loss_function == "prob", "This model only supports hinge or prob"

    def build_decoder(self):
        conf = self.conf
        # ===== Decoding Scorer =====
        conf.sc_conf._input_dim = self.enc_output_dim
        conf.sc_conf._num_label = self.label_vocab.trg_len(
            True)  # todo(WARN): use the original idx
        return Scorer(self.pc, conf.sc_conf)

    # =====
    # main procedures

    # decoding
    def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
        iconf = self.conf.iconf
        pconf = iconf.pruning_conf
        with BK.no_grad_env():
            self.refresh_batch(False)
            if iconf.use_pruning:
                # todo(note): for the testing of pruning mode, use the scores instead
                if self.g1_use_aux_scores:
                    valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(
                        insts, self.num_label, pconf)
                else:
                    valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(
                        insts, pconf)
                valid_mask_f = valid_mask.float()  # [*, len, len]
                mask_value = Constants.REAL_PRAC_MIN
                full_score = arc_score.unsqueeze(-1) + label_score
                full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1)
                info_pruning = G1Parser.collect_pruning_info(
                    insts, valid_mask_f)
                jpos_pack = [None, None, None]
            else:
                input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                    insts, False)
                mask_expr = BK.input_real(mask_arr)
                full_score = self.scorer_helper.score_full(enc_repr)
                info_pruning = None
            # =====
            self._decode(insts, full_score, mask_expr, "g1")
            # put jpos result (possibly)
            self.jpos_decode(insts, jpos_pack)
            # -----
            info = {"sent": len(insts), "tok": sum(map(len, insts))}
            if info_pruning is not None:
                info.update(info_pruning)
            return info

    # training
    def fb_on_batch(self,
                    annotated_insts: List[ParseInstance],
                    training=True,
                    loss_factor=1.,
                    **kwargs):
        self.refresh_batch(training)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
            annotated_insts, training)
        mask_expr = BK.input_real(mask_arr)
        # the parsing loss
        arc_score = self.scorer_helper.score_arc(enc_repr)
        lab_score = self.scorer_helper.score_label(enc_repr)
        full_score = arc_score + lab_score
        parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
        # other loss?
        jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
        reg_loss = self.reg_scores_loss(arc_score, lab_score)
        #
        info["loss_parse"] = BK.get_value(parsing_loss).item()
        final_loss = parsing_loss
        if jpos_loss is not None:
            info["loss_jpos"] = BK.get_value(jpos_loss).item()
            final_loss = parsing_loss + jpos_loss
        if reg_loss is not None:
            final_loss = final_loss + reg_loss
        info["fb"] = 1
        if training:
            BK.backward(final_loss, loss_factor)
        return info

    # =====
    def _decode(self, insts: List[ParseInstance], full_score, mask_expr,
                misc_prefix):
        # decode
        mst_lengths = [len(z) + 1
                       for z in insts]  # +=1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
        mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
            full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
        if self.conf.iconf.output_marginals:
            # todo(note): here, we care about marginals for arc
            # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
            arc_marginals = nmarginal_unproj(full_score,
                                             mask_expr,
                                             None,
                                             labeled=True).sum(-1)
            bsize, max_len = BK.get_shape(mask_expr)
            idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
            idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
            output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr,
                                        BK.input_idx(mst_heads_arr)]
            mst_marg_arr = BK.get_value(output_marg)
        else:
            mst_marg_arr = None
        # ===== assign, todo(warn): here, the labels are already the original idx, no need to change
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(
                mst_heads_arr[one_idx]
                [:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(
                mst_labels_arr[one_idx][:cur_length], self.label_vocab)
            one_scores = mst_scores_arr[one_idx][:cur_length]
            one_inst.pred_par_scores.set_vals(one_scores)
            # extra output
            one_inst.extra_pred_misc[misc_prefix +
                                     "_score"] = one_scores.tolist()
            if mst_marg_arr is not None:
                one_inst.extra_pred_misc[
                    misc_prefix +
                    "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()

    # here, adopt either the hinge (max-margin) or the prob (marginal) loss; mostly adapted from the previous graph parser
    def _loss(self,
              annotated_insts: List[ParseInstance],
              full_score_expr,
              mask_expr,
              valid_expr=None):
        bsize, max_len = BK.get_shape(mask_expr)
        # gold heads and labels
        gold_heads_arr, _ = self.predict_padder.pad(
            [z.heads.vals for z in annotated_insts])
        # todo(note): here use the original idx of label, no shift!
        gold_labels_arr, _ = self.predict_padder.pad(
            [z.labels.idxes for z in annotated_insts])
        gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
        gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
        #
        idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
        idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
        # scores for decoding or marginal
        margin = self.margin.value
        decoding_scores = full_score_expr.clone().detach()
        decoding_scores = self.scorer_helper.postprocess_scores(
            decoding_scores, mask_expr, margin, gold_heads_expr,
            gold_labels_expr)
        if self.loss_hinge:
            mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                         dtype=np.int32)
            pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores,
                                                               mask_expr,
                                                               mst_lengths_arr,
                                                               labeled=True,
                                                               ret_arr=False)
            # ===== add margin*cost, [bs, len]
            gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr,
                                                gold_heads_expr,
                                                gold_labels_expr]
            pred_final_scores = full_score_expr[
                idxes_bs_expr, idxes_m_expr, pred_heads_expr,
                pred_labels_expr] + margin * (
                    gold_heads_expr != pred_heads_expr).float() + margin * (
                        gold_labels_expr !=
                        pred_labels_expr).float()  # plus margin
            hinge_losses = pred_final_scores - gold_final_scores
            valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) >
                            0.).float().unsqueeze(-1)  # [*, 1]
            final_losses = hinge_losses * valid_losses
        else:
            lab_marginals = nmarginal_unproj(decoding_scores,
                                             mask_expr,
                                             None,
                                             labeled=True)
            lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr,
                          gold_labels_expr] -= 1.
            grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze(
                -1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
            final_losses = (full_score_expr * grads_masked).sum(-1).sum(
                -1)  # [bs, m]
        # divide loss by what?
        num_sent = len(annotated_insts)
        num_valid_tok = sum(len(z) for z in annotated_insts)
        # exclude non-valid ones: there can be pruning errors
        if valid_expr is not None:
            final_valids = valid_expr[idxes_bs_expr, idxes_m_expr,
                                      gold_heads_expr]  # [bs, m] of (0. or 1.)
            final_losses = final_losses * final_valids
            tok_valid = float(BK.get_value(final_valids[:, 1:].sum()))
            assert tok_valid <= num_valid_tok
            tok_prune_err = num_valid_tok - tok_valid
        else:
            tok_prune_err = 0
        # collect loss with mask, also excluding the ROOT symbol at position 0
        final_losses_masked = (final_losses * mask_expr)[:, 1:]
        final_loss_sum = BK.sum(final_losses_masked)
        if self.conf.tconf.loss_div_tok:
            final_loss = final_loss_sum / num_valid_tok
        else:
            final_loss = final_loss_sum / num_sent
        final_loss_sum_val = float(BK.get_value(final_loss_sum))
        info = {
            "sent": num_sent,
            "tok": num_valid_tok,
            "tok_prune_err": tok_prune_err,
            "loss_sum": final_loss_sum_val
        }
        return final_loss, info
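
To see the hinge branch in isolation, here is a toy, unlabeled numpy version of the margin-augmented loss (one sentence, labels dropped, and the pred heads are assumed to have come out of the cost-augmented decode):

import numpy as np

margin = 1.0
full_score = np.array([[0., 0., 0., 0.],   # score[m, h] for ROOT + 3 words
                       [2., 0., 0., 0.],
                       [0., 3., 0., 2.8],
                       [0., 0., 2., 0.]])
gold_heads = np.array([0, 0, 1, 2])
pred_heads = np.array([0, 0, 3, 2])        # one wrong head for token 2

idxes_m = np.arange(4)
gold_final = full_score[idxes_m, gold_heads]
pred_final = full_score[idxes_m, pred_heads] + margin * (gold_heads != pred_heads)
hinge = pred_final - gold_final
valid = float(hinge[1:].sum() > 0.)        # zero the sentence out if there is no violation
print(hinge[1:].sum() * valid)             # ~0.8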

    # =====
    # special preloading
    @staticmethod
    def special_pretrain_load(m: BaseParser, path, strict):
        if FileHelper.isfile(path):
            try:
                zlog(f"Trying to load pretrained model from {path}")
                m.load(path, strict)
                zlog(f"Finished loading pretrained model from {path}")
                return True
            except Exception:
                zlog(traceback.format_exc())
                zlog("Failed loading, keep the original ones.")
        else:
            zlog(f"File does not exist for pretraining loading: {path}")
        return False

    # init for the pre-trained G1, return g1parser and also modify m's params
    @staticmethod
    def pre_g1_init(m: BaseParser, pg1_conf: PreG1Conf, strict=True):
        # ===== basic G1 Parser's loading
        # todo(WARN): construct the g1conf here instead of loading for simplicity, since the scorer architecture should be the same
        g1conf = G1ParserConf()
        g1conf.bt_conf = deepcopy(m.conf.bt_conf)
        g1conf.sc_conf = deepcopy(m.conf.sc_conf)
        g1conf.validate()
        # todo(note): specific setting
        g1conf.g1_use_aux_scores = pg1_conf.g1_use_aux_scores
        #
        g1parser = G1Parser(g1conf, m.vpack)
        if not G1Parser.special_pretrain_load(
                g1parser, pg1_conf.g1_pretrain_path, strict):
            g1parser = None
        # current init
        if pg1_conf.g1_pretrain_init:
            G1Parser.special_pretrain_load(m, pg1_conf.g1_pretrain_path,
                                           strict)
        return g1parser

    # collect and batch scores
    @staticmethod
    def collect_aux_scores(insts: List[ParseInstance], output_num_label):
        score_tuples = [z.extra_features["aux_score"] for z in insts]
        num_label = score_tuples[0][1].shape[-1]
        max_len = max(len(z) + 1 for z in insts)
        mask_value = Constants.REAL_PRAC_MIN
        bsize = len(insts)
        arc_score_arr = np.full([bsize, max_len, max_len],
                                mask_value,
                                dtype=np.float32)
        lab_score_arr = np.full([bsize, max_len, max_len, output_num_label],
                                mask_value,
                                dtype=np.float32)
        mask_arr = np.full([bsize, max_len], 0., dtype=np.float32)
        for bidx, one_tuple in enumerate(score_tuples):
            one_score_arc, one_score_lab = one_tuple
            one_len = one_score_arc.shape[1]
            arc_score_arr[bidx, :one_len, :one_len] = one_score_arc
            lab_score_arr[bidx, :one_len, :one_len,
                          -num_label:] = one_score_lab
            mask_arr[bidx, :one_len] = 1.
        return BK.input_real(arc_score_arr).unsqueeze(-1), BK.input_real(
            lab_score_arr), BK.input_real(mask_arr)
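
The padding pattern used above, reduced to arc scores only (toy shapes; mask_value plays the role of Constants.REAL_PRAC_MIN):

import numpy as np

# per-sentence arc-score matrices of different sizes (stand-ins for extra_features["aux_score"][0])
scores = [np.zeros([3, 3], dtype=np.float32), np.ones([2, 2], dtype=np.float32)]
mask_value = -1e30
bsize, max_len = len(scores), max(s.shape[0] for s in scores)

arc_score_arr = np.full([bsize, max_len, max_len], mask_value, dtype=np.float32)
mask_arr = np.zeros([bsize, max_len], dtype=np.float32)
for bidx, one_score in enumerate(scores):
    one_len = one_score.shape[0]
    arc_score_arr[bidx, :one_len, :one_len] = one_score  # real scores go into the top-left block
    mask_arr[bidx, :one_len] = 1.                        # everything else stays at mask_value
print(arc_score_arr.shape, mask_arr)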

    # pruner: [bs, slen, slen], [bs, slen, slen, Lab]
    @staticmethod
    def prune_with_scores(arc_score,
                          label_score,
                          mask_expr,
                          pconf: PruneG1Conf,
                          arc_marginals=None):
        prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \
            pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \
            pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel
        full_score = arc_score + label_score
        final_valid_mask = BK.constants(BK.get_shape(arc_score),
                                        0,
                                        dtype=BK.uint8).squeeze(-1)
        # (put as argument) arc_marginals = None  # [*, mlen, hlen]
        if prune_use_marginal:
            if arc_marginals is None:  # not provided, calculate from scores
                if prune_labeled:
                    # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0]
                    # use sum of label marginals instead of max
                    arc_marginals = nmarginal_unproj(full_score,
                                                     mask_expr,
                                                     None,
                                                     labeled=True).sum(-1)
                else:
                    arc_marginals = nmarginal_unproj(arc_score,
                                                     mask_expr,
                                                     None,
                                                     labeled=True).squeeze(-1)
            if prune_mthresh_rel:
                # relative value
                max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(
                    -1)
                m_valid_mask = (arc_marginals.log() -
                                max_arc_marginals) > float(
                                    np.log(prune_mthresh))
            else:
                # absolute value
                m_valid_mask = (arc_marginals > prune_mthresh
                                )  # [*, len-m, len-h]
            final_valid_mask |= m_valid_mask
        if prune_use_topk:
            # prune by "in topk" and "gap-to-top less than gap" for each mod
            if prune_labeled:  # take argmax among label dim
                tmp_arc_score, _ = full_score.max(-1)
            else:
                # todo(note): may be modified in place, but it does not matter since it will finally be masked later
                tmp_arc_score = arc_score.squeeze(-1)
            # first apply mask
            mask_value = Constants.REAL_PRAC_MIN
            mask_mul = (mask_value * (1. - mask_expr))  # [*, len]
            tmp_arc_score += mask_mul.unsqueeze(-1)
            tmp_arc_score += mask_mul.unsqueeze(-2)
            maxlen = BK.get_shape(tmp_arc_score, -1)
            tmp_arc_score += mask_value * BK.eye(maxlen)
            prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen)
            if prune_topk >= maxlen:
                topk_arc_score = tmp_arc_score
            else:
                topk_arc_score, _ = BK.topk(tmp_arc_score,
                                            prune_topk,
                                            dim=-1,
                                            sorted=False)  # [*, len, k]
            min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(
                -1)  # [*, len, 1]
            max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(
                -1)  # [*, len, 1]
            arc_score_thresh = BK.max_elem(min_topk_arc_score,
                                           max_topk_arc_score -
                                           prune_gap)  # [*, len, 1]
            t_valid_mask = (tmp_arc_score > arc_score_thresh
                            )  # [*, len-m, len-h]
            final_valid_mask |= t_valid_mask
        return final_valid_mask, arc_marginals
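
The top-k branch keeps a head for a modifier only if it is both among that modifier's k best heads and within prune_gap of the best one; below is a small numpy sketch of exactly that rule (padding and diagonal masking are omitted for brevity):

import numpy as np

prune_topk, prune_gap = 2, 1.5
arc_score = np.array([[0.0, 2.0, 1.8, -1.0],
                      [3.0, 0.0, 2.9, 1.0],
                      [0.5, 2.0, 0.0, 1.9],
                      [1.0, 1.2, 0.8, 0.0]])  # [len-m, len-h] for one toy sentence

topk_scores = -np.sort(-arc_score, axis=-1)[:, :prune_topk]                # k best head scores per modifier
thresh = np.maximum(topk_scores.min(-1), topk_scores.max(-1) - prune_gap)  # per-row score threshold
valid = arc_score > thresh[:, None]                                        # [len-m, len-h] surviving arcs
print(valid.astype(int))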

    # combining the above two
    @staticmethod
    def score_and_prune(insts: List[ParseInstance], output_num_label,
                        pconf: PruneG1Conf):
        arc_score, lab_score, mask_expr = G1Parser.collect_aux_scores(
            insts, output_num_label)
        valid_mask, arc_marginals = G1Parser.prune_with_scores(
            arc_score, lab_score, mask_expr, pconf)
        return valid_mask, arc_score.squeeze(
            -1), lab_score, mask_expr, arc_marginals

    # =====
    # special mode: first-order model as scorer/pruner
    def score_on_batch(self, insts: List[ParseInstance]):
        with BK.no_grad_env():
            self.refresh_batch(False)
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                insts, False)
            # mask_expr = BK.input_real(mask_arr)
            arc_score = self.scorer_helper.score_arc(enc_repr)
            label_score = self.scorer_helper.score_label(enc_repr)
            return arc_score.squeeze(-1), label_score

    # todo(note): the union of two types of pruner!
    def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf):
        with BK.no_grad_env():
            self.refresh_batch(False)
            # encode
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                insts, False)
            mask_expr = BK.input_real(mask_arr)
            arc_score = self.scorer_helper.score_arc(enc_repr)
            label_score = self.scorer_helper.score_label(enc_repr)
            final_valid_mask, arc_marginals = G1Parser.prune_with_scores(
                arc_score, label_score, mask_expr, pconf)
            return final_valid_mask, arc_score.squeeze(
                -1), label_score, mask_expr, arc_marginals

    # =====
    @staticmethod
    def collect_pruning_info(insts: List[ParseInstance], valid_mask_f):
        # two dimensions: coverage and pruning-effect
        maxlen = BK.get_shape(valid_mask_f, -1)
        # 1. coverage
        valid_mask_f_flattened = valid_mask_f.view([-1,
                                                    maxlen])  # [bs*len, len]
        cur_mod_base = 0
        all_mods, all_heads = [], []
        for cur_idx, cur_inst in enumerate(insts):
            for m, h in enumerate(cur_inst.heads.vals[1:], 1):
                all_mods.append(m + cur_mod_base)
                all_heads.append(h)
            cur_mod_base += maxlen
        cov_count = len(all_mods)
        cov_valid = BK.get_value(
            valid_mask_f_flattened[all_mods, all_heads].sum()).item()
        # 2. pruning-rate
        # todo(warn): to speed up, these stats are approximate because they include padding
        # edges
        pr_edges = int(np.prod(BK.get_shape(valid_mask_f)))
        pr_edges_valid = BK.get_value(valid_mask_f.sum()).item()
        # valid as structured heads
        pr_o2_sib = pr_o2_g = pr_edges
        pr_o3_gsib = maxlen * pr_edges
        valid_chs_counts, valid_par_counts = valid_mask_f.sum(
            -2), valid_mask_f.sum(-1)  # [*, len]
        valid_gsibs = valid_chs_counts * valid_par_counts
        pr_o2_sib_valid = BK.get_value(valid_chs_counts.sum()).item()
        pr_o2_g_valid = BK.get_value(valid_par_counts.sum()).item()
        pr_o3_gsib_valid = BK.get_value(valid_gsibs.sum()).item()
        return {
            "cov_count": cov_count,
            "cov_valid": cov_valid,
            "pr_edges": pr_edges,
            "pr_edges_valid": pr_edges_valid,
            "pr_o2_sib": pr_o2_sib,
            "pr_o2_g": pr_o2_g,
            "pr_o3_gsib": pr_o3_gsib,
            "pr_o2_sib_valid": pr_o2_sib_valid,
            "pr_o2_g_valid": pr_o2_g_valid,
            "pr_o3_gsib_valid": pr_o3_gsib_valid
        }
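
The coverage part of the statistics above, replayed on a tiny batch so the flattened indexing is visible (a numpy stand-in; the gold heads are invented):

import numpy as np

valid_mask_f = np.ones([2, 4, 4], dtype=np.float32)  # pretend nothing was pruned...
valid_mask_f[1, 2, 3] = 0.                           # ...except one gold arc in the second sentence
gold_heads = [[0, 2, 0], [2, 0, 3]]                  # heads of tokens 1..3 per sentence (ROOT excluded)

maxlen = valid_mask_f.shape[-1]
flat = valid_mask_f.reshape([-1, maxlen])            # [bs*len, len]
all_mods, all_heads = [], []
for sid, heads in enumerate(gold_heads):
    for m, h in enumerate(heads, 1):
        all_mods.append(m + sid * maxlen)
        all_heads.append(h)
cov_count = len(all_mods)
cov_valid = float(flat[all_mods, all_heads].sum())
print(cov_count, cov_valid)  # 6 gold arcs, 5 survive the pruning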
Ejemplo n.º 28
0
class FpEncoder(BasicNode):
    def __init__(self, pc: BK.ParamCollection, conf: FpEncConf,
                 vpack: VocabPackage):
        super().__init__(pc, None, None)
        self.conf = conf
        # ===== Vocab =====
        self.word_vocab = vpack.get_voc("word")
        self.char_vocab = vpack.get_voc("char")
        self.pos_vocab = vpack.get_voc("pos")
        # avoid no params error
        self._tmp_v = self.add_param("nope", (1, ))
        # ===== Model =====
        # embedding
        self.emb = self.add_sub_node("emb",
                                     MyEmbedder(self.pc, conf.emb_conf, vpack))
        self.emb_output_dim = self.emb.get_output_dims()[0]
        # bert
        self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf))
        self.bert_output_dim = self.bert.get_output_dims()[0]
        # make sure there are inputs
        assert self.emb_output_dim > 0 or self.bert_output_dim > 0
        # middle?
        if conf.middle_dim > 0:
            self.middle_node = self.add_sub_node(
                "mid",
                Affine(self.pc,
                       self.emb_output_dim + self.bert_output_dim,
                       conf.middle_dim,
                       act="elu"))
            self.enc_input_dim = conf.middle_dim
        else:
            self.middle_node = None
            self.enc_input_dim = self.emb_output_dim + self.bert_output_dim  # concat the two parts (if needed)
        # encoder?
        # todo(note): feed compute-on-the-fly hp
        conf.enc_conf._input_dim = self.enc_input_dim
        self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf))
        self.enc_output_dim = self.enc.get_output_dims()[0]
        # ===== Input Specification =====
        # inputs (word, char, pos) and vocabulary
        self.need_word = self.emb.has_word
        self.need_char = self.emb.has_char
        # todo(warn): currently only allow extra fields for POS
        self.need_pos = False
        if len(self.emb.extra_names) > 0:
            assert len(
                self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos"
            self.need_pos = True
        #
        self.word_padder = DataPadder(2,
                                      pad_vals=self.word_vocab.pad,
                                      mask_range=2)
        self.char_padder = DataPadder(3,
                                      pad_lens=(0, 0, conf.char_max_length),
                                      pad_vals=self.char_vocab.pad)
        self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad)

    def get_output_dims(self, *input_dims):
        return (self.enc_output_dim, )

    #
    def refresh(self, rop=None):
        zfatal("Should call special_refresh instead!")

    # ====
    # special routines

    def special_refresh(self, embed_rop, other_rop):
        self.emb.refresh(embed_rop)
        self.enc.refresh(other_rop)
        self.bert.refresh(other_rop)
        if self.middle_node is not None:
            self.middle_node.refresh(other_rop)

    def prepare_training_rop(self):
        mconf = self.conf
        embed_rop = RefreshOptions(hdrop=mconf.drop_embed,
                                   dropmd=mconf.dropmd_embed,
                                   fix_drop=mconf.fix_drop)
        other_rop = RefreshOptions(hdrop=mconf.drop_hidden,
                                   idrop=mconf.idrop_rnn,
                                   gdrop=mconf.gdrop_rnn,
                                   fix_drop=mconf.fix_drop)
        return embed_rop, other_rop

    def aug_words_and_embs(self, aug_vocab, aug_wv):
        orig_vocab = self.word_vocab
        if self.emb.has_word:
            orig_arr = self.emb.word_embed.E.detach().cpu().numpy()
            # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
            # todo(warn): here aug_vocab should be found in aug_wv
            aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
            new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(
                orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
            # assign
            self.word_vocab = new_vocab
            self.emb.word_embed.replace_weights(new_arr)
        else:
            zwarn("No need to aug vocab since delexicalized model!!")
            new_vocab = orig_vocab
        return new_vocab

    # =====
    # run

    def prepare_inputs(self, insts, training, input_word_mask_repl=None):
        word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], []
        # ===== specially prepare for the words
        wv = self.word_vocab
        W_UNK = wv.unk
        word_act_idxes = [z.words.idxes for z in insts]
        # todo(warn): still need the masks
        word_arr, mask_arr = self.word_padder.pad(word_act_idxes)
        if input_word_mask_repl is not None:
            input_word_mask_repl = input_word_mask_repl.astype(np.int64)
            word_arr = word_arr * (
                1 - input_word_mask_repl
            ) + W_UNK * input_word_mask_repl  # replace with UNK
        # =====
        if not self.need_word:
            word_arr = None
        if self.need_char:
            chars = [z.chars.idxes for z in insts]
            char_arr, _ = self.char_padder.pad(chars)
        if self.need_pos:
            poses = [z.poses.idxes for z in insts]
            pos_arr, _ = self.pos_padder.pad(poses)
            extra_arrs.append(pos_arr)
        return word_arr, char_arr, extra_arrs, aux_arrs, mask_arr
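
The UNK-replacement arithmetic used above, shown on its own (numpy; W_UNK and the mask values are invented for the example):

import numpy as np

W_UNK = 1
word_arr = np.array([[0, 17, 42, 9]])                # padded word ids, column 0 = ROOT
input_word_mask_repl = np.array([[0., 1., 0., 1.]])  # 1. where the MLM chose to mask

repl = input_word_mask_repl.astype(np.int64)
masked = word_arr * (1 - repl) + W_UNK * repl        # masked positions become the UNK id
print(masked)                                        # [[ 0  1 42  1]]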

    # todo(note): for rnn, need to transpose masks, thus need np.array
    # return input_repr, enc_repr, mask_arr
    def run(self, insts, training, input_word_mask_repl=None):
        self._cache_subword_tokens(insts)
        # prepare inputs
        word_arr, char_arr, extra_arrs, aux_arrs, mask_arr = \
            self.prepare_inputs(insts, training, input_word_mask_repl=input_word_mask_repl)
        # layer0: emb + bert
        layer0_reprs = []
        if self.emb_output_dim > 0:
            emb_repr = self.emb(word_arr, char_arr, extra_arrs,
                                aux_arrs)  # [BS, Len, Dim]
            layer0_reprs.append(emb_repr)
        if self.bert_output_dim > 0:
            # prepare bert inputs
            BERT_MASK_ID = self.bert.tokenizer.mask_token_id
            batch_subword_ids, batch_subword_is_starts = [], []
            for bidx, one_inst in enumerate(insts):
                st = one_inst.extra_features["st"]
                if input_word_mask_repl is not None:
                    cur_subword_ids, cur_subword_is_start, _ = \
                        st.mask_and_return(input_word_mask_repl[bidx][1:], BERT_MASK_ID)  # todo(note): exclude ROOT for bert tokens
                else:
                    cur_subword_ids, cur_subword_is_start = st.subword_ids, st.subword_is_start
                batch_subword_ids.append(cur_subword_ids)
                batch_subword_is_starts.append(cur_subword_is_start)
            bert_repr, _ = self.bert.forward_batch(
                batch_subword_ids,
                batch_subword_is_starts,
                batched_typeids=None,
                training=training)  # [BS, Len, D']
            layer0_reprs.append(bert_repr)
        # layer1: enc
        enc_input_repr = BK.concat(layer0_reprs, -1)  # [BS, Len, D+D']
        if self.middle_node is not None:
            enc_input_repr = self.middle_node(enc_input_repr)  # [BS, Len, D??]
        enc_repr = self.enc(enc_input_repr, mask_arr)
        mask_repr = BK.input_real(mask_arr)
        return enc_repr, mask_repr  # [bs, len, *], [bs, len]

    # =====
    # caching
    def _cache_subword_tokens(self, insts):
        for one_inst in insts:
            if "st" not in one_inst.extra_features:
                one_inst.extra_features["st"] = self.bert.subword_tokenize2(
                    one_inst.words.vals[1:], True)
Ejemplo n.º 29
0
 def __init__(self,
              name,
              conf: HLabelConf,
              keys: Dict[str, int],
              nil_as_zero=True):
     assert nil_as_zero, "Currently assume nil as zero for all layers"
     self.conf = conf
     self.name = name
     # from original vocab
     self.orig_counts = {k: v for k, v in keys.items()}
     keys = sorted(set(
         keys.keys()))  # for example, can be "vocab.trg_keys()"
     max_layer = 0
     # =====
     # part 1: layering
     v = {}
     keys = [None] + keys  # add None as NIL
     for k in keys:
         cur_idx = HLabelIdx.construct_hlidx(k, conf.layered)
         max_layer = max(max_layer, len(cur_idx))
         v[k] = cur_idx
     # collect all the layered types and put idxes
     self.max_layer = max_layer
     self.layered_v = [{} for _ in range(max_layer)
                       ]  # key -> int-idx for each layer
     self.layered_k = [[] for _ in range(max_layer)
                       ]  # int-idx -> key for each layer
     self.layered_prei = [
         [] for _ in range(max_layer)
     ]  # int-idx -> int-idx: idx of prefix in previous layer
     self.layered_hlidx = [[] for _ in range(max_layer)
                           ]  # int-idx -> hlidx for each layer
     for k in keys:
         cur_hidx = v[k]
         cur_types = cur_hidx.pad_types(max_layer)
         assert len(cur_types) == max_layer
         cur_idxes = []  # int-idxes for each layer
         # assign from 0 to max-layer
         for cur_layer_i in range(max_layer):
             cur_layered_v, cur_layered_k, cur_layered_prei, cur_layered_hlidx = \
                 self.layered_v[cur_layer_i], self.layered_k[cur_layer_i], \
                 self.layered_prei[cur_layer_i], self.layered_hlidx[cur_layer_i]
             # also put empty classes here
             for valid_until_idx in range(cur_layer_i + 1):
                 cur_layer_types = cur_types[:valid_until_idx + 1] + (
                     None, ) * (cur_layer_i - valid_until_idx)
                 # for cur_layer_types in [cur_types[:cur_layer_i]+(None,), cur_types[:cur_layer_i+1]]:
                 if cur_layer_types not in cur_layered_v:
                     new_idx = len(cur_layered_k)
                     cur_layered_v[cur_layer_types] = new_idx
                     cur_layered_k.append(cur_layer_types)
                     cur_layered_prei.append(0 if cur_layer_i == 0 else
                                             cur_idxes[-1])  # previous idx
                     cur_layered_hlidx.append(
                         HLabelIdx(cur_layer_types, None)
                     )  # make a new hlidx, need to fill idxes later
             # put the actual idx
             cur_idxes.append(cur_layered_v[cur_types[:cur_layer_i + 1]])
         cur_hidx.idxes = cur_idxes  # put the idxes (actually not useful here)
     self.nil_as_zero = nil_as_zero
     if nil_as_zero:
         assert all(
             z[0].is_nil() for z in
             self.layered_hlidx)  # make sure each layer's 0 is all-Nil
     # put the idxes for layered_hlidx
     self.v = {}
     for cur_layer_i in range(max_layer):
         cur_layered_hlidx = self.layered_hlidx[cur_layer_i]
         for one_hlidx in cur_layered_hlidx:
             one_types = one_hlidx.pad_types(max_layer)
             one_hlidx.idxes = [
                 self.layered_v[i][one_types[:i + 1]]
                 for i in range(max_layer)
             ]
             self.v[str(
                 one_hlidx
             )] = one_hlidx  # todo(note): later layers will overwrite previous ones
     self.nil_idx = self.v[""]
     # =====
     # (main) part 2: representation
     # link each type representation to the overall pool
     self.pools_v = {None: 0}  # NIL as 0
     self.pools_k = [None]
     self.pools_hint_lexicon = [
         []
     ]  # hint lexicons to look up in pre-trained embeddings: List[pool] of List[str]
     # links for each label-embeddings to the pool-embeddings: List(layer) of List(label) of List(idx-in-pool)
     self.layered_pool_links = [[] for _ in range(max_layer)]
     # masks indicating local-NIL(None)
     self.layered_pool_isnil = [[] for _ in range(max_layer)]
     for cur_layer_i in range(max_layer):
         cur_layered_pool_links = self.layered_pool_links[
             cur_layer_i]  # List(label) of List
         cur_layered_k = self.layered_k[cur_layer_i]  # List(full-key-tuple)
         cur_layered_pool_isnil = self.layered_pool_isnil[
             cur_layer_i]  # List[int]
         for one_k in cur_layered_k:
             one_k_final = one_k[
                 cur_layer_i]  # use the final token at least for lexicon hint
             if one_k_final is None:
                 cur_layered_pool_links.append(
                     [0])  # todo(note): None is always zero
                 cur_layered_pool_isnil.append(1)
                 continue
             cur_layered_pool_isnil.append(0)
             # either splitting into pools or splitting for lexicon hint
             one_k_final_elems: List[str] = split_camel(one_k_final)
             # adding prefix according to the strategy
             if conf.pool_sharing == "nope":
                 one_prefix = f"L{cur_layer_i}-" + ".".join(
                     one_k[:-1]) + "."
             elif conf.pool_sharing == "layered":
                 one_prefix = f"L{cur_layer_i}-"
             elif conf.pool_sharing == "shared":
                 one_prefix = ""
             else:
                 raise NotImplementedError(
                     f"UNK pool-sharing strategy {conf.pool_sharing}!")
             # put in the pools (also two types of strategies)
             if conf.pool_split_camel:
                 # possibly multiple mappings to the pool, each pool-elem gets only one hint-lexicon
                 cur_layered_pool_links.append([])
                 for this_pool_key in one_k_final_elems:
                     this_pool_key1 = one_prefix + this_pool_key
                     if this_pool_key1 not in self.pools_v:
                         self.pools_v[this_pool_key1] = len(self.pools_k)
                         self.pools_k.append(this_pool_key1)
                         self.pools_hint_lexicon.append(
                             [self.get_lexicon_hint(this_pool_key)])
                     cur_layered_pool_links[-1].append(
                         self.pools_v[this_pool_key1])
             else:
                 # only one mapping to the pool, each pool-elem can get multiple hint-lexicons
                 this_pool_key1 = one_prefix + one_k_final
                 if this_pool_key1 not in self.pools_v:
                     self.pools_v[this_pool_key1] = len(self.pools_k)
                     self.pools_k.append(this_pool_key1)
                     self.pools_hint_lexicon.append([
                         self.get_lexicon_hint(z) for z in one_k_final_elems
                     ])
                 cur_layered_pool_links.append(
                     [self.pools_v[this_pool_key1]])
     assert self.pools_v[None] == 0, "Internal error!"
     # padding and masking for the links (in numpy)
     self.layered_pool_links_padded = []  # List[arr(#layered-label, #elem)]
     self.layered_pool_links_mask = []  # List[arr(...)]
     padder = DataPadder(2, mask_range=2)  # separately for each layer
     for cur_layer_i in range(max_layer):
         cur_arr, cur_mask = padder.pad(
             self.layered_pool_links[cur_layer_i]
         )  # [each-sublabel, padded-max-elems]
         self.layered_pool_links_padded.append(cur_arr)
         self.layered_pool_links_mask.append(cur_mask)
     self.layered_prei = [np.asarray(z) for z in self.layered_prei]
     self.layered_pool_isnil = [
         np.asarray(z) for z in self.layered_pool_isnil
     ]
     self.pool_init_vec = None
     #
     zlog(
         f"Build HLabelVocab {name} (max-layer={max_layer} from pools={len(self.pools_k)}): "
         + "; ".join(
             [f"L{i}={len(self.layered_k[i])}" for i in range(max_layer)]))