def test_visualization_with_sequence_model(
        use_combined_model: bool, imaging_feature_type: ImagingFeatureType,
        test_output_dirs: OutputFolderForTests) -> None:
    config = ToySequenceModel(use_combined_model,
                              imaging_feature_type,
                              should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    config.num_epochs = 1
    model = config.create_model()
    if config.use_gpu:
        model = model.cuda()
    dataloader = SequenceDataset(
        config,
        data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                             batch_size=2)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](
        images=np.random.uniform(0, 1, SCAN_SIZE),
        segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats',
                    return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            batch = transfer_batch_to_device(batch, torch.device(0))
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(
            model, target_indices=config.get_target_indices(),
            sample=batch)  # type: ignore
    number_sequences = model_inputs_and_labels.model_inputs[0].shape[1]
    number_subjects = len(model_inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)
    if use_combined_model:
        if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
            assert guided_grad_cams.shape[:2] == (number_subjects,
                                                  number_sequences * 2)
            assert grad_cams.shape[:2] == (number_subjects,
                                           number_sequences * 2)
        else:
            assert guided_grad_cams.shape[:2] == (number_subjects,
                                                  number_sequences)
            assert grad_cams.shape[:2] == (number_subjects, number_sequences)
    else:
        assert guided_grad_cams is None
        assert grad_cams is None
        assert pseudo_cam_non_img.shape[:2] == (number_subjects,
                                                number_sequences)
        assert probas.shape[0] == number_subjects
    non_image_features = config.numerical_columns + config.categorical_columns
    non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(
        model_inputs_and_labels.data_item,
        non_image_features,
        index=0,
        target_position=3)
    assert non_imaging_plot_labels == [
        'numerical1_0', 'numerical2_0', 'cat1_0', 'numerical1_1',
        'numerical2_1', 'cat1_1', 'numerical1_2', 'numerical2_2', 'cat1_2',
        'numerical1_3', 'numerical2_3', 'cat1_3'
    ]
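
# The expected label list above follows a position-major layout: for every
# sequence position up to and including target_position there is one label
# per non-imaging feature, suffixed with the position index. A minimal sketch
# of that convention (the helper below is hypothetical, not part of InnerEye):
from typing import List

def _expected_sequence_plot_labels(features: List[str], target_position: int) -> List[str]:
    # One label per feature, for every sequence position up to target_position.
    return [f"{feature}_{position}"
            for position in range(target_position + 1)
            for feature in features]

assert _expected_sequence_plot_labels(["numerical1", "numerical2", "cat1"], 3) == [
    'numerical1_0', 'numerical2_0', 'cat1_0', 'numerical1_1',
    'numerical2_1', 'cat1_1', 'numerical1_2', 'numerical2_2', 'cat1_2',
    'numerical1_3', 'numerical2_3', 'cat1_3'
]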
def training_or_validation_step(self,
                                sample: Dict[str, Any],
                                batch_index: int,
                                is_training: bool) -> torch.Tensor:
    """
    Runs a single minibatch of training or validation data through the model and computes all metrics.
    :param sample: The batched sample on which the model should be trained.
    :param batch_index: The index of the present batch (supplied only for diagnostics).
    :param is_training: If True, the method is called from `training_step`, otherwise it is called from
        `validation_step`.
    """
    model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, self.target_indices, sample)
    labels = model_inputs_and_labels.labels
    if is_training:
        logits = self.model(*model_inputs_and_labels.model_inputs)
    else:
        # No gradients are needed during validation, which saves memory.
        with torch.no_grad():
            logits = self.model(*model_inputs_and_labels.model_inputs)
    subject_ids = model_inputs_and_labels.subject_ids
    loss = self.loss_fn(logits, labels)
    self.write_loss(is_training, loss)
    self.compute_and_log_metrics(logits, labels, subject_ids, is_training)
    self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
                      value=len(model_inputs_and_labels.subject_ids),
                      is_training=is_training,
                      reduce_fx=sum)
    return loss
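
# A minimal sketch (an assumption about the surrounding class, which appears
# to be a PyTorch Lightning module) of how the shared step above is wired into
# Lightning's hooks; only the delegation is shown, inside the same class body:
def training_step(self, sample: Dict[str, Any], batch_index: int) -> torch.Tensor:
    return self.training_or_validation_step(sample, batch_index, is_training=True)

def validation_step(self, sample: Dict[str, Any], batch_index: int) -> torch.Tensor:
    return self.training_or_validation_step(sample, batch_index, is_training=False)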
def predict(self, sample: Dict[str, Any]) -> ScalarInferencePipelineBase.Result:
    """
    Runs the forward pass on a single batch.
    :param sample: A single batch of input data, in the form of a dict containing at least the fields:
        metadata, label, images, numerical_non_image_features, categorical_non_image_features and segmentations.
    :return: A ScalarInferencePipelineBase.Result with the subject ids, ground truth labels and predictions.
    """
    assert isinstance(self.model_config, ScalarModelBase)
    model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model.model,
                                                                 sample=sample)
    model_inputs_and_labels.move_to_device(self.model.device)
    with torch.no_grad():
        # This already gives the model outputs converted to posteriors
        posteriors: torch.Tensor = self.model.forward(*model_inputs_and_labels.model_inputs)
    return ScalarInferencePipelineBase.Result(subject_ids=model_inputs_and_labels.subject_ids,
                                              labels=model_inputs_and_labels.labels,
                                              posteriors=posteriors)
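
# Hypothetical usage sketch for predict(); `pipeline` and `dataloader` are
# illustrative names, not identifiers from the source:
#
#     batch = next(iter(dataloader))
#     result = pipeline.predict(batch)
#     assert len(result.subject_ids) == result.posteriors.shape[0]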
def test_visualization_for_different_target_weeks(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that the visualizations are differentiated depending on the target week
    for which we visualize it.
    """
    config = ToyMultiLabelSequenceModel(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_multi_label_sequence_dataframe()
    config.pre_process_dataset_dataframe()
    model = create_model_with_temperature_scaling(config)
    dataloader = SequenceDataset(
        config,
        data_frame=config.dataset_data_frame).as_data_loader(shuffle=False,
                                                             batch_size=2)
    batch = next(iter(dataloader))
    model_inputs_and_labels = get_scalar_model_inputs_and_labels(
        model, target_indices=config.get_target_indices(), sample=batch)

    visualizer = VisualizationMaps(model, config)
    if config.use_gpu:
        device = visualizer.grad_cam.device
        batch = transfer_batch_to_device(batch, device)
        model = model.to(device)
    # Pseudo-grad cam explaining the prediction at target sequence 2
    _, _, pseudo_cam_non_img_3, probas_3 = visualizer.generate(
        model_inputs_and_labels.model_inputs,
        target_position=2,
        target_label_index=2)
    # Pseudo-grad cam explaining the prediction at target sequence 0
    _, _, pseudo_cam_non_img_1, probas_1 = visualizer.generate(
        model_inputs_and_labels.model_inputs,
        target_position=0,
        target_label_index=0)
    assert pseudo_cam_non_img_1.shape[1] == 1
    assert pseudo_cam_non_img_3.shape[1] == 3
    # The two visualizations must not be element-wise identical
    assert np.any(pseudo_cam_non_img_1 != pseudo_cam_non_img_3)
    assert np.any(probas_3 != probas_1)
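
# The two shape assertions above encode a convention inferred from this test:
# explaining the prediction at target position p yields one non-imaging
# pseudo-CAM slice per sequence step up to and including p, i.e. p + 1 slices.
# A minimal, hypothetical spelling-out of that rule:
def _expected_pseudo_cam_steps(target_position: int) -> int:
    return target_position + 1

assert _expected_pseudo_cam_steps(0) == 1  # pseudo_cam_non_img_1.shape[1]
assert _expected_pseudo_cam_steps(2) == 3  # pseudo_cam_non_img_3.shape[1]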
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: OutputFolderForTests) -> None:
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
    S1,week0,scan1.npy,,1,10,Male,Val1
    S1,week1,scan2.npy,True,2,20,Female,Val2
    S2,week0,scan3.npy,,3,30,Female,Val3
    S2,week1,scan4.npy,False,4,40,Female,Val1
    S3,week0,scan1.npy,,5,50,Male,Val2
    S3,week1,scan3.npy,True,6,60,Male,Val2
    """
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    visualizer = VisualizationMaps(model, config)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            device = visualizer.grad_cam.device
            batch = transfer_batch_to_device(batch, device)
            visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
                                                                     target_indices=[],
                                                                     sample=batch)
    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

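    # Note: the parentheses around the conditional expression below are
    # required. Without them Python parses `assert x == a if cond else b` as
    # `assert (x == a) if cond else b`, so the else branch would assert a
    # non-empty tuple, which always passes.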
    assert grad_cams.shape[:2] == ((number_subjects, 1) if encode_channels_jointly
                                   else (number_subjects, number_channels))

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
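        # Layout check (hypothetical, inferred from the expected list above):
        # scalar-model labels are feature-major, every channel of one feature
        # before the next feature, with categorical2 restricted to week1 as
        # configured via specific_channels.
        channels_by_feature = {"numerical1": ["week1", "week0"],
                               "numerical2": ["week1", "week0"],
                               "categorical1": ["week1", "week0"],
                               "categorical2": ["week1"]}
        assert non_imaging_plot_labels == [f"{feature}_{channel}"
                                           for feature, channels in channels_by_feature.items()
                                           for channel in channels]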