Example #1
def main(*aligned_files):
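    # for aligned input files: randomly drop events from each sentence, keep
    # only the sentence indexes where (in every file) no remaining event has
    # overlapping arg spans, and write the filtered instances back out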
    import numpy as np  # note: presumably a module-level import in the original source
    RANDOM_DROP_EVT_RATE = 0.15
    # input
    aligned_insts = []
    for f in aligned_files:
        one_reader = ReaderGetterConf().get_reader(input_path=f)
        one_insts = list(one_reader)
        aligned_insts.append([z for z in yield_sents(one_insts)])
    # filter
    good_idxes = []
    for idx in range(len(aligned_insts[0])):
        sent_good = True
        for sent in yield_sents([z[idx] for z in aligned_insts]):
            if RANDOM_DROP_EVT_RATE > 0:
                for evt in list(sent.events):
                    if np.random.random_sample() < RANDOM_DROP_EVT_RATE:
                        sent.delete_frame(evt, "evt")
            for evt in sent.events:
                hits = set()
                for arg in evt.args:
                    widx, wlen = arg.arg.mention.get_span()
                    for ii in range(widx, widx + wlen):
                        if ii in hits:
                            sent_good = False
                        hits.add(ii)
        if sent_good:
            good_idxes.append(idx)
    # output
    output_prefix = "_tmp.json"
    for outi, insts in enumerate(aligned_insts):
        filtered_insts = [insts[ii] for ii in good_idxes]
        writer = WriterGetterConf().get_writer(
            output_path=f"{output_prefix}{outi}")
        writer.write_insts(filtered_insts)
        writer.close()
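All of these examples follow the same basic pattern: read instances from a file, flatten them into sentences with yield_sents, and process or write the result. A minimal sketch of that pattern (hypothetical function name; assuming the msp2-style ReaderGetterConf / yield_sents / zlog API used throughout these snippets):

def count_sents_and_toks(input_file):
    # read instances (Docs or Sents) from the input file
    insts = list(ReaderGetterConf().get_reader(input_path=input_file))
    n_sent = n_tok = 0
    for sent in yield_sents(insts):  # flatten Docs into their Sents
        n_sent += 1
        n_tok += len(sent)  # a Sent's length is its token count
    zlog(f"Read {input_file}: {n_sent} sents, {n_tok} toks")
    return n_sent, n_tok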
Example #2
 def __init__(self, conf: UDAnalyzerConf):
     super().__init__(conf)
     conf: UDAnalyzerConf = self.conf
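     # overall flow: read the gold file, eval each pred file against it, then
     # zip the aligned sentences/tokens into MatchedLists for later analysis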
     # --
     # read main files
     main_insts = list(conf.main.get_reader(input_path=conf.gold))
     self.set_var("main", main_insts, explanation="init")
     # eval
     self.evaler = DparEvaler(conf.econf)
     # --
     all_sents = [list(yield_sents(main_insts))]
     all_toks = [[t for s in yield_sents(main_insts) for t in s.tokens]]
     for one_pidx, one_pred in enumerate(conf.preds):
         one_insts = list(conf.extra.get_reader(input_path=one_pred))  # get all of them
         one_sents = list(yield_sents(one_insts))
         assert len(one_sents) == len(all_sents[0])
         # eval
         eres = self.evaler.eval(main_insts, one_insts)
         zlog(f"#=====\nEval with {conf.main} vs. {one_pred}: res = {eres}\n{eres.get_detailed_str()}")
         # --
         all_sents.append(one_sents)
         all_toks.append([t for s in one_sents for t in s.tokens])
     # --
     s_lists = [MatchedList(z) for z in zip(*all_sents)]
     self.set_var("sl", s_lists, explanation="init")  # sent pair
     s_toks = [MatchedList(z) for z in zip(*all_toks)]
     self.set_var("tl", s_toks, explanation="init")  # token pair
Example #3
def main(args):
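    # token-level accuracy for upos/dep-head/dep-label, plus head-word
    # accuracy for event args; gold and pred must be exactly aligned
    # (same sentences, frames, and args), as enforced by the asserts below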
    conf = MainConf()
    conf.update_from_args(args)
    zlog(f"Ready to evaluate with: {conf.gold} {conf.pred}")
    # --
    stat = Counter()
    gold_sents = list(yield_sents(conf.gold.get_reader()))
    pred_sents = list(yield_sents(conf.pred.get_reader()))
    assert len(gold_sents) == len(pred_sents)
    # --
    set_ee_heads(gold_sents)
    set_ee_heads(pred_sents)
    # --
    for g_sent, p_sent in zip(gold_sents, pred_sents):
        stat["sent"] += 1
        slen = len(g_sent)
        assert slen == len(p_sent)
        stat["tok"] += slen
        # check tokens
        for widx in range(slen):
            corr_pos = int(
                g_sent.seq_upos.vals[widx] == p_sent.seq_upos.vals[widx])
            corr_dhead = int(g_sent.tree_dep.seq_head.vals[widx] ==
                             p_sent.tree_dep.seq_head.vals[widx])
            corr_dlab = corr_dhead * int(
                g_sent.tree_dep.seq_label.vals[widx] ==
                p_sent.tree_dep.seq_label.vals[widx])
            stat["tok_pos"] += corr_pos
            stat["tok_dhead"] += corr_dhead
            stat["tok_dlab"] += corr_dlab
        # check args
        assert len(g_sent.events) == len(p_sent.events)
        for g_evt, p_evt in zip(g_sent.events, p_sent.events):
            assert g_evt.mention.is_equal(
                p_evt.mention) and g_evt.label == p_evt.label
            stat["frame"] += 1
            assert len(g_evt.args) == len(p_evt.args)
            frame_head_corr = 1
            for g_arg, p_arg in zip(g_evt.args, p_evt.args):
                stat["arg"] += 1
                assert g_arg.mention.is_equal(
                    p_arg.mention) and g_arg.label == p_arg.label
                arg_head_corr = int(
                    g_arg.mention.shead_widx == p_arg.mention.shead_widx)
                stat["arg_head"] += arg_head_corr
                frame_head_corr *= arg_head_corr
            stat["frame_head"] += frame_head_corr
        # --
    # --
    # report
    for k, v in stat.items():
        v0 = 0
        k0 = k.split("_", 1)[0]
        if k0 != k and k0 in stat:
            v0 = stat[k0]
        zlog(f"{k}: {v}" + (f" -> {v}/{v0}={v/v0}" if v0 > 0 else ""))
Example #4
def main(input_file: str, output_file: str, checking_file: str,
         keep_rate: float):
    keep_rate = float(keep_rate)
    _gen = Random.get_np_generator(12345)
    rstream = Random.stream(_gen.random_sample)
    # --
    # read input
    stat = {}
    input_sents = list(
        yield_sents(ReaderGetterConf().get_reader(input_path=input_file)))
    stat["input"] = get_stat(input_sents)
    if checking_file:
        checking_sents = list(
            yield_sents(
                ReaderGetterConf().get_reader(input_path=checking_file)))
        stat["check"] = get_stat(checking_sents)
        # collect keys
        hit_keys = set()
        for one_check_sent in checking_sents:
            tok_key = ''.join(one_check_sent.seq_word.vals).lower()
            tok_key = ''.join(tok_key.split())  # split and join again
            hit_keys.add(tok_key)
        # filter
        filtered_sents = []
        for one_input_sent in input_sents:
            tok_key = ''.join(one_input_sent.seq_word.vals).lower()
            tok_key = ''.join(tok_key.split())  # split and join again
            if tok_key not in hit_keys:
                filtered_sents.append(one_input_sent)
    else:
        filtered_sents = input_sents
    stat["filter"] = get_stat(filtered_sents)
    # sample
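    # keep_rate < 1: per-sentence sampling probability;
    # keep_rate > 10: target count (shuffle, then truncate);
    # otherwise keep everything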
    if keep_rate < 1.:
        sample_sents = [
            s for r, s in zip(rstream, filtered_sents) if r < keep_rate
        ]
    elif keep_rate > 10:
        sample_sents = [z for z in filtered_sents]
        for _ in range(10):
            _gen.shuffle(sample_sents)
        sample_sents = sample_sents[:int(keep_rate)]
    else:
        sample_sents = filtered_sents
    stat["sample"] = get_stat(sample_sents)
    # write
    if os.path.exists(output_file):
        assert False, f"File exists: {output_file}, delete it first!"
    if output_file:
        with WriterGetterConf().get_writer(output_path=output_file) as writer:
            writer.write_insts(sample_sents)
    # stat
    zlog(
        f"Read {input_file}, check {checking_file}, output {output_file}, stat:"
    )
    OtherHelper.printd(stat)
Example #5
File: dec_upos.py, project: zzsfornlp/zmsp
 def prep_inst(self, inst, dataset):
     wset = dataset.wset
     if wset == "train":
         voc_upos, = self.vpack
         for sent in yield_sents(inst):
             seq_upos = sent.seq_upos
             seq_idxes = [voc_upos.get_else_unk(z) for z in seq_upos.vals]
             seq_upos.set_idxes(seq_idxes)
     elif self.conf.upos_pred_clear:  # clear if there are
         for sent in yield_sents(inst):
             sent.build_uposes(["UNK"] * len(sent))
Example #6
def main(*args):
    conf: MainConf2 = init_everything(MainConf2(), args)
    # --
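    # overall flow: index the aux sentences, then for each input sentence that
    # the index can match, copy over the aux upos tags and dependency tree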
    # first read aux ones
    aux_insts = list(conf.aux.get_reader())
    aux_index = MyIndexer2(conf)
    num_aux_sent = 0
    for sent in yield_sents(aux_insts):
        num_aux_sent += 1
        fix_words(sent)
        aux_index.put(sent)
    zlog(
        f"Read from {conf.aux.input_path}: insts={len(aux_insts)}, sents={num_aux_sent}, len(index)={len(aux_index)}"
    )
    # --
    # then read input
    input_insts = list(conf.input.get_reader())
    output_sents = []
    num_input_sent = 0
    num_reset_sent = 0
    num_hit_sent = 0
    cc_status = Counter()
    for sent in yield_sents(input_insts):
        num_input_sent += 1
        fix_words(sent)
        # --
        trg_sent, trg_status = aux_index.query(sent)
        cc_status[trg_status] += 1
        if trg_sent is not None:
            num_hit_sent += 1
            # note: currently we replace upos & tree_dep
            upos_vals, head_vals, deplab_vals = \
                trg_sent.seq_upos.vals, trg_sent.tree_dep.seq_head.vals, trg_sent.tree_dep.seq_label.vals
            sent.build_uposes(upos_vals)
            sent.build_dep_tree(head_vals, deplab_vals)
            output_sents.append(sent)
        else:
            zlog(f"Miss sent: {sent.seq_word}")
            if not conf.output_sent_and_discard_nonhit:
                output_sents.append(sent)
    # --
    zlog(
        f"Read from {conf.input.input_path}: insts={len(input_insts)}, sents={num_input_sent}, (out-sent-{len(output_sents)}) "
        f"reset={num_reset_sent}({num_reset_sent/num_input_sent:.4f}) hit={num_hit_sent}({num_hit_sent/num_input_sent:.4f})"
    )
    zlog(f"Query status: {cc_status}")
    # write
    with conf.output.get_writer() as writer:
        if conf.output_sent_and_discard_nonhit:
            writer.write_insts(output_sents)
        else:  # write the original insts
            writer.write_insts(input_insts)
Example #7
 def __init__(self, conf: ZIConverterConf):
     self.conf = conf
     # --
     self._input_f = {
         "none": lambda x: x,
         "sent": lambda x: yield_sents(x),
         "sent_evt": lambda x: (z for z in yield_sents(x)
                                if len(z.events) > 0),  # sents with evts!
         "frame": lambda x: yield_frames(x)
     }[conf.input_strategy]
     self._convert_f = {
         "context": self._convert_context,
         "pairwise0": self._convert_pairwise0,
         "pairwise1": self._convert_pairwise1
     }[conf.convert_strategy]
Example #8
def do_stat(insts):
    cc = Counter()
    voc = SimpleVocab.build_empty()
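    # for nearby token pairs (distance <= 5), count the pair of dep-label
    # paths (first-level labels only) that connects them through the tree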
    for sent in yield_sents(insts):
        cc["sent"] += 1
        cc["tok"] += len(sent)
        cc["tok_pair"] += len(sent)**2
        _tree = sent.tree_dep
        _deplabs = _tree.seq_label.vals
        _slen = len(sent)
        for i0 in range(_slen):
            for i1 in range(_slen):
                if abs(i0 - i1) > 5:
                    continue
                path1, path2 = _tree.get_path(i0, i1)
                labs1, labs2 = sorted(
                    [[_deplabs[z].split(":")[0] for z in path1],
                     [_deplabs[z].split(":")[0] for z in path2]])
                _len = len(labs1) + len(labs2)
                # if _len<=0 or _len>2 or "punct" in labs1 or "punct" in labs2:
                if _len != 2 or "punct" in labs1 or "punct" in labs2:
                    continue
                _k = (tuple(labs1), tuple(labs2))
                voc.feed_one(_k)
    # --
    zlog(cc)
    voc.build_sort()
    d = voc.get_info_table()
    print(d[:100].to_string())
Example #9
 def predict(self, insts: List[Union[Doc, Sent]]):
     conf: RuleTargetExtractorConf = self.conf
     wl_min_count, wl_max_wlen = conf.wl_min_count, conf.wl_max_wlen
     # --
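     # for each token, try its whitelist items in priority order: an item hits
     # if its left/right feature contexts match and no blacklist rule fires;
     # otherwise fall back to the single-token alternative whitelist check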
     for sent in yield_sents(insts):
         sent.delete_frames(conf.ftag)  # first clear original ones!
         # extract
         sent_toks, sent_tok_feats = self.feat_toks(sent)
         for one_widx, one_feat in enumerate(sent_tok_feats):
             if one_feat in self.whitelist:
                 possible_items = self.whitelist[one_feat]['items']
             else:
                 possible_items = []
             # check each one: the order means priority
             hit_span = None
             for one_item in possible_items:
                 if one_item['count'] < wl_min_count:
                     continue  # must >= min_count
                 left_feats, right_feats = one_item['left'], one_item['right']
                 left_len, right_len = len(left_feats), len(right_feats)
                 if left_len + right_len > wl_max_wlen:
                     continue  # must <= max_wlen
                 if sent_tok_feats[max(0,one_widx-left_len):one_widx] == left_feats \
                         and sent_tok_feats[one_widx+1:one_widx+1+right_len] == right_feats:
                     # check blacklist
                     key_tok, mention_toks = sent_toks[one_widx], sent_toks[one_widx-left_len:one_widx+right_len+1]
                     if not any(rule.hit(key_tok, mention_toks, sent_toks) for rule in self.blacklist):
                         hit_span = (one_widx-left_len, left_len+1+right_len)
                         break
             # check alternative wl
             if hit_span is None and self.wl_alter_f(sent_toks[one_widx]):
                 hit_span = (one_widx, 1)  # make this token
             # new one -- directly adding (no frame-type)
             if hit_span is not None:
                 f = sent.make_frame(hit_span[0], hit_span[1], conf.ftag, adding=True)
Example #10
File: run.py, project: zzsfornlp/zmsp
def batch_stream(in_stream: Streamer, tconf: TConf, training: bool):
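    # batch-size counters in different units: sent = #sentences,
    # tok = #tokens, frame = #events, ftok = max(1, #events) * #tokens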
    _sent_counter = lambda d: len(list(yield_sents([d])))
    _tok_counter = lambda d: sum(len(s) for s in yield_sents([d]))
    _frame_counter = lambda d: sum(len(s.events) for s in yield_sents([d]))
    _ftok_counter = lambda d: sum(max(1, len(s.events))*len(s) for s in yield_sents([d]))
    batch_size_f_map = {"sent": _sent_counter, "tok": _tok_counter, "frame": _frame_counter, "ftok": _ftok_counter}
    if training:
        batch_size_f = batch_size_f_map[tconf.train_count_mode]
        b_stream = BatchArranger(in_stream, bsize=tconf.train_batch_size, maxi_bsize=tconf.train_maxibatch_size,
                                 batch_size_f=batch_size_f, dump_detectors=None, single_detectors=None, sorting_keyer=lambda x: len(x),
                                 shuffle_batches_times=tconf.train_batch_shuffle_times)
    else:
        batch_size_f = batch_size_f_map[tconf.test_count_mode]
        b_stream = BatchArranger(in_stream, bsize=tconf.test_batch_size, maxi_bsize=1, batch_size_f=batch_size_f,
                                 dump_detectors=None, single_detectors=None, sorting_keyer=None, shuffle_batches_times=0)
    return b_stream, batch_size_f
Example #11
def main(output_prefix, *input_files):
    # input
    all_sents = []
    for f in input_files:
        one_reader = ReaderGetterConf().get_reader(input_path=f)
        one_insts = list(one_reader)
        all_sents.append([z for z in yield_sents(one_insts)])
        zlog(f"Read from {f}: {len(all_sents[-1])} sents")
    # align
    sent_map = OrderedDict()
    for fidx, sents in enumerate(all_sents):
        for sent in sents:
            doc_id = sent.info.get("doc_id", "UNK")
            if doc_id.split("/", 1)[0] == "ontonotes":
                doc_id = doc_id.split("/", 1)[1]
            key = doc_id + "|".join(sent.seq_word.vals)  # map by doc_id + key
            if key not in sent_map:
                sent_map[key] = [sent]
            else:
                sent_map[key].append(sent)
    # --
    num_files = len(input_files)
    matched_sents = [vs for vs in sent_map.values() if len(vs) == num_files]
    unmatched_sents = [vs for vs in sent_map.values() if len(vs) != num_files]
    zlog(f"Aligned sent of {len(matched_sents)}")
    breakpoint()
    # output
    for outi in range(num_files):
        out_sents = [z[outi] for z in matched_sents]
        writer = WriterGetterConf().get_writer(
            output_path=f"{output_prefix}{outi}")
        writer.write_insts(out_sents)
        writer.close()
Example #12
File: run.py, project: zzsfornlp/zmsp
 def _f(self, inst):
     conf: MyDataReaderConf = self.conf
     wl_use_lc, deplab_use_label0, sent_loss_weight_non, assume_frame_lu = \
         conf.wl_use_lc, conf.deplab_use_label0, conf.sent_loss_weight_non, conf.assume_frame_lu
     for sent in yield_sents([inst]):
         if wl_use_lc:
             if sent.seq_word is not None:
                 sent.seq_word.set_vals([w.lower() for w in sent.seq_word.vals])
             if sent.seq_lemma is not None:
                 sent.seq_lemma.set_vals([w.lower() for w in sent.seq_lemma.vals])
         if deplab_use_label0:
             sent_tree = sent.tree_dep
             if sent_tree is not None and sent_tree.seq_label is not None:  # use first-level label!!
                 sent_tree.seq_label.set_vals([s.split(":")[0] for s in sent_tree.seq_label.vals])
         if sent_loss_weight_non != 1.0:  # set the special property
             sent._loss_weight_non = sent_loss_weight_non
         if assume_frame_lu:  # special assumption!!
             for evt in sent.events:
                 lu_lemma, lu_pos = evt.info.get("luName").split(".")
                 assign_idx = evt.mention.wridx-1  # note: simply put at the ending token!
                 sent.seq_lemma.vals[assign_idx] = lu_lemma
                 if lu_pos in self._FN_POS_MAP:
                     sent.seq_upos.vals[assign_idx] = self._FN_POS_MAP[lu_pos]
         if conf.set_ee_heads:  # set head using dep-parse tree
             sent_tree = sent.tree_dep
             if sent_tree is not None:
                 set_ee_heads([sent])
         if conf.do_unlabeled_arg:
             for evt in sent.events:
                 for arg in list(evt.args):  # note: simply remove "V" and "C-V" here!!
                     if arg.role in ["V", "C-V"]:
                         arg.delete_self()
                     else:
                         arg.role = "ZARG"  # just a placeholder!
Example #13
File: helper.py, project: zzsfornlp/zmsp
 def prepare(self, insts: List[Union[Doc, Sent]]):
     ret_sents, ret_toks = [], []
     for s in yield_sents(insts):
         ret_sents.append(s)
         ret_toks.append(s.seq_word.vals)
     return [(ret_sents, ret_toks)]  # List[List[Sent], List[Toks(List[str])]]
Example #14
File: run.py, project: zzsfornlp/zmsp
 def _summarize_exit_status(self, insts, final_one_result):
     ret = {"key": "zz_ee_key", "result": final_one_result}
     try:
         ret["flops"] = self._get_flops(insts)
     except Exception:  # flops may be unavailable; skip silently
         pass
     # early exit time info
     time_info = self.test_recorder.summary()
     time_info["pure_time"] = time_info["_time"] - time_info.get("srl_posttime", 0) - time_info.get("subtok_time", 0)
     ret["time"] = dict(time_info)
     # --
     try:  # todo(+N): ugly!!
         # early exit layer info
         from collections import Counter
         cc = Counter()
         for sent in yield_sents(insts):
             for frame in sent.events:
                 elidx = frame.info["exit_lidx"]
                 cc[elidx] += 1
         all_count = sum(cc.values())
         avg_lidx = sum(ii*cc[ii] for ii in cc.keys()) / all_count
         details = []
         for ii in sorted(cc.keys()):
             details.append(f"{ii}: {cc[ii]}({cc[ii]/all_count:.3f})")
         zlog(f"Exit status: avg={avg_lidx} details={details}")
         # finally
         ret["exit"] = {"exit_avg": avg_lidx, "exit_cc": dict(cc)}
     except Exception:
         zlog("Failed summarize exit status, simply skip!!")
     # --
     zlog(ret)
     return ret
Example #15
def main(args):
    conf = MainConf()
    conf.update_from_args(args)
    zlog(f"Ready to evaluate with: {conf.gold} {conf.pred} => {conf.output}")
    # --
    final_insts = list(conf.gold.get_reader())  # to modify inplace!
    stat = Counter()
    gold_sents = list(yield_sents(final_insts))
    pred_sents = list(yield_sents(conf.pred.get_reader()))
    assert len(gold_sents) == len(pred_sents)
    for g_sent, p_sent in zip(gold_sents, pred_sents):
        stat["sent"] += 1
        slen = len(g_sent)
        assert slen == len(p_sent)
        stat["tok"] += slen
        # put features
        assert len(g_sent.events) == len(p_sent.events)
        for g_evt, p_evt in zip(g_sent.events, p_sent.events):
            assert g_evt.mention.is_equal(
                p_evt.mention) and g_evt.label == p_evt.label
            stat["frame"] += 1
            stat["ftok"] += slen
            assert len(g_evt.args) == len(p_evt.args)
            # --
            evt_widx = g_evt.mention.shead_widx
            g_paths = [[len(z) for z in g_evt.sent.tree_dep.get_path(ii, evt_widx)]
                       for ii in range(slen)]
            p_paths = [[len(z) for z in p_evt.sent.tree_dep.get_path(ii, evt_widx)]
                       for ii in range(slen)]
            stat["ftok_corr"] += sum(a == b for a, b in zip(g_paths, p_paths))
            # assign
            g_evt.info["dpaths"] = [g_paths, p_paths]  # [2(g/p), SLEN, 2(word, predicate)]
        # --
    # --
    # report
    OtherHelper.printd(stat)
    zlog(
        f"FtokPathAcc: {stat['ftok_corr']} / {stat['ftok']} = {stat['ftok_corr']/stat['ftok']}"
    )
    # --
    # write
    if conf.output.output_path:
        with conf.output.get_writer() as writer:
            writer.write_insts(final_insts)
Example #16
 def loss_on_batch(self,
                   annotated_insts: List,
                   loss_factor=1.,
                   training=True,
                   **kwargs):
     self.refresh_batch(training)
     # --
     sents: List[Sent] = list(yield_sents(annotated_insts))
     # ==
     # extend to events
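      # build one row per event: a 0/1 marker at the predicate position, the
      # lowercased byte-encoded words, and BIO role labels for the event's args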
     import numpy as np
     bsize = sum(len(z.events) for z in sents)
     mlen = max(len(z) for z in sents)
     arr_preds = np.full([bsize, mlen], 0., dtype=np.int32)
     arr_inputs = np.full([bsize, mlen], b'<pad>', dtype=object)
     arr_labels = np.full([bsize, mlen], b'<pad>', dtype=object)
     ii = 0
     for sent in sents:
         for evt in sent.events:
             widx, wlen = evt.mention.get_span()
             assert wlen == 1
             # --
             arr_preds[ii, widx] = 1
             arr_inputs[ii, :len(sent)] = [
                 s.lower().encode() for s in sent.seq_word.vals
             ]
             # --
             tmp_labels = ["O"] * len(sent)
             for arg in evt.args:
                 role = arg.role
                 a_widx, a_wlen = arg.arg.mention.get_span()
                 a_labs = ["B-" + role] + ["I-" + role] * (a_wlen - 1)
                 assert all(z == "O"
                            for z in tmp_labels[a_widx:a_widx + a_wlen])
                 tmp_labels[a_widx:a_widx + a_wlen] = a_labs
             # --
             arr_labels[ii, :len(sent)] = [z.encode() for z in tmp_labels]
             # --
             ii += 1
     assert ii == bsize
     features, labels = data.lookup(({
         "preds": NpWarapper(arr_preds),
         "inputs": NpWarapper(arr_inputs)
     }, NpWarapper(arr_labels)), "train", self.params)
     # ==
     final_loss = self.M(features, labels)
     info = {
         "inst": len(annotated_insts),
         "sent": len(sents),
         "fb": 1,
         "loss": final_loss.item()
     }
     if training:
         assert final_loss.requires_grad
         BK.backward(final_loss, loss_factor)
     zlog(
         f"batch shape = {len(annotated_insts)} {bsize} {mlen} {bsize*mlen}"
     )
     return info
Example #17
def main(file_in="", file_out=""):
    insts = list(ReaderGetterConf().get_reader(input_path=file_in))  # read from stdin
    with WriterGetterConf().get_writer(output_path=file_out) as writer:
        for inst in insts:
            for sent in yield_sents([inst]):
                sent.delete_frames("evt")
                sent.delete_frames("ef")
            writer.write_inst(inst)
Example #18
File: dec_upos.py, project: zzsfornlp/zmsp
 def build_vocab(self, datasets: List):
     voc_upos = SimpleVocab.build_empty(self.name)
     for dataset in datasets:
         for sent in yield_sents(dataset.insts):
             voc_upos.feed_iter(sent.seq_upos.vals)
      # finished
     voc_upos.build_sort()
     return (voc_upos, )
Example #19
File: run.py, project: zzsfornlp/zmsp
def train_prep_stream(in_stream: Streamer, tconf: TConf):
    # for training, we get all the sentences!
    assert tconf.train_stream_mode == "sent", "Currently we only support sent training!"
    sent_stream = FListWrapperStreamer(
        in_stream,
        lambda d: [x for x in yield_sents([d])
                   if tconf.train_min_length <= len(x) <= tconf.train_max_length
                   and (len(x.events) > 0 or next(_BS_sample_stream) > tconf.train_skip_noevt_rate)])  # filter out certain sents!
    if tconf.train_stream_reshuffle_times > 0:  # reshuffle for sents
        sent_stream = ShuffleStreamer(sent_stream, shuffle_bsize=tconf.train_stream_reshuffle_bsize,
                                      shuffle_times=tconf.train_stream_reshuffle_times)
    return sent_stream
Example #20
def main(*args):
    conf: TrainConf = init_everything(TrainConf(), args)
    # --
    reader = conf.train.get_reader()
    inputs = yield_sents(reader)
    extractor = RuleTargetExtractor.train(inputs, conf.econf)
    # save
    zlog(f"Save extractor to {conf.save_name}")
    default_json_serializer.to_file(extractor.to_json(), conf.save_name)
Example #21
def do_stat_srl(insts):
    cc = Counter()
    cc_narg = Counter()
    voc = SimpleVocab.build_empty()
    # set_ee_heads(insts)
    voc_pred, voc_arg = SimpleVocab.build_empty(), SimpleVocab.build_empty()
    voc_deplab = SimpleVocab.build_empty()
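    # count sents/toks/frames/args, build label vocabularies, and flag pairs
    # of args within the same frame whose spans overlap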
    for sent in yield_sents(insts):
        cc["sent"] += 1
        cc["tok"] += len(sent)
        cc["frame"] += len(sent.events)
        # --
        _tree = sent.tree_dep
        if _tree is not None:
            voc_deplab.feed_iter(_tree.seq_label.vals)
        for evt in sent.events:
            voc_pred.feed_one(evt.label)
            evt_widx = evt.mention.shead_widx
            cc_narg[f"NARG={len(evt.args)}"] += 1
            for arg in evt.args:
                voc_arg.feed_one(arg.label)
                cc["arg"] += 1
                # check arg overlap
                for a2 in evt.args:
                    if a2 is arg: continue  # not self
                    overlap = not (arg.mention.widx >= a2.mention.wridx
                                   or a2.mention.widx >= arg.mention.wridx)
                    cc["arg_overlap"] += int(overlap)
    # --
    voc.build_sort()
    voc_pred.build_sort()
    voc_arg.build_sort()
    voc_deplab.build_sort()
    # --
    # get more stat
    cc2 = dict(cc)
    cc2.update({
        "t/s": f"{cc['tok']/cc['sent']:.2f}",
        "f/s": f"{cc['frame']/cc['sent']:.2f}",
        "a/f": f"{cc['arg']/cc['frame']:.2f}"
    })
    zlog(f"CC: {cc2}")
    zlog(cc_narg)
    zlog(voc_arg.counts)
    # --
    MAX_PRINT_ITEMS = 20
    d_pred = voc_pred.get_info_table()
    print(d_pred[:MAX_PRINT_ITEMS].to_string())
    d_arg = voc_arg.get_info_table()
    print(d_arg[:MAX_PRINT_ITEMS].to_string())
    d_deplab = voc_deplab.get_info_table()
    print(d_deplab[:MAX_PRINT_ITEMS].to_string())
    d = voc.get_info_table()
    print(d[:MAX_PRINT_ITEMS].to_string())
Example #22
File: run.py, project: zzsfornlp/zmsp
 def _go_index(self, inst: Union[Doc, Sent]):
     for sent in yield_sents([inst]):  # make it iterable
         self.index_helper.index_sent(sent)
     # set read_idx!!
     inst.set_read_idx(self.count())
     # inst_preparer
     if self.inst_preparer is not None:
         inst = self.inst_preparer(inst)
     # in fact, modified in place unless the inst_preparer wraps the inst
     return inst
Example #23
def main(*args):
    conf: MainConf = init_everything(MainConf(), args)
    # --
    # first read them all
    src_sents, trg_sents = list(yield_sents(conf.src_input.get_reader())), \
                           list(yield_sents(conf.trg_input.get_reader()))
    assert len(src_sents) == len(trg_sents)
    cc = Counter()
    conv = Converter(conf)
    # --
    outputs = []
    for src_sent, trg_sent in zip(src_sents, trg_sents):
        res = conv.convert(src_sent, trg_sent, cc)
        outputs.append(res)
    zlog("Stat:")
    OtherHelper.printd(cc)
    # --
    with conf.output.get_writer() as writer:
        writer.write_insts(outputs)
Example #24
File: dec_udep.py, project: zzsfornlp/zmsp
 def prep_inst(self, inst, dataset):
     conf: ZTaskUdepConf = self.conf
     wset = dataset.wset
     # --
     if wset == "train" or conf.udep_prep_all:
         voc_udep, udep_direct_range = self.vpack
         for sent in yield_sents(inst):
             _tree = sent.tree_dep
             if _tree is None:
                 continue
             _vals = _tree.seq_label.vals
             if conf.use_l1:
                 _vals = [z.split(":")[0] for z in _vals]
             idxes_labs = [voc_udep.get_else_unk(z) for z in _vals]
             _tree.seq_label.set_idxes(idxes_labs)
             # note: refresh the cache
             _mat = _tree.label_matrix  # [m,h]
     elif conf.udep_pred_clear:  # clear if there are
         for sent in yield_sents(inst):
             sent.build_dep_tree([0]*len(sent), ["UNK"]*len(sent))
Example #25
 def _convert_pairwise1(self, stream_inst):
     for inst in stream_inst:
         for sent in yield_sents([inst]):
             if sent.next_sent is None:
                 continue
             ret = InputItem.create([sent, sent.next_sent],
                                    inst=None,
                                    add_seps=[1, 1],
                                    seg_ids=[0, 1],
                                    center_sidx=None)
             yield ret
Example #26
File: dec_udep.py, project: zzsfornlp/zmsp
 def build_vocab(self, datasets: List):
     conf: ZTaskUdepConf = self.conf
     # --
     voc_udep = SimpleVocab.build_empty(self.name)
     for dataset in datasets:
         for sent in yield_sents(dataset.insts):
             _vals = sent.tree_dep.seq_label.vals
             if conf.use_l1:
                 _vals = [z.split(":")[0] for z in _vals]
             voc_udep.feed_iter(_vals)
     voc_udep.build_sort()
     _, udep_direct_range = voc_udep.non_special_range()  # range of direct labels
     zlog(f"Finish building voc_udep: {voc_udep}")
     return (voc_udep, udep_direct_range)
Example #27
File: run.py, project: zzsfornlp/zmsp
 def _get_flops(self, insts):
     all_flops = 0
     all_count = 0
     for sent in yield_sents(insts):
         for frame in sent.events:
             elidx = frame.info.get("exit_lidx")
             if elidx is None:  # static mode
                 one_flops = self._calculate_flops(frame, None)
             else:
                 one_flops = self._calculate_flops(frame, elidx)
             all_flops += one_flops
             all_count += 1
     # flops
     flops_per_inst = all_flops / all_count
     return flops_per_inst
Example #28
def main(input_format, *input_files: str):
    reader_conf = ReaderGetterConf().direct_update(input_format=input_format)
    reader_conf.validate()
    # --
    all_insts = []
    for ff in input_files:
        one_insts = list(reader_conf.get_reader(input_path=ff))
        cc = Counter()
        for sent in yield_sents(one_insts):
            cc['sent'] += 1
            for evt in sent.events:
                cc['evt'] += 1
                cc['arg'] += len(evt.args)
        zlog(
            f"Read from {ff}: {cc['sent']/1000:.1f}k&{cc['evt']/1000:.1f}k&{cc['arg']/1000:.1f}k"
        )
Example #29
def get_extra_hit_words(input_stream, emb, voc):
    word_counts = {}
    hit_words = set()
    # --
    for sent in yield_sents(input_stream):
        for one in sent.seq_word.vals:
            word_counts[one] = word_counts.get(one, 0) + 1
            if emb.find_key(one) is not None:
                hit_words.add(one)
    # --
    extra_hit_words = [z for z in hit_words if z not in voc]
    extra_hit_words.sort(key=lambda x: -word_counts[x])
    zlog(
        f"Iter hit words: all={len(word_counts)}, hit={len(hit_words)}, extra_hit={len(extra_hit_words)}"
    )
    return extra_hit_words
Example #30
 def _convert_context(self, stream_inst):
     conf: ZIConverterConf = self.conf
     _left_extend_nsent, _right_extend_nsent, _center_special_id = \
         conf.left_extend_nsent, conf.right_extend_nsent, conf.center_special_id
     _extend_word_budget = conf.extend_word_budget
     # --
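      # greedily extend context sentences around the center one, preferring
      # the previous side, until the word budget or per-side limits are hit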
     for inst in stream_inst:
         for sent in yield_sents([inst]):
             _cur_words = len(sent)
             _cur_left, _cur_right = sent.prev_sent, sent.next_sent
             left_sents, right_sents = [], []
             while _cur_words < _extend_word_budget and (
                     _cur_left is not None or _cur_right is not None):
                 # expand left? note: prefer previous!
                 if _cur_left is not None:
                     _one_len = len(_cur_left)
                     if len(left_sents) < _left_extend_nsent and (
                             _one_len + _cur_words) <= _extend_word_budget:
                         left_sents.append(_cur_left)
                         _cur_left = _cur_left.prev_sent
                         _cur_words += _one_len
                     else:
                         _cur_left = None
                 # expand right?
                 if _cur_right is not None:
                     _one_len = len(_cur_right)
                     if len(right_sents) < _right_extend_nsent and (
                             _one_len + _cur_words) <= _extend_word_budget:
                         right_sents.append(_cur_right)
                         _cur_right = _cur_right.next_sent
                         _cur_words += _one_len
                     else:
                         _cur_right = None
             # final one
             left_sents.reverse()
             cur_sents = left_sents + [sent] + right_sents
             seg_ids = [0] * len(cur_sents)
             center_sidx = len(left_sents)
             if _center_special_id:  # especially set center as 1
                 seg_ids[center_sidx] = 1
             ret = InputItem.create(cur_sents,
                                    inst=inst,
                                    add_seps=None,
                                    seg_ids=seg_ids,
                                    center_sidx=center_sidx)
             yield ret