Code Example #1
import pickle

# easy_post_processing and RougeStrEvaluation are assumed to be importable
# from this repo's own modules.

def merge_cnn_dm():
    # Pickled decode dumps for the CNN and DM halves of CNN/DM.
    cnn = "/scratch/cluster/jcxu/exComp/0.327,0.122,0.290-cnnTrue1.0-1True3-1093-cp_0.5"
    dm = "/scratch/cluster/jcxu/exComp/0.427,0.192,0.388-dmTrue1.0-1True3-10397-cp_0.7"
    total_pred = []
    total_ref = []
    for path in (cnn, dm):
        with open(path, "rb") as f:
            d = pickle.load(f)
        # Clean every predicted sentence before scoring.
        total_pred += [[easy_post_processing(s) for s in x] for x in d["pred"]]
        total_ref += d["ref"]

    # Score the merged CNN + DM predictions in one pass.
    rouge_metrics = RougeStrEvaluation(name='mine')
    for p, r in zip(total_pred, total_ref):
        rouge_metrics(pred=p, ref=r)
    rouge_metrics.get_metric(True, note='test')
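
Both pickled dumps are assumed to follow the schema implied by the access pattern above; spelled out here as a hedged sketch (the type name is hypothetical and the repo declares no such type in these excerpts):

from typing import List, TypedDict

class DecodeDump(TypedDict):
    # Hypothetical name; fields inferred from merge_cnn_dm()'s access pattern.
    pred: List[List[str]]  # per document: predicted summary sentences
    ref: List[List[str]]   # per document: reference summary sentences
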
Code Example #2
    def test_easy_post_processing(self):
        inp = [
            "In two years ' time , the Scandinavian nation is slated to become the first in the world to phase out radio entirely .",
            "Digitally , there are four times that number .",
            "Frum : Ukrainians want to enter EU and lessen dependence on Russia ; Putin fighting to stop it .",
            "-LRB- CNN -RRB- He might have just won one of sport 's most prestigious events , but it was n't long before Jordan Spieth 's thoughts turned to his autistic sister in the glow of victory . "
        ]
        for x in inp:
            y = easy_post_processing(x)
            print(y)
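
The test inputs above exercise typical PTB tokenization artifacts: -LRB-/-RRB- bracket tokens, spaces before punctuation, and detached clitics like n't and 's. The repo's actual easy_post_processing is not shown in these excerpts; a minimal sketch of the kind of cleanup these inputs call for might look like this:

import re

def easy_post_processing_sketch(s: str) -> str:
    # Hypothetical stand-in for the repo's easy_post_processing.
    s = s.replace("-LRB-", "(").replace("-RRB-", ")")
    s = re.sub(r"\s+([,.!?;:])", r"\1", s)                  # no space before punctuation
    s = re.sub(r"\s+(n't|'s|'re|'ve|'ll|'d|'m)", r"\1", s)  # reattach clitics
    s = re.sub(r"\(\s+", "(", s)                            # tighten parentheses
    s = re.sub(r"\s+\)", ")", s)
    return s.strip()
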
Code Example #3
    def run(self) -> List[List[str]]:
        # Walk every sentence, prune compressions, and emit the surviving
        # words for evaluation.  Assumes numpy (as np), json, logging, and
        # random are imported at module level.

        # First, keep all of the candidate compressions on record.
        self.read_sent_record_compressions(self.sent_idxs)

        # Start diverging: one prediction per keep threshold.
        # Delete compressions under each threshold.
        self.del_under_threshold()
        # Iteratively delete spans already covered by the context.
        processed_words = self.iterative_rep_del()
        # Reorder the sentences back into their original positions.
        order = np.argsort(self.sort_order)
        for kepidx in range(len(self.keep_threshold)):
            processed_words[kepidx] = [processed_words[kepidx][o] for o in order]

        # Assemble post-processed strings for evaluation, one list per threshold.
        bag_pred_eval = []
        for words in processed_words:
            _tmp = []
            for sent in words:
                # Drop special tokens before joining.
                sent = [x for x in sent if (not x.startswith(sp_tok)) and (not x.startswith(sp_tok_rep))]
                _tmp.append(easy_post_processing(" ".join(sent)))
            bag_pred_eval.append(_tmp)

        # Occasionally (about 0.5% of calls) log a visualization of the
        # compressions and dump the deletion record to disk.
        if random.random() < 0.005:
            try:
                logger = logging.getLogger()
                logger.info("Prob\t\tType\t\tRatio\t\tRouge\t\tLen\t\tContent")
                for idx, d in enumerate(self.compressions):
                    for key, value in d.items():
                        wt = [value['prob'], value['type'], value['ratio'], value['rouge'], value['len'], key]
                        wt = "\t\t".join([str(x) for x in wt])
                        logger.info(wt)
                log_universal(Partition=self.part, Name=self.name,
                              Abs=self.abs_str
                              )
                for idx in range(len(self.keep_threshold)):
                    lis = processed_words[idx]
                    lis_out = [" ".join(x) for x in lis]
                    log_universal(Kep=self.keep_threshold[idx],
                                  Visual=" | ".join(lis_out))
                # Append del_record to the serialization file on disk.
                with open(self.ser_fname, 'a') as f:
                    f.write("\n")
                    f.write(json.dumps(self.del_record))
            except ZeroDivisionError:
                pass

        return bag_pred_eval
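
The reordering step in run() relies on np.argsort to invert the recorded order. Assuming self.sort_order holds each selected sentence's original position in the document, a standalone illustration:

import numpy as np

sort_order = [2, 0, 1]            # slot i currently holds original sentence sort_order[i]
order = np.argsort(sort_order)    # array([1, 2, 0])
sents = ["s2", "s0", "s1"]        # sentences labeled by original position
restored = [sents[o] for o in order]
print(restored)                   # ['s0', 's1', 's2'] -- original order restored
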
Code Example #4
    def _dec_compression_one_step(self, predict_compression,
                                  sp_meta,
                                  word_sent: List[str], keep_threshold: List[float],
                                  context: List[List[str]] = None):

        full_set_len = set(range(len(word_sent)))

        # One working copy of the full index set per keep threshold.
        preds = [full_set_len.copy() for _ in range(len(keep_threshold))]

        # Record statistics for every compression span.
        stat_compression = {}
        for comp_idx, comp_meta in enumerate(sp_meta):
            p = predict_compression[comp_idx][1]
            node_type, sel_idx, rouge, ratio = comp_meta
            if node_type != "BASELINE":
                selected_words = [x for idx, x in enumerate(word_sent) if idx in sel_idx]
                selected_words_str = "_".join(selected_words)
                stat_compression["{}".format(selected_words_str)] = {
                    "prob": float("{0:.2f}".format(p)),  # float("{0:.2f}".format())
                    "type": node_type,
                    "rouge": float("{0:.2f}".format(rouge)),
                    "ratio": float("{0:.2f}".format(ratio)),
                    "sel_idx": sel_idx,
                    "len": len(sel_idx)
                }
        # Rank candidate spans by probability, highest first.
        stat_compression_order = OrderedDict(
            sorted(stat_compression.items(), key=lambda item: item[1]["prob"], reverse=True))
        for idx, _keep_thres in enumerate(keep_threshold):
            history: List[str] = context[idx]
            his_set = set((" ".join(history)).split(" "))
            for value in stat_compression_order.values():
                p = value['prob']
                sel_idx = value['sel_idx']
                sel_txt = set(word_sent[x] for x in sel_idx)
                if sel_txt <= his_set:
                    # Every word of the span already appears in the context; drop it.
                    preds[idx] = preds[idx] - set(sel_idx)
                    continue
                if p > _keep_thres:
                    # The model is confident enough in this compression; drop the span.
                    preds[idx] = preds[idx] - set(sel_idx)

        # Sorted word indices that survive, per threshold.
        preds = [sorted(x) for x in preds]
        # Visual output
        visual_outputs: List[str] = []
        words_for_evaluation: List[str] = []
        meta_keep_ratio_word = []

        for idx, compression in enumerate(preds):
            output = [word_sent[jdx] if (jdx in compression) else '_' + word_sent[jdx] + '_' for jdx in
                      range(len(word_sent))]
            visual_outputs.append(" ".join(output))

            words = [word_sent[x] for x in compression]
            meta_keep_ratio_word.append(len(words) / len(word_sent))
            words_for_evaluation.append(easy_post_processing(" ".join(words)))
        d: List[List] = []
        for kep_th, vis, words_eva, keep_word_ratio in zip(keep_threshold, visual_outputs, words_for_evaluation,
                                                           meta_keep_ratio_word):
            d.append([kep_th, vis, words_eva, keep_word_ratio])
        return stat_compression_order, d
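
The pruning rule in _dec_compression_one_step starts from the full index set and subtracts a span when its words are fully covered by context or when its predicted probability clears the keep threshold. A toy demonstration of the threshold branch (data invented for illustration):

word_sent = ["the", "cat", ",", "which", "was", "black", ",", "slept"]
compressions = {
    "which_was_black": {"prob": 0.9, "sel_idx": [3, 4, 5]},
    "the_cat":         {"prob": 0.1, "sel_idx": [0, 1]},
}
keep_threshold = 0.5
kept = set(range(len(word_sent)))
for value in compressions.values():
    if value["prob"] > keep_threshold:   # confident deletion
        kept -= set(value["sel_idx"])
print(" ".join(word_sent[i] for i in sorted(kept)))
# -> the cat , , slept
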
Code Example #5
    # Assumes pickle, os, and random.shuffle are imported at module level,
    # along with the repo helpers used below.
    path = "/scratch/cluster/jcxu/exComp"
    file = "0.325-0.120-0.289-cnnTrue1.0-1True-1093-cp_0.6"
    see_output = "/scratch/cluster/jcxu/data/cnndm_compar/pointgencov/cnn"
    ext_bag, model_bag, ext_dp_bag = [], [], []
    see_bag = read_abigail_output(see_output)
    with open(os.path.join(path, file), 'rb') as fd:
        x = pickle.load(fd)
        pred = x['pred']
        ori = x['ori']
        cnt = 0
        for pre, o in zip(pred, ori):
            shuffle(pre)
            shuffle(o)
            # Normalize both the predicted and the original sentences.
            p = [meta_str_surgery(easy_post_processing(replace_lrbrrb(fix_vowel(s)))).lower()
                 for s in pre]
            o = [meta_str_surgery(easy_post_processing(replace_lrbrrb(rm_head_cnn(s)))).lower()
                 for s in o]
            o_drop = [dropword(s) for s in o]

            o = [detok(s) for s in o]
            o_drop = [detok(s) for s in o_drop]
            p = [detok(s) for s in p]
            ext_bag += o
            ext_dp_bag += o_drop
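
The detok helper is not defined in these excerpts; one plausible stand-in, assuming PTB-style whitespace-separated tokens, is NLTK's Treebank detokenizer:

from nltk.tokenize.treebank import TreebankWordDetokenizer

_detok = TreebankWordDetokenizer()

def detok_sketch(s: str) -> str:
    # Hypothetical replacement for the repo's detok; joins PTB-style
    # tokens back into natural text.
    return _detok.detokenize(s.split())
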