예제 #1
0
 def __init__(self, conf: SRLAnalyzerConf):
     """Set up the analyzer and, unless variables are pre-loaded, derive the
     per-frame error information and its breakdowns.

     When ``conf.pre_load`` is set, variables are restored via ``do_load``.
     Otherwise every frame match stored under the variable ``"fl"`` gets an
     ``err_infos`` list (one ``ErrInfo`` per prediction), the matches are
     split into the variables ``"fl1"`` (all predictions fully correct) and
     ``"fl0"`` (the rest), and for each prediction index ``i`` the variable
     ``"eip{i}"`` is stored and grouped twice. Finally, if
     ``conf.load_vocab`` is set, vocabularies are read from that pickle.
     """
     super().__init__(conf)
     conf: SRLAnalyzerConf = self.conf
     self.err_map = ErrDetail.ERR_MAPS[conf.err_map]
     if conf.pre_load:
         # restore previously stored variables instead of recomputing
         self.do_load(conf.pre_load)
     else:
         n_systems = len(conf.preds)
         frame_pairs = self.get_var("fl")  # frame matches
         fully_right = []
         partly_wrong = []
         for pair in frame_pairs:
             gold = pair.gold
             predicted = pair.preds
             assert len(predicted) == n_systems
             # prepare gold args first, then each prediction in turn
             self._process_args(gold)
             infos = []
             for one_pred in predicted:
                 self._process_args(one_pred)
                 infos.append(ErrInfo.create(gold, one_pred))
             pair.err_infos = infos
             # route the pair by whether every prediction is fully correct
             bucket = (fully_right
                       if all(e.fully_correct() for e in infos)
                       else partly_wrong)
             bucket.append(pair)
         self.set_var("fl1", fully_right, explanation="init")
         self.set_var("fl0", partly_wrong, explanation="init")
         n_frames = len(frame_pairs)
         n_right = len(fully_right)
         zlog(
             f"All frames = {n_frames}, all_corr = {n_right}({n_right/max(1, n_frames)})"
         )
         # per-prediction breakdowns over all frames
         for sys_idx in range(n_systems):
             var_name = f"eip{sys_idx}"
             collected = []
             for pair in frame_pairs:
                 collected.extend(pair.err_infos[sys_idx].rps)
             self.set_var(var_name, collected)
             self.do_group(var_name,
                           "d.get_signature('etype', emap=self.err_map)")
             self.do_group(var_name, "d.get_signature('etype2')")
     # optional vocab loading (only the first element of the pickle is kept)
     if conf.load_vocab:
         from msp2.utils import default_pickle_serializer
         loaded = default_pickle_serializer.from_file(conf.load_vocab)
         self.vocabs, _ = loaded
예제 #2
0
파일: task.py 프로젝트: zzsfornlp/zmsp
 def load_vocab(self, v_dir: str):  # return whether succeed!
     """Try to load this object's vocab pack from ``v_dir``.

     The file looked up is ``v_{self.name}.pkl`` inside ``v_dir``.

     Returns:
         True if the file exists and was loaded into ``self.vpack``;
         False otherwise, in which case ``self.vpack`` is set to None.
     """
     pkl_path = os.path.join(v_dir, f"v_{self.name}.pkl")
     if not os.path.exists(pkl_path):
         self.vpack = None  # not found
         return False
     self.vpack = default_pickle_serializer.from_file(pkl_path)
     zlog(f"Load vocabs ``{self.vpack}'' for {self} from {pkl_path}")
     return True
예제 #3
0
 def load(self, prefix="./"):
     """Reload every registered vocab and embedding from files under ``prefix``.

     Vocabs are read from ``{prefix}vv_{name}.txt`` using the type registered
     in ``self.voc_types`` (falling back to ``self._default_vocab_type``);
     embeddings are unpickled from ``{prefix}ve_{name}.pic``. Entries whose
     file does not exist are set to None (with a warning for vocabs only).
     """
     for voc_name in self.vocabs:
         voc_path = prefix + "vv_" + voc_name + ".txt"
         if not os.path.exists(voc_path):
             zwarn("Cannot find Vocab " + voc_name)
             self.vocabs[voc_name] = None
             continue
         reader_cls = self.voc_types.get(voc_name, self._default_vocab_type)
         self.vocabs[voc_name] = reader_cls.read_from_file(voc_path)
     for emb_name in self.embeds:
         emb_path = prefix + "ve_" + emb_name + ".pic"
         self.embeds[emb_name] = (
             default_pickle_serializer.from_file(emb_path)
             if os.path.exists(emb_path) else None)
예제 #4
0
def main(vocab_file: str, input_path: str, output_file='lt.pkl'):
    """Estimate a log transition matrix over BIO argument tags and pickle it.

    The argument vocabulary is taken from ``vocabs[0]['arg']`` of the pickle
    at ``vocab_file`` and wrapped into a ``SeqVocab``. For every event of
    every sentence read from ``input_path``, a BIO tag sequence is built from
    the event's arguments and all adjacent tag pairs are counted. Counts
    start from one (add-1 smoothing), are normalized per row, converted to
    logs and written to ``output_file``. A table of the most frequent
    transitions is printed along the way.

    Args:
        vocab_file: path of the pickled vocabularies.
        input_path: path handed to the reader as ``input_path``.
        output_file: where the resulting matrix is pickled.
    """
    # first get vocab
    vocabs = default_pickle_serializer.from_file(vocab_file)
    arg_voc = vocabs[0]['arg']
    zlog(f"Read {arg_voc} from {vocab_file}")
    # make it to BIO-vocab
    bio_voc = SeqVocab(arg_voc)
    zlog(f"Build bio-voc of {bio_voc}")
    # read insts from the given input path
    insts = list(ReaderGetterConf().get_reader(input_path=input_path))
    all_sents = list(yield_sents(insts))
    # --
    mat = np.ones([len(bio_voc), len(bio_voc)],
                  dtype=np.float32)  # add-1 smoothing!
    cc = Counter()
    for sent in all_sents:
        for evt in sent.events:
            labels = ['O'] * len(sent)
            for arg in evt.args:
                widx, wlen = arg.mention.get_span()
                # wlen is a length (one "B-" plus wlen-1 "I-" tags are
                # produced), so the slice must end at widx + wlen; ending it
                # at wlen would misplace the tags and resize `labels`.
                labels[widx:widx + wlen] = (
                    ["B-" + arg.role] + ["I-" + arg.role] * (wlen - 1))
            for a, b in zip(labels, labels[1:]):
                cc[f"{a}->{b}"] += 1
                mat[bio_voc[a], bio_voc[b]] += 1
    # --
    # report the most frequent transitions
    v = SimpleVocab()
    for name, count in cc.items():
        v.feed_one(name, count)
    v.build_sort()
    print(v.get_info_table()[:50].to_string())
    # --
    # normalize & log according to row and save
    mat = mat / mat.sum(-1, keepdims=True)
    mat = np.log(mat)
    default_pickle_serializer.to_file(mat, output_file)
예제 #5
0
 def __init__(self, conf: SrlInferenceHelperConf, dec: 'ZDecoderSrl',
              **kwargs):
     """Prepare the constraint arrays and filters used at inference time.

     ``dec`` is attached through ``setattr_borrow``. ``self.role_cons`` is
     built from ``conf.frames_name`` and/or ``conf.frames_file`` (the file
     takes precedence since it is processed last); ``self.lu_cons`` is only
     built from ``conf.frames_file``. Both stay None when neither is given.
     """
     super().__init__(conf, **kwargs)
     conf: SrlInferenceHelperConf = self.conf
     # --
     self.setattr_borrow('dec', dec)
     self.arg_pp = PostProcessor(conf.arg_pp)
     # --
     self.lu_cons = None
     self.role_cons = None
     if conf.frames_name:  # only frame->role from the named resource
         from msp2.data.resources import get_frames_label_budgets
         budgets = get_frames_label_budgets(conf.frames_name)
         _, evt_voc, arg_voc = dec.ztask.vpack
         role_arrs = fchelper.build_constraint_arrs(budgets, arg_voc, evt_voc)
         self.role_cons = BK.input_real(role_arrs)
     if conf.frames_file:
         _, evt_voc, arg_voc = dec.ztask.vpack
         frame_coll = default_pickle_serializer.from_file(conf.frames_file)
         # lexicon->frame
         lu_arrs = fchelper.build_constraint_arrs(
             fchelper.build_lu_map(frame_coll), evt_voc, warning=False)
         # frame->role
         role_arrs = fchelper.build_constraint_arrs(
             fchelper.build_role_map(frame_coll), arg_voc, evt_voc)
         role_tensor = BK.input_real(role_arrs)
         self.lu_cons = lu_arrs
         self.role_cons = role_tensor
     # --
     self.cons_evt_tok_f = conf.get_cons_evt_tok()
     self.cons_evt_frame_f = conf.get_cons_evt_frame()
     # index selection for BIO tags, only when the decoder uses BIO args
     self.cons_arg_bio_sels = (
         BK.input_idx(self.dec.vocab_bio_arg.get_bio2origin())
         if self.dec.conf.arg_use_bio else None)
     # --
     from msp2.data.resources.frames import KBP17_TYPES
     known_filters = {'kbp17': KBP17_TYPES}
     self.pred_evt_filter = known_filters.get(conf.pred_evt_filter, None)
예제 #6
0
def main(*args):
    """Load or read a frame collection, optionally save it, then optionally
    answer interactive queries against it.

    The collection comes from ``conf.load_pkl`` when set, otherwise from
    ``FrameReader().read_all(conf.dir, conf.onto)``. It can be re-saved as a
    pickle (``conf.save_pkl``) and/or as text (``conf.save_txt``). With
    ``conf.query`` an endless prompt accepts ``<frame|lu|role> <key>`` lines
    and logs the matching entry.

    Args:
        *args: command-line style arguments fed to ``MainConf``.
    """
    conf = MainConf()
    conf.update_from_args(args)
    # --
    if conf.load_pkl:
        collection = default_pickle_serializer.from_file(conf.load_pkl)
    else:
        reader = FrameReader()
        collection = reader.read_all(conf.dir, conf.onto)
    if conf.save_pkl:
        default_pickle_serializer.to_file(collection, conf.save_pkl)
    if conf.save_txt:
        with zopen(conf.save_txt, 'w') as fd:
            for f in collection.frames:
                fd.write("#--\n" + f.to_string() + "\n")
    # --
    if conf.debug:
        breakpoint()
    if conf.query:
        map_frame = {f.name: f for f in collection.frames}
        map_lu = ZFrameCollectionHelper.build_lu_map(collection, split_lu={"pb":"_", "fn":None}[conf.onto])
        map_role = ZFrameCollectionHelper.build_role_map(collection)
        while True:
            line = input(">> ")
            fields = sh_split(line.strip())
            if not fields:
                continue
            try:
                query0, query1 = fields
                _map = {'frame': map_frame, 'lu': map_lu, 'role': map_role}[query0]
                answer = _map.get(query1, None)
                if isinstance(answer, ZFrame):
                    zlog(answer.to_string())
                else:
                    zlog(answer)
            except Exception:
                # Best-effort prompt: report malformed commands and keep
                # going, but let KeyboardInterrupt/SystemExit propagate so
                # the loop can still be interrupted.
                zlog(f"Wrong cmd: {fields}")
예제 #7
0
파일: run.py 프로젝트: zzsfornlp/zmsp
 def load(self, prefix="./"):
     """Restore ``self.vocabs`` and ``self.embeds`` from the pickle at
     ``prefix + "zsfp.voc.pkl"``, which must unpack into exactly two items.
     """
     loaded = default_pickle_serializer.from_file(prefix + "zsfp.voc.pkl")
     self.vocabs, self.embeds = loaded
예제 #8
0
파일: analyzer.py 프로젝트: zzsfornlp/zmsp
 def do_load(self, file: str):
     """Merge variables unpickled from ``file`` into ``self.vars``.

     Existing entries are kept unless the loaded mapping overrides them
     (update semantics, not replacement).
     """
     zlog(f"Try loading vars from {file}")
     loaded_vars = default_pickle_serializer.from_file(file)
     # note: merge into the current variables rather than replacing them
     self.vars.update(loaded_vars)
예제 #9
0
 def __init__(self, vocab: SimpleVocab, conf: SeqLabelerConf, **kwargs):
     """Build the sequence labeler: pre-MLPs, the labeler (optionally behind
     a sequence decoder), an optional bigram node, decoding constraints and
     the loss-mode flags.

     Args:
         vocab: label vocabulary handed to ``SimpleLabelerNode``; also asked
             for its allowed transitions when ``conf.pred_use_seq_cons`` is
             set.
         conf: configuration; after the parent constructor ran, the
             instance's own ``self.conf`` is what is read.
         **kwargs: forwarded unchanged to the parent constructor.

     Raises:
         AssertionError: if a constraint file is given together with
             ``conf.pred_use_seq_cons``, or if CRF loss is requested without
             the bigram node or together with the sequence decoder.
         NotImplementedError: if ``conf.loss_mode`` is neither ``"mle"`` nor
             ``"crf"``.
     """
     super().__init__(conf, **kwargs)
     conf: SeqLabelerConf = self.conf
     # pairwise mode is switched on by a positive pair-input size
     is_pairwise = (conf.psize > 0)
     self.is_pairwise = is_pairwise
     # --
     # 0. pre mlp
     isize, psize = conf.isize, conf.psize
     self.main_mlp = MLPNode(conf.main_mlp,
                             isize=isize,
                             osize=-1,
                             use_out=False)
     # from here on `isize` is the main MLP's output size
     isize = self.main_mlp.get_output_dims()[0]
     if is_pairwise:
         self.pair_mlp = MLPNode(conf.pair_mlp,
                                 isize=psize,
                                 osize=-1,
                                 use_out=False)
         # likewise `psize` becomes the pair MLP's output size
         psize = self.pair_mlp.get_output_dims()[0]
     else:
         # identity stand-in so callers can apply pair_mlp unconditionally
         self.pair_mlp = lambda x: x
     # 1/2. decoder & laber
     if conf.use_seqdec:
         # extra for seq-decoder
         dec_hid = conf.seqdec_conf.dec_hidden
         # setup labeler to get embedding dim
         self.laber = SimpleLabelerNode(vocab,
                                        conf.labeler_conf,
                                        isize=dec_hid,
                                        psize=psize)
         laber_embed_dim = self.laber.lookup_dim
         # init starting hidden; note: choose different according to 'is_pairwise'
         self.sd_init_aff = AffineNode(
             conf.sd_init_aff,
             isize=(psize if is_pairwise else isize),
             osize=dec_hid)
         self.sd_init_pool_f = ActivationHelper.get_pool(conf.sd_init_pool)
         # sd input: one_repr + one_idx_embed
         # NOTE(review): this node is configured with `conf.sd_init_aff`, not a
         # dedicated input-affine conf -- possibly a copy-paste slip; confirm
         # against SeqLabelerConf before changing.
         self.sd_input_aff = AffineNode(conf.sd_init_aff,
                                        isize=[isize, laber_embed_dim],
                                        osize=dec_hid)
         # sd output: cur_expr + hidden
         self.sd_output_aff = AffineNode(conf.sd_output_aff,
                                         isize=[isize, dec_hid],
                                         osize=dec_hid)
         # sd itself
         self.seqdec = PlainDecoder(conf.seqdec_conf, input_dim=dec_hid)
     else:
         # directly using the scorer (overwrite some values)
         self.laber = SimpleLabelerNode(vocab,
                                        conf.labeler_conf,
                                        isize=isize,
                                        psize=psize)
     # 3. bigram
     # todo(note): bigram does not consider skip_non
     if conf.use_bigram:
         self.bigram = BigramNode(conf.bigram_conf,
                                  osize=self.laber.output_dim)
     else:
         self.bigram = None
     # special decoding: constraint matrix from a pickle file, from the
     # vocab's allowed transitions, or none at all
     if conf.pred_use_seq_cons_from_file:
         # the two constraint sources are mutually exclusive
         assert not conf.pred_use_seq_cons
         _m = default_pickle_serializer.from_file(
             conf.pred_use_seq_cons_from_file)
         zlog(f"Load weights from {conf.pred_use_seq_cons_from_file}")
         self.pred_cons_mat = BK.input_real(_m)
     elif conf.pred_use_seq_cons:
         _m = vocab.get_allowed_transitions()
         # entries equal to 1 in _m map to 0; entries equal to 0 map to
         # Constants.REAL_PRAC_MIN
         self.pred_cons_mat = (1. -
                               BK.input_real(_m)) * Constants.REAL_PRAC_MIN
     else:
         self.pred_cons_mat = None
     # =====
     # loss: exactly one flag is True for a supported mode, both False otherwise
     self.loss_mle, self.loss_crf = [
         conf.loss_mode == z for z in ["mle", "crf"]
     ]
     if self.loss_mle:
         if conf.use_seqdec or conf.use_bigram:
             zlog("Setup SeqLabelerNode with Local complex mode!")
         else:
             zlog("Setup SeqLabelerNode with Local simple mode!")
     elif self.loss_crf:
         # CRF requires the bigram node and excludes the sequence decoder
         assert conf.use_bigram and (
             not conf.use_seqdec), "Wrong mode for crf"
         zlog("Setup SeqLabelerNode with CRF mode!")
     else:
         raise NotImplementedError(f"UNK loss mode: {conf.loss_mode}")