Code Example #1
def getMetric(self):
    logger.info(
        "[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d",
        self.match_true, self.pred, self.true, self.total_sentence_num,
        self.match)
    self._accu, self._precision, self._recall, self._F = eval_label(
        self.match_true, self.pred, self.true, self.total_sentence_num,
        self.match)
    logger.info(
        "[INFO] The size of totalset is %d, sent_number is %d, accu is %f, precision is %f, recall is %f, F is %f",
        self.example_num, self.total_sentence_num, self._accu,
        self._precision, self._recall, self._F)
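
The method above delegates to a helper eval_label that turns the five running counts into the reported metrics. A rough sketch of what such a helper might compute, assuming the counts have their usual meaning (match_true = true positives, pred = predicted positives, true = gold positives, match = correctly labelled sentences); the actual implementation in the repository may differ:

def eval_label(match_true, pred, true, total, match):
    # Hypothetical reimplementation; the repository's eval_label may differ.
    accu = match / total if total else 0.0          # correct labels / all sentences
    precision = match_true / pred if pred else 0.0  # true positives / predicted positives
    recall = match_true / true if true else 0.0     # true positives / gold positives
    F = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accu, precision, recall, F
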
Code Example #2
def run_training(model, train_loader, valid_loader, hps):
    """Repeatedly runs training iterations, logging loss to screen and writing summaries"""
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)

    lr = hps.lr
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.98),
    # eps=1e-09)
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=lr)
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    best_train_loss = None
    best_train_F = None
    best_loss = None
    best_F = None
    step_num = 0
    non_descent_cnt = 0
    for epoch in range(1, hps.n_epochs + 1):
        epoch_loss = 0.0
        train_loss = 0.0
        total_example_num = 0
        match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
        epoch_start_time = time.time()
        for i, (batch_x, batch_y) in enumerate(train_loader):
            # if i > 10:
            #     break
            model.train()

            iter_start_time = time.time()

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            # logger.info(batch_x["text"][0])
            # logger.info(input[0,:,:])
            # logger.info(input_len[0:5,:])
            # logger.info(batch_y["summary"][0:5])
            # logger.info(label[0:5,:])

            # logger.info((len(batch_x["text"][0]), sum(input[0].sum(-1) != 0)))

            batch_size, N, seq_len = input.size()

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            input = Variable(input)
            label = Variable(label)
            input_len = Variable(input_len)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]

            outputs = model_outputs["p_sent"].view(-1, 2)

            label = label.view(-1)

            loss = criterion(outputs, label)  # [batch_size, doc_max_timesteps]
            # input_len = input_len.float().view(-1)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)

            if not np.isfinite(loss.item()):
                logger.error("train Loss is not finite. Stopping.")
                logger.info(loss)
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.info(name)
                        logger.info(param.grad.data.sum())
                raise Exception("train Loss is not finite. Stopping.")

            optimizer.zero_grad()
            loss.backward()
            if hps.grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               hps.max_grad_norm)

            optimizer.step()
            step_num += 1

            train_loss += float(loss.data)
            epoch_loss += float(loss.data)

            if i % 100 == 0:
                # start debugger
                # import pdb; pdb.set_trace()
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        logger.debug(name)
                        logger.debug(param.grad.data.sum())
                logger.info(
                    '       | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | '
                    .format(i, (time.time() - iter_start_time),
                            float(train_loss / 100)))
                train_loss = 0.0

            # calculate the precision, recall and F
            prediction = outputs.max(1)[1]
            prediction = prediction.data
            label = label.data
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += int(batch_size * N)

        if hps.lr_descent:
            # new_lr = pow(hps.hidden_size, -0.5) * min(pow(step_num, -0.5),
            #                                           step_num * pow(hps.warmup_steps, -1.5))
            new_lr = max(5e-6, lr / (epoch + 1))
            for param_group in list(optimizer.param_groups):
                param_group['lr'] = new_lr
            logger.info("[INFO] The learning rate now is %f", new_lr)

        epoch_avg_loss = epoch_loss / len(train_loader)
        logger.info(
            '   | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | '
            .format(epoch, (time.time() - epoch_start_time),
                    float(epoch_avg_loss)))

        logger.info(
            "[INFO] Trainset match_true %d, pred %d, true %d, total %d, match %d",
            match_true, pred, true, total_example_num, match)
        accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                      total_example_num, match)
        logger.info(
            "[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
            total_example_num / hps.doc_max_timesteps, accu, precision, recall,
            F)

        if not best_train_loss or epoch_avg_loss < best_train_loss:
            save_file = os.path.join(train_dir, "bestmodel.pkl")
            logger.info(
                '[INFO] Found new best model with %.3f running_train_loss. Saving to %s',
                float(epoch_avg_loss), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_loss = epoch_avg_loss
        elif epoch_avg_loss > best_train_loss:
            logger.error(
                "[Error] training loss does not decrease. Stopping supervisor..."
            )
            save_file = os.path.join(train_dir, "earlystop.pkl")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return

        if not best_train_F or F > best_train_F:
            save_file = os.path.join(train_dir, "bestFmodel.pkl")
            logger.info(
                '[INFO] Found new best model with %.3f F score. Saving to %s',
                float(F), save_file)
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            best_train_F = F

        best_loss, best_F, non_descent_cnt = run_eval(model, valid_loader, hps,
                                                      best_loss, best_F,
                                                      non_descent_cnt)

        if non_descent_cnt >= 3:
            logger.error(
                "[Error] val loss has not decreased for three validation rounds. Stopping supervisor..."
            )
            save_file = os.path.join(train_dir, "earlystop")
            saver = ModelSaver(save_file)
            saver.save_pytorch(model)
            logger.info('[INFO] Saving early stop model to %s', save_file)
            return
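
The training loop above computes a per-sentence cross entropy with reduction='none', zeroes out the padded sentence slots, then sums over sentences and averages over the batch. A minimal, self-contained illustration of that pattern, with made-up shapes and assuming input_len stores per-sentence lengths where 0 marks padding (as the masked_fill call suggests):

import torch

batch_size, N = 2, 4                        # documents per batch, sentences per document
logits = torch.randn(batch_size * N, 2)     # stands in for p_sent.view(-1, 2)
labels = torch.randint(0, 2, (batch_size * N,))
sent_len = torch.tensor([[5, 3, 0, 0],      # 0 marks a padded sentence slot
                         [4, 2, 6, 1]])

criterion = torch.nn.CrossEntropyLoss(reduction='none')
loss = criterion(logits, labels).view(batch_size, N)
loss = loss.masked_fill(sent_len.eq(0), 0)  # drop the padded sentences
loss = loss.sum(1).mean()                   # sum over sentences, mean over documents
print(float(loss))
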
Code Example #3
def run_test(model, loader, hps, limited=False):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    test_dir = os.path.join(
        hps.save_root, "test")  # make a subdir of the root dir for eval data
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it.",
            eval_dir)
        raise Exception(
            "[Error] eval_dir %s doesn't exist. Run in train mode to create it."
            % (eval_dir))

    if hps.test_model == "evalbestmodel":
        bestmodel_load_path = os.path.join(
            eval_dir, 'bestmodel.pkl'
        )  # this is where checkpoints of best models are saved
    elif hps.test_model == "evalbestFmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestFmodel.pkl')
    elif hps.test_model == "trainbestmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestmodel.pkl')
    elif hps.test_model == "trainbestFmodel":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'bestFmodel.pkl')
    elif hps.test_model == "earlystop":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error(
            "None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop"
        )
        raise ValueError(
            "None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop"
        )
    logger.info("[INFO] Restoring %s for testing...The path is %s",
                hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
    modelloader.load_pytorch(model, bestmodel_load_path)

    import datetime
    nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # current time, used to name the result file
    if hps.save_label:
        log_dir = os.path.join(test_dir, hps.data_path.split("/")[-1])
        resfile = open(log_dir, "w")
    else:
        log_dir = os.path.join(test_dir, nowTime)
        resfile = open(log_dir, "wb")
    logger.info("[INFO] Write the Evaluation into %s", log_dir)

    model.eval()

    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    total_example_num = 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    pred_list = []
    iter_start_time = time.time()
    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            prediction = model_outputs["prediction"]

            if hps.save_label:
                pred_list.extend(
                    model_outputs["pred_idx"].data.cpu().view(-1).tolist())
                continue

            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(
                    original_article_sents[id].replace("\n", "")
                    for id in range(len(prediction[j]))
                    if prediction[j][id] == 1 and id < sent_max_number)
                if limited:
                    k = len(refer.split())
                    hyps = " ".join(hyps.split()[:k])
                    logger.info((len(refer.split()), len(hyps.split())))
                resfile.write(b"Original_article:")
                resfile.write("\n".join(batch_x["text"][j]).encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"Reference:")
                if isinstance(refer, list):
                    for ref in refer:
                        resfile.write(ref.encode('utf-8'))
                        resfile.write(b"\n")
                        resfile.write(b'*' * 40)
                        resfile.write(b"\n")
                else:
                    resfile.write(refer.encode('utf-8'))
                resfile.write(b"\n")
                resfile.write(b"hypothesis:")
                resfile.write(hyps.encode('utf-8'))
                resfile.write(b"\n")

                if hps.use_pyrouge:
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                else:
                    try:
                        scores = utils.rouge_all(hyps, refer)
                        pairs["hyps"].append(hyps)
                        pairs["refer"].append(refer)
                    except ValueError:
                        logger.error("Do not select any sentences!")
                        logger.debug("sent_max_number:%d", sent_max_number)
                        logger.debug(original_article_sents)
                        logger.debug("label:")
                        logger.debug(label[j])
                        continue

                    # single example res writer
                    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']) \
                            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']) \
                                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'])

                    resfile.write(res.encode('utf-8'))
                resfile.write(b'-' * 89)
                resfile.write(b"\n")

    if hps.save_label:
        import json
        json.dump(pred_list, resfile)
        logger.info('   | end of test | time: {:5.2f}s | '.format(
            (time.time() - iter_start_time)))
        return

    resfile.write(b"\n")
    resfile.write(b'=' * 89)
    resfile.write(b"\n")

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"],
                                                       pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)

    # the whole model res writer
    resfile.write(b"The total testset is:")
    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    resfile.write(res.encode("utf-8"))
    logger.info(res)
    logger.info('   | end of test | time: {:5.2f}s | '.format(
        (time.time() - iter_start_time)))

    # label prediction
    logger.info("match_true %d, pred %d, true %d, total %d, match %d", match,
                pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                  total_example_num, match)
    res = "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f" % (
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)
    resfile.write(res.encode('utf-8'))
    logger.info(
        "The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)
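
When hps.use_pyrouge is off, the scores come from the Rouge class (the rouge package on PyPI), whose get_scores(..., avg=True) returns a dict keyed by 'rouge-1' / 'rouge-2' / 'rouge-l' with 'p', 'r' and 'f' entries, which is the structure the res string above indexes into. A quick usage sketch, assuming that package:

from rouge import Rouge

rouge = Rouge()
scores_all = rouge.get_scores(["the cat sat on the mat"],        # hypotheses
                              ["a cat was sitting on the mat"],  # references
                              avg=True)
print(scores_all["rouge-1"]["f"], scores_all["rouge-2"]["f"], scores_all["rouge-l"]["f"])
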
Code Example #4
def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    logger.info("[INFO] Starting eval for this model ...")
    eval_dir = os.path.join(
        hps.save_root, "eval")  # make a subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    model.eval()

    running_loss = 0.0
    match, pred, true, match_true = 0.0, 0.0, 0.0, 0.0
    pairs = {}
    pairs["hyps"] = []
    pairs["refer"] = []
    total_example_num = 0
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    iter_start_time = time.time()

    with torch.no_grad():
        for i, (batch_x, batch_y) in enumerate(loader):
            # if i > 10:
            #     break

            input, input_len = batch_x[Const.INPUT], batch_x[Const.INPUT_LEN]
            label = batch_y[Const.TARGET]

            if hps.cuda:
                input = input.cuda()  # [batch, N, seq_len]
                label = label.cuda()
                input_len = input_len.cuda()

            batch_size, N, _ = input.size()

            input = Variable(input, requires_grad=False)
            label = Variable(label)
            input_len = Variable(input_len, requires_grad=False)

            model_outputs = model.forward(input, input_len)  # [batch, N, 2]
            outputs = model_outputs["p_sent"]
            prediction = model_outputs["prediction"]

            outputs = outputs.view(-1, 2)  # [batch * N, 2]
            label = label.view(-1)  # [batch * N]
            loss = criterion(outputs, label)
            loss = loss.view(batch_size, -1)
            loss = loss.masked_fill(input_len.eq(0), 0)
            loss = loss.sum(1).mean()
            logger.debug("loss %f", loss)
            running_loss += float(loss.data)

            label = label.data.view(batch_size, -1)
            pred += prediction.sum()
            true += label.sum()
            match_true += ((prediction == label) & (prediction == 1)).sum()
            match += (prediction == label).sum()
            total_example_num += batch_size * N

            # rouge
            prediction = prediction.view(batch_size, -1)
            for j in range(batch_size):
                original_article_sents = batch_x["text"][j]
                sent_max_number = len(original_article_sents)
                refer = "\n".join(batch_x["summary"][j])
                hyps = "\n".join(
                    original_article_sents[id]
                    for id in range(len(prediction[j]))
                    if prediction[j][id] == 1 and id < sent_max_number)
                if sent_max_number < hps.m and len(hyps) <= 1:
                    logger.error("sent_max_number is too short %d, Skip!",
                                 sent_max_number)
                    continue

                if len(hyps) >= 1 and hyps != '.':
                    # logger.debug(prediction[j])
                    pairs["hyps"].append(hyps)
                    pairs["refer"].append(refer)
                elif refer == "." or refer == "":
                    logger.error("Refer is None!")
                    logger.debug("label:")
                    logger.debug(label[j])
                    logger.debug(refer)
                elif hyps == "." or hyps == "":
                    logger.error("hyps is None!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug("prediction:")
                    logger.debug(prediction[j])
                    logger.debug(hyps)
                else:
                    logger.error("Do not select any sentences!")
                    logger.debug("sent_max_number:%d", sent_max_number)
                    logger.debug(original_article_sents)
                    logger.debug("label:")
                    logger.debug(label[j])
                    continue

    running_avg_loss = running_loss / len(loader)

    if hps.use_pyrouge:
        logger.info("The number of pairs is %d", len(pairs["hyps"]))
        logging.getLogger('global').setLevel(logging.WARNING)
        if not len(pairs["hyps"]):
            logger.error("During testing, no hyps is selected!")
            return best_loss, best_F, non_descent_cnt
        if isinstance(pairs["refer"][0], list):
            logger.info("Multi Reference summaries!")
            scores_all = utils.pyrouge_score_all_multi(pairs["hyps"],
                                                       pairs["refer"])
        else:
            scores_all = utils.pyrouge_score_all(pairs["hyps"], pairs["refer"])
    else:
        if len(pairs["hyps"]) == 0 or len(pairs["refer"]) == 0:
            logger.error("During testing, no hyps is selected!")
            return best_loss, best_F, non_descent_cnt
        rouge = Rouge()
        scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # try:
        #     scores_all = rouge.get_scores(pairs["hyps"], pairs["refer"], avg=True)
        # except ValueError as e:
        #     logger.error(repr(e))
        #     scores_all = []
        #     for idx in range(len(pairs["hyps"])):
        #         try:
        #             scores = rouge.get_scores(pairs["hyps"][idx], pairs["refer"][idx])[0]
        #             scores_all.append(scores)
        #         except ValueError as e:
        #             logger.error(repr(e))
        #             logger.debug("HYPS:\t%s", pairs["hyps"][idx])
        #             logger.debug("REFER:\t%s", pairs["refer"][idx])
        # finally:
        #     logger.error("During testing, some errors happen!")
        #     logger.error(len(scores_all))
        #     exit(1)

    logger.info(
        '[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | '.format(
            (time.time() - iter_start_time), float(running_avg_loss)))

    logger.info(
        "[INFO] Validset match_true %d, pred %d, true %d, total %d, match %d",
        match_true, pred, true, total_example_num, match)
    accu, precision, recall, F = utils.eval_label(match_true, pred, true,
                                                  total_example_num, match)
    logger.info(
        "[INFO] The size of totalset is %d, accu is %f, precision is %f, recall is %f, F is %f",
        total_example_num / hps.doc_max_timesteps, accu, precision, recall, F)

    res = "Rouge1:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \
            + "Rouge2:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \
                + "Rougel:\n\tp:%.6f, r:%.6f, f:%.6f\n" % (scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])
    logger.info(res)

    # If running_avg_loss is best so far, save this checkpoint (early stopping).
    # These checkpoints will appear as bestmodel-<iteration_number> in the eval dir
    if best_loss is None or running_avg_loss < best_loss:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_loss is not None:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s',
                float(running_avg_loss), float(best_loss), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s',
                float(running_avg_loss), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_loss = running_avg_loss
        non_descent_cnt = 0
    else:
        non_descent_cnt += 1

    if best_F is None or best_F < F:
        bestmodel_save_path = os.path.join(
            eval_dir, 'bestFmodel.pkl'
        )  # this is where checkpoints of best models are saved
        if best_F is not None:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s',
                float(F), float(best_F), bestmodel_save_path)
        else:
            logger.info(
                '[INFO] Found new best model with %.6f F. The original loss is None, Saving to %s',
                float(F), bestmodel_save_path)
        saver = ModelSaver(bestmodel_save_path)
        saver.save_pytorch(model)
        best_F = F

    return best_loss, best_F, non_descent_cnt
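
For orientation, the hyper-parameters the four snippets read off hps are collected below in one hypothetical bundle. The field names are taken from the code above, but the values are placeholders rather than the repository's defaults, and model and the data loaders are assumed to come from the rest of the project.

from types import SimpleNamespace
import torch

hps = SimpleNamespace(
    save_root="save/",                     # checkpoints and result files go under here
    lr=5e-4, n_epochs=20, lr_descent=True,
    grad_clip=True, max_grad_norm=1.0,
    cuda=torch.cuda.is_available(),
    doc_max_timesteps=50, m=3,             # sentences per document / minimum selected in run_eval
    use_pyrouge=False, save_label=False,
    test_model="evalbestmodel", data_path="data/test.jsonl")

# Typical entry points (model and loaders omitted here):
#   run_training(model, train_loader, valid_loader, hps)
#   run_test(model, test_loader, hps)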