Code example #1
File: train.py  Project: zxlzr/fastNLP
import os
import logging

from fastNLP import Tester
from fastNLP.io import ModelLoader  # import paths as in fastNLP 0.4/0.5

# LabelFMetric, PyRougeMetric and FastRougeMetric are project-local metrics
# defined elsewhere in this repository.
logger = logging.getLogger(__name__)


def run_test(model, loader, hps, limited=False):
    """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far."""
    test_dir = os.path.join(hps.save_root, "test")  # subdir of the root dir for test output
    eval_dir = os.path.join(hps.save_root, "eval")
    if not os.path.exists(test_dir): os.makedirs(test_dir)
    if not os.path.exists(eval_dir):
        logger.error("[Error] eval_dir %s doesn't exist. Run in train mode to create it.", eval_dir)
        raise Exception("[Error] eval_dir %s doesn't exist. Run in train mode to create it." % eval_dir)

    if hps.test_model == "evalbestmodel":
        bestmodel_load_path = os.path.join(eval_dir, 'bestmodel.pkl') # this is where checkpoints of best models are saved
    elif hps.test_model == "earlystop":
        train_dir = os.path.join(hps.save_root, "train")
        bestmodel_load_path = os.path.join(train_dir, 'earlystop.pkl')
    else:
        logger.error("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
        raise ValueError("None of such model! Must be one of evalbestmodel/trainbestmodel/earlystop")
    logger.info("[INFO] Restoring %s for testing...The path is %s", hps.test_model, bestmodel_load_path)

    modelloader = ModelLoader()
    modelloader.load_pytorch(model, bestmodel_load_path)

    if hps.use_pyrouge:
        logger.info("[INFO] Use PyRougeMetric for testing")
        rouge_metric = PyRougeMetric(hps, pred="prediction")
    else:
        logger.info("[INFO] Use FastRougeMetric for testing")
        rouge_metric = FastRougeMetric(hps, pred="prediction")
    tester = Tester(data=loader, model=model,
                    metrics=[LabelFMetric(pred="prediction"), rouge_metric],
                    batch_size=hps.batch_size)
    test_info = tester.test()
    logger.info(test_info)
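The checkpoint restored above is written by ModelSaver in run_training below. A minimal round-trip sketch of that fastNLP save/load pair, using a throwaway torch.nn.Linear in place of the project's model (import paths as in fastNLP 0.4/0.5 are an assumption):

import torch
from fastNLP.io import ModelLoader, ModelSaver

model = torch.nn.Linear(4, 2)
ModelSaver("bestmodel.pkl").save_pytorch(model)        # saves only the state dict by default
restored = torch.nn.Linear(4, 2)                       # must rebuild the same architecture first
ModelLoader().load_pytorch(restored, "bestmodel.pkl")  # then fill in the saved weights

Because save_pytorch stores parameters rather than the whole model by default, run_test has to construct the model before restoring into it, which is exactly the load_pytorch(model, bestmodel_load_path) call above.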
Code example #2
import os
import shutil
import logging

import torch
from fastNLP import Trainer, Const
from fastNLP.io import ModelSaver  # import paths as in fastNLP 0.4/0.5

# MyCrossEntropyLoss, LossMetric, LabelFMetric, FastRougeMetric and
# TrainCallback are project-local helpers defined elsewhere in this repository.
logger = logging.getLogger(__name__)


def run_training(model, train_loader, valid_loader, hps):
    logger.info("[INFO] Starting run_training")

    train_dir = os.path.join(hps.save_root, "train")
    if os.path.exists(train_dir): shutil.rmtree(train_dir)
    os.makedirs(train_dir)
    eval_dir = os.path.join(hps.save_root, "eval")  # subdir of the root dir for eval data
    if not os.path.exists(eval_dir): os.makedirs(eval_dir)

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=hps.lr)
    criterion = MyCrossEntropyLoss(pred="p_sent",
                                   target=Const.TARGET,
                                   mask=Const.INPUT_LEN,
                                   reduce='none')

    trainer = Trainer(model=model,
                      train_data=train_loader,
                      optimizer=optimizer,
                      loss=criterion,
                      n_epochs=hps.n_epochs,
                      print_every=100,
                      dev_data=valid_loader,
                      metrics=[
                          LossMetric(pred="p_sent",
                                     target=Const.TARGET,
                                     mask=Const.INPUT_LEN,
                                     reduce='none'),
                          LabelFMetric(pred="prediction"),
                          FastRougeMetric(hps, pred="prediction")
                      ],
                      metric_key="loss",
                      validate_every=-1,
                      save_path=eval_dir,
                      callbacks=[TrainCallback(hps, patience=5)],
                      use_tqdm=False)

    train_info = trainer.train(load_best_model=True)
    logger.info('   | end of Train | time: {:5.2f}s | '.format(
        train_info["seconds"]))
    logger.info('[INFO] best eval model in epoch %d and iter %d',
                train_info["best_epoch"], train_info["best_step"])
    logger.info(train_info["best_eval"])

    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel.pkl')  # where the best checkpoint is saved
    saver = ModelSaver(bestmodel_save_path)
    saver.save_pytorch(model)
    logger.info('[INFO] Saving eval best model to %s', bestmodel_save_path)
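MyCrossEntropyLoss is project-local, but its parameters (pred="p_sent", target, mask=Const.INPUT_LEN, reduce='none') suggest a per-sentence cross entropy masked by document length. A plain-PyTorch sketch of that idea; the function name and tensor shapes here are assumptions, not the project's code:

import torch
import torch.nn.functional as F

def masked_sentence_cross_entropy(p_sent, target, seq_len):
    # p_sent:  (batch, doc_len, 2) per-sentence class scores
    # target:  (batch, doc_len)    0/1 extraction labels
    # seq_len: (batch,)            number of real sentences per document
    loss = F.cross_entropy(p_sent.transpose(1, 2), target,
                           reduction="none")                 # (batch, doc_len)
    mask = torch.arange(p_sent.size(1))[None, :] < seq_len[:, None]
    return (loss * mask).sum() / mask.sum()                  # average over real sentences only

With reduce='none' the per-position losses are kept, so padded sentence slots can be zeroed out before averaging; the LossMetric in the metrics list presumably applies the same masking when validating on the dev set.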