Example #1
def test_visualization_for_different_target_weeks(test_output_dirs: TestOutputDirectories) -> None:
    """
    Tests that the visualizations differ depending on the target week
    for which they are generated.
    """
    config = ToyMultiLabelSequenceModel(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_multi_label_sequence_dataframe()
    config.pre_process_dataset_dataframe()
    model = create_model_with_temperature_scaling(config)
    dataloader = SequenceDataset(config,
                                 data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                                                      batch_size=2)
    batch = next(iter(dataloader))
    model_inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, batch)  # type: ignore

    visualizer = VisualizationMaps(model, config)
    # Pseudo grad-CAM explaining the prediction at target position 2 (the third sequence element)
    _, _, pseudo_cam_non_img_3, probas_3 = visualizer.generate(model_inputs_and_labels.model_inputs,
                                                               target_position=2,
                                                               target_label_index=2)
    # Pseudo grad-CAM explaining the prediction at target position 0 (the first sequence element)
    _, _, pseudo_cam_non_img_1, probas_1 = visualizer.generate(model_inputs_and_labels.model_inputs,
                                                               target_position=0,
                                                               target_label_index=0)
    assert pseudo_cam_non_img_1.shape[1] == 1
    assert pseudo_cam_non_img_3.shape[1] == 3
    # The two visualizations should differ
    assert np.any(pseudo_cam_non_img_1 != pseudo_cam_non_img_3)
    assert np.any(probas_3 != probas_1)
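The two width assertions above follow one pattern: the pseudo grad-CAM for non-imaging features appears to carry one column per sequence position up to and including the target. A minimal sketch of that relation (the helper name is hypothetical, not part of InnerEye):

def expected_pseudo_cam_width(target_position: int) -> int:
    # Positions 0..target_position inclusive, hence target_position + 1 columns.
    return target_position + 1

assert expected_pseudo_cam_width(0) == 1  # matches pseudo_cam_non_img_1.shape[1]
assert expected_pseudo_cam_width(2) == 3  # matches pseudo_cam_non_img_3.shape[1]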
Example #2
    def predict(self, sample: Dict[str, Any]) -> ScalarInferencePipelineBase.Result:
        """
        Runs the forward pass on a single batch.
        :param sample: Single batch of input data, in the form of a dict containing at least the fields:
            metadata, label, images, numerical_non_image_features, categorical_non_image_features
            and segmentations.
        :return: A ScalarInferencePipelineBase.Result with the subject ids, ground truth labels and predictions.
        """
        assert isinstance(self.model_config, ScalarModelBase)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model_config, self.model, sample)
        subject_ids = model_inputs_and_labels.subject_ids
        labels = self.model_config.get_gpu_tensor_if_possible(model_inputs_and_labels.labels)
        model_output: torch.Tensor = self.model.forward(*model_inputs_and_labels.model_inputs)
        if isinstance(model_output, list):
            # The model output is a list if we are using data parallel. Here, this will be a degenerate
            # list with only 1 element, which gather moves onto the target device.
            model_output = torch.nn.parallel.gather(model_output, target_device=0)

        # Apply any post-loss normalization to the logits
        model_output = self.model_config.get_post_loss_logits_normalization_function()(model_output)
        # Cast labels and model outputs back to float32, if the model had been run in mixed precision
        return ScalarInferencePipelineBase.Result(subject_ids, labels.float(),
                                                  model_output.float())
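For reference, a minimal sketch of the gather step used in predict above, assuming a machine with two CUDA devices and a data-parallel wrapper that returns one output chunk per device; in the degenerate single-element case the call simply moves the tensor onto the target device:

import torch

# Per-GPU output chunks, in the form a data-parallel forward pass can return them:
outputs = [torch.ones(2, 1, device="cuda:0"),
           torch.zeros(2, 1, device="cuda:1")]
# Concatenate along the batch dimension onto device 0:
merged = torch.nn.parallel.gather(outputs, target_device=0)
assert merged.shape == (4, 1)
assert merged.device == torch.device("cuda:0")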
Example #3
def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAwareModule) -> None:
    """
    Writes a human readable summary of the present model to logging.info, and logs the number of trainable
    parameters to AzureML.

    :param config: The configuration for the model.
    :param model: The instantiated Pytorch model.
    """
    random_state = RandomStateSnapshot.snapshot_random_state()
    # There appears to be a bug in apex, where previous use (in training for example) causes problems
    # when another model is later built on the CPU (for example, before loading from a checkpoint)
    # https://github.com/NVIDIA/apex/issues/694
    # Hence, move the model to the GPU before doing model summary.
    if config.use_gpu:
        model = model.cuda()
    if isinstance(config, ScalarModelBase):
        # To generate the model summary, read the first item of the dataset. Then use the model's own
        # get_model_input function to convert the dataset item to input tensors, and feed them through the model.
        train_dataset = config.get_torch_dataset_for_inference(ModelExecutionMode.TRAIN)
        train_item_0 = next(iter(train_dataset.as_data_loader(shuffle=False, batch_size=1, num_dataload_workers=0)))
        model_inputs = get_scalar_model_inputs_and_labels(config, model, train_item_0).model_inputs
        # The model inputs may already have been converted to float16, on the assumption that we will use
        # mixed precision. However, the model itself has not yet been converted to float16 at this point,
        # so the inputs may need to be cast back to float32 for the summary.
        summary = ModelSummary(model)
        summary.generate_summary(input_tensors=model_inputs, log_summaries_to_files=config.log_summaries_to_files)
    elif config.is_segmentation_model:
        summary_for_segmentation_models(config, model)
        assert model.summarizer
        summary = model.summarizer  # type: ignore
    else:
        raise ValueError("Don't know how to generate a summary for this type of model.")
    RUN_CONTEXT.log(LoggingColumns.NumTrainableParameters, summary.n_trainable_params)
    random_state.restore_random_state()
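RandomStateSnapshot is an InnerEye helper. A minimal stand-in for the snapshot/restore pattern above could look like this (hypothetical class; it only covers the CPU generators, while the real helper may cover more, e.g. CUDA state):

import random

import numpy as np
import torch

class SimpleRandomStateSnapshot:
    """Captures the Python, NumPy and PyTorch CPU RNG states so they can be restored later."""

    def __init__(self) -> None:
        self.python_state = random.getstate()
        self.numpy_state = np.random.get_state()
        self.torch_state = torch.random.get_rng_state()

    def restore_random_state(self) -> None:
        random.setstate(self.python_state)
        np.random.set_state(self.numpy_state)
        torch.random.set_rng_state(self.torch_state)

snapshot = SimpleRandomStateSnapshot()
_ = torch.rand(3)  # advances the RNG
snapshot.restore_random_state()  # rewinds it again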
Example #4
def test_amp_and_parallel_for_scalar_models(test_output_dirs: TestOutputDirectories,
                                            execution_mode: ModelExecutionMode,
                                            use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and data parallelism for scalar models.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count() > 1, "This test must be executed on a multi-GPU machine."
    config = ClassificationModelForTesting()
    config.use_mixed_precision = use_mixed_precision
    model = DummyScalarModel(
        expected_image_size_zyx=config.expected_image_size_zyx,
        activation=Identity())
    model.use_mixed_precision = use_mixed_precision
    model_and_info = ModelAndInfo(model=model,
                                  model_execution_mode=execution_mode)
    # This replicates the logic inside update_model_for_multiple_gpus:
    # execution_mode == ModelExecutionMode.TRAIN or (not use_model_parallel), which is always True in our case.
    use_data_parallel = True
    model_and_info = model_util.update_model_for_multiple_gpus(model_and_info, config)
    if use_data_parallel:
        assert isinstance(model_and_info.model, DataParallelModel)
    data_loaders = config.create_data_loaders()
    gradient_scaler = GradScaler() if use_mixed_precision else None
    train_val_parameters: TrainValidateParameters = TrainValidateParameters(
        model=model_and_info.model,
        data_loader=data_loaders[execution_mode],
        in_training_mode=execution_mode == ModelExecutionMode.TRAIN,
        gradient_scaler=gradient_scaler,
        dataframe_loggers=MetricsDataframeLoggers(
            Path(test_output_dirs.root_dir)),
        summary_writers=SummaryWriters(train=None, val=None)  # type: ignore
    )
    training_steps = ModelTrainingStepsForScalarModel(config,
                                                      train_val_parameters)
    sample = list(data_loaders[execution_mode])[0]
    model_input = get_scalar_model_inputs_and_labels(config, model, sample)
    logits, posteriors, loss = training_steps._compute_model_output_and_loss(
        model_input)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
        assert posteriors.dtype == torch.float16
        # BCEWithLogitsLoss outputs float32, even with float16 args
        assert loss.dtype == torch.float32
    else:
        assert first_logit.dtype == torch.float32
        assert posteriors.dtype == torch.float32
        assert loss.dtype == torch.float32
    # Verify that a second forward pass does not throw. It would, for example, if tensors were not
    # gathered correctly or float16 outputs were not converted back to float32.
    _, _, _ = training_steps._compute_model_output_and_loss(model_input)
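The dtype assertions above match standard autocast behaviour: under torch.cuda.amp.autocast, matmul-like ops run in float16, while binary_cross_entropy_with_logits is kept in float32. A minimal sketch, assuming a CUDA machine:

import torch

inputs = torch.randn(4, 8, device="cuda")
weight = torch.randn(8, 1, device="cuda")
targets = torch.ones(4, 1, device="cuda")
with torch.cuda.amp.autocast():
    logits = inputs @ weight  # autocast runs matmul in float16
    loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
assert logits.dtype == torch.float16
assert loss.dtype == torch.float32  # BCE-with-logits is autocast to float32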
Example #5
def test_visualization_with_sequence_model(use_combined_model: bool,
                                           imaging_feature_type: ImagingFeatureType,
                                           test_output_dirs: TestOutputDirectories) -> None:
    config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    config.num_epochs = 1

    model_and_info = ModelAndInfo(config=config, model_execution_mode=ModelExecutionMode.TEST,
                                  is_mean_teacher=False, checkpoint_path=None)
    model_loaded = model_and_info.try_create_model_load_from_checkpoint_and_adjust()
    assert model_loaded

    model = model_and_info.model

    dataloader = SequenceDataset(config,
                                 data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                                                      batch_size=2)
    # Patch the load_image_in_known_formats function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
                                                      segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, batch)  # type: ignore
    number_sequences = model_inputs_and_labels.model_inputs[0].shape[1]
    number_subjects = len(model_inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)
    if use_combined_model:
        if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
            assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
            assert grad_cams.shape[:2] == (number_subjects, number_sequences * 2)
        else:
            assert guided_grad_cams.shape[:2] == (number_subjects, number_sequences)
            assert grad_cams.shape[:2] == (number_subjects, number_sequences)
    else:
        assert guided_grad_cams is None
        assert grad_cams is None
        assert pseudo_cam_non_img.shape[:2] == (number_subjects, number_sequences)
        assert probas.shape[0] == number_subjects
    non_image_features = config.numerical_columns + config.categorical_columns
    non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                      non_image_features,
                                                                      index=0,
                                                                      target_position=3)
    assert non_imaging_plot_labels == ['numerical1_0',
                                       'numerical2_0',
                                       'cat1_0',
                                       'numerical1_1',
                                       'numerical2_1',
                                       'cat1_1',
                                       'numerical1_2',
                                       'numerical2_2',
                                       'cat1_2',
                                       'numerical1_3',
                                       'numerical2_3',
                                       'cat1_3']
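The expected list above follows a simple pattern: every non-imaging feature gets one '<feature>_<position>' label per sequence position from 0 up to target_position. A sketch of that construction (the helper is hypothetical, not the InnerEye implementation):

from typing import List

def sequence_plot_labels(features: List[str], target_position: int) -> List[str]:
    # Position-major order: all features for position 0, then position 1, etc.
    return [f"{feature}_{position}"
            for position in range(target_position + 1)
            for feature in features]

assert sequence_plot_labels(["numerical1", "numerical2", "cat1"], 3)[:3] == \
       ["numerical1_0", "numerical2_0", "cat1_0"]
assert len(sequence_plot_labels(["numerical1", "numerical2", "cat1"], 3)) == 12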
Example #6
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: TestOutputDirectories) -> None:
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
    S1,week0,scan1.npy,,1,10,Male,Val1
    S1,week1,scan2.npy,True,2,20,Female,Val2
    S2,week0,scan3.npy,,3,30,Female,Val3
    S2,week1,scan4.npy,False,4,40,Female,Val1
    S3,week0,scan1.npy,,5,50,Male,Val2
    S3,week1,scan3.npy,True,6,60,Male,Val2
    """
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    # Patch the load_image_in_known_formats function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, batch)

    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

    # The expected number of grad-CAM channels depends on whether channels are encoded jointly.
    if encode_channels_jointly:
        assert grad_cams.shape[:2] == (number_subjects, 1)
    else:
        assert grad_cams.shape[:2] == (number_subjects, number_channels)

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
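CategoricalToOneHotEncoder is InnerEye-specific. As a rough illustration of what fitting a one-hot encoding from a dataframe column involves, here is a plain-Python sketch (not the repo's implementation; the sorted category ordering is an assumption):

from typing import Dict, List

import pandas as pd

def fit_one_hot_mapping(dataframe: pd.DataFrame, column: str) -> Dict[str, List[float]]:
    # Map each observed category to a one-hot vector; sort for a deterministic order.
    categories = sorted(dataframe[column].dropna().unique())
    return {value: [1.0 if value == category else 0.0 for category in categories]
            for value in categories}

frame = pd.DataFrame({"categorical1": ["Male", "Female", "Female", "Female", "Male", "Male"]})
assert fit_one_hot_mapping(frame, "categorical1")["Female"] == [1.0, 0.0]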