def download_recovery_checkpoints_or_weights(self) -> None:
    """
    Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
    run_recovery_object, weights_url or local_weights_path.
    This is called at the start of training.

    Attributes written:
    * ``self.run_recovery``: the result of ``RunRecovery.download_all_checkpoints_from_run`` for the run named by
      ``azure_config.run_recovery_id``, or None when that id is not set.
    * ``self.container.extra_downloaded_run_id``: the result of ``RunRecovery.download_all_checkpoints_from_run``
      (with ``EXTRA_RUN_SUBFOLDER``) for the run named by ``azure_config.pretraining_run_recovery_id``, or None
      when that id is not set.
    * ``self.local_weights_path``: the return value of ``get_and_save_modified_weights()``, only when the
      container has ``weights_url`` or ``local_weights_path`` set; otherwise it is left unchanged.

    A run id that is None, empty, or whitespace-only is treated as "not set" for both ids.
    """
    # Both ids are normalized identically. Previously the pretraining id was only checked with `is not None`,
    # so "" (or "  ") was passed on to fetch_run as an empty id.
    run_recovery_id = (self.azure_config.run_recovery_id or "").strip()
    if run_recovery_id:
        run_to_recover = self.azure_config.fetch_run(run_recovery_id)
        self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover)
    else:
        self.run_recovery = None

    pretraining_run_id = (self.azure_config.pretraining_run_recovery_id or "").strip()
    if pretraining_run_id:
        pretraining_run = self.azure_config.fetch_run(pretraining_run_id)
        # EXTRA_RUN_SUBFOLDER presumably keeps these checkpoints apart from the main run's - confirm in RunRecovery.
        # NOTE(review): despite its name, extra_downloaded_run_id holds the RunRecovery object, not an id string.
        self.container.extra_downloaded_run_id = RunRecovery.download_all_checkpoints_from_run(
            self.output_params, pretraining_run, EXTRA_RUN_SUBFOLDER)
    else:
        self.container.extra_downloaded_run_id = None

    if self.container.weights_url or self.container.local_weights_path:
        self.local_weights_path = self.get_and_save_modified_weights()
def test_download_recovery_single_run(test_output_dirs: OutputFolderForTests, runner_config: AzureConfig) -> None:
    """
    Download all checkpoints of the most recent run and check the recovery and best checkpoint lists.

    The run is obtained via get_most_recent_run, falling back to FALLBACK_SINGLE_RUN for local execution.
    Both the recovery-checkpoint paths and the best-checkpoint paths are passed to check_single_checkpoint,
    which presumably asserts that each list holds exactly one checkpoint.
    """
    # NOTE(review): runner_config is not referenced in the body; presumably requested as a pytest fixture.
    root = test_output_dirs.root_dir
    model_config = ModelConfigBase(should_validate=False)
    model_config.set_output_to(root)

    source_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    recovered = RunRecovery.download_all_checkpoints_from_run(model_config, source_run)

    # Fails when the downloaded run has no recovery checkpoint.
    check_single_checkpoint(recovered.get_recovery_checkpoint_paths())
    check_single_checkpoint(recovered.get_best_checkpoint_paths())
def download_recovery_checkpoints_or_weights(self) -> None:
    """
    Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
    run_recovery_object, weights_url or local_weights_path.
    This is called at the start of training.

    Attributes written:
    * ``self.run_recovery``: the result of ``RunRecovery.download_all_checkpoints_from_run`` for the run named by
      ``azure_config.run_recovery_id``, or None when that id is not set.
    * ``self.local_weights_path``: the return value of ``get_and_save_modified_weights()``, only when the model
      config has ``weights_url`` or ``local_weights_path`` set; otherwise it is left unchanged.

    A run id that is None, empty, or whitespace-only is treated as "not set".
    """
    # Strip before the truthiness check so that a whitespace-only id is treated as unset,
    # instead of being stripped to "" and passed to fetch_run.
    run_recovery_id = (self.azure_config.run_recovery_id or "").strip()
    if run_recovery_id:
        run_to_recover = self.azure_config.fetch_run(run_recovery_id)
        self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.model_config, run_to_recover)
    else:
        self.run_recovery = None

    if self.model_config.weights_url or self.model_config.local_weights_path:
        self.local_weights_path = self.get_and_save_modified_weights()