Beispiel #1
0
def merge_cnn_dm():
    """Score the combined CNN + DailyMail test outputs with one ROUGE evaluator.

    Each of the two hard-coded result files is a pickled dict with "pred" and
    "ref" lists. Every predicted sentence is passed through
    ``easy_post_processing``. The CNN pairs come first, then the DailyMail
    pairs. Each (prediction, reference) pair is fed to a single
    ``RougeStrEvaluation``, and the metric is then finalized with
    ``get_metric(True, note='test')``.

    Raises:
        ValueError: if a result file's "pred" and "ref" lists differ in length.
            Concatenating such a file would misalign every later pair.
    """
    import pickle

    cnn = "/scratch/cluster/jcxu/exComp/0.327,0.122,0.290-cnnTrue1.0-1True3-1093-cp_0.5"
    dm = "/scratch/cluster/jcxu/exComp/0.427,0.192,0.388-dmTrue1.0-1True3-10397-cp_0.7"

    def _load_post_processed(path):
        """Return (post-processed predictions, references) from one pickled result file."""
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at result files produced locally.
        with open(path, "rb") as f:
            result = pickle.load(f)
        preds = [[easy_post_processing(s) for s in x] for x in result["pred"]]
        refs = result["ref"]
        if len(preds) != len(refs):
            raise ValueError("{}: {} predictions but {} references".format(
                path, len(preds), len(refs)))
        return preds, refs

    total_pred = []
    total_ref = []
    for path in (cnn, dm):
        preds, refs = _load_post_processed(path)
        total_pred += preds
        total_ref += refs

    rouge_metrics = RougeStrEvaluation(name='mine')
    for p, r in zip(total_pred, total_ref):
        rouge_metrics(pred=p, ref=r)
    rouge_metrics.get_metric(True, note='test')
Beispiel #2
0
def _para_get_metric_reset_false(metric: RougeStrEvaluation, key, note):
    """Read *metric*'s current scores without resetting it and pick out the ``*_A`` score.

    Args:
        metric: evaluator to query; ``get_metric`` is called with ``reset=False``.
        key: opaque tag returned unchanged. Presumably it lets a caller running
            several of these (e.g. in parallel) match results back to inputs.
        note: forwarded to ``get_metric``.

    Returns:
        ``(scores, a_score, metric, key)``, where ``scores`` is what
        ``get_metric`` returned and ``a_score`` is the value stored under its
        single key ending in ``"_A"``.

    Raises:
        AssertionError: if the number of ``"_A"`` keys is not exactly one
            (only while assertions are enabled).
    """
    scores = metric.get_metric(reset=False, note=note)
    a_keys = list(filter(lambda name: name.endswith("_A"), scores.keys()))
    assert len(a_keys) == 1
    a_score = scores[a_keys[0]]
    return scores, a_score, metric, key
Beispiel #3
0
def get_refresh_metric():
    """Compute ROUGE for the REFRESH ensemble outputs over the CNN and DailyMail test sets.

    For each dataset, the URL list is read as UTF-8 and converted with
    ``get_url_hashes``. For every hash, the gold summary (``<hash>.gold``) and
    model output (``<hash>.model``) are read line by line and fed to one shared
    ``RougeStrEvaluation``. Hashes whose files cannot be read are printed and
    skipped. The metric is finalized once, after both datasets.
    """
    gold_path = "/backup3/jcxu/data/gold-cnn-dailymail-test-orgcase"
    # 0034b7c223e24477e046cf3ee085dd006be38b27.gold
    model_path = "/backup3/jcxu/data/cnn-dailymail-ensemble-model11-model7"
    # 0034b7c223e24477e046cf3ee085dd006be38b27.model
    url_list_template = '/backup3/jcxu/data/cnn-dailymail/url_lists/{}_wayback_test_urls.txt'

    rouge_metrics_sent = RougeStrEvaluation(name='refresh')
    for full_dataname in ("cnn", "dailymail"):
        with open(url_list_template.format(full_dataname), 'r', encoding='utf-8') as fd:
            url_names = get_url_hashes(fd.read().splitlines())
        print("len of urls {}".format(len(url_names)))
        if url_names:
            print(url_names[0])
        for url in url_names:
            # Only the file reads are guarded; errors from the scorer propagate.
            try:
                with open(os.path.join(gold_path, url + '.gold'), 'r') as fd:
                    gold_lines = fd.read().splitlines()
                with open(os.path.join(model_path, url + '.model'), 'r') as fd:
                    pred = fd.read().splitlines()
            except IOError:
                # Missing output for this article: report it and move on.
                print(url)
                continue
            rouge_metrics_sent(pred=pred, ref=[gold_lines])
    rouge_metrics_sent.get_metric(True)
    def __init__(self,
                 rnn_type: str = 'lstm',
                 dec_hidden_size: int = 100,
                 dec_input_size: int = 50,
                 dropout: float = 0.1,
                 fixed_dec_step: int = -1,
                 max_dec_steps: int = 2,
                 min_dec_steps: int = 2,
                 schedule_ratio_from_ground_truth: float = 0.5,
                 dec_avd_trigram_rep: bool = True,
                 mult_orac_sample_one: bool = True,
                 abs_board_file: str = "/home/cc/exComp/board.txt",
                 valid_tmp_path: str = '/scratch/cluster/jcxu/exComp',
                 serilization_name: str = ""):
        """Build the sentence-level decoder: RNN, initial-state projections, attention, loss and ROUGE tracker.

        Args:
            rnn_type: passed to ``self.build_rnn`` together with the input and
                hidden sizes (default 'lstm').
            dec_hidden_size: RNN hidden size. It is also the in/out size of the
                two initial-state ``Linear`` layers and the attention ``dec_dim``.
            dec_input_size: RNN input size; also the attention ``enc_dim``.
            dropout: probability for the ``nn.Dropout`` stored as ``self._dropout``.
            fixed_dec_step: when -1, ``min_dec_steps``/``max_dec_steps`` are taken
                from the arguments below. Otherwise both are set to this value.
            max_dec_steps: upper step bound, used only when ``fixed_dec_step == -1``.
            min_dec_steps: lower step bound, used only when ``fixed_dec_step == -1``.
            schedule_ratio_from_ground_truth: stored as-is. Presumably a
                scheduled-sampling ratio — TODO confirm against the decode loop.
            dec_avd_trigram_rep: stored as ``self.dec_avd_trigram_rep``.
                Presumably toggles trigram-repetition avoidance — confirm.
            mult_orac_sample_one: stored as ``self.mult_orac_sample_one_as_gt``.
            abs_board_file: accepted but not referenced in this constructor.
            valid_tmp_path: passed to ``RougeStrEvaluation`` as both
                ``path_to_valid`` and ``writting_address``.
            serilization_name: forwarded to ``RougeStrEvaluation``.
        """
        # NOTE(review): the order in which submodules are created below fixes
        # their registration order and (likely) the RNG draws used for weight
        # init; keep it stable to reproduce checkpoints/results.
        super().__init__()
        self.device = get_device()
        self._rnn_type = rnn_type
        self._dec_input_size = dec_input_size
        self._dec_hidden_size = dec_hidden_size

        self.fixed_dec_step = fixed_dec_step
        if fixed_dec_step == -1:
            self.min_dec_steps = min_dec_steps
            self.max_dec_steps = max_dec_steps
        else:
            # A fixed step count overrides both bounds.
            self.min_dec_steps, self.max_dec_steps = fixed_dec_step, fixed_dec_step
        self.schedule_ratio_from_ground_truth = schedule_ratio_from_ground_truth
        self.mult_orac_sample_one_as_gt = mult_orac_sample_one
        self._dropout = nn.Dropout(dropout)

        self.rnn = self.build_rnn(
            self._rnn_type,
            self._dec_input_size,
            self._dec_hidden_size,
        )
        # Square projections (hidden -> hidden). Presumably they map an encoder
        # summary to the RNN's initial hidden (h) and cell (c) states — confirm
        # in the forward pass.
        self.rnn_init_state_h = torch.nn.Linear(dec_hidden_size,
                                                dec_hidden_size)
        self.rnn_init_state_c = torch.nn.Linear(dec_hidden_size,
                                                dec_hidden_size)

        self.attn = NewAttention(enc_dim=dec_input_size,
                                 dec_dim=dec_hidden_size)
        # Per-element losses (reduction='none'); target index -1 is ignored.
        self.CELoss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                reduction='none')  # TODO
        self.rouge_metrics_sent = RougeStrEvaluation(
            name='sent',
            path_to_valid=valid_tmp_path,
            writting_address=valid_tmp_path,
            serilization_name=serilization_name)
        self.dec_avd_trigram_rep = dec_avd_trigram_rep
Beispiel #5
0
def my_lead3(data_path="/scratch/cluster/jcxu/data/2merge-nyt"):
    """Evaluate a lead-k extractive baseline with ROUGE.

    Every file in *data_path* whose name starts with "test.pkl" is unpickled.
    For each instance, the first ``lead`` entries of ``metadata['doc_list']``
    are scored against ``metadata['abs_list']``. ``lead`` is 5 when "nyt"
    appears in the path and 3 otherwise. Each sentence entry (presumably a
    token list) is joined with single spaces before scoring.

    Args:
        data_path: directory holding the pickled test shards. The default is
            the previously hard-coded NYT path.
    """
    import pickle

    print(data_path)
    lead = 5 if 'nyt' in data_path else 3
    files = [x for x in os.listdir(data_path) if x.startswith("test.pkl")]
    rouge_metrics_sent = RougeStrEvaluation(name='mine')
    print(lead)
    for idx, file in enumerate(files):
        print(idx)
        print("reading {}".format(file))
        # `with` closes the shard even if unpickling fails (it was never closed before).
        with open(os.path.join(data_path, file), 'rb') as f:
            data = pickle.load(f)
        for instance_fields in data:
            meta = instance_fields['metadata']
            doc_list = [" ".join(x) for x in meta['doc_list'][:lead]]
            abs_list = [" ".join(x) for x in meta['abs_list']]
            rouge_metrics_sent(pred=doc_list, ref=[abs_list])
    rouge_metrics_sent.get_metric(True, note='test')
Beispiel #6
0
def get_qyz(data="cnn"):
    """Compute ROUGE for the qyz system outputs on one dataset.

    Two files are read from a fixed directory:

    - ``qyz_<data>.txt``: predictions; the text is the second tab-separated
      field of each line.
    - ``qyz_<data>_sum.txt``: reference summaries; ``##SENT##`` markers are
      replaced by spaces.

    The files are paired line by line and scored with a ``RougeStrEvaluation``.

    Args:
        data: dataset tag substituted into both file names (default "cnn").

    Raises:
        ValueError: if the two files have different numbers of lines.
    """
    path = "/scratch/cluster/jcxu/data/cnndm_compar/qyz-output"
    ref = "qyz_{}_sum.txt"
    pred = "qyz_{}.txt"
    rouge_metrics = RougeStrEvaluation(name='neusum')
    pred_path = os.path.join(path, pred.format(data))
    ref_path = os.path.join(path, ref.format(data))

    def read_prediction(pred_path):
        """Return the second tab-separated field of every line in *pred_path*."""
        with open(pred_path, 'r', encoding='utf-8') as fd:
            lines = fd.read().splitlines()
        lines = [x.split("\t")[1] for x in lines]
        return lines

    # References mark sentence boundaries with ##SENT##.
    def read_sum(sum_path):
        """Return each line of *sum_path* with ##SENT## markers replaced by spaces."""
        with open(sum_path, 'r', encoding='utf-8') as fd:
            lines = fd.read().splitlines()
        lines = [x.replace("##SENT##", " ") for x in lines]
        return lines

    predictions = read_prediction(pred_path)
    summaries = read_sum(ref_path)
    if len(predictions) != len(summaries):
        raise ValueError("{} has {} lines but {} has {}".format(
            pred_path, len(predictions), ref_path, len(summaries)))
    # NOTE(review): each prediction/summary is one whole line, so it is passed
    # as a single-sentence list, mirroring pred=[...], ref=[[...]] used by the
    # other evaluators here — confirm this matches the intended granularity.
    for prediction, summary in zip(predictions, summaries):
        rouge_metrics(pred=[prediction], ref=[[summary]])
    rouge_metrics.get_metric(True)
Beispiel #7
0
    def __init__(self,
                 context_dim,
                 dec_state_dim,
                 enc_hid_dim,
                 text_field_embedder,
                 aggressive_compression: int = -1,
                 keep_threshold: float = 0.5,
                 abs_board_file: str = "/home/cc/exComp/board.txt",
                 gather: str = 'mean',
                 dropout: float = 0.5,
                 dropout_emb: float = 0.2,
                 valid_tmp_path: str = '/scratch/cluster/jcxu/exComp',
                 serilization_name: str = "",
                 vocab=None,
                 elmo: bool = False,
                 elmo_weight: str = "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"):
        """Build the compression module.

        Sets up the following:

        - the word representation: ELMo when ``elmo`` is true, otherwise the
          supplied ``text_field_embedder``;
        - an ``EncCompression`` encoder and an attention layer;
        - only when ``aggressive_compression < 0``: a 3-layer feed-forward
          classifier with 2 outputs, plus one ``RougeStrEvaluation`` per value
          in ``self.keep_thres``.

        Args:
            context_dim: not referenced in this constructor.
            dec_state_dim: decoder-state size. It is added into the attention
                ``dec_dim`` and into ``self.concat_size``.
            enc_hid_dim: hidden size passed to ``EncCompression``.
            text_field_embedder: used and stored only when ``elmo`` is false.
                Its ``get_output_dim()`` sets ``self.word_emb_dim``.
            aggressive_compression: stored on ``self``. When negative, the
                classifier head and per-threshold ROUGE trackers are built.
            keep_threshold: not referenced in this constructor; the fixed
                ``self.keep_thres`` list is used instead.
            abs_board_file: not referenced in this constructor.
            gather: forwarded to ``EncCompression``.
            dropout: dropout of the feed-forward head.
            dropout_emb: dropout passed to ``Elmo`` (ELMo path only).
            valid_tmp_path: stored on ``self`` and passed to every
                ``RougeStrEvaluation`` as ``path_to_valid`` and ``writting_address``.
            serilization_name: stored on ``self`` and forwarded to the ROUGE trackers.
            vocab: stored as ``self.vocab`` only on the ELMo path.
            elmo: use ELMo instead of ``text_field_embedder``.
            elmo_weight: ELMo weight file passed to ``Elmo``.
        """
        # NOTE(review): submodule creation order below fixes registration order
        # and (likely) weight-init RNG draws; keep it stable for reproducibility.
        super().__init__()
        self.use_elmo = elmo
        self.serilization_name = serilization_name
        if elmo:
            # NOTE(review): batch_to_ids, Seq2SeqEncoder and PytorchSeq2SeqWrapper
            # are imported here but not used in this constructor.
            from allennlp.modules.elmo import Elmo, batch_to_ids
            from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
            self.vocab = vocab

            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = elmo_weight
            # Third positional argument: number of output representations (allennlp Elmo).
            self.elmo = Elmo(options_file, weight_file, 1, dropout=dropout_emb)
            # print(self.elmo.get_output_dim())
            # self.word_emb_dim = text_field_embedder.get_output_dim()
            # self._context_layer = PytorchSeq2SeqWrapper(
            #     torch.nn.LSTM(self.word_emb_dim + self.elmo.get_output_dim(), self.word_emb_dim,
            #                   batch_first=True, bidirectional=True))
            self.word_emb_dim = self.elmo.get_output_dim()
        else:
            self._text_field_embedder = text_field_embedder
            self.word_emb_dim = text_field_embedder.get_output_dim()

        # NOTE(review): differs only in case from self.XELoss below, which also
        # sets ignore_index=-1 and exists only when aggressive_compression < 0;
        # confirm which one the loss code actually uses.
        self.XEloss = torch.nn.CrossEntropyLoss(reduction='none')
        self.device = get_device()

        # self.rouge_metrics_compression = RougeStrEvaluation(name='cp', path_to_valid=valid_tmp_path,
        #                                                     writting_address=valid_tmp_path,
        #                                                     serilization_name=serilization_name)
        # self.rouge_metrics_compression_best_possible = RougeStrEvaluation(name='cp_ub', path_to_valid=valid_tmp_path,
        #                                                                   writting_address=valid_tmp_path,
        #                                                                   serilization_name=serilization_name)
        self.enc = EncCompression(inp_dim=self.word_emb_dim, hid_dim=enc_hid_dim, gather=gather)  # TODO dropout

        self.aggressive_compression = aggressive_compression
        self.relu = torch.nn.ReLU()

        self.attn = NewAttention(enc_dim=self.enc.get_output_dim(),
                                 dec_dim=self.enc.get_output_dim_unit() * 2 + dec_state_dim)

        # Equals the attention's enc_dim + dec_dim; used as the classifier input width.
        self.concat_size = self.enc.get_output_dim() + self.enc.get_output_dim_unit() * 2 + dec_state_dim
        self.valid_tmp_path = valid_tmp_path
        if self.aggressive_compression < 0:
            # Per-element losses; target index -1 is ignored.
            self.XELoss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
            # self.nn_lin = torch.nn.Linear(self.concat_size, self.concat_size)
            # self.nn_lin2 = torch.nn.Linear(self.concat_size, 2)

            # Identity activation on the last layer, so the 2 outputs are raw
            # logits (presumably keep vs. drop — confirm).
            self.ff = FeedForward(input_dim=self.concat_size, num_layers=3,
                                  hidden_dims=[self.concat_size, self.concat_size, 2],
                                  activations=[torch.nn.Tanh(), torch.nn.Tanh(), lambda x: x],
                                  dropout=dropout
                                  )
            # Keep thresholds: one ROUGE tracker per value, keyed by the
            # formatted threshold (e.g. "0.5").

            # self.keep_thres = list(np.arange(start=0.2, stop=0.6, step=0.075))
            self.keep_thres = [0.0, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 1.0]
            self.rouge_metrics_compression_dict = OrderedDict()
            for thres in self.keep_thres:
                self.rouge_metrics_compression_dict["{}".format(thres)] = RougeStrEvaluation(name='cp_{}'.format(thres),
                                                                                             path_to_valid=valid_tmp_path,
                                                                                             writting_address=valid_tmp_path,
                                                                                             serilization_name=serilization_name)
from neusum.evaluation.rouge_with_pythonrouge import RougeStrEvaluation
import os

# Module-level evaluator created at import time; this runs the
# RougeStrEvaluation constructor (and whatever side effects it has) on import.
# NOTE(review): not referenced in the visible code; presumably used by the
# test_* helpers further down — confirm.
r = RougeStrEvaluation(name="test")


def read_story(file):
    """Map each story id to its list of reference-summary sentences.

    Each non-empty line of *file* must have exactly four tab-separated fields:
    ``part``, ``uid``, ``doc`` and ``abs``. The ``abs`` field holds the summary
    sentences joined by ``<SPLIT>``. Empty lines are skipped.

    Args:
        file: path to a UTF-8 text file.

    Returns:
        dict mapping ``uid`` to the list of summary sentence strings. If a uid
        occurs more than once, the last occurrence wins.

    Raises:
        ValueError: if a non-empty line does not have exactly four fields.
    """
    summaries = {}
    # Explicit encoding so results do not depend on the platform locale.
    with open(file, 'r', encoding='utf-8') as fd:
        lines = fd.read().splitlines()
    for lineno, line in enumerate(lines, start=1):
        if not line:
            continue
        fields = line.split("\t")
        if len(fields) != 4:
            raise ValueError("{}:{}: expected 4 tab-separated fields, got {}".format(
                file, lineno, len(fields)))
        _part, uid, _doc, abstract = fields
        summaries[uid] = abstract.split("<SPLIT>")
    return summaries


def test_xxz():
    test = "dm"

    root = "/backup3/jcxu/data"

    xxz = root + "/xxz-latent/xxz-output"

    if test == 'cnndm':
        files = [x for x in os.listdir(xxz)]
        d_abs_cnn = read_story(os.path.join(root, "sent_cnn.txt"))
        d_abs_dm = read_story(os.path.join(root, "sent_dm.txt"))
        d_abs = {**d_abs_cnn, **d_abs_dm}
    else: