Example #1
def get_metric(self, reset=True):
    logger.info("[INFO] Hyps and Refer number is %d, %d", len(self.hyps), len(self.refers))
    if len(self.hyps) == 0 or len(self.refers) == 0:
        logger.error("During testing, no hyps or refers is selected!")
        return
    if isinstance(self.refers[0], list):
        logger.info("Multi Reference summaries!")
        scores_all = pyrouge_score_all_multi(self.hyps, self.refers)
    else:
        scores_all = pyrouge_score_all(self.hyps, self.refers)
    if reset:
        self.hyps = []
        self.refers = []
    logger.info(scores_all)
    return scores_all
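
# The pyrouge_score_all helpers used above are not shown in this listing; when
# use_pyrouge is off, the functions below fall back to the `rouge` package instead.
# A minimal sketch of that fallback (assumes the `rouge` pip package is installed;
# the function name and the example strings are illustrative, not from the original code).
def rouge_fallback_demo():
    from rouge import Rouge
    hyps = ["the cat sat on the mat", "a quick brown fox jumps"]            # one summary string per document
    refers = ["the cat was sitting on the mat", "the quick brown fox jumps over the dog"]
    # avg=True returns a single averaged dict: {'rouge-1': {'r','p','f'}, 'rouge-2': ..., 'rouge-l': ...}
    scores_all = Rouge().get_scores(hyps, refers, avg=True)
    return scores_all
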
def run_test(model, loader, hps, limited=False):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    test_dir = os.path.join(
        hps.save_root, "test")  # make a subdir of the root dir for eval data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it.",
            eval_dir)
        raise Exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it."
            % (eval_dir))

    if hps.test_model == "evalbestmodel":
        bestmodel_load_path = os.path.join(
            eval_dir, 'bestmodel.pkl'
        )  # this is where checkpoints of best models are saved
    elif hps.test_model == "evalbestFmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl')
    elif hps.test_model == "trainbestmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl')
    elif hps.test_model == "trainbestFmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl')
    elif hps.test_model == "earlystop":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error(
            "Invalid test_model! Must be one of evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop"
        )
        raise ValueError(
            "Invalid test_model! Must be one of evalbestmodel/evalbestFmodel/trainbestmodel/trainbestFmodel/earlystop"
        )
    logger.info("[INFO] Restoring %s for testing...The path is %s",
                hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
    modelloader.load_pytorch(model, bestmodel_load_path)

    import datetime
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # timestamp used to name the result file
    if hps.save_label:
        log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1])
        resfile = open(log_dir, "w")
    else:
        log_dir = os.path.join(test_dir, nowTime)
        resfile = open(log_dir, "wb")
    logger.info("[INFO] Write the Evaluation into %s", log_dir)

    model.eval()

    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    total_example_num = 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    pred_list = []
    iter_start_time = time.time()
    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input)  # torch.autograd.Variable is a no-op wrapper on PyTorch >= 0.4; kept for compatibility
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            prediction = model_outputs["prediction"]

            if hps.save_label:
                pred_list.extend(
                    model_outputs["pred_idx"].data.cpu().view(-1).tolist())
                continue

            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(
                    original_article_sents[id].replace("\n", "")
                    for id in range(len(prediction[j]))
                    if prediction[j][id] == 1 and id < sent_max_number)
                if limited:
                    k = len(refer.split())
                    hyps = " ".join(hyps.split()[:k])
                    logger.info((len(refer.split()), len(hyps.split())))
                resfile.write(b"Original_article:")
                resfile.write("\n".join(batch_x["text"][j]).encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"Reference:")
                if isinstance(refer, list):
                    for ref in refer:
                        resfile.write(ref.encode('utf-8'))
                        resfile.write(b"\n")
                        resfile.write(b'*' * 40)
                        resfile.write(b"\n")
                else:
                    resfile.write(refer.encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"hypothesis:")
                resfile.write(hyps.encode('utf-8'))
                resfile.write(b"\n")

                if hps.use_pyrouge:
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                else:
                    try:
                        scores = utils.rouge_all(hyps, refer)
                        pairs["hyps"].append(hyps)
                        pairs["refer"].append(refer)
                    except ValueError:
                        logger.error("Do not select any sentences!")
                        logger.debug("sent_max_number:%d", sent_max_number)
                        logger.debug(original_article_sents)
                        logger.debug("label:")
                        logger.debug(label[j])
                        continue

                    # single example res writer
                    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \
                            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \
                                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'])

                    resfile.write(res.encode('utf-8'))
                resfile.write(b'-' * 89)
                resfile.write(b"\n")

    if hps.save_label:
        import json
        json.dump(pred_list, resfile)
        logger.info('   | end of test | time: {:5.2f}s | '.format(
            (time.time() - iter_start_time)))
        return

    resfile.write(b"\n")
    resfile.write(b'=' * 89)
    resfile.write(b"\n")

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"],
                                                       pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)

    # the whole model res writer
    resfile.write(b"The total testset is:")
    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    resfile.write(res.encode("utf-8"))
    logger.info(res)
    logger.info('   | end of test | time: {:5.2f}s | '.format(
        (time.time() - iter_start_time)))

    # label prediction
    logger.info("match_true %d, pred %d, true %d, total %d, match %d", match,
                pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                  total_example_num, match)
    res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)
    resfile.write(res.encode('utf-8'))
    logger.info(
        "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
        len(loader), accu, precision, recall, F)
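
# utils.eval_label (used above) is not part of this listing. A hedged sketch of the
# formulas its counters imply: match = correct select/skip decisions, match_true =
# correctly selected sentences, pred = selected, true = gold-selected, total = all
# sentences. The real helper may differ in details.
def eval_label_counts(match_true, pred, true, total, match):
    accu = match / total if total else 0.0
    precision = match_true / pred if pred else 0.0
    recall = match_true / true if true else 0.0
    F = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accu, precision, recall, F

# e.g. eval_label_counts(30, 45, 50, 600, 520) -> (0.8667, 0.6667, 0.6, 0.6316), rounded
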
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    logger.info("[INFO] Starting eval for this model ...")
    eval_dir = os.path.join(
        hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    model.eval()

    running_loss = 0.0
    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    total_example_num = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    iter_start_time = time.time()

    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):
            # if i > 10:
            #     break

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input, requires_grad=False)
            label = Variable(label)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            outputs = model_outputs["p_sent"]
            prediction = model_outputs["prediction"]

            outputs = outputs.view(-1, 2)  # [batch * N, 2]
            label = label.view(-1)  # [batch * N]
            loss = criterion(outputs, label)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)
            running_loss += float(loss.data)

            label = label.data.view(batch_size, -1)
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            # rouge
            prediction = prediction.view(batch_size, -1)
            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(
                    original_article_sents[id]
                    for id in range(len(prediction[j]))
                    if prediction[j][id] == 1 and id < sent_max_number)
                if sent_max_number < hps.m and len(hyps) <= 1:
                    logger.error("sent_max_number is too short %d, Skip!",
                                 sent_max_number)
                    continue

                if len(hyps) >= 1 and hyps != '.':
                    # logger.debug(prediction[j])
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                elif refer == "." or refer == "":
                    logger.error("Refer is None!")
                    logger.debug("label:")
                    logger.debug(label[j])
                    logger.debug(refer)
                elif hyps == "." or hyps == "":
                    logger.error("hyps is None!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug("prediction:")
                    logger.debug(prediction[j])
                    logger.debug(hyps)
                else:
                    logger.error("Do not select any sentences!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug(original_article_sents)
                    logger.debug("label:")
                    logger.debug(label[j])
                    continue

    running_avg_loss = running_loss / len(loader)

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        logging.getLogger('global').setLevel(logging.WARNING)
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"],
                                                       pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0:
            logger.error("During testing, no hyps is selected!")
            return
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # try:
        #     scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # except ValueError as e:
        #     logger.error(repr(e))
        #     scores_all = []
        #     for idx in range(len(pairs["hyps"])):
        #         try:
        #             scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0]
        #             scores_all.append(scores)
        #         except ValueError as e:
        #             logger.error(repr(e))
        #             logger.debug("HYPS:\t%s", pairs["hyps"][idx])
        #             logger.debug("REFER:\t%s", pairs["refer"][idx])
        # finally:
        #     logger.error("During testing, some errors happen!")
        #     logger.error(len(scores_all))
        #     exit(1)

    logger.info(
        '[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | '.format(
            (time.time() - iter_start_time), float(running_avg_loss)))

    logger.info(
        "[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d",
        match_true, pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                  total_example_num, match)
    logger.info(
        "[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)

    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    logger.info(res)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_loss is not None:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s',
                float(running_avg_loss), float(best_loss), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s',
                float(running_avg_loss), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_loss = running_avg_loss
        non_descent_cnt = 0
    else:
        non_descent_cnt += 1

    if best_F is None or best_F < F:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestFmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_F is not None:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s',
                float(F), float(best_F), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original F is None, Saving to %s',
                float(F), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_F = F

    return best_loss, best_F, non_descent_cnt
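
# Usage sketch: how run_eval's return values would typically drive early stopping in
# the surrounding training loop. train_one_epoch, the loaders and the patience value
# are illustrative assumptions; only the (best_loss, best_F, non_descent_cnt) contract
# comes from run_eval itself.
def train_with_early_stopping(model, train_one_epoch, train_loader, valid_loader, hps,
                              patience=3, max_epochs=20):
    best_loss, best_F, non_descent_cnt = None, None, 0
    for epoch in range(max_epochs):
        train_one_epoch(model, train_loader, hps)        # caller-supplied training step
        result = run_eval(model, valid_loader, hps, best_loss, best_F, non_descent_cnt)
        if result is None:                               # run_eval bails out when no hyps were selected
            break
        best_loss, best_F, non_descent_cnt = result
        if non_descent_cnt >= patience:                  # stop after `patience` evals without a new best loss
            logger.info("[INFO] Validation loss has not improved for %d evals; stopping early.", patience)
            break
    return best_loss, best_F
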
Example #4
def run_test(model, dataset, loader, model_name, hps):
    """Runs the model over the test set via SLTester and reports ROUGE scores."""
    test_dir = os.path.join(
        hps.save_root, "test")  # make a subdir of the root dir for eval data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it.",
            eval_dir)
        raise Exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it."
            % (eval_dir))

    resfile = None
    if hps.save_label:
        log_dir = os.path.join(test_dir, hps.cache_dir.split("/")[-1])
        resfile = open(log_dir, "w")
        logger.info("[INFO] Write the Evaluation into %s", log_dir)

    model = load_test_model(model, model_name, eval_dir, hps.save_root)
    model.eval()

    iter_start_time = time.time()
    with torch.no_grad():
        logger.info("[Model] Sequence Labeling!")
        tester = SLTester(model, hps.m, limited=hps.limited, test_dir=test_dir)

        for i, (G, index) in enumerate(loader):
            if hps.cuda:
                G = G.to(torch.device("cuda"))  # reassign: recent DGL versions return a moved copy rather than modifying G in place
            tester.evaluation(G, index, dataset, blocking=hps.blocking)

    running_avg_loss = tester.running_avg_loss

    if hps.save_label:
        # save label and do not calculate rouge
        json.dump(tester.extractLabel, resfile)
        tester.SaveDecodeFile()
        logger.info('   | end of test | time: {:5.2f}s | '.format(
            (time.time() - iter_start_time)))
        return

    logger.info("The number of pairs is %d", tester.rougePairNum)
    if not tester.rougePairNum:
        logger.error("During testing, no hyps is selected!")
        sys.exit(1)

    if hps.use_pyrouge:
        if isinstance(tester.refer[0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(tester.hyps,
                                                       tester.refer)
        else:
            scores_all = utils.pyrouge_score_all(tester.hyps, tester.refer)
    else:
        rouge = Rouge()
        scores_all = rouge.get_scores(tester.hyps, tester.refer, avg=True)

    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    logger.info(res)

    tester.getMetric()
    tester.SaveDecodeFile()
    logger.info(
        '[INFO] End of test | time: {:5.2f}s | test loss {:5.4f} | '.format(
            (time.time() - iter_start_time), float(running_avg_loss)))
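
# tester.evaluation above takes a `blocking` flag. In extractive summarization this
# usually means n-gram (typically trigram) blocking: skip a candidate sentence if it
# repeats an n-gram already present in the partial summary. SLTester's actual
# implementation is not shown here, so this is only an illustration of the technique.
def ngram_blocking(ranked_sents, n=3, max_sents=3):
    def ngrams(words, n):
        return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}
    selected, seen = [], set()
    for sent in ranked_sents:                 # sentences ordered by predicted score, best first
        grams = ngrams(sent.lower().split(), n)
        if grams & seen:                      # overlapping n-gram -> likely redundant
            continue
        selected.append(sent)
        seen |= grams
        if len(selected) == max_sents:
            break
    return selected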