def setup(self, stage: Optional[str] = None) -> None: """ Checks if the dataset folder is present, and the dataset file exists. This is execute on each node in distributed training. """ # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been # loaded (typically only during tests) if self.config.dataset_data_frame is None: assert self.config.local_dataset is not None validate_dataset_paths(self.config.local_dataset, self.config.dataset_csv) self.config.read_dataset_if_needed() self.data_loaders = self.config.create_data_loaders()
def validate_dataset_paths(
        model_config: Union[ScalarModelBase, SegmentationModelBase]) -> None:
    """
    Check that validation of dataset paths succeeds when the dataset CSV file exists, and fails with a
    ValueError once that file has been deleted.

    Note: this deletes the dataset CSV file inside ``model_config.local_dataset``, so it must only be run
    against a disposable copy of the dataset.

    :param model_config: The model configuration whose ``local_dataset`` folder and ``dataset_csv`` file name
        are validated. ``local_dataset`` must not be None.
    """
    assert model_config.local_dataset is not None
    # With the CSV file present, validation must pass without raising.
    ml_util.validate_dataset_paths(model_config.local_dataset, model_config.dataset_csv)
    dataset_csv_path = model_config.local_dataset / model_config.dataset_csv
    dataset_csv_path.unlink()
    ex_message = f"The dataset file {model_config.dataset_csv} is not present"
    with pytest.raises(ValueError) as ex:
        ml_util.validate_dataset_paths(model_config.local_dataset, model_config.dataset_csv)
    # Check the exception's own message. str() of the ExceptionInfo wrapper is its repr in recent pytest
    # versions, which escapes characters and may truncate long messages.
    assert ex_message in str(ex.value)
def run(self) -> None:
    """
    Driver function to run a ML experiment. If an offline cross validation run is requested, then this function is
    recursively called for each cross validation split.

    For any other run, the steps are, in order:
    1. If this is not an offline run and PARENT_RUN_CONTEXT is set, set the run tags from the parent run. Then save
       the build information and set the multiprocessing start method.
    2. Create a CheckpointHandler and let it discover and download checkpoints from previous runs.
    3. Unless azure_config.register_model_only_for_epoch is set:
       - mount or download the dataset into model_config.local_dataset;
       - write the arguments file, log the config, and make PyTorch reproducible;
       - validate the dataset paths, but only if no data frame is loaded yet;
       - either train the model (azure_config.train), or only write the dataset files and create activation maps;
       - log the number of epochs to RUN_CONTEXT.
    4. Run inference and model registration, then generate a report. The report uses the returned best epoch if
       it is truthy; otherwise, for a scalar model with exactly one test epoch, it uses that epoch.

    :raises NotImplementedError: If offline cross validation is requested for a segmentation model.
    """
    if self.is_offline_cross_val_parent_run():
        # The parent run only spawns the child runs and returns immediately; it does no training itself.
        if self.model_config.is_segmentation_model:
            raise NotImplementedError(
                "Offline cross validation is only supported for classification models."
            )
        self.spawn_offline_cross_val_classification_child_runs()
        return

    # Copy tags from the parent run, but only for runs that are not offline and do have a parent run context.
    if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
        logging.info("Setting tags from parent run.")
        self.set_run_tags_from_parent()

    self.save_build_info_for_dotnet_consumers()

    # Set data loader start method
    self.set_multiprocessing_start_method()

    # Configure the recovery container if provided: fetch checkpoints from previous runs before training/inference.
    checkpoint_handler = CheckpointHandler(model_config=self.model_config, azure_config=self.azure_config,
                                           project_root=self.project_root, run_context=RUN_CONTEXT)
    checkpoint_handler.discover_and_download_checkpoints_from_previous_runs(
    )

    # Do training and inference, unless the "only register" switch is set (which requires a run_recovery
    # to be valid).
    if not self.azure_config.register_model_only_for_epoch:
        # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
        # and config.local_dataset was not already set.
        self.model_config.local_dataset = self.mount_or_download_dataset()
        self.model_config.write_args_file()
        logging.info(str(self.model_config))
        # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
        make_pytorch_reproducible()

        # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
        # loaded (typically only during tests)
        if self.model_config.dataset_data_frame is None:
            assert self.model_config.local_dataset is not None
            # NOTE(review): only the dataset folder is passed here, so a custom model_config.dataset_csv name is
            # not checked by this call - confirm against this version's ml_util.validate_dataset_paths signature.
            ml_util.validate_dataset_paths(self.model_config.local_dataset)

        # Train a new model if required; otherwise only write the dataset files and create activation maps.
        if self.azure_config.train:
            with logging_section("Model training"):
                model_train(self.model_config, checkpoint_handler)
        else:
            self.model_config.write_dataset_files()
            self.create_activation_maps()

        # log the number of epochs used for model training
        RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

    # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
    # the current run is a single one. See the documentation of ModelProcessing for more details.
    best_epoch = self.run_inference_and_register_model(
        checkpoint_handler, ModelProcessing.DEFAULT)

    # Generate the report. A falsy best_epoch (e.g. None) falls through to the scalar-model branch.
    if best_epoch:
        Runner.generate_report(self.model_config, best_epoch, ModelProcessing.DEFAULT)
    elif self.model_config.is_scalar_model and len(
            self.model_config.get_test_epochs()) == 1:
        # We don't register scalar models but still want to create a report if we have run inference.
        Runner.generate_report(self.model_config, self.model_config.get_test_epochs()[0], ModelProcessing.DEFAULT)
def run(self) -> None:
    """
    Driver function to run a ML experiment. If an offline cross validation run is requested, then this function is
    recursively called for each cross validation split.

    For any other run, the steps are, in order:
    1. If this is not an offline run and PARENT_RUN_CONTEXT is set, set the run tags from the parent run. Then save
       the build information and set the multiprocessing start method.
    2. Create a CheckpointHandler and let it download recovery checkpoints or weights.
    3. Unless azure_config.only_register_model is set:
       - mount or download the dataset into model_config.local_dataset;
       - validate the dataset folder and the model_config.dataset_csv file, but only if no data frame is loaded yet;
       - either train the model on azure_config.num_nodes nodes (azure_config.train), or only write the dataset
         files and create activation maps;
       - log the number of epochs to RUN_CONTEXT.
    4. Run inference and model registration, and generate a report if model_config.generate_report is set.
    5. For cross-validation runs that should wait for the other child runs, wait for them to finish and then
       create the ensemble model.

    :raises NotImplementedError: If offline cross validation is requested for a segmentation model.
    """
    if self.is_offline_cross_val_parent_run():
        # The parent run only spawns the child runs and returns immediately; it does no training itself.
        if self.model_config.is_segmentation_model:
            raise NotImplementedError("Offline cross validation is only supported for classification models.")
        self.spawn_offline_cross_val_classification_child_runs()
        return

    # Copy tags from the parent run, but only for runs that are not offline and do have a parent run context.
    if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
        logging.info("Setting tags from parent run.")
        self.set_run_tags_from_parent()

    self.save_build_info_for_dotnet_consumers()

    # Set data loader start method
    self.set_multiprocessing_start_method()

    # Configure the recovery container if provided: fetch recovery checkpoints or weights before training/inference.
    checkpoint_handler = CheckpointHandler(model_config=self.model_config, azure_config=self.azure_config,
                                           project_root=self.project_root, run_context=RUN_CONTEXT)
    checkpoint_handler.download_recovery_checkpoints_or_weights()

    # Do training and inference, unless the "only register" switch is set (which requires a run_recovery
    # to be valid).
    if not self.azure_config.only_register_model:
        # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
        # and config.local_dataset was not already set.
        self.model_config.local_dataset = self.mount_or_download_dataset()
        # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
        # loaded (typically only during tests)
        if self.model_config.dataset_data_frame is None:
            assert self.model_config.local_dataset is not None
            ml_util.validate_dataset_paths(
                self.model_config.local_dataset,
                self.model_config.dataset_csv)

        # Train a new model if required; otherwise only write the dataset files and create activation maps.
        if self.azure_config.train:
            with logging_section("Model training"):
                model_train(self.model_config, checkpoint_handler, num_nodes=self.azure_config.num_nodes)
        else:
            self.model_config.write_dataset_files()
            self.create_activation_maps()

        # log the number of epochs used for model training
        RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

    # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
    # the current run is a single one. See the documentation of ModelProcessing for more details.
    self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)

    if self.model_config.generate_report:
        self.generate_report(ModelProcessing.DEFAULT)

    # If this is a cross-validation run and the present run is child run 0, then wait for the sibling runs,
    # build the ensemble model, and write a report for that.
    if self.model_config.number_of_cross_validation_splits > 0:
        if self.model_config.should_wait_for_other_cross_val_child_runs():
            self.wait_for_runs_to_finish()
            self.create_ensemble_model()