Code Example #1
def test_rnn_classifier_via_config_1(
        use_combined_model: bool, imaging_feature_type: ImagingFeatureType,
        combine_hidden_state: bool, use_encoder_layer_norm: bool,
        use_mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test if we can build a simple RNN model that only feeds off non-image features.
    This just tests the mechanics of training, but not if the model learned.
    """
    logging_to_stdout()
    config = ToySequenceModel(use_combined_model,
                              imaging_feature_type=imaging_feature_type,
                              combine_hidden_states=combine_hidden_state,
                              use_encoder_layer_norm=use_encoder_layer_norm,
                              use_mean_teacher_model=use_mean_teacher_model,
                              should_validate=False)
    # This fails with 16bit precision, saying "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are
    # unsafe to autocast. Many models use a sigmoid layer right before the binary cross entropy layer. In this case,
    # combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits or
    # torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are safe to autocast."
    config.use_mixed_precision = False
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    # Patch the load_image_in_known_formats function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](
        images=np.random.uniform(0, 1, SCAN_SIZE),
        segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats',
                    return_value=image_and_seg):
        model_train(
            config,
            get_default_checkpoint_handler(
                model_config=config, project_root=test_output_dirs.root_dir))
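The autocast failure described in the comment above can be avoided by fusing the sigmoid into the loss, exactly as the PyTorch message suggests. A minimal standalone sketch of the safe pattern (tensor shapes and values are illustrative, not taken from InnerEye):

import torch

# Unsafe under autocast: probs = torch.sigmoid(logits); torch.nn.BCELoss()(probs, targets)
# Safe alternative: BCEWithLogitsLoss applies the sigmoid internally.
logits = torch.randn(4, 1)                     # raw model outputs, no sigmoid applied
targets = torch.randint(0, 2, (4, 1)).float()  # binary labels as floats
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
print(loss.item())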
Code Example #2
def test_recover_training_mean_teacher_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999
    config.recovery_checkpoint_save_interval = 1
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    original_checkpoint_folder = config.checkpoint_folder

    # First round of training
    config.num_epochs = 2
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 2

    # Restart training from previous run
    config.start_epoch = 2
    config.num_epochs = 3
    config.set_output_to(test_output_dirs.root_dir / "recovered")
    os.makedirs(str(config.outputs_folder))
    # Make it seem like run recovery objects have been downloaded
    checkpoint_root = config.checkpoint_folder / "old_run"
    shutil.copytree(str(original_checkpoint_folder), str(checkpoint_root))
    checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])

    model_train(config, checkpoint_handler=checkpoint_handler)
    # remove recovery checkpoints
    shutil.rmtree(checkpoint_root)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 2
Code Example #3
def test_rnn_classifier_via_config_1(
        use_combined_model: bool, imaging_feature_type: ImagingFeatureType,
        combine_hidden_state: bool, use_encoder_layer_norm: bool,
        use_mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test if we can build a simple RNN model that only feeds off non-image features.
    This just tests the mechanics of training, but not if the model learned.
    """
    logging_to_stdout()
    config = ToySequenceModel(use_combined_model,
                              imaging_feature_type=imaging_feature_type,
                              combine_hidden_states=combine_hidden_state,
                              use_encoder_layer_norm=use_encoder_layer_norm,
                              use_mean_teacher_model=use_mean_teacher_model,
                              should_validate=False)
    config.use_mixed_precision = True
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    # Patch the load_image_in_known_formats function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](
        images=np.random.uniform(0, 1, SCAN_SIZE),
        segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats',
                    return_value=image_and_seg):
        model_train(
            config,
            get_default_checkpoint_handler(
                model_config=config, project_root=test_output_dirs.root_dir))
Code Example #4
def test_mean_teacher_model() -> None:
    """
    Test training and weight updates of the mean teacher model computation.
    """
    def _get_parameters_of_model(
            model: Union[torch.nn.Module, DataParallelModel]) -> Any:
        """
        Returns the iterator of model parameters
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.num_epochs = 1
    # Set the train batch size arbitrarily large to ensure we have only one training step,
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config)

    # Retrieve the weight after one epoch
    model = create_model_with_temperature_scaling(config)
    print(config.get_path_to_checkpoint(1))
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model
    ml_util.set_random_seed(config.get_effective_random_seed())
    _ = create_model_with_temperature_scaling(config)
    mean_teach_model = create_model_with_temperature_scaling(config)
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config)

    # Retrieve weight of mean teacher model saved in the checkpoint
    mean_teacher_model = create_model_with_temperature_scaling(config)
    _ = model_util.load_checkpoint(
        mean_teacher_model,
        config.get_path_to_checkpoint(1, for_mean_teacher_model=True))
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    student_model_weight = next(_get_parameters_of_model(model))

    # Assert that the student weight corresponds to the weight of a plain training run without the mean teacher
    # computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters
    assert torch.all(alpha * initial_weight_mean_teacher_model +
                     (1 - alpha) * student_model_weight == result_weight)
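The final assertion encodes the exponential moving average that defines a mean teacher: after one update step, teacher = alpha * teacher_initial + (1 - alpha) * student. A minimal sketch of that update, written independently of the InnerEye training loop (the function name is illustrative):

import torch

def update_mean_teacher(teacher: torch.nn.Module, student: torch.nn.Module, alpha: float) -> None:
    # EMA update: move each teacher parameter a fraction (1 - alpha) towards the student.
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(alpha).add_(s_param, alpha=1 - alpha)

teacher, student = torch.nn.Linear(2, 1), torch.nn.Linear(2, 1)
update_mean_teacher(teacher, student, alpha=0.999)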
Code Example #5
def test_rnn_classifier_via_config_1(use_combined_model: bool,
                                     imaging_feature_type: ImagingFeatureType,
                                     combine_hidden_state: bool,
                                     use_encoder_layer_norm: bool,
                                     use_mean_teacher_model: bool,
                                     test_output_dirs: TestOutputDirectories) -> None:
    """
    Test if we can build a simple RNN model that only feeds off non-image features.
    This just tests the mechanics of training, but not if the model learned.
    """
    logging_to_stdout()
    config = ToySequenceModel(use_combined_model,
                              imaging_feature_type=imaging_feature_type,
                              combine_hidden_states=combine_hidden_state,
                              use_encoder_layer_norm=use_encoder_layer_norm,
                              use_mean_teacher_model=use_mean_teacher_model,
                              should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    # Patch the load_image_in_known_formats function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
                                                      segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        results = model_train(config)
        assert len(results.optimal_temperature_scale_values_per_checkpoint_epoch) \
               == config.get_total_number_of_save_epochs()
Code Example #6
def test_rnn_classifier_via_config_2(test_output_dirs: TestOutputDirectories) -> None:
    """
    Test if we can build an RNN classifier that learns sequences, of the same kind as in
    test_rnn_classifier_toy_problem, but built via the config.
    """
    expected_max_train_loss = 0.71
    expected_max_val_loss = 0.71
    num_sequences = 100
    ml_util.set_random_seed(123)
    dataset_contents = "subject,index,feature,label\n"
    for subject in range(num_sequences):
        # Sequences have variable length
        sequence_length = np.random.choice([9, 10, 11, 12])
        # Each sequence is a series of 0s and 1s
        inputs = np.random.choice([0, 1], size=(sequence_length,), p=[1. / 3, 2. / 3])
        label = np.sum(inputs) > (sequence_length // 2)
        for i, value in enumerate(inputs.tolist()):
            dataset_contents += f"S{subject},{i},{value},{label}\n"
    logging_to_stdout()
    config = ToySequenceModel2(should_validate=False)
    config.num_epochs = 2
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents)
    results = model_train(config)

    actual_train_loss = results.train_results_per_epoch[-1].values()[MetricType.LOSS.value][0]
    actual_val_loss = results.val_results_per_epoch[-1].values()[MetricType.LOSS.value][0]
    print(f"Training loss after {config.num_epochs} epochs: {actual_train_loss}")
    print(f"Validation loss after {config.num_epochs} epochs: {actual_val_loss}")
    assert actual_train_loss <= expected_max_train_loss, "Training loss too high"
    assert actual_val_loss <= expected_max_val_loss, "Validation loss too high"
    assert len(results.optimal_temperature_scale_values_per_checkpoint_epoch) \
           == config.get_total_number_of_save_epochs()
    assert np.allclose(results.optimal_temperature_scale_values_per_checkpoint_epoch, [0.97], rtol=0.1)
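For reference, the dataset_contents built above is a long-format CSV in which each row is one step of a sequence, and the label repeats the per-subject majority vote on every row. An illustrative prefix (actual values depend on the random seed):

subject,index,feature,label
S0,0,1,True
S0,1,1,True
S0,2,0,True
S1,0,0,False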
Code Example #7
def model_train_unittest(config: Optional[DeepLearningConfig],
                         dirs: OutputFolderForTests,
                         checkpoint_handler: Optional[CheckpointHandler] = None,
                         lightning_container: Optional[LightningContainer] = None) -> \
        Tuple[StoringLogger, CheckpointHandler]:
    """
    A shortcut for running model training in the unit test suite. It runs training for the given config, with the
    default checkpoint handler initialized to point to the test output folder specified in dirs.
    :param config: The configuration of the model to train.
    :param dirs: The test fixture that provides an output folder for the test.
    :param lightning_container: An optional LightningContainer object that will be passed through to the training routine.
    :param checkpoint_handler: The checkpoint handler that should be used for training. If not provided, it will be
    created via get_default_checkpoint_handler.
    :return: A tuple of the StoringLogger that collected training metrics and the checkpoint handler that was used.
    """
    runner = MLRunner(model_config=config, container=lightning_container)
    # Setup sets random seeds before model creation and stores the created model in the container;
    # later we use the container that has been initialized this way.
    # For all tests running in AzureML, we need to skip the downloading of datasets that would otherwise happen,
    # because all unit test configs come with their own local dataset already.
    runner.setup(use_mount_or_download_dataset=False)
    if checkpoint_handler is None:
        azure_config = get_default_azure_config()
        checkpoint_handler = CheckpointHandler(azure_config=azure_config,
                                               container=runner.container,
                                               project_root=dirs.root_dir)
    _, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                    container=runner.container)
    return storing_logger, checkpoint_handler  # type: ignore
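A typical call of this helper from a test could look like the following sketch (the config class and the final assertion are illustrative, not part of the helper's contract):

def test_training_via_helper(test_output_dirs: OutputFolderForTests) -> None:
    config = DummyClassification()
    config.num_epochs = 2
    config.set_output_to(test_output_dirs.root_dir)
    storing_logger, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
    # The StoringLogger holds per-epoch metrics; the handler points at the checkpoints written.
    assert len(list(storing_logger.epochs)) == config.num_epochs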
Code Example #8
def test_train_2d_classification_model(test_output_dirs: OutputFolderForTests,
                                       use_mixed_precision: bool) -> None:
    """
    Test training and testing of 2d classification models.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting2D()
    config.set_output_to(test_output_dirs.root_dir)

    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    config.use_mixed_precision = use_mixed_precision

    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=Path(test_output_dirs.root_dir))
    model_training_result = model_training.model_train(config, checkpoint_handler=checkpoint_handler)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]

    expected_train_loss = [0.705931, 0.698664, 0.694489, 0.693151]
    expected_val_loss = [1.078517, 1.140510, 1.199026, 1.248595]

    actual_train_loss = model_training_result.get_metric(is_training=True, metric_type=MetricType.LOSS.value)
    actual_val_loss = model_training_result.get_metric(is_training=False, metric_type=MetricType.LOSS.value)
    actual_lr = model_training_result.get_metric(is_training=True, metric_type=MetricType.LEARNING_RATE.value)

    assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6)
    assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6)
    assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5)
    test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN, checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
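The loss assertions above mix absolute and relative tolerances: abs=1e-6 accepts any value within 1e-6 of the expectation, while rel=1e-5 scales the allowed deviation with the magnitude of the expected value, which is the natural choice for small quantities like learning rates. A self-contained illustration (the numeric values are chosen for the example):

from pytest import approx

assert 0.705931 == approx(0.705932, abs=1e-5)  # absolute: |diff| = 1e-6 <= 1e-5
assert 1.0001e-4 == approx(1.0e-4, rel=1e-2)   # relative: |diff| / expected = 1e-4 <= 1e-2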
Code Example #9
def test_recover_training_mean_teacher_model() -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999

    # First round of training
    config.num_epochs = 2
    model_train(config)
    assert len(os.listdir(config.checkpoint_folder)) == 1

    # Restart training from previous run
    config.start_epoch = 2
    config.num_epochs = 3
    model_train(config)
    assert len(os.listdir(config.checkpoint_folder)) == 2
Code Example #10
def test_non_image_encoder(test_output_dirs: OutputFolderForTests,
                           hidden_layer_num_feature_channels: Optional[int]) -> None:
    """
    Test if we can build a simple MLP model that only feeds off non-image features.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = _get_fake_dataset_contents()
    (dataset_folder / DATASET_CSV_FILE_NAME).write_text(dataset_contents)
    config = NonImageEncoder(should_validate=False, hidden_layer_num_feature_channels=hidden_layer_num_feature_channels)
    config.local_dataset = dataset_folder
    config.max_batch_grad_cam = 1
    config.validate()
    # run model training
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=Path(test_output_dirs.root_dir))
    model_train(config, checkpoint_handler=checkpoint_handler)
    # run model inference
    MLRunner(config).model_inference_train_and_test(checkpoint_handler=checkpoint_handler)
    assert config.get_total_number_of_non_imaging_features() == 18
Code Example #11
def test_train_2d_classification_model(test_output_dirs: OutputFolderForTests,
                                       use_mixed_precision: bool) -> None:
    """
    Test training and testing of 2d classification models.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting2D()
    config.set_output_to(test_output_dirs.root_dir)

    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    config.use_mixed_precision = use_mixed_precision
    config.save_start_epoch = 2
    config.save_step_epochs = 2
    config.test_start_epoch = 2
    config.test_step_epochs = 2
    config.test_diff_epochs = 2
    expected_epochs = [2, 4]
    assert config.get_test_epochs() == expected_epochs

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=Path(test_output_dirs.root_dir))
    model_training_result = model_training.model_train(
        config, checkpoint_handler=checkpoint_handler)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]

    expected_train_loss = [0.705931, 0.698664, 0.694489, 0.693151]
    expected_val_loss = [1.078517, 1.140510, 1.199026, 1.248595]

    def extract_loss(results: List[MetricsDict]) -> List[float]:
        return [d.values()[MetricType.LOSS.value][0] for d in results]

    actual_train_loss = extract_loss(
        model_training_result.train_results_per_epoch)
    actual_val_loss = extract_loss(model_training_result.val_results_per_epoch)
    actual_learning_rates = list(
        flatten(model_training_result.learning_rates_per_epoch))

    assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6)
    assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6)
    assert actual_learning_rates == pytest.approx(expected_learning_rates,
                                                  rel=1e-5)
    test_results = model_testing.model_test(
        config,
        ModelExecutionMode.TRAIN,
        checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    assert list(test_results.epochs.keys()) == expected_epochs
Code Example #12
def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test that starting training in an output folder whose checkpoints folder already contains recovery checkpoints
    is picked up as a recovery run. Also checks that we update the start epoch in the config at loading time.
    """
    model_config = DummyClassification()
    model_config.set_output_to(test_output_dirs.root_dir)
    model_config.num_epochs = FIXED_EPOCH + 2
    # We save all checkpoints - if recovery works as expected, we should get new checkpoints for epochs 4 and 5.
    model_config.recovery_checkpoint_save_interval = 1
    model_config.recovery_checkpoints_save_last_k = -1
    runner = MLRunner(model_config=model_config)
    runner.setup(use_mount_or_download_dataset=False)
    # Epochs are 0-based for saving
    create_model_and_store_checkpoint(model_config,
                                      runner.container.checkpoint_folder /
                                      f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
                                      f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}",
                                      weights_only=False)
    azure_config = get_default_azure_config()
    checkpoint_handler = CheckpointHandler(
        azure_config=azure_config,
        container=runner.container,
        project_root=test_output_dirs.root_dir)
    _, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                    container=runner.container)
    # We expect to have 4 checkpoints: FIXED_EPOCH - 1 (the pre-created recovery checkpoint), FIXED_EPOCH,
    # FIXED_EPOCH + 1, and best.
    assert len(os.listdir(runner.container.checkpoint_folder)) == 4
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH + 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
    # Check that training really restarted from epoch FIXED_EPOCH.
    assert list(storing_logger.epochs) == [FIXED_EPOCH,
                                           FIXED_EPOCH + 1]  # type: ignore
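To make the checkpoint file names concrete: assuming the InnerEye constants resolve to RECOVERY_CHECKPOINT_FILE_NAME = "recovery" and CHECKPOINT_SUFFIX = ".ckpt" (an assumption, neither value is shown in this code), and FIXED_EPOCH is 4 (consistent with the "epochs 4 and 5" comment near the top of the test), the pre-created checkpoint would be recovery_epoch=3.ckpt, and after the restarted run the folder would additionally hold recovery_epoch=4.ckpt, recovery_epoch=5.ckpt, and the best checkpoint.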
Code Example #13
def test_rnn_classifier_via_config_2(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test if we can build an RNN classifier that learns sequences, of the same kind as in
    test_rnn_classifier_toy_problem, but built via the config.
    """
    expected_max_train_loss = 0.71
    expected_max_val_loss = 0.71
    num_sequences = 100
    ml_util.set_random_seed(123)
    dataset_contents = "subject,index,feature,label\n"
    for subject in range(num_sequences):
        # Sequences have variable length
        sequence_length = np.random.choice([9, 10, 11, 12])
        # Each sequence is a series of 0s and 1s
        inputs = np.random.choice([0, 1],
                                  size=(sequence_length, ),
                                  p=[1. / 3, 2. / 3])
        label = np.sum(inputs) > (sequence_length // 2)
        for i, value in enumerate(inputs.tolist()):
            dataset_contents += f"S{subject},{i},{value},{label}\n"
    logging_to_stdout()
    config = ToySequenceModel2(should_validate=False)
    config.num_epochs = 2
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents)
    results = model_train(
        config,
        get_default_checkpoint_handler(model_config=config,
                                       project_root=test_output_dirs.root_dir))

    actual_train_loss = results.get_metric(
        is_training=True, metric_type=MetricType.LOSS.value)[-1]
    actual_val_loss = results.get_metric(is_training=False,
                                         metric_type=MetricType.LOSS.value)[-1]
    print(
        f"Training loss after {config.num_epochs} epochs: {actual_train_loss}")
    print(
        f"Validation loss after {config.num_epochs} epochs: {actual_val_loss}")
    assert actual_train_loss <= expected_max_train_loss, "Training loss too high"
    assert actual_val_loss <= expected_max_val_loss, "Validation loss too high"
Code Example #14
def test_train_classification_model(test_output_dirs: TestOutputDirectories,
                                    use_mixed_precision: bool) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and testing.
    Expected test results are stored for GPU with and without mixed precision.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting()
    config.set_output_to(test_output_dirs.root_dir)
    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    config.use_mixed_precision = use_mixed_precision
    config.save_start_epoch = 2
    config.save_step_epochs = 2
    config.test_start_epoch = 2
    config.test_step_epochs = 2
    config.test_diff_epochs = 2
    expected_epochs = [2, 4]
    assert config.get_test_epochs() == expected_epochs
    model_training_result = model_training.model_train(config)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
    use_mixed_precision_and_gpu = use_mixed_precision and machine_has_gpu
    if use_mixed_precision_and_gpu:
        expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
        expected_val_loss = [0.737039, 0.736721, 0.736339, 0.735957]
    else:
        expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
        expected_val_loss = [0.737061, 0.736690, 0.736321, 0.735952]

    def extract_loss(results: List[MetricsDict]) -> List[float]:
        return [d.values()[MetricType.LOSS.value][0] for d in results]

    actual_train_loss = extract_loss(
        model_training_result.train_results_per_epoch)
    actual_val_loss = extract_loss(model_training_result.val_results_per_epoch)
    actual_learning_rates = list(
        flatten(model_training_result.learning_rates_per_epoch))
    assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6)
    assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6)
    assert actual_learning_rates == pytest.approx(expected_learning_rates,
                                                  rel=1e-5)
    test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN)
    assert isinstance(test_results, InferenceMetricsForClassification)
    assert list(test_results.epochs.keys()) == expected_epochs
    if use_mixed_precision_and_gpu:
        expected_metrics = {
            2: [0.635942, 0.736691],
            4: [0.636085, 0.735952],
        }
    else:
        expected_metrics = {
            2: [0.635941, 0.736690],
            4: [0.636084, 0.735952],
        }
    for epoch in expected_epochs:
        assert test_results.epochs[epoch].values()[MetricType.CROSS_ENTROPY.value] == \
               pytest.approx(expected_metrics[epoch], abs=1e-6)
    # Run the detailed log file check only on CPU; it will contain slightly different metrics on GPU. Here we
    # mostly want to assert that the files look reasonable
    if not machine_has_gpu:
        # Check log EPOCH_METRICS_FILE_NAME
        epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
        # Auto-format will break the long header line, hence the strange way of writing it!
        expected_epoch_metrics = \
            "loss,cross_entropy,accuracy_at_threshold_05,seconds_per_batch,seconds_per_epoch,learning_rate," + \
            "area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
            "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
            "optimal_threshold,subject_count,epoch,cross_validation_split_index\n" + \
            """0.6866141557693481,0.6866141557693481,0.5,0,0,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,1,-1
            0.6864652633666992,0.6864652633666992,0.5,0,0,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,2,-1
            0.6863163113594055,0.6863162517547607,0.5,0,0,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,3,-1
            0.6861673593521118,0.6861673593521118,0.5,0,0,9.998613801725043e-05,1.0,1.0,0.5,0.0,0.0,0.529399,2.0,4,-1
            """
        check_log_file(epoch_metrics_path,
                       expected_epoch_metrics,
                       ignore_columns=[
                           LoggingColumns.SecondsPerBatch.value,
                           LoggingColumns.SecondsPerEpoch.value
                       ])

        # Check log METRICS_FILE_NAME
        metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / METRICS_FILE_NAME
        metrics_expected = \
            """prediction_target,epoch,subject,model_output,label,cross_validation_split_index,data_split
Default,1,S4,0.5216594338417053,0.0,-1,Train
Default,1,S2,0.5295137763023376,1.0,-1,Train
Default,2,S4,0.5214819312095642,0.0,-1,Train
Default,2,S2,0.5294750332832336,1.0,-1,Train
Default,3,S4,0.5213046073913574,0.0,-1,Train
Default,3,S2,0.5294366478919983,1.0,-1,Train
Default,4,S4,0.5211275815963745,0.0,-1,Train
Default,4,S2,0.5293986201286316,1.0,-1,Train
"""
        check_log_file(metrics_path, metrics_expected, ignore_columns=[])
Code Example #15
def _test_model_train(output_dirs: OutputFolderForTests,
                      image_channels: Any,
                      ground_truth_ids: Any,
                      no_mask_channel: bool = False) -> None:
    def _check_patch_centers(diagnostics_per_epoch: List[np.ndarray],
                             should_equal: bool) -> None:
        patch_centers_epoch1 = diagnostics_per_epoch[0]
        assert len(
            diagnostics_per_epoch
        ) > 1, "Not enough data to check patch centers, need at least 2"
        for diagnostic in diagnostics_per_epoch[1:]:
            assert np.array_equal(patch_centers_epoch1,
                                  diagnostic) == should_equal

    def _check_voxel_count(results_per_epoch: List[Dict[str, float]],
                           expected_voxel_count_per_epoch: List[float],
                           prefix: str) -> None:
        assert len(results_per_epoch) == len(expected_voxel_count_per_epoch)
        for epoch, (results, voxel_count) in enumerate(
                zip(results_per_epoch, expected_voxel_count_per_epoch)):
            # In the test data, both structures "region" and "region_1" are read from the same nifti file, hence
            # their voxel counts must be identical.
            for structure in ["region", "region_1"]:
                assert results[f"{MetricType.VOXEL_COUNT.value}/{structure}"] == pytest.approx(voxel_count, abs=1e-2), \
                    f"{prefix} voxel count mismatch for '{structure}' epoch {epoch}"

    def _mean(a: List[float]) -> float:
        return sum(a) / len(a)

    def _mean_list(lists: List[List[float]]) -> List[float]:
        return list(map(_mean, lists))

    logging_to_stdout(log_level=logging.DEBUG)
    train_config = DummyModel()
    train_config.local_dataset = base_path
    train_config.set_output_to(output_dirs.root_dir)
    train_config.image_channels = image_channels
    train_config.ground_truth_ids = ground_truth_ids
    train_config.mask_id = None if no_mask_channel else train_config.mask_id
    train_config.random_seed = 42
    train_config.class_weights = [0.5, 0.25, 0.25]
    train_config.store_dataset_sample = True
    train_config.recovery_checkpoint_save_interval = 1

    if machine_has_gpu:
        expected_train_losses = [0.4553468, 0.454904]
        expected_val_losses = [0.4553881, 0.4553041]
    else:
        expected_train_losses = [0.4553469, 0.4548947]
        expected_val_losses = [0.4553880, 0.4553041]
    loss_absolute_tolerance = 1e-6
    expected_learning_rates = [train_config.l_rate, 5.3589e-4]

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=train_config, project_root=Path(output_dirs.root_dir))
    model_training_result = model_training.model_train(
        train_config, checkpoint_handler=checkpoint_handler)
    assert isinstance(model_training_result, ModelTrainingResults)

    def assert_all_close(metric: str, expected: List[float],
                         **kwargs: Any) -> None:
        actual = model_training_result.get_training_metric(metric)
        assert np.allclose(
            actual, expected, **kwargs
        ), f"Mismatch for {metric}: Got {actual}, expected {expected}"

    # check to make sure training batches are NOT all the same across epochs
    _check_patch_centers(model_training_result.train_diagnostics,
                         should_equal=False)
    # check to make sure validation batches are all the same across epochs
    _check_patch_centers(model_training_result.val_diagnostics,
                         should_equal=True)
    assert_all_close(MetricType.SUBJECT_COUNT.value, [3.0, 3.0])
    assert_all_close(MetricType.LEARNING_RATE.value,
                     expected_learning_rates,
                     rtol=1e-6)

    if is_windows():
        # Randomization comes out slightly different on Windows. Skip the rest of the detailed checks.
        return

    # Simple regression test: Voxel counts should be the same in both epochs on the validation set,
    # and be the same across 'region' and 'region_1' because they derive from the same Nifti files.
    # The following values are read off directly from the results of compute_dice_across_patches in the training loop
    # This checks that averages are computed correctly, and that metric computers are reset after each epoch.
    train_voxels = [[83092.0, 83212.0, 82946.0], [83000.0, 82881.0, 83309.0]]
    val_voxels = [[82765.0, 83212.0], [82765.0, 83212.0]]
    _check_voxel_count(model_training_result.train_results_per_epoch,
                       _mean_list(train_voxels), "Train")
    _check_voxel_count(model_training_result.val_results_per_epoch,
                       _mean_list(val_voxels), "Val")

    actual_train_losses = model_training_result.get_training_metric(
        MetricType.LOSS.value)
    actual_val_losses = model_training_result.get_validation_metric(
        MetricType.LOSS.value)
    print("actual_train_losses = {}".format(actual_train_losses))
    print("actual_val_losses = {}".format(actual_val_losses))
    assert np.allclose(actual_train_losses,
                       expected_train_losses,
                       atol=loss_absolute_tolerance), "Train losses"
    assert np.allclose(actual_val_losses,
                       expected_val_losses,
                       atol=loss_absolute_tolerance), "Val losses"
    # Check that the metric we track for Hyperdrive runs is actually written.
    assert TrackedMetrics.Val_Loss.value.startswith(VALIDATION_PREFIX)
    tracked_metric = TrackedMetrics.Val_Loss.value[len(VALIDATION_PREFIX):]
    for val_result in model_training_result.val_results_per_epoch:
        assert tracked_metric in val_result

    # The following values are read off directly from the results of compute_dice_across_patches in the
    # training loop. Results are slightly different for CPU, hence use a larger tolerance there.
    dice_tolerance = 1e-4 if machine_has_gpu else 4.5e-4
    train_dice_region = [[0.0, 0.0, 4.0282e-04], [0.0309, 0.0334, 0.0961]]
    train_dice_region1 = [[0.4806, 0.4800, 0.4832], [0.4812, 0.4842, 0.4663]]
    # There appears to be some amount of non-determinism here: When using a tolerance of 1e-4, we get occasional
    # test failures on Linux in the cloud (not on Windows, not on AzureML). It is unclear where it comes from. Even
    # when failing here, the losses match up to the expected tolerance.
    assert_all_close("Dice/region",
                     _mean_list(train_dice_region),
                     atol=dice_tolerance)
    assert_all_close("Dice/region_1",
                     _mean_list(train_dice_region1),
                     atol=dice_tolerance)
    expected_average_dice = [
        _mean(train_dice_region[i] + train_dice_region1[i])  # type: ignore
        for i in range(len(train_dice_region))
    ]
    assert_all_close("Dice/AverageAcrossStructures",
                     expected_average_dice,
                     atol=dice_tolerance)

    # check output files/directories
    assert train_config.outputs_folder.is_dir()
    assert train_config.logs_folder.is_dir()

    # TensorBoard event files go into a Lightning subfolder (PyTorch Lightning default)
    assert (train_config.logs_folder / "Lightning").is_dir()
    assert len(list((train_config.logs_folder / "Lightning").glob("events*"))) == 1

    assert train_config.num_epochs == 2
    # Checkpoint folder
    assert train_config.checkpoint_folder.is_dir()
    actual_checkpoints = list(train_config.checkpoint_folder.rglob("*.ckpt"))
    assert len(
        actual_checkpoints) == 2, f"Actual checkpoints: {actual_checkpoints}"
    assert (train_config.checkpoint_folder /
            RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
    assert (train_config.checkpoint_folder /
            BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
    assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()

    # Patch visualization: There should be 3 slices for each of the 2 subjects
    sampling_folder = train_config.outputs_folder / PATCH_SAMPLING_FOLDER
    assert sampling_folder.is_dir()
    assert train_config.show_patch_sampling > 0
    assert len(list(sampling_folder.rglob(
        "*.png"))) == 3 * train_config.show_patch_sampling

    # Time per epoch: Test that we have all these times logged.
    model_training_result.get_training_metric(
        MetricType.SECONDS_PER_EPOCH.value)
    model_training_result.get_validation_metric(
        MetricType.SECONDS_PER_EPOCH.value)
    model_training_result.get_validation_metric(
        MetricType.SECONDS_PER_BATCH.value)
    model_training_result.get_training_metric(
        MetricType.SECONDS_PER_BATCH.value)
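As a worked check of the voxel-count aggregation above: _mean_list([[83092.0, 83212.0, 82946.0], [83000.0, 82881.0, 83309.0]]) yields approximately [83083.33, 83063.33], and _mean_list(val_voxels) yields [82988.5, 82988.5]; these per-epoch means are what _check_voxel_count compares against, with an absolute tolerance of 1e-2.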
Code Example #16
def test_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and weight updates of the mean teacher model computation.
    """
    def _get_parameters_of_model(model: DeviceAwareModule) -> Any:
        """
        Returns the iterator of model parameters
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)

    config.num_epochs = 1
    # Set the train batch size arbitrarily large to ensure we have only one training step,
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve the weight after one epoch
    model_and_info = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    model_and_info.try_create_model_and_load_from_checkpoint()
    model = model_and_info.model
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model
    ml_util.set_random_seed(config.get_effective_random_seed())

    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=None)
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()

    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint()
    mean_teach_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teach_model is not None  # for mypy
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve weight of mean teacher model saved in the checkpoint
    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(1))
    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint()
    mean_teacher_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teacher_model is not None  # for mypy
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()
    student_model = model_and_info_mean_teacher.model
    student_model_weight = next(_get_parameters_of_model(student_model))

    # Assert that the student weight corresponds to the weight of a plain training run without the mean teacher
    # computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters
    assert torch.all(alpha * initial_weight_mean_teacher_model +
                     (1 - alpha) * student_model_weight == result_weight)
Code Example #17
def test_train_classification_model(
        class_name: str, test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and
    testing.
    Expected test results are stored for GPU with and without mixed precision.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting()
    config.class_names = [class_name]
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=Path(test_output_dirs.root_dir))
    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    model_training_result = model_training.model_train(
        config, checkpoint_handler=checkpoint_handler)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
    expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
    expected_val_loss = [0.737061, 0.736691, 0.736321, 0.735952]
    # Ensure that all metrics are computed on both training and validation set
    assert len(
        model_training_result.train_results_per_epoch) == config.num_epochs
    assert len(
        model_training_result.val_results_per_epoch) == config.num_epochs
    assert len(model_training_result.train_results_per_epoch[0]) >= 11
    assert len(model_training_result.val_results_per_epoch[0]) >= 11

    for metric in [
            MetricType.ACCURACY_AT_THRESHOLD_05,
            MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
            MetricType.AREA_UNDER_PR_CURVE, MetricType.AREA_UNDER_ROC_CURVE,
            MetricType.CROSS_ENTROPY, MetricType.LOSS,
            MetricType.SECONDS_PER_BATCH, MetricType.SECONDS_PER_EPOCH,
            MetricType.SUBJECT_COUNT
    ]:
        assert metric.value in model_training_result.train_results_per_epoch[0], \
            f"{metric.value} not in training"
        assert metric.value in model_training_result.val_results_per_epoch[0], \
            f"{metric.value} not in validation"

    actual_train_loss = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LOSS.value)
    actual_val_loss = model_training_result.get_metric(
        is_training=False, metric_type=MetricType.LOSS.value)
    actual_lr = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LEARNING_RATE.value)
    assert actual_train_loss == pytest.approx(expected_train_loss,
                                              abs=1e-6), "Training loss"
    assert actual_val_loss == pytest.approx(expected_val_loss,
                                            abs=1e-6), "Validation loss"
    assert actual_lr == pytest.approx(expected_learning_rates,
                                      rel=1e-5), "Learning rates"
    test_results = model_testing.model_test(
        config,
        ModelExecutionMode.TRAIN,
        checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    expected_metrics = [0.636085, 0.735952]
    assert test_results.metrics.values(class_name)[MetricType.CROSS_ENTROPY.value] == \
           pytest.approx(expected_metrics, abs=1e-5)
    # Run the detailed log file check only on CPU; it will contain slightly different metrics on GPU. Here we
    # mostly want to assert that the files look reasonable
    if machine_has_gpu:
        return

    # Check epoch_metrics.csv
    epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
    # Auto-format will break the long header line, hence the strange way of writing it!
    expected_epoch_metrics = \
        f"{LoggingColumns.Loss.value},{LoggingColumns.CrossEntropy.value}," \
        f"{LoggingColumns.AccuracyAtThreshold05.value},{LoggingColumns.LearningRate.value}," + \
        f"{LoggingColumns.AreaUnderRocCurve.value}," \
        f"{LoggingColumns.AreaUnderPRCurve.value}," \
        f"{LoggingColumns.AccuracyAtOptimalThreshold.value}," \
        f"{LoggingColumns.FalsePositiveRateAtOptimalThreshold.value}," \
        f"{LoggingColumns.FalseNegativeRateAtOptimalThreshold.value}," \
        f"{LoggingColumns.OptimalThreshold.value}," \
        f"{LoggingColumns.SubjectCount.value},{LoggingColumns.Epoch.value}," \
        f"{LoggingColumns.CrossValidationSplitIndex.value}\n" + \
        """0.6866141557693481,0.6866141557693481,0.5,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1	
        0.6864652633666992,0.6864652633666992,0.5,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1	
        0.6863163113594055,0.6863162517547607,0.5,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1	
        0.6861673593521118,0.6861673593521118,0.5,9.998613801725043e-05,1.0,1.0,0.5,0.0,0.0,0.529399,2.0,3,-1	
        """
    check_log_file(epoch_metrics_path,
                   expected_epoch_metrics,
                   ignore_columns=[])
    # Check metrics.csv: This contains the per-subject per-epoch model outputs
    # Randomization comes out slightly different on Windows, hence only execute the test on Linux
    if common_util.is_windows():
        return
    metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / SUBJECT_METRICS_FILE_NAME
    metrics_expected = \
        f"""epoch,subject,prediction_target,model_output,label,data_split,cross_validation_split_index
0,S2,{class_name},0.529514,1,Train,-1
0,S4,{class_name},0.521659,0,Train,-1
1,S4,{class_name},0.521482,0,Train,-1
1,S2,{class_name},0.529475,1,Train,-1
2,S4,{class_name},0.521305,0,Train,-1
2,S2,{class_name},0.529437,1,Train,-1
3,S2,{class_name},0.529399,1,Train,-1
3,S4,{class_name},0.521128,0,Train,-1
"""
    check_log_file(metrics_path, metrics_expected, ignore_columns=[])
    # Check log METRICS_FILE_NAME inside of the folder epoch_004/Train, which is written when we run model_test.
    # Normally, we would run it on the Test and Val splits, but for convenience we test on the train split here.
    inference_metrics_path = config.outputs_folder / get_epoch_results_path(ModelExecutionMode.TRAIN) / \
                             SUBJECT_METRICS_FILE_NAME
    inference_metrics_expected = \
        f"""prediction_target,subject,model_output,label,cross_validation_split_index,data_split
{class_name},S2,0.5293986201286316,1.0,-1,Train
{class_name},S4,0.5211275815963745,0.0,-1,Train
"""
    check_log_file(inference_metrics_path,
                   inference_metrics_expected,
                   ignore_columns=[])
Code Example #18
def test_train_classification_multilabel_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and
    testing.
    Expected test results are stored for GPU with and without mixed precision.
    """
    logging_to_stdout(logging.DEBUG)
    config = DummyMulticlassClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=Path(test_output_dirs.root_dir))
    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    model_training_result = model_training.model_train(
        config, checkpoint_handler=checkpoint_handler)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
    expected_train_loss = [
        0.699870228767395, 0.6239662170410156, 0.551329493522644,
        0.4825132489204407
    ]
    expected_val_loss = [
        0.6299371719360352, 0.5546272993087769, 0.4843321740627289,
        0.41909298300743103
    ]
    # Ensure that all metrics are computed on both training and validation set
    assert len(
        model_training_result.train_results_per_epoch) == config.num_epochs
    assert len(
        model_training_result.val_results_per_epoch) == config.num_epochs
    assert len(model_training_result.train_results_per_epoch[0]) >= 11
    assert len(model_training_result.val_results_per_epoch[0]) >= 11
    for class_name in config.class_names:
        for metric in [
                MetricType.ACCURACY_AT_THRESHOLD_05,
                MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
                MetricType.AREA_UNDER_PR_CURVE,
                MetricType.AREA_UNDER_ROC_CURVE, MetricType.CROSS_ENTROPY
        ]:
            assert f'{metric.value}/{class_name}' in model_training_result.train_results_per_epoch[
                0], f"{metric.value} not in training"
            assert f'{metric.value}/{class_name}' in model_training_result.val_results_per_epoch[
                0], f"{metric.value} not in validation"
    for metric in [
            MetricType.LOSS, MetricType.SECONDS_PER_EPOCH,
            MetricType.SUBJECT_COUNT
    ]:
        assert metric.value in model_training_result.train_results_per_epoch[
            0], f"{metric.value} not in training"
        assert metric.value in model_training_result.val_results_per_epoch[
            0], f"{metric.value} not in validation"

    actual_train_loss = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LOSS.value)
    actual_val_loss = model_training_result.get_metric(
        is_training=False, metric_type=MetricType.LOSS.value)
    actual_lr = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LEARNING_RATE.value)
    assert actual_train_loss == pytest.approx(expected_train_loss,
                                              abs=1e-6), "Training loss"
    assert actual_val_loss == pytest.approx(expected_val_loss,
                                            abs=1e-6), "Validation loss"
    assert actual_lr == pytest.approx(expected_learning_rates,
                                      rel=1e-5), "Learning rates"
    test_results = model_testing.model_test(
        config,
        ModelExecutionMode.TRAIN,
        checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)

    expected_metrics = {
        MetricType.CROSS_ENTROPY: [1.3996, 5.2966, 1.4020, 0.3553, 0.6908],
        MetricType.ACCURACY_AT_THRESHOLD_05:
        [0.0000, 0.0000, 0.0000, 1.0000, 1.0000]
    }

    for i, class_name in enumerate(config.class_names):
        for metric in expected_metrics.keys():
            assert expected_metrics[metric][i] == pytest.approx(
                test_results.metrics.get_single_metric(metric_name=metric,
                                                       hue=class_name), 1e-4)

    def get_epoch_path(mode: ModelExecutionMode) -> Path:
        p = get_epoch_results_path(mode=mode)
        return config.outputs_folder / p / SUBJECT_METRICS_FILE_NAME

    path_to_best_epoch_train = get_epoch_path(ModelExecutionMode.TRAIN)
    path_to_best_epoch_val = get_epoch_path(ModelExecutionMode.VAL)
    path_to_best_epoch_test = get_epoch_path(ModelExecutionMode.TEST)
    generate_classification_notebook(
        result_notebook=config.outputs_folder /
        get_ipynb_report_name(config.model_category.value),
        config=config,
        train_metrics=path_to_best_epoch_train,
        val_metrics=path_to_best_epoch_val,
        test_metrics=path_to_best_epoch_test)
    assert (config.outputs_folder /
            get_html_report_name(config.model_category.value)).exists()

    report_name_multilabel = f"{config.model_category.value}_multilabel"
    generate_classification_multilabel_notebook(
        result_notebook=config.outputs_folder /
        get_ipynb_report_name(report_name_multilabel),
        config=config,
        train_metrics=path_to_best_epoch_train,
        val_metrics=path_to_best_epoch_val,
        test_metrics=path_to_best_epoch_test)
    assert (config.outputs_folder /
            get_html_report_name(report_name_multilabel)).exists()
Code Example #19
    def run(self) -> None:
        """
        Driver function to run an ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError(
                    "Offline cross validation is only supported for classification models."
                )
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.discover_and_download_checkpoints_from_previous_runs()
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.register_model_only_for_epoch:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            self.model_config.write_args_file()
            logging.info(str(self.model_config))
            # Ensure that training runs are fully reproducible - setting random seeds alone is not enough!
            make_pytorch_reproducible()

            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(self.model_config.local_dataset)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs",
                            value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        best_epoch = self.run_inference_and_register_model(
            checkpoint_handler, ModelProcessing.DEFAULT)

        # Generate report
        if best_epoch:
            Runner.generate_report(self.model_config, best_epoch,
                                   ModelProcessing.DEFAULT)
        elif self.model_config.is_scalar_model and len(
                self.model_config.get_test_epochs()) == 1:
            # We don't register scalar models but still want to create a report if we have run inference.
            Runner.generate_report(self.model_config,
                                   self.model_config.get_test_epochs()[0],
                                   ModelProcessing.DEFAULT)
Code Example #20
def test_train_classification_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and testing of classification models, asserting on the individual results from training and
    testing.
    Expected test results are stored for GPU with and without mixed precision.
    """
    logging_to_stdout(logging.DEBUG)
    config = ClassificationModelForTesting()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=Path(test_output_dirs.root_dir))
    # Train for 4 epochs, checkpoints at epochs 2 and 4
    config.num_epochs = 4
    model_training_result = model_training.model_train(
        config, checkpoint_handler=checkpoint_handler)
    assert model_training_result is not None
    expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
    expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]
    expected_val_loss = [0.737061, 0.736691, 0.736321, 0.735952]
    # Ensure that all metrics are computed on both training and validation set
    assert len(
        model_training_result.train_results_per_epoch) == config.num_epochs
    assert len(
        model_training_result.val_results_per_epoch) == config.num_epochs
    assert len(model_training_result.train_results_per_epoch[0]) >= 11
    assert len(model_training_result.val_results_per_epoch[0]) >= 11
    for metric in [
            MetricType.ACCURACY_AT_THRESHOLD_05,
            MetricType.ACCURACY_AT_OPTIMAL_THRESHOLD,
            MetricType.AREA_UNDER_PR_CURVE,
            MetricType.AREA_UNDER_ROC_CURVE,
            MetricType.CROSS_ENTROPY,
            MetricType.LOSS,
            # For unknown reasons, we don't get seconds_per_batch for the training data.
            # MetricType.SECONDS_PER_BATCH,
            MetricType.SECONDS_PER_EPOCH,
            MetricType.SUBJECT_COUNT,
    ]:
        assert metric.value in model_training_result.train_results_per_epoch[
            0], f"{metric.value} not in training"
        assert metric.value in model_training_result.val_results_per_epoch[
            0], f"{metric.value} not in validation"
    actual_train_loss = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LOSS.value)
    actual_val_loss = model_training_result.get_metric(
        is_training=False, metric_type=MetricType.LOSS.value)
    actual_lr = model_training_result.get_metric(
        is_training=True, metric_type=MetricType.LEARNING_RATE.value)
    assert actual_train_loss == pytest.approx(expected_train_loss,
                                              abs=1e-6), "Training loss"
    assert actual_val_loss == pytest.approx(expected_val_loss,
                                            abs=1e-6), "Validation loss"
    assert actual_lr == pytest.approx(expected_learning_rates,
                                      rel=1e-5), "Learning rates"
    test_results = model_testing.model_test(
        config,
        ModelExecutionMode.TRAIN,
        checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    expected_metrics = [0.636085, 0.735952]
    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
           pytest.approx(expected_metrics, abs=1e-5)
    # Run detailed logs file check only on CPU, it will contain slightly different metrics on GPU, but here
    # we want to mostly assert that the files look reasonable
    if machine_has_gpu:
        return
    # Check epoch_metrics.csv
    epoch_metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / EPOCH_METRICS_FILE_NAME
    # Auto-format will break the long header line, hence the strange way of writing it!
    expected_epoch_metrics = \
        "loss,cross_entropy,accuracy_at_threshold_05,seconds_per_epoch,learning_rate," + \
        "area_under_roc_curve,area_under_pr_curve,accuracy_at_optimal_threshold," \
        "false_positive_rate_at_optimal_threshold,false_negative_rate_at_optimal_threshold," \
        "optimal_threshold,subject_count,epoch,cross_validation_split_index\n" + \
        """0.6866141557693481,0.6866141557693481,0.5,0,0.0001,1.0,1.0,0.5,0.0,0.0,0.529514,2.0,0,-1	
        0.6864652633666992,0.6864652633666992,0.5,0,9.999712322065557e-05,1.0,1.0,0.5,0.0,0.0,0.529475,2.0,1,-1	
        0.6863163113594055,0.6863162517547607,0.5,0,9.999306876841536e-05,1.0,1.0,0.5,0.0,0.0,0.529437,2.0,2,-1	
        0.6861673593521118,0.6861673593521118,0.5,0,9.998613801725043e-05,1.0,1.0,0.5,0.0,0.0,0.529399,2.0,3,-1	
        """
    # We cannot compare columns like "seconds_per_epoch" because timing will obviously vary between machines.
    # Column must still be present, though.
    check_log_file(epoch_metrics_path,
                   expected_epoch_metrics,
                   ignore_columns=[LoggingColumns.SecondsPerEpoch.value])
    # Check metrics.csv: This contains the per-subject per-epoch model outputs
    # Randomization comes out slightly different on Windows, hence only execute the test on Linux
    if common_util.is_windows():
        return
    metrics_path = config.outputs_folder / ModelExecutionMode.TRAIN.value / SUBJECT_METRICS_FILE_NAME
    metrics_expected = \
        """prediction_target,epoch,subject,model_output,label,cross_validation_split_index,data_split
Default,0,S2,0.5295137763023376,1.0,-1,Train
Default,0,S4,0.5216594338417053,0.0,-1,Train
Default,1,S4,0.5214819312095642,0.0,-1,Train
Default,1,S2,0.5294750332832336,1.0,-1,Train
Default,2,S2,0.5294366478919983,1.0,-1,Train
Default,2,S4,0.5213046073913574,0.0,-1,Train
Default,3,S2,0.5293986201286316,1.0,-1,Train
Default,3,S4,0.5211275815963745,0.0,-1,Train
"""
    check_log_file(metrics_path, metrics_expected, ignore_columns=[])
    # Check the METRICS_FILE_NAME log inside the folder epoch_004/Train, which is written when we run model_test.
    # Normally, we would run it on the Test and Val splits, but for convenience we test on the train split here.
    inference_metrics_path = config.outputs_folder / get_epoch_results_path(ModelExecutionMode.TRAIN) / \
                             SUBJECT_METRICS_FILE_NAME
    inference_metrics_expected = \
        """prediction_target,subject,model_output,label,cross_validation_split_index,data_split
Default,S2,0.5293986201286316,1.0,-1,Train
Default,S4,0.5211275815963745,0.0,-1,Train
"""
    check_log_file(inference_metrics_path,
                   inference_metrics_expected,
                   ignore_columns=[])
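
The log checks above compare freshly written CSV files against expected literals while skipping columns such as seconds_per_epoch, whose values vary from machine to machine. A rough sketch of that idea, assuming both sides are comma-separated text with a header row (compare_csv_ignoring is a hypothetical helper, not the test suite's actual check_log_file):

import csv
import io
from typing import Dict, List


def compare_csv_ignoring(actual: str, expected: str, ignore_columns: List[str]) -> None:
    """Assert that two CSV strings match, ignoring values in the given columns."""
    def rows(text: str) -> List[Dict[str, str]]:
        return list(csv.DictReader(io.StringIO(text)))

    actual_rows, expected_rows = rows(actual), rows(expected)
    assert len(actual_rows) == len(expected_rows), "Row counts differ"
    for actual_row, expected_row in zip(actual_rows, expected_rows):
        for column, expected_value in expected_row.items():
            if column in ignore_columns:
                # The column must be present, but its value is machine-dependent.
                assert column in actual_row
            else:
                assert actual_row[column] == expected_value, f"Mismatch in column {column}"


compare_csv_ignoring("a,b\n1,5\n", "a,b\n1,99\n", ignore_columns=["b"])
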
Code example #21
def test_image_encoder(test_output_dirs: OutputFolderForTests,
                       encode_channels_jointly: bool,
                       use_non_imaging_features: bool,
                       kernel_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       stride_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       reduction_factor: float,
                       expected_num_reduced_features: int,
                       aggregation_type: AggregationType) -> None:
    """
    Test if the image encoder networks can be trained without errors (including GradCam computation and data
    augmentation).
    """
    logging_to_stdout()
    set_random_seed(0)
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    scan_size = (6, 64, 60)
    scan_files: List[str] = []
    for s in range(4):
        random_scan = np.random.uniform(0, 1, scan_size)
        scan_file_name = f"scan{s + 1}{NumpyFile.NUMPY.value}"
        np.save(str(dataset_folder / scan_file_name), random_scan)
        scan_files.append(scan_file_name)

    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    (dataset_folder / "dataset.csv").write_text(dataset_contents)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}
    config_for_dataset = ScalarModelBase(
        local_dataset=dataset_folder,
        image_channels=["week0", "week1"],
        image_file_column="path",
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=non_image_feature_channels,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        should_validate=False
    )
    config_for_dataset.read_dataset_into_dataframe_and_pre_process()

    dataset = ScalarDataset(config_for_dataset,
                            sample_transforms=ScalarItemAugmentation(
                                RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
    assert len(dataset) == 3

    config = ImageEncoder(
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=config_for_dataset.categorical_feature_encoder,
        encoder_dimensionality_reduction_factor=reduction_factor,
        aggregation_type=aggregation_type,
        scan_size=(6, 64, 60)
    )

    if kernel_size_per_encoding_block:
        config.kernel_size_per_encoding_block = kernel_size_per_encoding_block
    if stride_size_per_encoding_block:
        config.stride_size_per_encoding_block = stride_size_per_encoding_block

    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    model = create_model_with_temperature_scaling(config)
    input_size: List[Tuple] = [(len(config.image_channels), *scan_size)]
    if use_non_imaging_features:
        input_size.append((config.get_total_number_of_non_imaging_features(),))

        # Original number output channels (unreduced) is
        # num initial channel * (num encoder block - 1) = 4 * (3-1) = 8
        if encode_channels_jointly:
            # reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()
        else:
            # num_img_channels * reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == len(config.image_channels) * expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()

    summarizer = ModelSummary(model)
    summarizer.generate_summary(input_sizes=input_size)
    config.local_dataset = dataset_folder
    config.validate()
    model_train(config, checkpoint_handler=get_default_checkpoint_handler(model_config=config,
                                                                          project_root=Path(test_output_dirs.root_dir)))
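
The assertions on final_num_feature_channels above encode a simple piece of arithmetic: with joint encoding, the reduced image features appear once; with per-channel encoding, every image channel contributes its own reduced features; non-imaging features are appended in both cases. A small sketch of that rule (expected_feature_count is a hypothetical helper that only restates the comments in the test):

def expected_feature_count(num_image_channels: int,
                           reduced_channels: int,
                           num_non_image_features: int,
                           encode_channels_jointly: bool) -> int:
    """Feature count seen by the classification head, per the test's comments."""
    if encode_channels_jointly:
        # All image channels share one encoder: reduced features appear once.
        image_features = reduced_channels
    else:
        # One encoder per channel: each contributes its own reduced features.
        image_features = num_image_channels * reduced_channels
    return image_features + num_non_image_features


# Two image channels, features reduced to 4 per encoder, plus 5 tabular features:
assert expected_feature_count(2, 4, 5, encode_channels_jointly=True) == 9
assert expected_feature_count(2, 4, 5, encode_channels_jointly=False) == 13
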
Code example #22
def _test_model_train(output_dirs: TestOutputDirectories,
                      image_channels: Any,
                      ground_truth_ids: Any,
                      no_mask_channel: bool = False) -> None:
    def _check_patch_centers(epoch_results: List[MetricsDict],
                             should_equal: bool) -> None:
        diagnostics_per_epoch = [
            m.diagnostics[MetricType.PATCH_CENTER.value] for m in epoch_results
        ]
        patch_centers_epoch1 = diagnostics_per_epoch[0]
        for diagnostic in diagnostics_per_epoch[1:]:
            assert np.array_equal(patch_centers_epoch1,
                                  diagnostic) == should_equal

    train_config = DummyModel()
    train_config.local_dataset = base_path
    train_config.set_output_to(output_dirs.root_dir)
    train_config.image_channels = image_channels
    train_config.ground_truth_ids = ground_truth_ids
    train_config.mask_id = None if no_mask_channel else train_config.mask_id
    train_config.random_seed = 42
    train_config.class_weights = [0.5, 0.25, 0.25]
    train_config.store_dataset_sample = True

    expected_train_losses = [0.455538, 0.455213]
    expected_val_losses = [0.455190, 0.455139]

    expected_stats = "Epoch\tLearningRate\tTrainLoss\tTrainDice\tValLoss\tValDice\n" \
                     "1\t1.00e-03\t0.456\t0.242\t0.455\t0.000\n" \
                     "2\t5.36e-04\t0.455\t0.247\t0.455\t0.000"

    expected_learning_rates = [[train_config.l_rate], [5.3589e-4]]

    loss_absolute_tolerance = 1e-3
    model_training_result = model_training.model_train(train_config)
    assert isinstance(model_training_result, ModelTrainingResults)

    # check to make sure training batches are NOT all the same across epochs
    _check_patch_centers(model_training_result.train_results_per_epoch,
                         should_equal=False)
    # check to make sure validation batches are all the same across epochs
    _check_patch_centers(model_training_result.val_results_per_epoch,
                         should_equal=True)
    assert isinstance(model_training_result.train_results_per_epoch[0],
                      MetricsDict)
    actual_train_losses = [
        m.get_single_metric(MetricType.LOSS)
        for m in model_training_result.train_results_per_epoch
    ]
    actual_val_losses = [
        m.get_single_metric(MetricType.LOSS)
        for m in model_training_result.val_results_per_epoch
    ]
    print("actual_train_losses = {}".format(actual_train_losses))
    print("actual_val_losses = {}".format(actual_val_losses))
    assert np.allclose(actual_train_losses,
                       expected_train_losses,
                       atol=loss_absolute_tolerance)
    assert np.allclose(actual_val_losses,
                       expected_val_losses,
                       atol=loss_absolute_tolerance)
    assert np.allclose(model_training_result.learning_rates_per_epoch,
                       expected_learning_rates,
                       rtol=1e-6)

    # check output files/directories
    assert train_config.outputs_folder.is_dir()
    assert train_config.logs_folder.is_dir()

    # The train and val folder should contain Tensorflow event files
    assert (train_config.logs_folder / "train").is_dir()
    assert (train_config.logs_folder / "val").is_dir()
    assert len(list((train_config.logs_folder / "train").glob("*"))) == 1
    assert len(list((train_config.logs_folder / "val").glob("*"))) == 1

    # Checkpoint folder
    # With these settings, we should see a checkpoint only at epoch 2:
    # That's the last epoch, and there should always be a checkpoint at the last epoch.
    assert train_config.save_start_epoch == 1
    assert train_config.save_step_epochs == 100
    assert train_config.num_epochs == 2
    assert os.path.isdir(train_config.checkpoint_folder)
    assert os.path.isfile(
        os.path.join(train_config.checkpoint_folder,
                     "2" + CHECKPOINT_FILE_SUFFIX))
    assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()
    assert_file_contents(train_config.outputs_folder / TRAIN_STATS_FILE,
                         expected_stats)

    # Test for saving of example images
    assert os.path.isdir(train_config.example_images_folder)
    example_files = os.listdir(train_config.example_images_folder)
    assert len(example_files) == 3 * 2
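
The _check_patch_centers helper above relies on a sampling convention: validation crops are drawn from a generator that is reseeded identically every epoch, so they repeat exactly, while training crops keep advancing the random state and therefore differ between epochs. A minimal sketch of that convention using numpy's Generator API (the function and seeds are illustrative, not InnerEye's actual sampler):

import numpy as np


def sample_patch_centers(epoch: int, is_validation: bool, count: int = 4) -> np.ndarray:
    """Draw random 3D patch centers; validation reseeds per epoch, training does not."""
    if is_validation:
        # Identical seed every epoch: validation batches repeat exactly.
        rng = np.random.default_rng(seed=42)
    else:
        # Seed advances with the epoch: training batches differ across epochs.
        rng = np.random.default_rng(seed=42 + epoch)
    return rng.integers(0, 64, size=(count, 3))


val_epochs = [sample_patch_centers(e, is_validation=True) for e in range(3)]
train_epochs = [sample_patch_centers(e, is_validation=False) for e in range(3)]
assert all(np.array_equal(val_epochs[0], v) for v in val_epochs[1:])
assert not np.array_equal(train_epochs[0], train_epochs[1])
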
Code example #23
def test_recover_testing_from_run_recovery(
        mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that inference results are the same whether from a checkpoint in the same run, from a run recovery or from a
    local_weights_path param.
    """
    # Train for 4 epochs
    config = DummyClassification()
    if mean_teacher_model:
        config.mean_teacher_alpha = 0.999
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))
    config.save_start_epoch = 2
    config.save_step_epochs = 2

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    train_results = model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(train_results.learning_rates_per_epoch) == config.num_epochs

    # Run inference on this
    test_results = model_test(config=config,
                              data_split=ModelExecutionMode.TEST,
                              checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    assert list(test_results.epochs.keys()) == [config.num_epochs]

    # Mimic using a run recovery and see if it is the same
    config_run_recovery = DummyClassification()
    if mean_teacher_model:
        config_run_recovery.mean_teacher_alpha = 0.999
    config_run_recovery.set_output_to(test_output_dirs.root_dir /
                                      "run_recovery")
    os.makedirs(str(config_run_recovery.outputs_folder))

    checkpoint_handler_run_recovery = get_default_checkpoint_handler(
        model_config=config_run_recovery,
        project_root=test_output_dirs.root_dir)
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config_run_recovery.checkpoint_folder / "recovered"
    shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
    checkpoint_handler_run_recovery.run_recovery = RunRecovery(
        [checkpoint_root])
    test_results_run_recovery = model_test(
        config_run_recovery,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_run_recovery)
    assert isinstance(test_results_run_recovery,
                      InferenceMetricsForClassification)
    assert list(test_results_run_recovery.epochs.keys()) == [config.num_epochs]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_run_recovery.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value]

    # Run inference with the local checkpoints
    config_local_weights = DummyClassification()
    if mean_teacher_model:
        config_local_weights.mean_teacher_alpha = 0.999
    config_local_weights.set_output_to(test_output_dirs.root_dir /
                                       "local_weights_path")
    os.makedirs(str(config_local_weights.outputs_folder))

    local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
    shutil.copyfile(
        str(
            create_checkpoint_path(config.checkpoint_folder,
                                   epoch=config.num_epochs)),
        local_weights_path)
    config_local_weights.local_weights_path = local_weights_path

    checkpoint_handler_local_weights = get_default_checkpoint_handler(
        model_config=config_local_weights,
        project_root=test_output_dirs.root_dir)
    checkpoint_handler_local_weights.discover_and_download_checkpoints_from_previous_runs()
    test_results_local_weights = model_test(
        config_local_weights,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_local_weights)
    assert isinstance(test_results_local_weights,
                      InferenceMetricsForClassification)
    assert list(test_results_local_weights.epochs.keys()) == [0]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_local_weights.epochs[0].values()[MetricType.CROSS_ENTROPY.value]
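
Several of these tests set mean_teacher_alpha = 0.999. In the mean-teacher scheme, a second copy of the weights is maintained as an exponential moving average of the student after every optimizer step, and inference uses the smoothed teacher. A minimal sketch of the update rule on plain numpy tensors (this shows the standard formulation, not a copy of InnerEye's implementation):

from typing import Dict

import numpy as np


def update_mean_teacher(teacher: Dict[str, np.ndarray],
                        student: Dict[str, np.ndarray],
                        alpha: float = 0.999) -> None:
    """In-place EMA update: teacher <- alpha * teacher + (1 - alpha) * student."""
    for name, student_param in student.items():
        teacher[name] = alpha * teacher[name] + (1.0 - alpha) * student_param


student = {"w": np.array([1.0, 2.0])}
teacher = {"w": np.array([0.0, 0.0])}
update_mean_teacher(teacher, student, alpha=0.999)
# After one step, the teacher has moved 0.1% of the way towards the student.
assert np.allclose(teacher["w"], [0.001, 0.002])
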
Code example #24
    def run(self) -> None:
        """
        Driver function to run a ML experiment. If an offline cross validation run is requested, then
        this function is recursively called for each cross validation split.
        """
        if self.is_offline_cross_val_parent_run():
            if self.model_config.is_segmentation_model:
                raise NotImplementedError("Offline cross validation is only supported for classification models.")
            self.spawn_offline_cross_val_classification_child_runs()
            return

        # Get the AzureML context in which the script is running
        if not self.model_config.is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        self.save_build_info_for_dotnet_consumers()

        # Set data loader start method
        self.set_multiprocessing_start_method()

        # configure recovery container if provided
        checkpoint_handler = CheckpointHandler(model_config=self.model_config,
                                               azure_config=self.azure_config,
                                               project_root=self.project_root,
                                               run_context=RUN_CONTEXT)
        checkpoint_handler.download_recovery_checkpoints_or_weights()
        # do training and inference, unless the "only register" switch is set (which requires a run_recovery
        # to be valid).
        if not self.azure_config.only_register_model:
            # Set local_dataset to the mounted path specified in azure_runner.py, if any, or download it if that fails
            # and config.local_dataset was not already set.
            self.model_config.local_dataset = self.mount_or_download_dataset()
            # Check for existing dataset.csv file in the correct locations. Skip that if a dataset has already been
            # loaded (typically only during tests)
            if self.model_config.dataset_data_frame is None:
                assert self.model_config.local_dataset is not None
                ml_util.validate_dataset_paths(
                    self.model_config.local_dataset,
                    self.model_config.dataset_csv)

            # train a new model if required
            if self.azure_config.train:
                with logging_section("Model training"):
                    model_train(self.model_config, checkpoint_handler, num_nodes=self.azure_config.num_nodes)
            else:
                self.model_config.write_dataset_files()
                self.create_activation_maps()

            # log the number of epochs used for model training
            RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs)

        # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because
        # the current run is a single one. See the documentation of ModelProcessing for more details.
        self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT)

        if self.model_config.generate_report:
            self.generate_report(ModelProcessing.DEFAULT)

        # If this is a cross validation run, and the present run is child run 0, then wait for the sibling runs,
        # build the ensemble model, and write a report for that.
        if self.model_config.number_of_cross_validation_splits > 0:
            if self.model_config.should_wait_for_other_cross_val_child_runs():
                self.wait_for_runs_to_finish()
                self.create_ensemble_model()
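
When number_of_cross_validation_splits > 0, the driver either spawns one child run per split (offline) or, as child run 0, waits for the sibling runs before building the ensemble. The splitting itself amounts to partitioning subjects into disjoint folds; a rough sketch of such a partition (split_subjects is a hypothetical helper, not InnerEye's dataset splitter, and ignores stratification):

from typing import List, Tuple


def split_subjects(subjects: List[str], num_splits: int) -> List[Tuple[List[str], List[str]]]:
    """Return (train, validation) subject lists for each of num_splits folds."""
    folds: List[Tuple[List[str], List[str]]] = []
    for split_index in range(num_splits):
        # Every num_splits-th subject goes to validation; the rest are training.
        val = [s for i, s in enumerate(subjects) if i % num_splits == split_index]
        train = [s for i, s in enumerate(subjects) if i % num_splits != split_index]
        folds.append((train, val))
    return folds


folds = split_subjects(["S1", "S2", "S3", "S4", "S5"], num_splits=5)
assert all(len(val) == 1 for _, val in folds)
assert folds[0] == (["S2", "S3", "S4", "S5"], ["S1"])
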