def test_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests,
        perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    """
    Run train/val/test inference from pre-stored checkpoints and check which per-epoch
    result folders are written to the outputs folder.

    :param test_output_dirs: Provides the root folder (``root_dir``) that all outputs of this test go to.
    :param perform_cross_validation: If True, the model is configured with 2 cross-validation splits,
        otherwise with 0.
    :param perform_training_set_inference: Copied to ``config.perform_training_set_inference``. The TRAIN and
        VAL result folders are only required to exist when this is True; the TEST folder is always required.
    :raises ValueError: If the inference call returns None as its result.
    """
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()
    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    # To make it seem like there was a training run before this, copy checkpoints into the checkpoints folder.
    stored_checkpoints = full_ml_test_data_path("checkpoints")
    shutil.copytree(str(stored_checkpoints), str(config.checkpoint_folder))
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    checkpoint_handler.additional_training_done()

    result, _, _ = MLRunner(config).model_inference_train_and_test(checkpoint_handler=checkpoint_handler)
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)

    # TRAIN/VAL results are only checked when training-set inference was requested; TEST results always are.
    train_and_val = (ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value)
    all_folders = (ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value, ModelExecutionMode.TEST.value)
    # Only the epoch keys are needed; iterating the mapping directly avoids the unused value (PERF102).
    for epoch in result.epochs:
        epoch_folder_name = common_util.epoch_folder_name(epoch)
        for folder in all_folders:
            if folder in train_and_val and not perform_training_set_inference:
                continue
            results_folder = config.outputs_folder / epoch_folder_name / folder
            assert results_folder.is_dir(), f"Expected results folder {results_folder} to exist"
def test_model_inference_train_and_test(
        test_output_dirs: TestOutputDirectories,
        perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    """
    Run train/val/test inference from pre-stored checkpoints and check which per-epoch
    result folders are written to the outputs folder.

    :param test_output_dirs: Provides the root folder (``root_dir``) that all outputs of this test go to.
    :param perform_cross_validation: If True, the model is configured with 2 cross-validation splits,
        otherwise with 0.
    :param perform_training_set_inference: Copied to ``config.perform_training_set_inference``. The TRAIN and
        VAL result folders are only required to exist when this is True; the TEST folder is always required.
    :raises ValueError: If the inference call returns None as its result.
    """
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Plotting crashes with random TCL errors on Windows, disable that for Windows PR builds.
    config.is_plotting_enabled = common_util.is_linux()
    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    stored_checkpoints = full_ml_test_data_path("checkpoints")
    shutil.copytree(str(stored_checkpoints), str(config.checkpoint_folder))

    result, _, _ = MLRunner(config).model_inference_train_and_test()
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)

    # TRAIN/VAL results are only checked when training-set inference was requested; TEST results always are.
    train_and_val = (ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value)
    all_folders = (ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value, ModelExecutionMode.TEST.value)
    # Only the epoch keys are needed; iterating the mapping directly avoids the unused value (PERF102).
    for epoch in result.epochs:
        epoch_folder_name = common_util.epoch_folder_name(epoch)
        for folder in all_folders:
            if folder in train_and_val and not perform_training_set_inference:
                continue
            results_folder = config.outputs_folder / epoch_folder_name / folder
            assert results_folder.is_dir(), f"Expected results folder {results_folder} to exist"
def test_model_inference_train_and_test(
        test_output_dirs: OutputFolderForTests,
        perform_cross_validation: bool,
        perform_training_set_inference: bool) -> None:
    """
    Create a model checkpoint at the "best checkpoint" location, run train/val/test inference
    with it, and check which result folders appear under the best-epoch outputs folder.

    :param test_output_dirs: Provides the root folder (``root_dir``) that all outputs of this test go to.
    :param perform_cross_validation: If True, the model is configured with 2 cross-validation splits,
        otherwise with 0.
    :param perform_training_set_inference: Copied to ``config.perform_training_set_inference``. The TRAIN and
        VAL result folders are only required to exist when this is True; the TEST folder is always required.
    :raises ValueError: If the inference call returns None as its result.
    """
    config = DummyModel()
    config.number_of_cross_validation_splits = 2 if perform_cross_validation else 0
    config.perform_training_set_inference = perform_training_set_inference
    # Windows PR builds hit random TCL errors when plotting, so plotting is only enabled on Linux.
    config.is_plotting_enabled = common_util.is_linux()
    config.set_output_to(test_output_dirs.root_dir)
    config.local_dataset = full_ml_test_data_path()

    # Pretend a training run already happened by storing a checkpoint under the best-checkpoint file name.
    create_model_and_store_checkpoint(config, config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
    handler = get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir)
    # NOTE(review): presumably marks training as finished so the handler picks up the checkpoint above — confirm.
    handler.additional_training_done()

    runner = MLRunner(config)
    result, _, _ = runner.model_inference_train_and_test(checkpoint_handler=handler)
    if result is None:
        raise ValueError("Error result cannot be None")
    assert isinstance(result, InferenceMetricsForSegmentation)

    best_epoch_folder = config.outputs_folder / common_util.BEST_EPOCH_FOLDER_NAME
    optional_folders = (ModelExecutionMode.TRAIN.value, ModelExecutionMode.VAL.value)
    for mode in (ModelExecutionMode.TRAIN, ModelExecutionMode.VAL, ModelExecutionMode.TEST):
        # TRAIN/VAL are only required with training-set inference; TEST is always required.
        required = perform_training_set_inference or mode.value not in optional_folders
        if required:
            assert (best_epoch_folder / mode.value).is_dir()