def test_predict_non_ensemble(batch_size: int, empty_labels: bool) -> None:
    """
    Runs a single inference pipeline on a batch of all-zero images, and checks the subject IDs, the labels
    and the posteriors that the pipeline returns.
    :param batch_size: The number of samples in the batch that is passed to the pipeline.
    :param empty_labels: If True, all labels in the batch are NaN, otherwise they are all zero.
    """
    config = ConstantScalarConfig(1.)
    model = create_lightning_model(config, set_optimizer_and_scheduler=False)
    assert isinstance(model, ScalarLightning)
    pipeline = ScalarInferencePipeline(model, config, 0)

    label_shape = (batch_size, 1)
    if empty_labels:
        # Multiplying by NaN turns every entry of the zero tensor into NaN.
        actual_labels = torch.zeros(label_shape) * np.nan
    else:
        actual_labels = torch.zeros(label_shape)
    image_shape = label_shape + config.expected_image_size_zyx
    data = {
        "metadata": [GeneralSampleMetadata(id='2')] * batch_size,
        "label": actual_labels,
        "images": torch.zeros(image_shape),
        "numerical_non_image_features": torch.tensor([]),
        "categorical_non_image_features": torch.tensor([]),
        "segmentations": torch.tensor([])
    }

    results = pipeline.predict(data)
    expected_ids = ['2'] * batch_size
    assert results.subject_ids == expected_ids
    # NaN labels must come back as NaN, hence equal_nan.
    assert torch.allclose(results.labels, actual_labels, equal_nan=True)
    # The model always returns 1, so the posteriors should be sigmoid(1)
    expected_posteriors = torch.full(label_shape, 0.731058578)
    assert torch.allclose(results.posteriors, expected_posteriors)
def test_predict_ensemble(batch_size: int) -> None:
    """
    Builds an ensemble out of 5 inference pipelines (3 around a model that returns 0, 2 around a model that
    returns 1), runs it with average aggregation on a batch of all-zero images, and checks the subject IDs,
    the labels and the posteriors that the ensemble returns.
    :param batch_size: The number of samples in the batch that is passed to the ensemble.
    """
    config_returns_0 = ConstantScalarConfig(0.)
    model_returns_0 = create_lightning_model(config_returns_0,
                                             set_optimizer_and_scheduler=False)
    assert isinstance(model_returns_0, ScalarLightning)

    config_returns_1 = ConstantScalarConfig(1.)
    model_returns_1 = create_lightning_model(config_returns_1,
                                             set_optimizer_and_scheduler=False)
    assert isinstance(model_returns_1, ScalarLightning)

    # Pipelines 0, 1, 2 wrap the model that returns 0, pipelines 3 and 4 the model that returns 1.
    # The position in this list is passed as the third constructor argument of each pipeline.
    members = [(model_returns_0, config_returns_0)] * 3 + [(model_returns_1, config_returns_1)] * 2
    pipelines = [ScalarInferencePipeline(member_model, member_config, index)
                 for index, (member_model, member_config) in enumerate(members)]
    ensemble_pipeline = ScalarEnsemblePipeline(pipelines, config_returns_0, EnsembleAggregationType.Average)

    label_shape = (batch_size, 1)
    image_shape = label_shape + config_returns_0.expected_image_size_zyx
    data = {
        "metadata": [GeneralSampleMetadata(id='2')] * batch_size,
        "label": torch.zeros(label_shape),
        "images": torch.zeros(image_shape),
        "numerical_non_image_features": torch.tensor([]),
        "categorical_non_image_features": torch.tensor([]),
        "segmentations": torch.tensor([])
    }

    results = ensemble_pipeline.predict(data)
    expected_ids = ['2'] * batch_size
    assert results.subject_ids == expected_ids
    assert torch.equal(results.labels, torch.zeros(label_shape))
    # 3 models return 0, 2 return 1, so the posteriors should be ((sigmoid(0)*3)+(sigmoid(1)*2))/5
    expected_posteriors = torch.full(label_shape, 0.592423431)
    assert torch.allclose(results.posteriors, expected_posteriors)
예제 #3
0
def create_model_and_store_checkpoint(config: ModelConfigBase,
                                      checkpoint_path: Path) -> None:
    """
    Builds a Lightning model from the given model configuration and writes it to a checkpoint file.
    The model is moved to the GPU before it is stored if the machine has one.
    Before saving, `current_epoch` and `global_step` of the trainer are set to fixed non-default values.
    :param config: The model configuration.
    :param checkpoint_path: The path and filename of the checkpoint file.
    """
    trainer, _unused_logger = create_lightning_trainer(config)
    lightning_model = create_lightning_model(config)
    trainer.model = lightning_model.cuda() if machine_has_gpu else lightning_model  # type: ignore
    # The values for epoch and step are incremented before saving. Store them one below the fixed values,
    # such that it is easy to assert on them later.
    trainer.current_epoch = FIXED_EPOCH - 1
    trainer.global_step = FIXED_GLOBAL_STEP - 1
    # In PL, saving the model is the Trainer's responsibility: checkpoint handling refers back to the trainer
    # to get a save_func. That is mimicked here.
    trainer.save_checkpoint(checkpoint_path, weights_only=True)
def run_inference_on_unet(size: TupleInt3) -> None:
    """
    Creates a fresh UNet3D model and runs whole-image inference on a random input image of the given size.
    Asserts that all posteriors and the segmentation have the same size as the input image, and passes
    each of them through the array range check.
    :param size: The size of the input image, without the channel dimension.
    """
    foreground_classes = ["tumour_mass", "subtract"]
    num_foreground = len(foreground_classes)
    # One additional class on top of the foreground classes.
    num_classes = num_foreground + 1
    uniform_class_weight = 1.0 / num_classes
    config = SegmentationModelBase(
        architecture="UNet3D",
        local_dataset=Path("dummy"),
        feature_channels=[1],
        kernel_size=3,
        largest_connected_component_foreground_classes=foreground_classes,
        posterior_smoothing_mm=(2, 2, 2),
        crop_size=(64, 64, 64),
        # test_crop_size must be larger than 'size' for the bug to trigger
        test_crop_size=(80, 80, 80),
        image_channels=["mr"],
        ground_truth_ids=foreground_classes,
        ground_truth_ids_display_names=foreground_classes,
        colours=[(255, 0, 0)] * num_foreground,
        fill_holes=[False] * num_foreground,
        mask_id=None,
        class_weights=[uniform_class_weight] * num_classes,
        train_batch_size=8,
        inference_batch_size=1,
        inference_stride_size=(40, 40, 40),
        use_mixed_precision=True
    )
    lightning_model = create_lightning_model(config)
    assert isinstance(lightning_model, SegmentationLightning)
    pipeline = InferencePipeline(model=lightning_model, model_config=config)

    # Random image with a single channel in front of the spatial dimensions, and a mask that is 1 everywhere.
    image_shape = (1,) + size
    image = np.random.uniform(-1, 1, image_shape)
    full_mask = np.ones(size)
    result = pipeline.predict_and_post_process_whole_image(image, mask=full_mask, voxel_spacing_mm=(1, 1, 1))

    # All posteriors and the segmentation must have the size of the input image
    outputs = list(result.posteriors) + [result.segmentation]
    for output in outputs:
        assert output.shape == size
        # Check that all results are not NaN. In particular, if stride size is not adjusted
        # correctly, the results would be partially NaN.
        image_util.check_array_range(output)
예제 #5
0
def _aggregate_subject_output_files(outputs_folder: Path) -> None:
    """
    Concatenates the per-rank subject output files into a single file per execution mode (training and
    validation). The header line is taken from the first file only, and is dropped from all other files.
    No result file is written for an execution mode that has no per-rank files.
    :param outputs_folder: The folder that contains one subfolder per execution mode, each of which holds
    the per-rank files.
    """
    for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
        # Collect and sort all matches before anything is written: the result file is created inside the
        # folder that is being searched, and sorting makes the row order independent of the order in which
        # the file system lists the files.
        temp_files = sorted((outputs_folder / mode.value).rglob(SUBJECT_OUTPUT_PER_RANK_PREFIX + "*"))
        if not temp_files:
            continue
        result_file = outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME
        result_text = ""
        for i, file in enumerate(temp_files):
            lines = file.read_text().splitlines()
            if i > 0:
                # For all files but the first one, cut off the header line.
                lines = lines[1:]
            # Terminate every line, such that the first row of the next file does not end up on the last
            # line of this one. Use "\n" and not os.linesep: the file is written in text mode, which
            # translates newlines to the platform convention already.
            result_text += "".join(line + "\n" for line in lines)
        # Write once. Path.write_text truncates the file, calling it per input file would keep only the
        # rows of the last file.
        result_file.write_text(result_text)


def model_train(config: ModelConfigBase,
                checkpoint_handler: CheckpointHandler,
                num_nodes: int = 1) -> ModelTrainingResults:
    """
    The main training loop. It creates the Pytorch model based on the configuration options passed in,
    creates a Pytorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.
    Processes with a global rank other than 0 terminate via sys.exit() once training has finished, only
    rank 0 runs the clean-up and aggregation steps and returns.
    :param config: The arguments which specify all required information.
    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
    :param num_nodes: The number of nodes to use in distributed training.
    :return: The per-epoch training and validation metrics that the storing logger collected, together with
    the training and validation diagnostics of the Lightning model.
    """
    # Get the path to the checkpoint to recover from
    checkpoint_path = checkpoint_handler.get_recovery_path_train()
    # This reads the dataset file, and possibly sets required pre-processing objects, like one-hot encoder
    # for categorical features, that need to be available before creating the model.
    config.read_dataset_if_needed()

    # Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    seed_everything(config.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)

    logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
                 f"trainer.global_rank: {trainer.global_rank}")
    logging.debug("Creating the PyTorch model.")
    lightning_model = create_lightning_model(config)
    lightning_model.storing_logger = storing_logger

    resource_monitor = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_rank_zero():
        config.write_args_file()
        logging.info(str(config))
        # Save the dataset files for later use in cross validation analysis
        config.write_dataset_files()
        logging.info(f"Model checkpoints are saved at {config.checkpoint_folder}")

        # set the random seed for all libraries
        ml_util.set_random_seed(config.get_effective_random_seed(), "Patch visualization")
        # Visualize how patches are sampled for segmentation models. This changes the random generator, but we don't
        # want training to depend on how many patients we visualized, and hence set the random seed again right after.
        with logging_section("Visualizing the effect of sampling random crops for training"):
            visualize_random_crops_for_dataset(config)

        # Print out a detailed breakdown of layers, memory consumption and time.
        generate_and_print_model_summary(config, lightning_model.model)

        if config.monitoring_interval_seconds > 0:
            # initialize and start GPU monitoring
            diagnostics_events = config.logs_folder / "diagnostics"
            logging.info(f"Starting resource monitor, outputting to {diagnostics_events}")
            resource_monitor = ResourceMonitor(interval_seconds=config.monitoring_interval_seconds,
                                               tensorboard_folder=diagnostics_events)
            resource_monitor.start()

    # Training loop
    logging.info("Starting training")

    lightning_data = TrainingAndValidationDataLightning(config)  # type: ignore
    # When trying to store the config object in the constructor, it does not appear to get stored at all, later
    # reference of the object simply fail. Hence, have to set explicitly here.
    lightning_data.config = config
    trainer.fit(lightning_model, datamodule=lightning_data)
    trainer.logger.close()  # type: ignore
    lightning_model.close_all_loggers()
    # Falls back to 0 if the trainer does not have a world_size attribute.
    world_size = getattr(trainer, "world_size", 0)
    is_azureml_run = not config.is_offline_run
    # Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
    # Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
    if is_azureml_run and world_size > 1 and isinstance(lightning_model, ScalarLightning):
        upload_output_file_as_temp(lightning_model.train_subject_outputs_logger.csv_path, config.outputs_folder)
        upload_output_file_as_temp(lightning_model.val_subject_outputs_logger.csv_path, config.outputs_folder)
    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Choosing the best checkpoint and removing redundant files.")
    cleanup_checkpoint_folder(config.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    if world_size and isinstance(lightning_model, ScalarLightning):
        if is_azureml_run and world_size > 1:
            # In a DDP run on the local box, all ranks will write to local disk, hence no download needed.
            # In a multi-node DDP, each rank would upload to AzureML, and rank 0 will now download all results and
            # concatenate
            for rank in range(world_size):
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
                    file = mode.value + "/" + get_subject_output_file_per_rank(rank)
                    RUN_CONTEXT.download_file(name=TEMP_PREFIX + file, output_file_path=config.outputs_folder / file)
        # Concatenate all temporary files per execution mode
        _aggregate_subject_output_files(config.outputs_folder)

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=TRAIN_PREFIX).values()),
        val_results_per_epoch=list(storing_logger.to_metrics_dicts(prefix_filter=VALIDATION_PREFIX).values()),
        train_diagnostics=lightning_model.train_diagnostics,
        val_diagnostics=lightning_model.val_diagnostics,
        optimal_temperature_scale_values_per_checkpoint_epoch=[]
    )

    logging.info("Finished training")

    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
    # checkpoints correctly.
    checkpoint_handler.additional_training_done()

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER, path=str(config.visualization_folder))

    if resource_monitor:
        # stop the resource monitoring process
        logging.info("Shutting down the resource monitor process. Aggregate resource utilization:")
        for name, value in resource_monitor.read_aggregate_metrics():
            logging.info(f"{name}: {value}")
            if not config.is_offline_run:
                RUN_CONTEXT.log(name, value)
        resource_monitor.kill()

    return model_training_results