Ejemplo n.º 1
0
def test_predict_non_ensemble(batch_size: int, empty_labels: bool) -> None:
    """
    Run a single (non-ensemble) scalar inference pipeline on a model built from ConstantScalarConfig(1.)
    and check the returned subject ids, the pass-through labels and the model outputs.

    :param batch_size: Number of identical samples in the batch.
    :param empty_labels: If True, every label is NaN; otherwise every label is 0.
    """
    constant_config = ConstantScalarConfig(1.)
    info = ModelAndInfo(config=constant_config,
                        model_execution_mode=ModelExecutionMode.TEST,
                        checkpoint_path=None)
    # Keep the call outside the assert so it still runs when asserts are stripped (python -O).
    created = info.try_create_model_load_from_checkpoint_and_adjust()
    assert created

    inference = ScalarInferencePipeline(info.model, constant_config, 0, 0)

    label_shape = (batch_size, 1)
    if empty_labels:
        expected_labels = torch.zeros(label_shape) * np.nan
    else:
        expected_labels = torch.zeros(label_shape)

    image_shape = (batch_size, 1) + constant_config.expected_image_size_zyx
    batch = dict(
        metadata=[GeneralSampleMetadata(id='2')] * batch_size,
        label=expected_labels,
        images=torch.zeros(image_shape),
        numerical_non_image_features=torch.tensor([]),
        categorical_non_image_features=torch.tensor([]),
        segmentations=torch.tensor([]),
    )

    output = inference.predict(batch)
    subject_ids, returned_labels, posteriors = output.subject_ids, output.labels, output.model_outputs
    assert subject_ids == ['2'] * batch_size
    assert torch.allclose(returned_labels, expected_labels, equal_nan=True)
    # The model always returns 1, so every posterior should equal sigmoid(1).
    expected_posteriors = torch.full((batch_size, 1), 0.731058578)
    assert torch.allclose(posteriors, expected_posteriors)
def test_visualization_with_sequence_model(use_combined_model: bool,
                                           imaging_feature_type: ImagingFeatureType,
                                           test_output_dirs: TestOutputDirectories) -> None:
    """
    Build a ToySequenceModel, push one mocked batch through VisualizationMaps and check the shapes of the
    generated maps and the labels produced for the non-imaging plot.

    :param use_combined_model: Passed to ToySequenceModel; decides whether GradCam maps are expected.
    :param imaging_feature_type: Passed to ToySequenceModel; ImageAndSegmentation doubles the expected map count.
    :param test_output_dirs: Fixture providing the output folder for this test.
    """
    config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    config.num_epochs = 1

    info = ModelAndInfo(config=config, model_execution_mode=ModelExecutionMode.TEST,
                        is_mean_teacher=False, checkpoint_path=None)
    created = info.try_create_model_load_from_checkpoint_and_adjust()
    assert created
    model = info.model

    loader = SequenceDataset(config,
                             data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                                                  batch_size=2)
    # Accessing a dataset item calls load_image_in_known_formats; serve random arrays instead of files.
    fake_scan = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
                                                  segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=fake_scan):
        first_batch = next(iter(loader))
        inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, first_batch)  # type: ignore

    n_sequences = inputs_and_labels.model_inputs[0].shape[1]
    n_subjects = len(inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(inputs_and_labels.model_inputs)

    if not use_combined_model:
        # Without the combined model no GradCam maps are produced, only the non-imaging pseudo-CAM.
        assert guided_grad_cams is None
        assert grad_cams is None
        assert pseudo_cam_non_img.shape[:2] == (n_subjects, n_sequences)
        assert probas.shape[0] == n_subjects
    else:
        # With ImageAndSegmentation inputs, twice as many maps per subject are expected.
        if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
            maps_per_subject = n_sequences * 2
        else:
            maps_per_subject = n_sequences
        assert guided_grad_cams.shape[:2] == (n_subjects, maps_per_subject)
        assert grad_cams.shape[:2] == (n_subjects, maps_per_subject)

    non_image_features = config.numerical_columns + config.categorical_columns
    plot_labels = visualizer._get_non_imaging_plot_labels(inputs_and_labels.data_item,
                                                          non_image_features,
                                                          index=0,
                                                          target_position=3)
    # Expected: every feature suffixed with its sequence position, for positions 0..3.
    expected_plot_labels = [f"{feature}_{position}"
                            for position in range(4)
                            for feature in ('numerical1', 'numerical2', 'cat1')]
    assert plot_labels == expected_plot_labels
Ejemplo n.º 3
0
def test_try_create_model_load_from_checkpoint_and_adjust(config: ModelConfigBase, checkpoint_path: str,
                                                          model_execution_mode: ModelExecutionMode) -> None:
    """
    Exercise try_create_model_load_from_checkpoint_and_adjust with no checkpoint, a missing checkpoint file
    and a valid checkpoint. In every case the resulting model must be a DataParallelModel.

    :param config: Model configuration under test; use_gpu is forced to True.
    :param checkpoint_path: Name of a valid checkpoint inside the ML test data folder.
    :param model_execution_mode: Execution mode passed to ModelAndInfo.
    """
    config.use_gpu = True

    # Case 1: no checkpoint. Reading .model before creation raises ValueError.
    fresh = ModelAndInfo(config,
                         model_execution_mode=model_execution_mode,
                         checkpoint_path=None)
    with pytest.raises(ValueError):
        fresh.model
    created = fresh.try_create_model_load_from_checkpoint_and_adjust()
    assert created
    assert isinstance(fresh.model, DataParallelModel)

    # Case 2: checkpoint file does not exist. Loading reports failure, yet the model object is still created.
    missing_file = full_ml_test_data_path("non_exist.pth.tar")
    broken = ModelAndInfo(config,
                          model_execution_mode=model_execution_mode,
                          checkpoint_path=missing_file)
    created = broken.try_create_model_load_from_checkpoint_and_adjust()
    assert not created
    assert isinstance(broken.model, DataParallelModel)

    # Case 3: valid checkpoint. The epoch stored in the checkpoint is exposed after loading.
    existing_file = full_ml_test_data_path(checkpoint_path)
    restored = ModelAndInfo(config,
                            model_execution_mode=model_execution_mode,
                            checkpoint_path=existing_file)
    created = restored.try_create_model_load_from_checkpoint_and_adjust()
    assert created
    assert isinstance(restored.model, DataParallelModel)
    assert restored.checkpoint_epoch == 1
def test_predict_non_ensemble(batch_size: int, empty_labels: bool) -> None:
    """
    Run a single (non-ensemble) scalar inference pipeline on a ScalarOnesModel that always returns 1.
    Check the returned subject ids, the pass-through labels and the posteriors.

    :param batch_size: Number of identical samples in the batch.
    :param empty_labels: If True, every label is NaN; otherwise every label is 0.
    """
    config = ClassificationModelForTesting()
    model: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)
    # Run inference on the model returned by update_model_for_multiple_gpus, as test_predict_ensemble does.
    # Previously the returned ModelAndInfo was discarded and the pipeline used the unadjusted model.
    model = update_model_for_multiple_gpus(ModelAndInfo(model),
                                           args=config,
                                           execution_mode=ModelExecutionMode.TEST).model
    pipeline = ScalarInferencePipeline(model, config, 0, 0)
    actual_labels = torch.zeros(
        (batch_size, 1)) * np.nan if empty_labels else torch.zeros(
            (batch_size, 1))
    data = {
        "metadata": [GeneralSampleMetadata(id='2')] * batch_size,
        "label": actual_labels,
        "images": torch.zeros(
            ((batch_size, 1) + config.expected_image_size_zyx)),
        "numerical_non_image_features": torch.tensor([]),
        "categorical_non_image_features": torch.tensor([]),
        "segmentations": torch.tensor([])
    }

    results = pipeline.predict(data)
    ids, labels, predicted = results.subject_ids, results.labels, results.model_outputs
    assert ids == ['2'] * batch_size
    assert torch.allclose(labels, actual_labels, equal_nan=True)
    # The model always returns 1, so predicted should be sigmoid(1)
    assert torch.allclose(predicted, torch.full((batch_size, 1), 0.731058578))
def test_amp_and_parallel_for_scalar_models(
        test_output_dirs: TestOutputDirectories,
        execution_mode: ModelExecutionMode, use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and data parallelism for scalar models.

    The model is always wrapped for data parallelism here, so the forward pass is expected to return a list
    of logits tensors, one per GPU. Requires a machine with more than one GPU.

    :param test_output_dirs: Fixture providing the output folder for the metrics data frame loggers.
    :param execution_mode: Selects the data loader to draw the sample from; TRAIN also enables training mode.
    :param use_mixed_precision: If True, float16 outputs are expected and a GradScaler is used.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine"
    config = ClassificationModelForTesting()
    config.use_mixed_precision = use_mixed_precision
    model = DummyScalarModel(
        expected_image_size_zyx=config.expected_image_size_zyx,
        activation=Identity())
    model.use_mixed_precision = use_mixed_precision
    model_and_info = ModelAndInfo(model=model,
                                  model_execution_mode=execution_mode)
    # update_model_for_multiple_gpus applies data parallelism when
    # execution_mode == ModelExecutionMode.TRAIN or (not use_model_parallel), which is always True in our case.
    model_and_info = model_util.update_model_for_multiple_gpus(
        model_and_info, config)
    assert isinstance(model_and_info.model, DataParallelModel)
    data_loaders = config.create_data_loaders()
    gradient_scaler = GradScaler() if use_mixed_precision else None
    train_val_parameters: TrainValidateParameters = TrainValidateParameters(
        model=model_and_info.model,
        data_loader=data_loaders[execution_mode],
        in_training_mode=execution_mode == ModelExecutionMode.TRAIN,
        gradient_scaler=gradient_scaler,
        dataframe_loggers=MetricsDataframeLoggers(
            Path(test_output_dirs.root_dir)),
        summary_writers=SummaryWriters(train=None, val=None)  # type: ignore
    )
    training_steps = ModelTrainingStepsForScalarModel(config,
                                                      train_val_parameters)
    # Only the first batch is needed: next(iter(...)) avoids loading every batch of the data set into a list.
    sample = next(iter(data_loaders[execution_mode]))
    model_input = get_scalar_model_inputs_and_labels(config, model, sample)
    logits, posteriors, loss = training_steps._compute_model_output_and_loss(
        model_input)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    assert isinstance(logits, list)
    first_logit = logits[0]
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
        assert posteriors.dtype == torch.float16
        # BCEWithLogitsLoss outputs float32, even with float16 args
        assert loss.dtype == torch.float32
    else:
        assert first_logit.dtype == torch.float32
        assert posteriors.dtype == torch.float32
        assert loss.dtype == torch.float32
    # Verify that forward pass does not throw. It would for example if it fails to gather tensors or not convert
    # float16 to float32
    _, _, _ = training_steps._compute_model_output_and_loss(model_input)
def test_try_create_model_and_load_from_checkpoint(
        config: ModelConfigBase, checkpoint_path: str) -> None:
    """
    Exercise try_create_model_and_load_from_checkpoint with no checkpoint, a missing checkpoint file and a
    valid checkpoint, all in TEST mode without a mean teacher.

    Segmentation configs must yield a BaseModel; every other config must yield a DeviceAwareModule.

    :param config: Model configuration under test.
    :param checkpoint_path: Name of a valid checkpoint inside the ML test data folder.
    """
    expected_model_type = BaseModel if isinstance(config, SegmentationModelBase) else DeviceAwareModule

    # No checkpoint: reading .model before the model is created raises ValueError.
    fresh = ModelAndInfo(config,
                         model_execution_mode=ModelExecutionMode.TEST,
                         is_mean_teacher=False,
                         checkpoint_path=None)
    with pytest.raises(ValueError):
        fresh.model
    loaded = fresh.try_create_model_and_load_from_checkpoint()
    assert loaded
    assert isinstance(fresh.model, expected_model_type)

    # Missing checkpoint file: loading reports failure, but the model object itself has been created.
    broken = ModelAndInfo(config,
                          model_execution_mode=ModelExecutionMode.TEST,
                          is_mean_teacher=False,
                          checkpoint_path=full_ml_test_data_path("non_exist.pth.tar"))
    loaded = broken.try_create_model_and_load_from_checkpoint()
    assert not loaded
    assert isinstance(broken.model, expected_model_type)

    # Valid checkpoint: loading succeeds and the stored epoch is exposed.
    restored = ModelAndInfo(config,
                            model_execution_mode=ModelExecutionMode.TEST,
                            is_mean_teacher=False,
                            checkpoint_path=full_ml_test_data_path(checkpoint_path))
    loaded = restored.try_create_model_and_load_from_checkpoint()
    assert loaded
    assert isinstance(restored.model, expected_model_type)
    assert restored.checkpoint_epoch == 1
def test_invalid_stride_size() -> None:
    """
    Creating a model whose inference stride (120, 120, 120) exceeds the test crop size (80, 80, 80) must
    raise a ValueError whose message mentions both sizes.
    """
    config = SegmentationModelBase(
        architecture="UNet3D",
        feature_channels=[1],
        crop_size=(64, 64, 64),
        test_crop_size=(80, 80, 80),
        image_channels=["mr"],
        ground_truth_ids=["tumour_mass", "subtract"],
        train_batch_size=8,
        inference_batch_size=1,
        inference_stride_size=(120, 120, 120),
        should_validate=False
    )
    with pytest.raises(ValueError) as raised:
        info = ModelAndInfo(config=config, model_execution_mode=ModelExecutionMode.TEST,
                            checkpoint_path=None)
        info.try_create_model_load_from_checkpoint_and_adjust()

    message = raised.value.args[0]
    expected_fragments = ("inference stride size must be smaller",
                          str(config.inference_stride_size),
                          str(config.test_crop_size))
    for fragment in expected_fragments:
        assert fragment in message
def test_predict_ensemble(batch_size: int) -> None:
    """
    Average five inference pipelines, three backed by a model built from ConstantScalarConfig(0.) and two
    by one built from ConstantScalarConfig(1.), and check the ensembled posterior.

    :param batch_size: Number of identical samples in the batch.
    """
    zero_config = ConstantScalarConfig(0.)
    zero_info = ModelAndInfo(config=zero_config, model_execution_mode=ModelExecutionMode.TEST,
                             is_mean_teacher=False, checkpoint_path=None)
    created = zero_info.try_create_model_load_from_checkpoint_and_adjust()
    assert created
    zero_model = zero_info.model

    one_config = ConstantScalarConfig(1.)
    one_info = ModelAndInfo(config=one_config, model_execution_mode=ModelExecutionMode.TEST,
                            is_mean_teacher=False, checkpoint_path=None)
    created = one_info.try_create_model_load_from_checkpoint_and_adjust()
    assert created
    one_model = one_info.model

    members = [ScalarInferencePipeline(zero_model, zero_config, 0, index) for index in (0, 1, 2)]
    members += [ScalarInferencePipeline(one_model, one_config, 0, index) for index in (3, 4)]
    ensemble = ScalarEnsemblePipeline(members, zero_config, EnsembleAggregationType.Average)

    image_shape = (batch_size, 1) + zero_config.expected_image_size_zyx
    batch = dict(
        metadata=[GeneralSampleMetadata(id='2')] * batch_size,
        label=torch.zeros((batch_size, 1)),
        images=torch.zeros(image_shape),
        numerical_non_image_features=torch.tensor([]),
        categorical_non_image_features=torch.tensor([]),
        segmentations=torch.tensor([]),
    )

    output = ensemble.predict(batch)
    subject_ids, returned_labels, posteriors = output.subject_ids, output.labels, output.model_outputs
    assert subject_ids == ['2'] * batch_size
    assert torch.equal(returned_labels, torch.zeros((batch_size, 1)))
    # Expected average: (3 * sigmoid(0) + 2 * sigmoid(1)) / 5.
    assert torch.allclose(posteriors, torch.full((batch_size, 1), 0.592423431))
def test_predict_ensemble(batch_size: int) -> None:
    """
    Average five inference pipelines and check the ensembled posterior.

    Three pipelines use a ScalarOnesModel returning 0 and two use one returning 1; both models are first
    passed through update_model_for_multiple_gpus.

    :param batch_size: Number of identical samples in the batch.
    """
    config = ClassificationModelForTesting()
    raw_zero: Any = ScalarOnesModel(config.expected_image_size_zyx, 0.)
    raw_one: Any = ScalarOnesModel(config.expected_image_size_zyx, 1.)
    zero_model, one_model = (
        update_model_for_multiple_gpus(ModelAndInfo(raw_model),
                                       args=config,
                                       execution_mode=ModelExecutionMode.TEST).model
        for raw_model in (raw_zero, raw_one))

    members = [ScalarInferencePipeline(zero_model, config, 0, index) for index in (0, 1, 2)]
    members += [ScalarInferencePipeline(one_model, config, 0, index) for index in (3, 4)]
    ensemble = ScalarEnsemblePipeline(members, config, EnsembleAggregationType.Average)

    image_shape = (batch_size, 1) + config.expected_image_size_zyx
    batch = dict(
        metadata=[GeneralSampleMetadata(id='2')] * batch_size,
        label=torch.zeros((batch_size, 1)),
        images=torch.zeros(image_shape),
        numerical_non_image_features=torch.tensor([]),
        categorical_non_image_features=torch.tensor([]),
        segmentations=torch.tensor([]),
    )

    output = ensemble.predict(batch)
    subject_ids, returned_labels, posteriors = output.subject_ids, output.labels, output.model_outputs
    assert subject_ids == ['2'] * batch_size
    assert torch.equal(returned_labels, torch.zeros((batch_size, 1)))
    # Expected average: (3 * sigmoid(0) + 2 * sigmoid(1)) / 5.
    assert torch.allclose(posteriors, torch.full((batch_size, 1), 0.592423431))
def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler) -> ModelTrainingResults:
    """
    The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
    to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.

    The Tensorboard writers and the metrics data frame loggers are closed, and the resource monitor (if one
    was started) is killed, even when training raises. A failed run therefore does not leave the monitor
    running.

    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading a previous checkpoint.
    :return: Per-epoch training and validation metrics, per-epoch learning rates, and the optimal
        temperature found at each checkpoint epoch (only filled for sequence models with temperature scaling).
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()

    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Patch visualization")
    # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
    # want training to depend on how many patients we visualized, and hence set the random seed again right after.
    with logging_section(
            "Visualizing the effect of sampling random crops for training"):
        visualize_random_crops_for_dataset(config)
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Model training")

    logging.debug("Creating the PyTorch model.")

    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()

    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()

    models_and_optimizer = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TRAIN,
        checkpoint_path=checkpoint_path)

    # Create the main model
    # If continuing from a previous run at a specific epoch, then load the previous model.
    model_loaded = models_and_optimizer.try_create_model_and_load_from_checkpoint()
    if not model_loaded:
        raise ValueError(
            "There was no checkpoint file available for the model for given start_epoch {}"
            .format(config.start_epoch))

    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizer.model)

    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizer.adjust_model_for_gpus()

    # Create the mean teacher model and move to GPU
    if config.compute_mean_teacher_model:
        mean_teacher_model_loaded = \
            models_and_optimizer.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()
        if not mean_teacher_model_loaded:
            raise ValueError(
                "There was no checkpoint file available for the mean teacher model "
                f"for given start_epoch {config.start_epoch}")

    # Create optimizer
    models_and_optimizer.create_optimizer()
    if checkpoint_handler.should_load_optimizer_checkpoint():
        optimizer_loaded = models_and_optimizer.try_load_checkpoint_for_optimizer()
        if not optimizer_loaded:
            raise ValueError(
                f"There was no checkpoint file available for the optimizer for given start_epoch "
                f"{config.start_epoch}")

    # Create checkpoint directory for this run if it doesn't already exist.
    # parents=True also creates missing parent folders (a plain mkdir() raised in that case), and
    # exist_ok=True replaces the racy is_dir()/mkdir() check-then-act pair.
    logging.info(f"Models are saved at {config.checkpoint_folder}")
    config.checkpoint_folder.mkdir(parents=True, exist_ok=True)

    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()

    # Create LR scheduler
    l_rate_scheduler = SchedulerWithWarmUp(config,
                                           models_and_optimizer.optimizer)

    # Training loop
    logging.info("Starting training")
    train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []

    resource_monitor = None
    if config.monitoring_interval_seconds > 0:
        # initialize and start GPU monitoring
        diagnostics_events = config.logs_folder / "diagnostics"
        logging.info(
            f"Starting resource monitor, outputting to {diagnostics_events}")
        resource_monitor = ResourceMonitor(
            interval_seconds=config.monitoring_interval_seconds,
            tensorboard_folder=diagnostics_events)
        resource_monitor.start()

    # Everything after the monitor has started runs under try/finally, so that writers, loggers and the
    # monitor are shut down even if an epoch raises. On success the shutdown order is unchanged.
    try:
        gradient_scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
        optimal_temperature_scale_values = []
        for epoch in config.get_train_epochs():
            logging.info("Starting epoch {}".format(epoch))
            save_epoch = config.should_save_epoch(
                epoch) and models_and_optimizer.optimizer is not None

            # store the learning rates used for each epoch
            epoch_lrs = l_rate_scheduler.get_last_lr()
            learning_rates_per_epoch.append(epoch_lrs)

            train_val_params: TrainValidateParameters = \
                TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                        model=models_and_optimizer.model,
                                        mean_teacher_model=models_and_optimizer.mean_teacher_model,
                                        epoch=epoch,
                                        optimizer=models_and_optimizer.optimizer,
                                        gradient_scaler=gradient_scaler,
                                        epoch_learning_rate=epoch_lrs,
                                        summary_writers=writers,
                                        dataframe_loggers=config.metrics_data_frame_loggers,
                                        in_training_mode=True)
            training_steps = create_model_training_steps(config, train_val_params)
            train_epoch_results = train_or_validate_epoch(training_steps)
            train_results_per_epoch.append(train_epoch_results.metrics)

            metrics.validate_and_store_model_parameters(writers.train, epoch,
                                                        models_and_optimizer.model)
            # Run without adjusting weights on the validation set
            train_val_params.in_training_mode = False
            train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
            # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
            # as these will be re-computed after performing temperature scaling on the validation set.
            if isinstance(config, SequenceModelBase):
                train_val_params.save_metrics = not (
                    save_epoch and config.temperature_scaling_config)

            training_steps = create_model_training_steps(config, train_val_params)
            val_epoch_results = train_or_validate_epoch(training_steps)
            val_results_per_epoch.append(val_epoch_results.metrics)

            if config.is_segmentation_model:
                metrics.store_epoch_stats_for_segmentation(
                    config.outputs_folder, epoch, epoch_lrs,
                    train_epoch_results.metrics, val_epoch_results.metrics)

            if save_epoch:
                # perform temperature scaling if required
                if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
                    optimal_temperature, scaled_val_results = \
                        temperature_scaling_steps(config, train_val_params, val_epoch_results)
                    optimal_temperature_scale_values.append(optimal_temperature)
                    # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                    val_results_per_epoch[-1] = scaled_val_results.metrics

                models_and_optimizer.save_checkpoint(epoch)

            # Updating the learning rate should happen at the end of the training loop, so that the
            # initial learning rate will be used for the very first epoch.
            l_rate_scheduler.step()

        model_training_results = ModelTrainingResults(
            train_results_per_epoch=train_results_per_epoch,
            val_results_per_epoch=val_results_per_epoch,
            learning_rates_per_epoch=learning_rates_per_epoch,
            optimal_temperature_scale_values_per_checkpoint_epoch=optimal_temperature_scale_values)

        logging.info("Finished training")

        # Since we have trained the model further, let the checkpoint_handler object know so it can handle
        # checkpoints correctly.
        checkpoint_handler.additional_training_done()

        # Upload visualization directory to AML run context to be able to see it
        # in the Azure UI.
        if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
            RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                      path=str(config.visualization_folder))
    finally:
        writers.close_all()
        config.metrics_data_frame_loggers.close_all()
        if resource_monitor:
            # stop the resource monitoring process
            logging.info(
                "Shutting down the resource monitor process. Aggregate resource utilization:"
            )
            for name, value in resource_monitor.read_aggregate_metrics():
                logging.info(f"{name}: {value}")
                if not is_offline_run_context(RUN_CONTEXT):
                    RUN_CONTEXT.log(name, value)
            resource_monitor.kill()

    return model_training_results
def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = None) -> ModelTrainingResults:
    """
    The main training loop. It creates the data loaders, the model (plus a mean teacher model if
    config.compute_mean_teacher_model is set), the optimizer and the learning rate scheduler, then trains for the
    epochs returned by config.get_train_epochs(). If config.should_load_checkpoint_for_training() is true, the
    model(s) and optimizer are loaded from the checkpoint located via run_recovery for config.start_epoch.

    The TensorBoard summary writers, the metrics data frame loggers and the optional resource monitor are always
    shut down, even if training raises an exception.

    :param config: The arguments which specify all required information.
    :param run_recovery: Recovery information to restart training from an existing run.
    :return: The per-epoch training and validation metrics, the learning rates used in each epoch, and the
        optimal temperature scale values found at checkpoint epochs (sequence models with temperature scaling).
    :raises ValueError: When there are issues loading a previous checkpoint.
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()

    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(), "Model Training")

    logging.debug("Creating the PyTorch model.")

    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()

    # Create models, optimizers, and whether is_mean_teacher
    checkpoint_path = get_recovery_path_train(run_recovery=run_recovery,
                                              is_mean_teacher=False,
                                              epoch=config.start_epoch)
    models_and_optimizers = [ModelAndInfo(config=config,
                                          model_execution_mode=ModelExecutionMode.TRAIN,
                                          is_mean_teacher=False,
                                          checkpoint_path=checkpoint_path if config.should_load_checkpoint_for_training() else None)]

    if config.compute_mean_teacher_model:
        checkpoint_path = get_recovery_path_train(run_recovery=run_recovery,
                                                  is_mean_teacher=True,
                                                  epoch=config.start_epoch)
        models_and_optimizers.append(ModelAndInfo(config=config,
                                                  model_execution_mode=ModelExecutionMode.TRAIN,
                                                  is_mean_teacher=True,
                                                  checkpoint_path=checkpoint_path if config.should_load_checkpoint_for_training() else None))

    # Create the models.
    # If continuing from a previous run at a specific epoch, then load the previous model.
    for model_and_info in models_and_optimizers:
        model_loaded = model_and_info.try_create_model_and_load_from_checkpoint()
        if not model_loaded:
            raise ValueError("There was no checkpoint file available for the model for given start_epoch {}"
                             .format(config.start_epoch))

    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizers[0].model)

    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizers[0].adjust_model_for_gpus()
    if len(models_and_optimizers) > 1:
        models_and_optimizers[1].create_summary_and_adjust_model_for_gpus()

    # Create optimizer
    optimizer_loaded = models_and_optimizers[0].try_create_optimizer_and_load_from_checkpoint()
    if not optimizer_loaded:
        raise ValueError("There was no checkpoint file available for the optimizer for given start_epoch {}"
                         .format(config.start_epoch))

    # Create checkpoint directory for this run if it doesn't already exist. exist_ok avoids the race between
    # checking for the folder and creating it.
    logging.info("Models are saved at {}".format(config.checkpoint_folder))
    os.makedirs(config.checkpoint_folder, exist_ok=True)

    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()

    resource_monitor: Optional[ResourceMonitor] = None
    # From here on, run inside try/finally so that the summary writers, the data frame loggers and the
    # resource monitor process are released even if training raises part-way through.
    try:
        model = models_and_optimizers[0].model
        optimizer = models_and_optimizers[0].optimizer
        mean_teacher_model = models_and_optimizers[1].model if len(models_and_optimizers) > 1 else None

        # Create LR scheduler
        l_rate_scheduler = SchedulerWithWarmUp(config, optimizer)

        # Training loop
        logging.info("Starting training")
        train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tb_log_file_path=str(config.logs_folder / "diagnostics"))
            resource_monitor.start()

        gradient_scaler = GradScaler() if config.use_gpu and config.use_mixed_precision else None
        optimal_temperature_scale_values = []
        for epoch in config.get_train_epochs():
            logging.info("Starting epoch {}".format(epoch))
            # Checkpointing (and temperature scaling) only happens on epochs selected by the config.
            save_epoch = config.should_save_epoch(epoch) and optimizer is not None

            # store the learning rates used for each epoch
            epoch_lrs = l_rate_scheduler.get_last_lr()
            learning_rates_per_epoch.append(epoch_lrs)

            train_val_params: TrainValidateParameters = \
                TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                        model=model,
                                        mean_teacher_model=mean_teacher_model,
                                        epoch=epoch,
                                        optimizer=optimizer,
                                        gradient_scaler=gradient_scaler,
                                        epoch_learning_rate=epoch_lrs,
                                        summary_writers=writers,
                                        dataframe_loggers=config.metrics_data_frame_loggers,
                                        in_training_mode=True)
            training_steps = create_model_training_steps(config, train_val_params)
            train_epoch_results = train_or_validate_epoch(training_steps)
            train_results_per_epoch.append(train_epoch_results.metrics)

            metrics.validate_and_store_model_parameters(writers.train, epoch, model)
            # Run without adjusting weights on the validation set
            train_val_params.in_training_mode = False
            train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
            # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
            # as these will be re-computed after performing temperature scaling on the validation set.
            if isinstance(config, SequenceModelBase):
                train_val_params.save_metrics = not (save_epoch and config.temperature_scaling_config)

            training_steps = create_model_training_steps(config, train_val_params)
            val_epoch_results = train_or_validate_epoch(training_steps)
            val_results_per_epoch.append(val_epoch_results.metrics)

            if config.is_segmentation_model:
                metrics.store_epoch_stats_for_segmentation(config.outputs_folder, epoch, epoch_lrs,
                                                           train_epoch_results.metrics,
                                                           val_epoch_results.metrics)

            if save_epoch:
                # perform temperature scaling if required
                if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
                    optimal_temperature, scaled_val_results = \
                        temperature_scaling_steps(config, train_val_params, val_epoch_results)
                    optimal_temperature_scale_values.append(optimal_temperature)
                    # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                    val_results_per_epoch[-1] = scaled_val_results.metrics

                assert optimizer is not None
                save_checkpoint(model, optimizer, epoch, config)
                if config.compute_mean_teacher_model:
                    assert mean_teacher_model is not None
                    save_checkpoint(mean_teacher_model, optimizer, epoch, config, mean_teacher_model=True)

            # Updating the learning rate should happen at the end of the training loop, so that the
            # initial learning rate will be used for the very first epoch.
            l_rate_scheduler.step()

        model_training_results = ModelTrainingResults(
            train_results_per_epoch=train_results_per_epoch,
            val_results_per_epoch=val_results_per_epoch,
            learning_rates_per_epoch=learning_rates_per_epoch,
            optimal_temperature_scale_values_per_checkpoint_epoch=optimal_temperature_scale_values
        )

        logging.info("Finished training")

        # Upload visualization directory to AML run context to be able to see it
        # in the Azure UI.
        if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
            RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))
    finally:
        writers.close_all()
        config.metrics_data_frame_loggers.close_all()
        if resource_monitor:
            # stop the resource monitoring process
            resource_monitor.kill()

    return model_training_results
def test_amp_activated(use_model_parallel: bool,
                       execution_mode: ModelExecutionMode,
                       use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and the model parallel flag.

    A SimpleModel is wrapped via model_util.update_model_for_multiple_gpus. If that call raises
    NotImplementedError while use_model_parallel is set, the error must mention model partitioning (SimpleModel
    does not implement it) and the test ends there. Otherwise the test checks that:
    - the returned model is a DataParallelModel whenever data parallelism is expected,
    - the logits are float16 with mixed precision and float32 without,
    - a full forward/backward pass does not raise.

    :param use_model_parallel: Value for the config's use_model_parallel flag.
    :param execution_mode: Execution mode passed to update_model_for_multiple_gpus.
    :param use_mixed_precision: Value for the config's use_mixed_precision flag. Also controls whether a
        GradScaler is passed to the forward pass.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine"
    # image, labels, and mask to run forward and backward passes
    image = torch.from_numpy(
        np.random.uniform(size=[1, 1, 4, 4, 4]).astype(
            ImageDataType.IMAGE.value))
    labels = torch.from_numpy(
        np.random.uniform(size=[1, 2, 4, 4, 4]).astype(
            ImageDataType.SEGMENTATION.value))
    mask = torch.from_numpy((np.round(np.random.uniform(
        size=[1, 4, 4, 4])).astype(dtype=ImageDataType.MASK.value)))

    crop_size = (4, 4, 4)

    model = SimpleModel(1, [1], 2, 2)
    model_config = SegmentationModelBase(
        crop_size=crop_size,
        image_channels=["ct"],
        ground_truth_ids=["Lung"],
        use_mixed_precision=use_mixed_precision,
        use_model_parallel=use_model_parallel,
        should_validate=False)
    assert model_config.use_gpu
    # Move the model to the GPU. This is mostly to avoid issues with AMP, which has trouble
    # with first using a GPU model and later using a CPU-based one.
    model = model.cuda()
    optimizer = model_util.create_optimizer(model_config, model)
    model_and_info = ModelAndInfo(model, optimizer)
    try:
        model_and_info_amp = model_util.update_model_for_multiple_gpus(
            model_and_info, model_config, execution_mode)
    except NotImplementedError as ex:
        if use_model_parallel:
            # The SimpleModel does not implement model partitioning, and should hence fail at this step.
            assert "Model partitioning is not implemented" in str(ex)
            return
        else:
            raise ValueError(f"Expected this call to succeed, but got: {ex}") from ex

    # This is the same logic spelt out in update_model_for_multiple_gpu
    use_data_parallel = (execution_mode == ModelExecutionMode.TRAIN) or (
        not use_model_parallel)
    if use_data_parallel:
        # Check the object returned by update_model_for_multiple_gpus: that is the model used below.
        assert isinstance(model_and_info_amp.model, DataParallelModel)
    gradient_scaler = GradScaler() if use_mixed_precision else None
    # Constant zero loss: this test only checks dtypes and that the passes run, not the loss value.
    criterion = lambda x, y: torch.tensor([0.0], requires_grad=True).cuda()
    pipeline = SegmentationForwardPass(model_and_info_amp.model,
                                       model_config,
                                       batch_size=1,
                                       optimizer=optimizer,
                                       gradient_scaler=gradient_scaler,
                                       criterion=criterion)
    logits, _ = pipeline._compute_loss(image, labels)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
    else:
        assert first_logit.dtype == torch.float32
    # Verify that forward and backward passes do not throw an exception
    pipeline._forward_pass(patches=image, mask=mask, labels=labels)
def test_register_and_score_model(
        is_ensemble: bool, dataset_expected_spacing_xyz: Any,
        model_outside_package: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    End-to-end test which ensures the scoring pipeline is functioning as expected by performing the following:
    1) Registering a model (randomly initialized checkpoints written on the fly) to AML
    2) Downloading the registered model and checking it contains the scoring scripts, environment and inference
       files, and all checkpoints
    3) Calling the scoring pipeline to check inference can be run from the published model successfully

    :param is_ensemble: If True, two checkpoints (epochs 10 and 20) are registered instead of one.
    :param dataset_expected_spacing_xyz: Value assigned to config.dataset_expected_spacing_xyz before registering.
    :param model_outside_package: If True, use the "BasicModel2EpochsOutsidePackage" config from
        Tests.ML.configs and add "Tests" as extra code directory; otherwise use the "BasicModel2Epochs" config.
    :param test_output_dirs: Output folder for checkpoints, the downloaded model and the test datastore.
    """
    # We are creating checkpoints on the fly in this test, writing a randomly initialized model.
    set_random_seed(0)
    # Get an existing config as template
    loader = get_model_loader(
        "Tests.ML.configs" if model_outside_package else None)
    config: SegmentationModelBase = loader.create_model_config_from_name(
        model_name="BasicModel2EpochsOutsidePackage"
        if model_outside_package else "BasicModel2Epochs")
    config.dataset_expected_spacing_xyz = dataset_expected_spacing_xyz
    config.set_output_to(test_output_dirs.root_dir)
    checkpoints_absolute = []
    model_and_info = ModelAndInfo(
        config=config, model_execution_mode=ModelExecutionMode.TRAIN)
    model_and_info.create_model()
    model_and_info.create_optimizer()
    # save_checkpoint returns the path of the written file. For the ensemble case the same (unchanged) weights
    # are saved a second time under a different epoch.
    checkpoints_absolute.append(model_and_info.save_checkpoint(epoch=10))
    if is_ensemble:
        checkpoints_absolute.append(model_and_info.save_checkpoint(epoch=20))
    # Paths relative to the checkpoint folder, used below to locate the checkpoints inside the downloaded model.
    checkpoints_relative = [
        f.relative_to(config.checkpoint_folder) for f in checkpoints_absolute
    ]
    # Stays None until registration succeeds, so the finally block only deletes a model that was registered.
    azureml_model = None
    # Simulate a project root: We can't derive that from the repository root because that might point
    # into Python's package folder
    project_root = Path(__file__).parent.parent
    # Double-check that we are at the right place, by testing for a file that would quite certainly not be found
    # somewhere else
    assert (project_root / fixed_paths.SCORE_SCRIPT).is_file()
    try:
        azure_config = get_default_azure_config()
        if model_outside_package:
            azure_config.extra_code_directory = "Tests"  # contains BasicModel2EpochsOutsidePackage
        # Dummy deployment hook: returns values derived from its arguments, so the assertion on
        # deployment_result below proves the hook was called with this config and azure_config.
        deployment_hook = lambda cfg, azure_cfg, mdl, is_ens: (Path(
            cfg.model_name), azure_cfg.docker_shm_size)
        ml_runner = MLRunner(config,
                             azure_config,
                             project_root=project_root,
                             model_deployment_hook=deployment_hook)
        registration_result = ml_runner.register_segmentation_model(
            model_description="",
            checkpoint_paths=checkpoints_absolute,
            model_proc=ModelProcessing.DEFAULT)
        assert registration_result is not None
        azureml_model, deployment_result = registration_result
        assert azureml_model is not None
        assert deployment_result == (Path(config.model_name),
                                     azure_config.docker_shm_size)

        # download the registered model and test that we can run the score pipeline on it
        model_root = Path(
            azureml_model.download(str(test_output_dirs.root_dir)))
        # The model needs to contain score.py at the root, the (merged) environment definition,
        # and the inference config.
        expected_files = [
            *fixed_paths.SCRIPTS_AT_ROOT,
            fixed_paths.ENVIRONMENT_YAML_FILE_NAME,
            fixed_paths.MODEL_INFERENCE_JSON_FILE_NAME,
            "InnerEye/ML/runner.py",
        ]
        # All checkpoints go into their own folder
        expected_files.extend(
            str(Path(CHECKPOINT_FOLDER) / c) for c in checkpoints_relative)
        for expected_file in expected_files:
            assert (model_root /
                    expected_file).is_file(), f"File {expected_file} missing"

        # create a dummy datastore to store the image data
        test_datastore = test_output_dirs.root_dir / "test_datastore"
        # move test data into the data folder to simulate an actual run
        train_and_test_data_dir = full_ml_test_data_path("train_and_test_data")
        img_files = ["id1_channel1.nii.gz", "id1_channel2.nii.gz"]
        data_root = test_datastore / fixed_paths.DEFAULT_DATA_FOLDER
        data_root.mkdir(parents=True)
        for f in img_files:
            shutil.copy(str(train_and_test_data_dir / f), str(data_root))

        # run score pipeline as a separate process
        python_executable = sys.executable
        # Sanity check first: the interpreter can be spawned; its version is printed for the test log.
        [return_code1,
         stdout1] = SubprocessConfig(process=python_executable,
                                     args=["--version"
                                           ]).spawn_and_monitor_subprocess()
        assert return_code1 == 0
        print(f"Executing Python version {stdout1[0]}")
        # Run the downloaded score script on the two test image channels, forcing CPU inference.
        return_code, stdout2 = SubprocessConfig(
            process=python_executable,
            args=[
                str(model_root / fixed_paths.SCORE_SCRIPT),
                f"--data_folder={str(data_root)}",
                f"--image_files={img_files[0]},{img_files[1]}",
                "--use_gpu=False"
            ]).spawn_and_monitor_subprocess()

        # check that the process completed as expected
        assert return_code == 0, f"Subprocess failed with return code {return_code}. Stdout: {os.linesep.join(stdout2)}"
        # The result image is expected in the root folder of the downloaded model.
        expected_segmentation_path = Path(
            model_root) / DEFAULT_RESULT_IMAGE_NAME
        assert expected_segmentation_path.exists(
        ), f"Result file not found: {expected_segmentation_path}"

        # sanity check the resulting segmentation
        # The output must have the shape of the first input channel. NOTE(review): [3] is presumably the set of
        # voxel values expected in the result for this randomly initialized model - confirm against
        # assert_nifti_content.
        expected_shape = get_nifti_shape(train_and_test_data_dir /
                                         img_files[0])
        image_header = get_unit_image_header()
        assert_nifti_content(str(expected_segmentation_path), expected_shape,
                             image_header, [3], np.ubyte)

    finally:
        # delete the registered model
        if azureml_model:
            azureml_model.delete()
Ejemplo n.º 14
0
def test_anomaly_detection(value_to_insert: float,
                           in_training_mode: bool) -> None:
    """
    Test anomaly detection for the segmentation forward pass.

    The value is written into the first voxel of the input image, and the criterion returns it as the loss.
    With config.detect_anomaly=True, a non-finite value (nan or inf) must make the forward pass raise a
    RuntimeError whose message names the returned loss; a finite value must run without error.

    :param value_to_insert: The value to insert into the image and return as the loss (nan, inf, or a valid float)
    :param in_training_mode: If true, run the segmentation forward pass in training mode, otherwise use the
        settings for running on the validation set.
    """
    image_size = [1, 1, 4, 4, 4]
    labels_size = [1, 2, 4, 4, 4]
    mask_size = [1, 4, 4, 4]
    crop_size = (4, 4, 4)
    inference_stride_size = (2, 2, 2)
    ground_truth_ids = ["Lung"]

    # image to run inference on
    image = torch.from_numpy(
        np.random.uniform(size=image_size).astype(ImageDataType.IMAGE.value))
    # labels for criterion
    labels = torch.from_numpy(
        np.random.uniform(size=labels_size).astype(
            ImageDataType.SEGMENTATION.value))
    # create a random mask if required
    mask = torch.from_numpy((np.round(np.random.uniform(
        size=mask_size)).astype(dtype=ImageDataType.MASK.value)))

    config = SegmentationModelBase(crop_size=crop_size,
                                   inference_stride_size=inference_stride_size,
                                   image_channels=["ct"],
                                   ground_truth_ids=ground_truth_ids,
                                   should_validate=False,
                                   detect_anomaly=True)

    model_and_info = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TRAIN,
        checkpoint_path=None)
    # Inject a small model directly instead of creating the one described by the config.
    model_and_info._model: BaseModel = SimpleModel(1, [1], 2,
                                                   2)  # type: ignore
    model_and_info.create_summary_and_adjust_model_for_gpus()
    model_and_info.try_create_optimizer_and_load_from_checkpoint()
    # NOTE(review): use_gpu is switched off only after the model was adjusted for GPUs; presumably this keeps
    # the forward pass on the CPU - confirm this is intended when the test runs on a GPU machine.
    config.use_gpu = False

    model = model_and_info.model
    optimizer = model_and_info.optimizer

    # Create the loss criterion
    criterion = lambda x, y: torch.tensor(value_to_insert, requires_grad=True)
    pipeline = SegmentationForwardPass(model,
                                       config,
                                       batch_size=1,
                                       optimizer=optimizer,
                                       in_training_mode=in_training_mode,
                                       criterion=criterion)
    image[0, 0, 0, 0, 0] = value_to_insert
    if np.isnan(value_to_insert) or np.isinf(value_to_insert):
        with pytest.raises(RuntimeError) as ex:
            pipeline.forward_pass_patches(patches=image,
                                          mask=mask,
                                          labels=labels)
        # ex is a pytest ExceptionInfo; the exception message is in str(ex.value), not in str(ex).
        assert f"loss computation returned {value_to_insert}" in str(ex.value)
    else:
        pipeline.forward_pass_patches(patches=image, mask=mask, labels=labels)
Ejemplo n.º 15
0
def test_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and weight updates of the mean teacher model computation.

    Trains one epoch without a mean teacher, then one epoch with mean_teacher_alpha=0.999. With a batch size
    large enough for a single training step, it checks that:
    - the student weights of both runs agree, and
    - the saved mean teacher weight equals alpha * initial_mean_teacher + (1 - alpha) * student, which is a
      single moving-average update.
    """
    def _get_parameters_of_model(model: DeviceAwareModule) -> Any:
        """
        Returns the iterator of model parameters, unwrapping a DataParallelModel if necessary.
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)

    config.num_epochs = 1
    # Set train batch size to be arbitrary big to ensure we have only one training step
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve the weight after one epoch
    model_and_info = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    model_and_info.try_create_model_and_load_from_checkpoint()
    model = model_and_info.model
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model. The seed is reset and the student is created first,
    # so the random state (and hence the initial weights) depends on the order of the calls below.
    ml_util.set_random_seed(config.get_effective_random_seed())

    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=None)
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()

    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint(
    )
    mean_teach_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teach_model is not None  # for mypy
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve weight of mean teacher model saved in the checkpoint
    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(1))
    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint(
    )
    mean_teacher_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teacher_model is not None  # for mypy
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()
    student_model = model_and_info_mean_teacher.model
    student_model_weight = next(_get_parameters_of_model(student_model))

    # Assert that the student weight corresponds to the weight of a simple training without mean teacher
    # computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters. Compare with a tolerance, as for the student weights above, rather
    # than exact equality: the training code may evaluate the moving average in a different order, which can
    # change the result in the last bits.
    expected_weight = alpha * initial_weight_mean_teacher_model + (1 - alpha) * student_model_weight
    assert torch.allclose(expected_weight, result_weight)
Ejemplo n.º 16
0
def test_try_create_optimizer_and_load_from_checkpoint(config: ModelConfigBase, checkpoint_path: str) -> None:
    """
    Exercise ModelAndInfo model and optimizer creation for three kinds of checkpoint path:

    - None: reading the optimizer before it exists raises ValueError; model and optimizer creation succeed.
    - A file that does not exist: both load attempts report failure, yet an Optimizer is still created.
    - A valid checkpoint from the test data folder: both loads succeed and checkpoint_epoch is 1.
    """
    # --- Case 1: no checkpoint path ---
    without_checkpoint = ModelAndInfo(config,
                                      model_execution_mode=ModelExecutionMode.TEST,
                                      checkpoint_path=None)
    # The optimizer is not available until it has been created.
    with pytest.raises(ValueError):
        _ = without_checkpoint.optimizer

    created_model = without_checkpoint.try_create_model_and_load_from_checkpoint()
    assert created_model
    created_optimizer = without_checkpoint.try_create_optimizer_and_load_from_checkpoint()
    assert created_optimizer
    assert isinstance(without_checkpoint.optimizer, Optimizer)

    # --- Case 2: path to a checkpoint file that does not exist ---
    missing_file = full_ml_test_data_path("non_exist.pth.tar")
    with_missing_checkpoint = ModelAndInfo(config,
                                           model_execution_mode=ModelExecutionMode.TEST,
                                           checkpoint_path=missing_file)
    model_from_missing = with_missing_checkpoint.try_create_model_and_load_from_checkpoint()
    assert not model_from_missing
    # The current code assumes that a False result still leaves a created model; only loading the
    # checkpoint failed.
    optimizer_from_missing = with_missing_checkpoint.try_create_optimizer_and_load_from_checkpoint()
    assert not optimizer_from_missing
    # Likewise, the optimizer object is created even though loading its state failed.
    assert isinstance(with_missing_checkpoint.optimizer, Optimizer)

    # --- Case 3: valid checkpoint ---
    with_valid_checkpoint = ModelAndInfo(config,
                                         model_execution_mode=ModelExecutionMode.TEST,
                                         checkpoint_path=full_ml_test_data_path(checkpoint_path))
    model_from_valid = with_valid_checkpoint.try_create_model_and_load_from_checkpoint()
    assert model_from_valid
    assert with_valid_checkpoint.checkpoint_epoch == 1
    optimizer_from_valid = with_valid_checkpoint.try_create_optimizer_and_load_from_checkpoint()
    assert optimizer_from_valid
    assert isinstance(with_valid_checkpoint.optimizer, Optimizer)
    assert with_valid_checkpoint.checkpoint_epoch == 1
Ejemplo n.º 17
0
def test_save_checkpoint(config: ModelConfigBase) -> None:
    """
    Test that checkpoints are saved correctly.

    Every nn.Conv3d weight of the model is set to 1.0 and of the mean teacher model to 2.0. A checkpoint is
    written for epoch 3, both models are restored from it, and their Conv3d weights must come back unchanged.
    """
    config.mean_teacher_alpha = 0.999

    original = ModelAndInfo(config,
                            model_execution_mode=ModelExecutionMode.TEST,
                            checkpoint_path=None)
    original.try_create_model_and_load_from_checkpoint()
    original.try_create_mean_teacher_model_and_load_from_checkpoint()
    original.try_create_optimizer_and_load_from_checkpoint()

    def fill_conv3d_weights(value: float) -> Callable:
        """Return a function for nn.Module.apply that fills the weight of every exact nn.Conv3d with value."""
        def _fill(layer: nn.Module) -> None:
            if type(layer) is nn.Conv3d:
                layer.weight.data.fill_(value)  # type: ignore
        return _fill

    assert original.mean_teacher_model is not None  # for mypy

    original.model.apply(fill_conv3d_weights(1.0))
    original.mean_teacher_model.apply(fill_conv3d_weights(2.0))

    epoch = 3

    # Make sure the folder for the checkpoint exists before saving.
    os.makedirs(config.get_path_to_checkpoint(epoch=epoch).parent, exist_ok=True)
    original.save_checkpoint(epoch=epoch)

    restored = ModelAndInfo(config,
                            model_execution_mode=ModelExecutionMode.TEST,
                            checkpoint_path=config.get_path_to_checkpoint(epoch=epoch))
    restored.try_create_model_load_from_checkpoint_and_adjust()
    restored.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()

    assert restored.mean_teacher_model is not None  # for mypy

    # Each restored network must carry the constant it was filled with before saving.
    for network, expected_value in ((restored.model, 1.0), (restored.mean_teacher_model, 2.0)):
        for module in network.modules():
            if type(module) is nn.Conv3d:
                weight = module.weight.detach()
                assert torch.equal(weight, torch.full_like(weight, expected_value))  # type: ignore