Beispiel #1
0
    def download_recovery_checkpoints_or_weights(self) -> None:
        """
        Fetch any checkpoints or weights that were requested and record the results on this object.
        This is called at the start of training.

        Three independent steps are carried out, in this order:

        * If `azure_config.run_recovery_id` is non-empty, the run with that (whitespace-stripped) ID is fetched and
          all of its checkpoints are downloaded; the result is stored in `self.run_recovery`. Otherwise
          `self.run_recovery` is set to None.
        * If `azure_config.pretraining_run_recovery_id` is not None, the checkpoints of that run are downloaded as
          well, passing EXTRA_RUN_SUBFOLDER to the download call; the result is stored in
          `self.container.extra_downloaded_run_id`. Otherwise that attribute is set to None.
        * If the container has a `weights_url` or a `local_weights_path`, `self.local_weights_path` is set to the
          return value of `get_and_save_modified_weights()`.
        """
        azure_config = self.azure_config

        # Step 1: checkpoints of the recovery run, if one was requested.
        recovery_id = azure_config.run_recovery_id
        if not recovery_id:
            self.run_recovery = None
        else:
            recovery_source = azure_config.fetch_run(recovery_id.strip())
            self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, recovery_source)

        # Step 2: checkpoints of the pretraining run. Note the explicit None test: unlike step 1, an empty
        # string is NOT treated as "nothing requested" here.
        pretraining_id = azure_config.pretraining_run_recovery_id
        if pretraining_id is None:
            self.container.extra_downloaded_run_id = None
        else:
            pretraining_source = azure_config.fetch_run(pretraining_id.strip())
            self.container.extra_downloaded_run_id = RunRecovery.download_all_checkpoints_from_run(
                self.output_params, pretraining_source, EXTRA_RUN_SUBFOLDER)

        # Step 3: weights given via URL or local path. Without either, local_weights_path is left untouched.
        if not (self.container.weights_url or self.container.local_weights_path):
            return
        self.local_weights_path = self.get_and_save_modified_weights()
def test_download_recovery_single_run(test_output_dirs: OutputFolderForTests,
                                      runner_config: AzureConfig) -> None:
    """
    Download all checkpoints of a single run into the test output folder, then run `check_single_checkpoint`
    on both the recovery checkpoint paths and the best checkpoint paths of the result.

    The run is obtained via `get_most_recent_run`, with FALLBACK_SINGLE_RUN as the fallback run ID for local
    execution. `runner_config` is not used in the body.
    """
    destination = test_output_dirs.root_dir
    model_config = ModelConfigBase(should_validate=False)
    model_config.set_output_to(destination)

    source_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
    recovered = RunRecovery.download_all_checkpoints_from_run(model_config, source_run)

    # Getting the recovery checkpoint paths fails if there is no recovery checkpoint.
    recovery_paths = recovered.get_recovery_checkpoint_paths()
    check_single_checkpoint(recovery_paths)

    best_paths = recovered.get_best_checkpoint_paths()
    check_single_checkpoint(best_paths)
    def download_recovery_checkpoints_or_weights(self) -> None:
        """
        Fetch any checkpoints or weights that were requested and record the results on this object.
        This is called at the start of training.

        * If `azure_config.run_recovery_id` is non-empty, the run with that (whitespace-stripped) ID is fetched and
          all of its checkpoints are downloaded using `self.model_config`; the result is stored in
          `self.run_recovery`. Otherwise `self.run_recovery` is set to None.
        * If the model config has a `weights_url` or a `local_weights_path`, `self.local_weights_path` is set to the
          return value of `get_and_save_modified_weights()`.
        """
        recovery_id = self.azure_config.run_recovery_id
        if not recovery_id:
            # No recovery run requested.
            self.run_recovery = None
        else:
            recovery_source = self.azure_config.fetch_run(recovery_id.strip())
            self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.model_config, recovery_source)

        # Without a weights URL or local weights path, local_weights_path is left untouched.
        if not (self.model_config.weights_url or self.model_config.local_weights_path):
            return
        self.local_weights_path = self.get_and_save_modified_weights()