Example #1
def write_insts(filename, insts: List[ParseInstance]):
    with zopen(filename, "w") as fd:
        for inst in insts:
            write_conllu(
                fd,
                *(inst.get_real_values_select(
                    ["words", "poses", "heads", "labels"])))
Example #2
def main(args):
    conf: OverallConf = init_everything(args)
    dconf = conf.dconf
    #
    bmodel = get_berter(dconf.bconf)
    for one_input, one_output in zip(
        [dconf.train, dconf.dev, dconf.test],
        [dconf.aux_repr_train, dconf.aux_repr_dev, dconf.aux_repr_test]):
        zlog(f"Read from {one_input} and write to {one_output}")
        num_doc, num_sent = 0, 0
        if one_input and one_output:
            one_streamer = get_data_reader(one_input, dconf.input_format,
                                           dconf.use_label0, dconf.noef_link0,
                                           None)
            bertaug_streamer = BerterDataAuger(one_streamer, bmodel,
                                               "aux_repr")
            with zopen(one_output, 'wb') as fd:
                for one_doc in bertaug_streamer:
                    PickleRW.save_list(
                        [s.extra_features["aux_repr"] for s in one_doc.sents],
                        fd)
                    num_doc += 1
                    num_sent += len(one_doc.sents)
            zlog(f"Finish with doc={num_doc}, sent={num_sent}")
        else:
            zlog("Skip empty files")
    zlog("Finish all.")
Example #3
 def write_txt(self, fname):
     with zopen(fname, "w") as fd:
         for pack in self.yield_infos():
             i, w, count, perc, accu_count, accu_perc = pack
             ss = f"{i} {w} {count}({perc:.3f}) {accu_count}({accu_perc:.3f})\n"
             fd.write(ss)
     zlog("Write (txt) to %s: %s" % (fname, str(self)))
Example #4
def init_everything(args, ConfType=None):
    # search for basic confs
    all_argv, basic_argv = Conf.search_args(args, ["model_type", "conf_output", "log_file", "msp_seed"],
                                            [str, str, str, int], [None, None, Logger.MAGIC_CODE, None])
    # for basic argvs
    model_type = basic_argv["model_type"]
    conf_output = basic_argv["conf_output"]
    log_file = basic_argv["log_file"]
    msp_seed = basic_argv["msp_seed"]
    if conf_output:
        with zopen(conf_output, "w") as fd:
            for k,v in all_argv.items():
                # todo(note): do not save this one
                if k != "conf_output":
                    fd.write(f"{k}:{v}\n")
    utils.init(log_file, msp_seed)
    # real init of the conf
    if model_type is None:
        utils.zlog("Using the default model type = simple!")
        model_type = "simple"
    if ConfType is None:
        conf = OverallConf(model_type, args)
    else:
        conf = ConfType(model_type, args)
    nn.init(conf.niconf)
    return conf
Example #5
 def end(self):
     # sorting by idx of reading
     self.insts.sort(key=lambda x: x.inst_idx)
     # todo(+1): write other output file
     if self.outf is not None:
         with zopen(self.outf, "w") as fd:
             data_writer = get_data_writer(fd, self.out_format)
             data_writer.write(self.insts)
     #
     evaler = ParserEvaler()
     # evaler2 = ParserEvaler(ignore_punct=True, punct_set={"PUNCT", "SYM"})
     eval_arg_names = [
         "poses", "heads", "labels", "pred_poses", "pred_heads",
         "pred_labels"
     ]
     for one_inst in self.insts:
         # todo(warn): exclude the ROOT symbol; the model should assign pred_*
         real_values = one_inst.get_real_values_select(eval_arg_names)
         evaler.eval_one(*real_values)
         # evaler2.eval_one(*real_values)
     report_str, res = evaler.summary()
     # _, res2 = evaler2.summary()
     #
     zlog("Results of %s vs. %s" % (self.outf, self.goldf), func="result")
     zlog(report_str, func="result")
     res["gold"] = self.goldf  # record which file
     # res2["gold"] = self.goldf            # record which file
     zlog("zzzzztest: testing result is " + str(res))
     # zlog("zzzzztest2: testing result is " + str(res2))
     zlog("zzzzzpar: %s" % res["res"], func="result")
     return res
Example #6
def yield_ones(file):
    r = ConlluReader()
    with zopen(file) as fd:
        for p in r.yield_ones(fd):
            # add extra info
            for line in p.headlines:
                try:
                    one = json.loads(line)
                    p.misc.update(one)
                except:
                    continue
            # add extra info to tokens
            cur_length = len(p)
            if "ef_score" in p.misc:
                ef_order, ef_score = p.misc["ef_order"], p.misc["ef_score"]
                assert len(ef_order) == cur_length and len(ef_score) == cur_length
                for cur_step in range(cur_length):
                    cur_idx, cur_score = ef_order[cur_step], ef_score[cur_step]
                    assert cur_idx > 0
                    cur_t = p.get_tok(cur_idx)
                    cur_t.misc["efi"] = cur_step
                    cur_t.misc["efs"] = cur_score
            if "g1_score" in p.misc:
                g1_score = p.misc["g1_score"][1:]
                assert len(g1_score) == cur_length
                sorted_idxes = list(reversed(np.argsort(g1_score)))
                for cur_step in range(cur_length):
                    cur_idx = sorted_idxes[cur_step] + 1  # offset by root
                    cur_score = g1_score[cur_idx-1]
                    assert cur_idx > 0
                    cur_t = p.get_tok(cur_idx)
                    cur_t.misc["g1i"] = cur_step
                    cur_t.misc["g1s"] = cur_score
            yield p
Example #7
def yield_ones(file):
    r = ConlluReader()
    with zopen(file) as fd:
        for p in r.yield_ones(fd):
            # add root info
            corder_valid_flag = False
            for line in p.headlines:
                try:
                    misc0 = json.loads(line)["ROOT-MISC"]
                    for s in misc0.split("|"):
                        k, v = s.split("=")
                        p.tokens[0].misc[k] = v
                    corder_valid_flag = True
                    break
                except:
                    continue
            # add children order
            if corder_valid_flag:
                for t in p.tokens:
                    corder_str = t.misc.get("CORDER")
                    if corder_str is not None:
                        t.corder = [int(x) for x in corder_str.split(",")]
                    else:
                        t.corder = []
            yield p
Example #8
def yield_data(filename):
    if filename is None or filename == "":
        zlog("Start to read raw sentence from stdin")
        while True:
            line = input(">> ")
            if len(line) == 0:
                break
            sent = line.split()
            cur_len = len(sent)
            one = ParseInstance(sent, ["_"] * cur_len, [0] * cur_len,
                                ["_"] * cur_len)
            yield one
    # todo(note): decided by the special file extension!!
    elif filename.endswith(".pic"):
        zlog(f"Start to read pickle from file {filename}")
        with zopen(filename, 'rb') as fd:
            while True:
                try:
                    one = pickle.load(fd)
                    yield one
                except EOFError:
                    break
    else:
        # otherwise read conllu
        zlog(f"Start to read conllu from file {filename}")
        for one in get_data_reader(filename, "conllu", "", True):
            yield one
Example #9
def main(args):
    conf, model, vpack, test_iter = prepare_test(args)
    dconf = conf.dconf
    # todo(note): here is the main change
    # make sure the model is a first-order graph model, otherwise this cannot run through
    all_results = []
    all_insts = []
    with utils.Timer(tag="Run-score", info="", print_date=True):
        for cur_insts in test_iter:
            all_insts.extend(cur_insts)
            batched_arc_scores, batched_label_scores = model.score_on_batch(
                cur_insts)
            batched_arc_scores, batched_label_scores = BK.get_value(
                batched_arc_scores), BK.get_value(batched_label_scores)
            for cur_idx in range(len(cur_insts)):
                cur_len = len(cur_insts[cur_idx]) + 1
                # discarding paddings
                cur_res = (batched_arc_scores[cur_idx, :cur_len, :cur_len],
                           batched_label_scores[cur_idx, :cur_len, :cur_len])
                all_results.append(cur_res)
    # reorder to the original order
    orig_indexes = [z.inst_idx for z in all_insts]
    orig_results = [None] * len(orig_indexes)
    for new_idx, orig_idx in enumerate(orig_indexes):
        assert orig_results[orig_idx] is None
        orig_results[orig_idx] = all_results[new_idx]
    # saving
    with utils.Timer(tag="Run-write",
                     info=f"Writing to {dconf.output_file}",
                     print_date=True):
        import pickle
        with utils.zopen(dconf.output_file, "wb") as fd:
            for one in orig_results:
                pickle.dump(one, fd)
    utils.printing("The end.")
Example #10
 def _restart(self):
     if not self.input_is_fd:
         if self.fd is not None:
             self.fd.close()
         self.fd = zopen(self.file)
     else:
         zcheck(self.restart_times_ == 0, "Cannot restart a FdStreamer")
Example #11
 def end(self):
     # sorting by idx of reading
     self.insts.sort(key=lambda x: x.inst_idx)
     # todo(+1): write other output file
     # self._set_type(self.insts)  # todo(note): no need for this step
     if self.outf is not None:
         with zopen(self.outf, "w") as fd:
             data_writer = get_data_writer(fd, self.out_format)
             data_writer.write(self.insts)
     # evaluation
     evaler = MyIEEvaler(self.eval_conf)
     res = evaler.eval(self.insts, self.insts)
     # the criterion will be average of U/L-evt/arg; now using only labeled results
     # all_results = [res["event"][0], res["event"][1], res["argument"][0], res["argument"][1]]
     # all_results = [res["event"][1], res["argument"][1]]
     all_results = [res[z][1] for z in self.eval_conf.res_list]
     res["res"] = float(np.average([float(z) for z in all_results]))
     # make it loadable by json
     for k in ["event", "argument", "argument2", "entity_filler"]:
         res[k] = str(res.get(k))
     zlog("zzzzzevent: %s" % res["res"], func="result")
     # =====
     # clear pred ones for possible reusing
     for one_doc in self.insts:
         for one_sent in one_doc.sents:
             one_sent.pred_events.clear()
             one_sent.pred_entity_fillers.clear()
     return res
Example #12
 def save_hits(self, fname):
     num_hits = len(self.hits)
     printing(f"Saving hit w2v num_words={num_hits:d}, embed_size={self.embed_size:d} to {fname}.")
     with zopen(fname, "w") as fd:
         tmp_words = sorted(self.hits.keys(), key=lambda k: self.wmap[k])  # use original ordering
         tmp_vecs = [self.vecs[self.wmap[k]] for k in tmp_words]
         WordVectors.save_txt(fd, tmp_words, tmp_vecs, self.sep)
Example #13
 def write_one(self, inst: BaseDataItem):
     if self.is_dir:
         fname = self.fname_getter(inst)
         with zopen(os.path.join(self.path_or_fd, fname), 'w') as fd:
             json.dump(inst.to_builtin(), fd)
     else:
         self.fd.write(json.dumps(inst.to_builtin()))
         self.fd.write("\n")
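In the non-directory branch above, each instance is written as one JSON object per line (JSON Lines). A minimal sketch of a corresponding reader, assuming only the standard library; read_jsonl is illustrative and the built-in open is used instead of zopen:

def read_jsonl(fname):
    # read back one JSON object per line, mirroring write_one above
    import json
    items = []
    with open(fname) as fd:
        for line in fd:
            line = line.strip()
            if line:
                items.append(json.loads(line))
    return items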
Example #14
 def _restart(self):
     self.base_streamer_.restart()
     if isinstance(self.file, str):
         if self.fd is not None:
             self.fd.close()
         self.fd = zopen(self.file, mode='rb', encoding=None)
     else:
         zcheck(self.restart_times_ == 0, "Cannot restart a FdStreamer")
Example #15
 def _restart(self):
     if self.is_dir:
         self.dir_file_ptr = 0
     elif not self.input_is_fd:
         if self.fd is not None:
             self.fd.close()
         self.fd = zopen(self.path_or_fd)
     else:
         assert self.restart_times_ == 0, "Cannot restart (read multiple times) a FdStreamer"
Example #16
def read_results(fname):
    results = []
    with zopen(fname, "rb") as fd:
        while True:
            try:
                one = pickle.load(fd)
                results.append(one)
            except EOFError:
                break
    return results
Example #17
def main(args):
    conf: DecodeAConf = init_everything(args, DecodeAConf)
    dconf, mconf = conf.dconf, conf.mconf
    iconf = mconf.iconf
    # vocab
    vpack = IEVocabPackage.build_by_reading(conf)
    # prepare data
    test_streamer = get_data_reader(dconf.test,
                                    dconf.input_format,
                                    dconf.use_label0,
                                    dconf.noef_link0,
                                    dconf.aux_repr_test,
                                    max_evt_layers=dconf.max_evt_layers)
    # model
    model = build_model(conf.model_type, conf, vpack)
    model.load(dconf.model_load_name)
    # use bert?
    if dconf.use_bert:
        bmodel = get_berter(dconf.bconf)
        test_streamer = BerterDataAuger(test_streamer, bmodel, "aux_repr")
    # finally prepare iter (No Cache!!, actually no batch_stream)
    test_inst_preparer = model.get_inst_preper(False)
    test_iter = index_stream(test_streamer, vpack, False, False,
                             test_inst_preparer)
    # =====
    # run
    decoder = ArgAugDecoder(conf.aconf, model)
    all_docs = []
    stat_recorder = StatRecorder(False)
    with Timer(tag="Decode", info="Decoding", print_date=True):
        with zopen(dconf.output_file, 'w') as fd:
            data_writer = get_data_writer(fd, dconf.output_format)
            for one_doc in test_iter:
                info = decoder.decode(one_doc)
                stat_recorder.record(info)
                if conf.verbose:
                    zlog(f"Decode one doc, id={one_doc.doc_id} info={info}")
                # release resources
                for one_sent in one_doc.sents:
                    one_sent.extra_features[
                        "aux_repr"] = None  # todo(note): special name!
                # write output
                data_writer.write([one_doc])
                #
                all_docs.append(one_doc)
    if conf.verbose:
        zlog(f"Finish decoding, overall: {stat_recorder.summary()}")
    # eval?
    if conf.do_eval:
        evaler = MyIEEvaler(MyIEEvalerConf())
        result = evaler.eval(all_docs, all_docs)
        Helper.printd(result)
    zlog("The end.")
Example #18
 def end(self):
     # sorting by idx of reading
     self.insts.sort(key=lambda x: x.sid)
     # todo(+1): write other output file
     if self.outf is not None:
         with zopen(self.outf, "w") as fd:
             data_writer = get_data_writer(fd, self.out_format)
             data_writer.write_list(self.insts)
     # eval for parsing & pos & ner
     evaler = ParserEvaler()
     ner_golds, ner_preds = [], []
     has_syntax, has_ner = False, False
     for one_inst in self.insts:
         gold_pos_seq = getattr(one_inst, "pos_seq", None)
         pred_pos_seq = getattr(one_inst, "pred_pos_seq", None)
         gold_ner_seq = getattr(one_inst, "ner_seq", None)
         pred_ner_seq = getattr(one_inst, "pred_ner_seq", None)
         gold_tree = getattr(one_inst, "dep_tree", None)
         pred_tree = getattr(one_inst, "pred_dep_tree", None)
         if gold_pos_seq is not None and gold_tree is not None:
             # todo(+N): only ARTI_ROOT in trees
             gold_pos, gold_heads, gold_labels = gold_pos_seq.vals, gold_tree.heads[
                 1:], gold_tree.labels[1:]
             pred_pos = pred_pos_seq.vals if (
                 pred_pos_seq is not None) else [""] * len(gold_pos)
             pred_heads, pred_labels = (pred_tree.heads[1:], pred_tree.labels[1:]) if (pred_tree is not None) \
                 else ([-1]*len(gold_heads), [""]*len(gold_labels))
             evaler.eval_one(gold_pos, gold_heads, gold_labels, pred_pos,
                             pred_heads, pred_labels)
             has_syntax = True
         if gold_ner_seq is not None and pred_ner_seq is not None:
             ner_golds.append(gold_ner_seq.vals)
             ner_preds.append(pred_ner_seq.vals)
             has_ner = True
     # -----
     zlog("Results of %s vs. %s" % (self.outf, self.goldf), func="result")
     if has_syntax:
         report_str, res = evaler.summary()
         zlog(report_str, func="result")
     else:
         res = {}
     if has_ner:
         for _n, _v in zip("prf", get_prf_scores(ner_golds, ner_preds)):
             res[f"ner_{_n}"] = _v
     res["gold"] = self.goldf  # record which file
     zlog("zzzzztest: testing result is " + str(res))
     # note that res['res'] will not be used
     if 'res' in res:
         del res['res']
     # -----
     return res
Example #19
 def __init__(self, file_or_fd, format='json', use_pred=True):
     if isinstance(file_or_fd, str):
         self.fd = zopen(file_or_fd, "w")
     else:
         self.fd = file_or_fd
     self.format = format
     self.use_pred = use_pred
     #
     self._write_f = {
         "json": self.write_json,
         "txt": self.write_txt,
         "tbf": self.write_tbf,
         "ann": self.write_ann
     }[format]
Example #20
 def _next(self):
     if self.is_dir:
         cur_ptr = self.dir_file_ptr
         if cur_ptr >= len(self.dir_file_list):
             return None
         else:
             with zopen(self.dir_file_list[cur_ptr]) as fd:
                 ss = fd.read()
             self.dir_file_ptr = cur_ptr + 1
             return ss
     else:
         line = self.fd.readline()
         if len(line) == 0:
             return None
         else:
             return line
Example #21
 def __init__(self,
              path_or_fd: str,
              is_dir=False,
              fname_getter=None,
              suffix=""):
     self.path_or_fd = path_or_fd
     # dir mode
     self.is_dir = is_dir
     self.anon_counter = -1
     self.fname_getter = fname_getter if fname_getter else self.anon_name_getter
     self.suffix = suffix
     # otherwise
     if isinstance(path_or_fd, str):
         self.fd = zopen(path_or_fd, "w")
     else:
         self.fd = path_or_fd
Example #22
 def yield_data(self, files):
     #
     if not isinstance(files, (list, tuple)):
         files = [files]
     #
     cur_num = 0
     for f in files:
         cur_num += 1
         zlog("-----\nDataReader: [#%d] Start reading file %s." %
              (cur_num, f))
         with zopen(f) as fd:
             for z in self._yield_tokens(fd):
                 yield z
         if cur_num % self.report_freq == 0:
             zlog("** DataReader: [#%d] Summary till now:" % cur_num)
             Helper.printd(self.stats)
     zlog("=====\nDataReader: End reading ALL (#%d) ==> Summary ALL:" %
          cur_num)
     Helper.printd(self.stats)
Example #23
 def _load_txt(fname, sep=" "):
     printing("Going to load pre-trained (txt) w2v from %s ..." % fname)
     one = WordVectors(sep=sep)
     repeated_count = 0
     with zopen(fname) as fd:
         # first line
         line = fd.readline()
         try:
             one.num_words, one.embed_size = [int(x) for x in line.split(sep)]
             printing("Reading w2v num_words=%d, embed_size=%d." % (one.num_words, one.embed_size))
             line = fd.readline()
         except:
             printing("Reading w2v.")
         # the rest
         while len(line) > 0:
             line = line.rstrip()
             fields = line.split(sep)
             word, vec = fields[0], [float(x) for x in fields[1:]]
             # zcheck(word not in one.wmap, "Repeated key.")
             # keep the old one
             if word in one.wmap:
                 repeated_count += 1
                 zwarn(f"Repeat key {word}")
                 line = fd.readline()
                 continue
             #
             if one.embed_size is None:
                 one.embed_size = len(vec)
             else:
                 zcheck(len(vec) == one.embed_size, "Unmatched embed dimension.")
             one.vecs.append(vec)
             one.wmap[word] = len(one.words)
             one.words.append(word)
             line = fd.readline()
     # final
     if one.num_words is not None:
         zcheck(one.num_words == len(one.vecs)+repeated_count, "Unmatched num of words.")
     one.num_words = len(one.vecs)
     printing(f"Read ok: w2v num_words={one.num_words:d}, embed_size={one.embed_size:d}, repeat={repeated_count:d}")
     return one
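_load_txt expects a word2vec-style text file: an optional first line holding num_words and embed_size, then one word per line followed by its vector components, all joined by sep. A tiny sketch that writes such a file with toy vectors (purely illustrative; the built-in open is used instead of zopen):

def write_toy_w2v(fname, sep=" "):
    # toy embeddings in the text format that _load_txt above can parse
    vecs = {"the": [0.1, 0.2, 0.3], "cat": [0.4, 0.5, 0.6]}
    embed_size = 3
    with open(fname, "w") as fd:
        fd.write(f"{len(vecs)}{sep}{embed_size}\n")  # optional header line
        for word, vec in vecs.items():
            fd.write(sep.join([word] + [str(x) for x in vec]) + "\n")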
Example #24
 def collect_freqs(file):
     zlog(f"Starting dealing with {file}")
     MIN_TOK_PER_DOC = 100  # minimum token per doc
     MIN_PARA_PER_DOC = 5  # minimum line-num (paragraph)
     word2info = {}  # str -> [count, doc-count]
     num_doc = 0
     with zopen(file) as fd:
         docs = fd.read().split("\n\n")
         for one_doc in docs:
             tokens = one_doc.split()  # space or newline
             if len(tokens) >= MIN_TOK_PER_DOC and len(
                     one_doc.split("\n")) >= MIN_PARA_PER_DOC:
                 num_doc += 1
                 # first raw counts
                 for t in tokens:
                     t = str.lower(t)  # todo(note): lowercase!!
                     if t not in word2info:
                         word2info[t] = [0, 0]
                     word2info[t][0] += 1
                 # then doc counts (the key must already exist from the raw-count loop above)
                 for t in set(tokens):
                     t = str.lower(t)
                     word2info[t][1] += 1
     return num_doc, word2info
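collect_freqs returns the number of kept documents and a dict mapping each lowercased token to [total count, document count]. A small sketch of how that table could be consumed, for instance ranking words by a simple count-times-IDF score; the scoring formula here is only an illustration, not the scoring used by the original model:

import math

def top_by_idf(num_doc, word2info, k=10):
    # word2info: token -> [total_count, doc_count], as built by collect_freqs above
    scored = []
    for word, (count, doc_count) in word2info.items():
        idf = math.log((1 + num_doc) / (1 + doc_count))
        scored.append((count * idf, word))
    scored.sort(reverse=True)
    return [w for _, w in scored[:k]]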
Example #25
def load_model(file) -> StatModel:
    import pickle
    # return PickleRW.from_file(file)
    with zopen(file, 'rb') as fd:
        return pickle.load(fd)
Example #26
 def _restart(self):
     if self.fd is not None:
         self.fd.close()
     self.fd = zopen(self.file)
Example #27
 def __init__(self, file_or_fd):
     if isinstance(file_or_fd, str):
         self.fd = zopen(file_or_fd, "w")
     else:
         self.fd = file_or_fd
Example #28
        global_ranks = [self.get_rank(x) for x in local_words]
        ret = [(w, fr, gr, s1, s2) for w, fr, gr, s1, s2 in zip(
            local_words, global_ranks, cur_tf, cur_tf_idf, cur_range_jump)]
        return ret


if __name__ == '__main__':
    import sys
    conf = KeyWordConf()
    conf.update_from_args(sys.argv[1:])
    if len(conf.build_files) > 0:
        m = KeyWordModel.build(conf)
        if conf.save_file:
            m.save(conf.save_file)
    else:
        m = KeyWordModel.load(conf.load_file, conf)
        # test it
        testing_file = "1993.01.tok"
        with zopen(testing_file) as fd:
            docs = fd.read().split("\n\n")
            for one_doc in docs:
                tokens = one_doc.split()  # space or newline
                one_result = m.extract(tokens)
                res_sort1 = sorted(one_result, key=lambda x: -x[-2])
                res_sort2 = sorted(one_result, key=lambda x: -x[-1])
                res_sort0 = sorted(one_result, key=lambda x: -x[-3])
                zzzzz = 0

# PYTHONPATH=../src/ python3 k2.py build_files:./*.tok build_num_core:10 save_file:./nyt.voc.pic
# PYTHONPATH=../src/ python3 -m pdb k2.py build_files:[] load_file:./nyt.voc.pic
Example #29
def main_loop(conf: SDBasicConf, sp: SentProcessor):
    np.seterr(all='raise')
    nn_init(conf.niconf)
    np.random.seed(conf.rand_seed)
    records = defaultdict(int)
    #
    # will trigger an error otherwise; also saves the time of loading the model
    featurer = None if conf.already_pre_computed else Featurer(conf.fconf)
    output_pic_fd = zopen(conf.output_pic, 'wb') if conf.output_pic else None
    all_insts = []
    vocab = Vocab.read(conf.vocab_file) if conf.vocab_file else None
    unk_repl_upos_set = set(conf.unk_repl_upos)
    with BK.no_grad_env():
        input_stream = yield_data(conf.input_file)
        if conf.rand_input:
            inputs = list(input_stream)
            np.random.shuffle(inputs)
            input_stream = inputs
        for one_inst in input_stream:
            # -----
            # make sure the results stay the same; this checks whether we mistakenly use gold information anywhere in the analysis
            if conf.debug_no_gold:
                one_inst.heads.vals = [0] * len(one_inst.heads.vals)
                if len(one_inst.heads.vals) > 2:
                    one_inst.heads.vals[2] = 1  # avoid err in certain analysis
                one_inst.labels.vals = ["_"] * len(one_inst.labels.vals)
            # -----
            if len(one_inst) >= conf.min_len and len(one_inst) <= conf.max_len:
                folded_distances = one_inst.extra_features.get("sd2_scores")
                if folded_distances is None:
                    if conf.fake_scores:
                        one_inst.extra_features["sd2_scores"] = np.zeros(
                            featurer.output_shape(len(
                                one_inst.words.vals[1:])))
                    else:
                        # ===== replace certain words?
                        word_seq = one_inst.words.vals[1:]
                        upos_seq = one_inst.poses.vals[1:]
                        if conf.unk_repl_thresh > 0:
                            word_seq = [
                                (conf.unk_repl_token if
                                 (u in unk_repl_upos_set and
                                  vocab.getval(w, 0) <= conf.unk_repl_thresh)
                                 else w) for w, u in zip(word_seq, upos_seq)
                            ]
                        if conf.unk_repl_split_thresh < 10:
                            berter_toker = featurer.berter.tokenizer
                            word_seq = [
                                conf.unk_repl_token if
                                (u in unk_repl_upos_set
                                 and len(berter_toker.tokenize(w)) >
                                 conf.unk_repl_split_thresh) else w
                                for w, u in zip(word_seq, upos_seq)
                            ]
                        # ===== auto repl by bert?
                        sent_repls = [word_seq]
                        sent_fixed = [np.zeros(len(word_seq), dtype=bool)]
                        for _ in range(conf.sent_repl_times):
                            new_sent, new_fixed = featurer.repl_sent(
                                sent_repls[-1], sent_fixed[-1])
                            sent_repls.append(new_sent)
                            sent_fixed.append(
                                new_fixed)  # once fixed, always fixed
                        one_inst.extra_features["sd3_repls"] = sent_repls
                        one_inst.extra_features["sd3_fixed"] = sent_fixed
                        # ===== score
                        folded_distances = featurer.get_scores(sent_repls[-1])
                        assert len(sent_repls[-1]) == len(word_seq)
                        # ---
                        records["repl_count"] += len(word_seq)
                        records["repl_repl"] += sum(
                            a != b for a, b in zip(sent_repls[-1], word_seq))
                        # ---
                        one_inst.extra_features[
                            "sd2_scores"] = folded_distances
                        one_inst.extra_features["feat_seq"] = word_seq
                if output_pic_fd is not None:
                    pickle.dump(one_inst, output_pic_fd)
                if conf.processing:
                    one_info = sp.test_one_sent(one_inst)
                    # put prediction
                    one_inst.pred_heads.set_vals([0] +
                                                 list(one_info["output"][0]))
                    one_inst.pred_labels.set_vals(["_"] *
                                                  len(one_inst.labels.vals))
                    #
                    phrase_tree_string = one_info.get("phrase_tree")
                    if phrase_tree_string is not None:
                        one_inst.extra_pred_misc[
                            "phrase_tree"] = phrase_tree_string
                all_insts.append(one_inst)
    if output_pic_fd is not None:
        output_pic_fd.close()
    if conf.output_file:
        with zopen(conf.output_file, 'w') as wfd:
            data_writer = get_data_writer(wfd, "conllu")
            data_writer.write(all_insts)
    if conf.output_file_ptree:
        with zopen(conf.output_file_ptree, 'w') as wfd:
            for one_inst in all_insts:
                phrase_tree_string = one_inst.extra_pred_misc.get(
                    "phrase_tree")
                wfd.write(str(phrase_tree_string) + "\n")
    # -----
    Helper.printd(records)
    Helper.printd(sp.summary())
Example #30
def write_results(fname, results):
    with zopen(fname, "wb") as fd:
        for one in results:
            pickle.dump(one, fd)
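write_results is the counterpart of read_results in Example #16: each result gets its own pickle.dump call, and the reader keeps calling pickle.load until EOFError. A self-contained round-trip sketch of that pattern, with made-up result dicts and the built-in open instead of zopen:

import pickle

def roundtrip_demo(fname="results.pkl"):
    results = [{"uas": 0.91}, {"uas": 0.88}]
    # write: one dump per result, as in write_results above
    with open(fname, "wb") as fd:
        for one in results:
            pickle.dump(one, fd)
    # read: keep loading until the stream is exhausted, as in read_results
    loaded = []
    with open(fname, "rb") as fd:
        while True:
            try:
                loaded.append(pickle.load(fd))
            except EOFError:
                break
    assert loaded == results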