def test_get_checkpoints_to_test( test_output_dirs: OutputFolderForTests) -> None: config = ModelConfigBase(should_validate=False) config.set_output_to(test_output_dirs.root_dir) config.outputs_folder.mkdir() manage_recovery = get_default_checkpoint_handler( model_config=config, project_root=test_output_dirs.root_dir) # Set a local_weights_path to get checkpoint from. Model has not trained and no run recovery provided, # so the local weights should be used ignoring any epochs to test config.epochs_to_test = [1, 2] local_weights_path = test_output_dirs.root_dir / "exist.pth" stored_checkpoint = create_checkpoint_path( full_ml_test_data_path("checkpoints"), epoch=1) shutil.copyfile(str(stored_checkpoint), local_weights_path) config.local_weights_path = local_weights_path manage_recovery.discover_and_download_checkpoints_from_previous_runs() checkpoint_and_paths = manage_recovery.get_checkpoints_to_test() assert checkpoint_and_paths assert len(checkpoint_and_paths) == 1 assert checkpoint_and_paths[0].epoch == 0 assert checkpoint_and_paths[0].checkpoint_paths == [ manage_recovery.model_config.outputs_folder / WEIGHTS_FILE ] # Now set a run recovery object and set the start epoch to 1, so we get one epoch from # run recovery and one from the training checkpoints manage_recovery.azure_config.run_recovery_id = DEFAULT_RUN_RECOVERY_ID config.start_epoch = 1 manage_recovery.additional_training_done() manage_recovery.discover_and_download_checkpoints_from_previous_runs() # Copy checkpoint to make it seem like training has happened stored_checkpoint = create_checkpoint_path( path=full_ml_test_data_path("checkpoints"), epoch=1) expected_checkpoint = create_checkpoint_path(path=config.checkpoint_folder, epoch=2) shutil.copyfile(str(stored_checkpoint), str(expected_checkpoint)) checkpoint_and_paths = manage_recovery.get_checkpoints_to_test() assert checkpoint_and_paths assert len(checkpoint_and_paths) == 2 assert checkpoint_and_paths[0].epoch == 1 assert checkpoint_and_paths[0].checkpoint_paths == [ create_checkpoint_path(path=config.checkpoint_folder / DEFAULT_RUN_RECOVERY_ID.split(":")[1], epoch=1) ] assert checkpoint_and_paths[1].epoch == 2 assert checkpoint_and_paths[1].checkpoint_paths == [ create_checkpoint_path(path=config.checkpoint_folder, epoch=2) ] # This epoch does not exist config.epochs_to_test = [3] checkpoint_and_paths = manage_recovery.get_checkpoints_to_test() assert checkpoint_and_paths is None
def test_get_recovery_path_train( test_output_dirs: OutputFolderForTests) -> None: config = ModelConfigBase(should_validate=False) config.set_output_to(test_output_dirs.root_dir) config.outputs_folder.mkdir() checkpoint_handler = get_default_checkpoint_handler( model_config=config, project_root=test_output_dirs.root_dir) assert checkpoint_handler.get_recovery_path_train() is None checkpoint_handler.azure_config.run_recovery_id = DEFAULT_RUN_RECOVERY_ID checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() # We have not set a start_epoch but we are trying to use run_recovery, this should fail with pytest.raises(ValueError) as ex: checkpoint_handler.get_recovery_path_train() assert "Run recovery set, but start epoch is 0" in ex.value.args[0] # Run recovery with start epoch provided should succeed config.start_epoch = 20 expected_path = create_checkpoint_path( path=config.checkpoint_folder / DEFAULT_RUN_RECOVERY_ID.split(":")[1], epoch=config.start_epoch) assert checkpoint_handler.get_recovery_path_train() == expected_path # set an ensemble run as recovery - not supported checkpoint_handler.azure_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() with pytest.raises(ValueError) as ex: checkpoint_handler.get_recovery_path_train() assert "Found more than one checkpoint for epoch" in ex.value.args[0] # weights from local_weights_path and weights_url will be modified if needed and stored at this location expected_path = checkpoint_handler.model_config.outputs_folder / WEIGHTS_FILE # Set a weights_url to get checkpoint from checkpoint_handler.azure_config.run_recovery_id = "" config.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() assert checkpoint_handler.local_weights_path == expected_path config.start_epoch = 0 assert checkpoint_handler.get_recovery_path_train() == expected_path # Can't resume training from an external checkpoint config.start_epoch = 20 with pytest.raises(ValueError) as ex: checkpoint_handler.get_recovery_path_train() assert ex.value.args == "Start epoch is > 0, but no run recovery object has been provided to resume training." # Set a local_weights_path to get checkpoint from config.weights_url = "" local_weights_path = test_output_dirs.root_dir / "exist.pth" stored_checkpoint = create_checkpoint_path( full_ml_test_data_path("checkpoints"), epoch=1) shutil.copyfile(str(stored_checkpoint), local_weights_path) config.local_weights_path = local_weights_path checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() assert checkpoint_handler.local_weights_path == expected_path config.start_epoch = 0 assert checkpoint_handler.get_recovery_path_train() == expected_path # Can't resume training from an external checkpoint config.start_epoch = 20 with pytest.raises(ValueError) as ex: checkpoint_handler.get_recovery_path_train() assert ex.value.args == "Start epoch is > 0, but no run recovery object has been provided to resume training."
def test_discover_and_download_checkpoints_from_previous_runs( test_output_dirs: OutputFolderForTests) -> None: config = ModelConfigBase(should_validate=False) config.set_output_to(test_output_dirs.root_dir) config.outputs_folder.mkdir() # No checkpoint handling options set. checkpoint_handler = get_default_checkpoint_handler( model_config=config, project_root=test_output_dirs.root_dir) checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() assert not checkpoint_handler.run_recovery assert not checkpoint_handler.local_weights_path # Set a run recovery object - non ensemble checkpoint_handler.azure_config.run_recovery_id = DEFAULT_RUN_RECOVERY_ID checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() expected_checkpoint_root = config.checkpoint_folder / DEFAULT_RUN_RECOVERY_ID.split( ":")[1] expected_paths = [ create_checkpoint_path(path=expected_checkpoint_root, epoch=epoch) for epoch in [1, 2, 3, 4, 20] ] assert checkpoint_handler.run_recovery assert checkpoint_handler.run_recovery.checkpoints_roots == [ expected_checkpoint_root ] for path in expected_paths: assert path.is_file() # Set a run recovery object - ensemble checkpoint_handler.azure_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() expected_checkpoint_roots = [ config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i) for i in range(3) ] expected_path_lists = [[ create_checkpoint_path(path=expected_checkpoint_root, epoch=epoch) for epoch in [1, 2] ] for expected_checkpoint_root in expected_checkpoint_roots] assert set(checkpoint_handler.run_recovery.checkpoints_roots) == set( expected_checkpoint_roots) for path_list in expected_path_lists: for path in path_list: assert path.is_file() # weights from local_weights_path and weights_url will be modified if needed and stored at this location expected_path = checkpoint_handler.model_config.outputs_folder / WEIGHTS_FILE # Set a weights_path checkpoint_handler.azure_config.run_recovery_id = "" config.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() assert checkpoint_handler.local_weights_path == expected_path assert checkpoint_handler.local_weights_path.is_file() # set a local_weights_path config.weights_url = "" local_weights_path = test_output_dirs.root_dir / "exist.pth" stored_checkpoint = create_checkpoint_path( path=full_ml_test_data_path("checkpoints"), epoch=1) shutil.copyfile(str(stored_checkpoint), local_weights_path) config.local_weights_path = local_weights_path checkpoint_handler.discover_and_download_checkpoints_from_previous_runs() assert checkpoint_handler.local_weights_path == expected_path
def test_get_checkpoint_from_epoch( test_output_dirs: OutputFolderForTests) -> None: config = ModelConfigBase(should_validate=False) config.set_output_to(test_output_dirs.root_dir) config.outputs_folder.mkdir() manage_recovery = get_default_checkpoint_handler( model_config=config, project_root=test_output_dirs.root_dir) # We have not set a run_recovery, nor have we trained, so this should fail to get a checkpoint with pytest.raises(ValueError) as ex: manage_recovery.get_checkpoint_from_epoch(1) assert "no run recovery object provided and no training has been done in this run" in ex.value.args[ 0] # We have set a run_recovery_id now, so this should work manage_recovery.azure_config.run_recovery_id = DEFAULT_RUN_RECOVERY_ID manage_recovery.discover_and_download_checkpoints_from_previous_runs() expected_checkpoint = create_checkpoint_path( path=config.checkpoint_folder / DEFAULT_RUN_RECOVERY_ID.split(":")[1], epoch=1) checkpoint = manage_recovery.get_checkpoint_from_epoch(1) assert checkpoint assert len(checkpoint.checkpoint_paths) == 1 assert expected_checkpoint == checkpoint.checkpoint_paths[0] assert checkpoint.epoch == 1 # ensemble run recovery manage_recovery.azure_config.run_recovery_id = DEFAULT_ENSEMBLE_RUN_RECOVERY_ID manage_recovery.discover_and_download_checkpoints_from_previous_runs() expected_checkpoints = [ create_checkpoint_path(path=config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i), epoch=1) for i in range(3) ] checkpoint = manage_recovery.get_checkpoint_from_epoch(1) assert checkpoint assert len(checkpoint.checkpoint_paths) == 3 assert set(expected_checkpoints) == set(checkpoint.checkpoint_paths) assert checkpoint.epoch == 1 # From now on, the checkpoint handler will think that the run was started from epoch 1, i.e. we should use the # run recovery checkpoint for epoch 1 and the training run checkpoint for epoch 2 manage_recovery.additional_training_done() # go back to non ensemble run recovery manage_recovery.azure_config.run_recovery_id = DEFAULT_RUN_RECOVERY_ID manage_recovery.discover_and_download_checkpoints_from_previous_runs() config.start_epoch = 1 # We haven't actually done a training run ,so the checkpoint for epoch 2 is missing - and we should not use the one # from run recovery assert manage_recovery.get_checkpoint_from_epoch(2) is None # Should work for epoch 1 checkpoint = manage_recovery.get_checkpoint_from_epoch(1) expected_checkpoint = create_checkpoint_path( path=config.checkpoint_folder / DEFAULT_RUN_RECOVERY_ID.split(":")[1], epoch=1) assert checkpoint assert len(checkpoint.checkpoint_paths) == 1 assert checkpoint.checkpoint_paths[0] == expected_checkpoint assert checkpoint.epoch == 1 # Copy over checkpoints to make it look like training has happened stored_checkpoint = create_checkpoint_path( path=full_ml_test_data_path("checkpoints"), epoch=1) expected_checkpoint = create_checkpoint_path(path=config.checkpoint_folder, epoch=2) shutil.copyfile(str(stored_checkpoint), str(expected_checkpoint)) # Should now work for epoch 2 checkpoint = manage_recovery.get_checkpoint_from_epoch(2) assert checkpoint assert len(checkpoint.checkpoint_paths) == 1 assert expected_checkpoint == checkpoint.checkpoint_paths[0] assert checkpoint.epoch == 2