def test_train_and_predict_main(config_paths):
    """
    Test main in train and predict by checking it can run.

    Runs ``train_main`` with the given configuration files and checks that the
    expected training outputs exist under ``logs/test_train``. It then runs
    ``predict_main`` on checkpoint ``logs/test_train/save/ckpt-2`` (test split)
    and checks the prediction outputs under ``logs/test_predict``.

    Both log folders are removed afterwards, even if an assertion or one of
    the main calls fails, so stale outputs cannot leak into later runs.

    NOTE(review): a second function with this exact name is defined after
    this one in the file as shown. At module level that later definition
    shadows this one, so pytest only collects the later version. Confirm
    which variant is intended.

    :param config_paths: list of file paths for configuration.
    """
    try:
        train_main(
            args=[
                "--gpu",
                "",
                "--exp_name",
                "test_train",
                "--config_path",
            ]
            + config_paths
        )

        # check output folders
        assert os.path.isdir("logs/test_train/save")
        assert os.path.isdir("logs/test_train/train")
        assert os.path.isdir("logs/test_train/validation")
        assert os.path.isfile("logs/test_train/config.yaml")

        predict_main(
            args=[
                "--gpu",
                "",
                "--ckpt_path",
                "logs/test_train/save/ckpt-2",
                "--split",
                "test",
                "--exp_name",
                "test_predict",
                "--save_nifti",
                "--save_png",
            ]
        )

        # check output folders
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_0")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_1")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_2")
        assert os.path.isfile("logs/test_predict/test/metrics.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_per_label.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_overall.csv")
    finally:
        # Always clean up, including on failure. ignore_errors=True because a
        # run that failed early may not have created both folders, and a
        # FileNotFoundError here would hide the original failure.
        shutil.rmtree("logs/test_train", ignore_errors=True)
        shutil.rmtree("logs/test_predict", ignore_errors=True)
def test_train_and_predict_main():
    """
    Test main in train and predict by checking it can run.

    Trains with ``config/unpaired_labeled_ddf.yaml`` into log dir
    ``test_train`` and checks that the expected outputs exist under
    ``logs/test_train``. It then predicts in ``test`` mode from checkpoint
    ``logs/test_train/save/weights-epoch2.ckpt`` into log dir ``test_predict``
    and checks the outputs under ``logs/test_predict``.

    Both log folders are removed afterwards, even if an assertion or one of
    the main calls fails, so repeated runs start from a clean state.
    """
    # Local import: this block did not previously depend on shutil.
    import shutil

    try:
        train_main(
            args=[
                "--gpu",
                "",
                "--log_dir",
                "test_train",
                "--config_path",
                "config/unpaired_labeled_ddf.yaml",
            ]
        )

        # check output folders
        assert os.path.isdir("logs/test_train/save")
        assert os.path.isdir("logs/test_train/train")
        assert os.path.isdir("logs/test_train/validation")
        assert os.path.isfile("logs/test_train/config.yaml")

        predict_main(
            args=[
                "--gpu",
                "",
                "--ckpt_path",
                "logs/test_train/save/weights-epoch2.ckpt",
                "--mode",
                "test",
                "--log_dir",
                "test_predict",
                "--save_nifti",
                "--save_png",
            ]
        )

        # check output folders
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_0")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_1")
        assert os.path.isdir("logs/test_predict/test/pair_0_1/label_2")
        assert os.path.isfile("logs/test_predict/test/metrics.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_per_label.csv")
        assert os.path.isfile("logs/test_predict/test/metrics_stats_overall.csv")
    finally:
        # Remove generated artifacts even on failure. ignore_errors=True so a
        # folder that was never created does not mask the real error.
        shutil.rmtree("logs/test_train", ignore_errors=True)
        shutil.rmtree("logs/test_predict", ignore_errors=True)