コード例 #1
0
 def setup(self, stage: Optional[str] = None) -> None:
     """
     Prepare the data on the current node. This is executed on each node in distributed training.

     Unless a data frame has already been loaded (typically only during tests), first checks that the dataset
     folder is present and that the dataset CSV file exists in it. Then reads the dataset if needed and builds
     the data loaders, storing them in ``self.data_loaders``.

     :param stage: Not used by this implementation.
     """
     config = self.config
     # Only check the files on disk when no dataset has been injected directly (typically only during tests).
     needs_path_check = config.dataset_data_frame is None
     if needs_path_check:
         assert config.local_dataset is not None
         validate_dataset_paths(config.local_dataset, config.dataset_csv)
     config.read_dataset_if_needed()
     self.data_loaders = config.create_data_loaders()
コード例 #2
0
def validate_dataset_paths(
        model_config: Union[ScalarModelBase, SegmentationModelBase]) -> None:
    """
    Check that dataset path validation succeeds when the dataset CSV file exists, and fails with a ValueError
    once that file has been removed.

    Note: as part of the check, this deletes the dataset CSV file inside ``model_config.local_dataset``.

    :param model_config: A model config whose ``local_dataset`` folder contains the file named by
        ``model_config.dataset_csv``.
    """
    assert model_config.local_dataset is not None
    # Positive case: the CSV file is present, so validation must not raise.
    ml_util.validate_dataset_paths(model_config.local_dataset, model_config.dataset_csv)

    dataset_csv_path = model_config.local_dataset / model_config.dataset_csv
    dataset_csv_path.unlink()

    # Negative case: with the CSV file removed, validation must raise a ValueError that names the file.
    ex_message = f"The dataset file {model_config.dataset_csv} is not present"
    with pytest.raises(ValueError) as ex:
        ml_util.validate_dataset_paths(model_config.local_dataset, model_config.dataset_csv)
    # Inspect the raised exception itself: str() of the ExceptionInfo wrapper is its repr, not the message.
    assert ex_message in str(ex.value)
コード例 #3
0
    def run(self) -> None:
        """
        Driver function to run an ML experiment.

        Offline cross validation parent run:
            The per-split child runs are spawned via ``spawn_offline_cross_val_classification_child_runs``
            (classification models only), and the method returns.

        Any other run:
            1. Prepare the run: parent tags, build info, multiprocessing start method, and checkpoints from
               previous runs.
            2. Unless ``azure_config.register_model_only_for_epoch`` is set, mount or download and validate the
               dataset. Then train the model, or, when training is disabled, write dataset files and activation
               maps.
            3. Run inference and model registration, then generate a report.

        :raises NotImplementedError: If offline cross validation is requested for a segmentation model.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError(
                    "Offline cross validation is only supported for classification models."
                )
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # When not running offline and a parent run context is available, copy the parent's tags to this run.
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # Configure the checkpoint handler and download checkpoints from previous runs (e.g. run recovery), if any.
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.discover_and_download_checkpoints_from_previous_runs()
        # Do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.register_model_only_for_epoch:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            self.model_config.write_args_file()
            logging.info(str(self.model_config))
            # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
            make_pytorch_reproducible()

            # Check for an existing dataset CSV file in the correct locations. Skip that if a dataset has already
            # been loaded (typically only during tests). Pass the configured CSV file name so that a custom
            # dataset_csv is the file that actually gets validated.
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset,
                                               self.model_config.dataset_csv)

            # Train a new model if required; otherwise only write the dataset files and activation maps.
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # Log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        best_epoch = self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)

        # Generate a report for the returned best epoch, if it is truthy.
        if best_epoch:
            Runner.generate_report(self.model_config, best_epoch, ModelProcessing.DEFAULT)
        elif self.model_config.is_scalar_model:
            # We don't register scalar models but still want to create a report if we have run inference
            # on exactly one test epoch. Query the test epochs once instead of twice.
            test_epochs = self.model_config.get_test_epochs()
            if len(test_epochs) == 1:
                Runner.generate_report(self.model_config, test_epochs[0], ModelProcessing.DEFAULT)
コード例 #4
0
    def run(self) -> None:
        """
        Top-level entry point for a single ML experiment run.

        Offline cross validation parent run:
            Only the per-split child runs are spawned (classification models only), and the method returns.

        Any other run:
            1. Prepare the run: parent tags, build info, multiprocessing start method, and recovery
               checkpoints or weights.
            2. Unless ``azure_config.only_register_model`` is set, mount or download and validate the dataset.
               Then either train the model or write dataset files and activation maps.
            3. Run inference and model registration, and optionally generate a report.
            4. For cross validation runs where this child should wait for its siblings, wait for them and
               create the ensemble model.

        :raises NotImplementedError: If offline cross validation is requested for a segmentation model.
        """
        if self.is_offline_cross_val_parent_run():
            # Offline cross validation: this process only fans out the child runs.
            if self.model_config.is_segmentation_model:
                raise NotImplementedError("Offline cross validation is only supported for classification models.")
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Not offline and a parent run context exists: copy the parent's tags onto this run.
        if not (self.model_config.is_offline_run or PARENT_RUN_CONTEXT is None):
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()
        # Choose how data loader worker processes are started.
        self.set_multiprocessing_start_method()

        # Download recovery checkpoints or weights, if configured.
        ckpt_handler = CheckpointHandler(
            model_config=self.model_config,
            azure_config=self.azure_config,
            project_root=self.project_root,
            run_context=RUN_CONTEXT,
        )
        ckpt_handler.download_recovery_checkpoints_or_weights()

        # Everything dataset- and training-related is skipped when only registering a model (this switch
        # requires a run_recovery to be valid).
        if not self.azure_config.only_register_model:
            # Prefer the mount path specified in azure_runner.py; fall back to downloading the dataset if that
            # fails and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            # A pre-loaded data frame (typically only during tests) means there is no dataset CSV to look for.
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset, self.model_config.dataset_csv)

            if not self.azure_config.train:
                # No training requested: still write out the dataset files and the activation maps.
                self.model_config.write_dataset_files()
                self.create_activation_maps()
            else:
                with logging_section("Model training"):
                    model_train(self.model_config, ckpt_handler, num_nodes=self.azure_config.num_nodes)

            RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

        # ModelProcessing.DEFAULT is correct even if run_recovery points to an ensemble run: the current run
        # itself is a single one. See the documentation of ModelProcessing for more details.
        self.run_inference_and_register_model(ckpt_handler, ModelProcessing.DEFAULT)

        if self.model_config.generate_report:
            self.generate_report(ModelProcessing.DEFAULT)

        # In a cross validation run, the child run that should wait (child run 0) waits for its sibling runs
        # and then builds the ensemble model.
        if (self.model_config.number_of_cross_validation_splits > 0
                and self.model_config.should_wait_for_other_cross_val_child_runs()):
            self.wait_for_runs_to_finish()
            self.create_ensemble_model()