Пример #1
0
def test_train_and_predict_main(config_paths):
    """
    Test main in train and predict by checking it can run.

    Training is run first and its output folders are checked; prediction
    is then run with a checkpoint path under the training log folder and
    its outputs are checked. Both log folders are removed afterwards, even
    when a call or a check fails, so a failed run does not leave stale
    outputs that could satisfy the checks of a later run.

    :param config_paths: list of file paths for configuration.
    """
    try:
        train_main(
            args=[
                "--gpu",
                "",
                "--exp_name",
                "test_train",
                "--config_path",
            ]
            # the configuration paths are the values of --config_path
            + config_paths
        )

        # check output folders
        assert os.path.isdir("logs/test_train/save")
        assert os.path.isdir("logs/test_train/train")
        assert os.path.isdir("logs/test_train/validation")
        assert os.path.isfile("logs/test_train/config.yaml")

        predict_main(
            args=[
                "--gpu",
                "",
                "--ckpt_path",
                "logs/test_train/save/ckpt-2",
                "--split",
                "test",
                "--exp_name",
                "test_predict",
                "--save_nifti",
                "--save_png",
            ]
        )

        # check output folders
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_0")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_1")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_2")
        assert os.path.isfile("logs/test_predict/test/metrics.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_per_label.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_overall.csv")
    finally:
        # ignore_errors: a folder may not exist if the run failed early, and
        # a cleanup error must not mask the original failure
        shutil.rmtree("logs/test_train", ignore_errors=True)
        shutil.rmtree("logs/test_predict", ignore_errors=True)
Пример #2
0
def test_train_and_predict_main():
    """
    Smoke test for the train and predict entry points.

    Runs train_main, checks the expected training outputs exist, then runs
    predict_main with a checkpoint path under the training log folder and
    checks the expected prediction outputs exist.
    """
    train_args = [
        "--gpu",
        "",
        "--log_dir",
        "test_train",
        "--config_path",
        "config/unpaired_labeled_ddf.yaml",
    ]
    train_main(args=train_args)

    # training must have produced these folders and the saved config
    for train_dir in (
        "logs/test_train/save",
        "logs/test_train/train",
        "logs/test_train/validation",
    ):
        assert os.path.isdir(train_dir)
    assert os.path.isfile("logs/test_train/config.yaml")

    predict_args = [
        "--gpu",
        "",
        "--ckpt_path",
        "logs/test_train/save/weights-epoch2.ckpt",
        "--mode",
        "test",
        "--log_dir",
        "test_predict",
        "--save_nifti",
        "--save_png",
    ]
    predict_main(args=predict_args)

    # prediction must have produced one folder per label and the metric files
    for label_dir in (
        "logs/test_predict/test/pair_0_1/label_0",
        "logs/test_predict/test/pair_0_1/label_1",
        "logs/test_predict/test/pair_0_1/label_2",
    ):
        assert os.path.isdir(label_dir)
    for metrics_file in (
        "logs/test_predict/test/metrics.csv",
        "logs/test_predict/test/metrics_stats_per_label.csv",
        "logs/test_predict/test/metrics_stats_overall.csv",
    ):
        assert os.path.isfile(metrics_file)