示例#1
0
def test(saved_model_path='./model.pth', gpus=1):
    """Restore a trained BertClassifier checkpoint and run the Lightning test loop on it.

    Args:
        saved_model_path: Checkpoint to restore. The default ``./model.pth`` is
            meant to be a symlink to the best checkpoint under
            ``lightning_logs/version_N/checkpoints/``.
        gpus: Value forwarded to ``Trainer(gpus=...)``; defaults to 1 as before.

    Returns:
        The value returned by ``trainer.test`` (it is also printed).
    """
    model = BertClassifier.load_from_checkpoint(saved_model_path)
    model.eval()
    print(model)
    trainer = Trainer(gpus=gpus)
    # No datamodule/dataloader is passed, so the model presumably supplies its
    # own test dataloader — confirm against BertClassifier.
    result = trainer.test(model)
    print(result)
    # Return the metrics so callers can use them instead of only reading stdout.
    return result
示例#2
0
    # Locate a checkpoint anywhere under the experiment directory ("**/" recurses).
    # NOTE(review): Path.glob yields in arbitrary order, so with several *.ckpt files
    # the one picked is not deterministic; with none, the [0] raises IndexError.
    model_path = str(list(exp_dir.glob("**/*.ckpt"))[0])
    # exp_dir is used with .glob and "/", so it is presumably a pathlib.Path — confirm.
    hparams_path = str(exp_dir / "hparams.yaml")

    # load model
    # `device` is not defined in the visible lines — presumably a parameter or global.
    model = Distiller.load_from_checkpoint(
        model_path,
        hparams_file=hparams_path,
        map_location=device,
    )
    # Put the model in eval mode and freeze it; the __main__ block only runs
    # trainer.test on the returned model.
    model.eval()
    model.freeze()
    return model


def load_data(data_dir, num_workers, hparams):
    """Configure the dataset section of *hparams* and build the datamodule from it.

    Sets ``hparams.dataset.path`` to *data_dir*, ``on_memory`` to True and
    ``num_workers`` to *num_workers*, then returns
    ``TripleEmbeddingDataModule(hparams)``.

    The update happens in place on the object passed in. When the caller passes
    ``model.hparams`` (as the ``__main__`` block does), the model's own
    hyper-parameters are modified as well.
    """
    dataset_cfg = hparams.dataset
    dataset_cfg.path = data_dir
    # presumably keeps the whole dataset in memory — confirm in the datamodule
    dataset_cfg.on_memory = True
    dataset_cfg.num_workers = num_workers
    return TripleEmbeddingDataModule(hparams)


if __name__ == "__main__":
    # Restore the trained model from the experiment directory.
    # NOTE(review): `args` is not defined in the visible code — presumably parsed
    # from the command line earlier in the file.
    model = load_model(args.exp_dir)

    # Build the datamodule from the model's stored hyper-parameters.
    dm = load_data(
        data_dir=args.data_dir,
        num_workers=args.num_workers,
        hparams=model.hparams,
    )
    # NOTE(review): dm.prepare_data() is skipped — presumably the data is already
    # prepared on disk; confirm.
    dm.setup("test")

    # A default Trainer is sufficient here: it only runs the test loop.
    trainer = Trainer()
    trainer.test(model, datamodule=dm)