def test_mean_teacher_model() -> None:
    """
    Test training and weight updates of the mean teacher model computation.

    Trains once without a mean teacher to get reference student weights. Then trains again with
    ``mean_teacher_alpha`` set and checks that:
    (a) the student weights match the reference run, and
    (b) the mean teacher weights saved in the checkpoint equal
        ``alpha * initial_teacher + (1 - alpha) * student``.
    """
    def _get_parameters_of_model(
            model: Union[torch.nn.Module, DataParallelModel]) -> Any:
        """
        Returns the iterator of model parameters, unwrapping a DataParallelModel if needed.
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.num_epochs = 1
    # Set train batch size to be arbitrary big to ensure we have only one training step
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config)

    # Retrieve the weight after one epoch.
    # Keep a detached copy rather than a reference to the live parameter, for two reasons:
    # - `model` is loaded again below, and load_checkpoint presumably copies into the existing
    #   parameters in place (torch load_state_dict semantics).
    # - The second training run overwrites the checkpoint file.
    # A bare reference would make the student-weight comparison below trivially true.
    model = create_model_with_temperature_scaling(config)
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    model_weight = next(_get_parameters_of_model(model)).detach().clone()

    # Get the starting weight of the mean teacher model.
    # After re-seeding, a throw-away model is created first, so the mean teacher is the second
    # model drawn from the seeded RNG. Presumably this mirrors the creation order in model_train;
    # confirm.
    ml_util.set_random_seed(config.get_effective_random_seed())
    _ = create_model_with_temperature_scaling(config)
    mean_teach_model = create_model_with_temperature_scaling(config)
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model)).detach().clone()

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config)

    # Retrieve weight of mean teacher model saved in the checkpoint
    mean_teacher_model = create_model_with_temperature_scaling(config)
    _ = model_util.load_checkpoint(
        mean_teacher_model,
        config.get_path_to_checkpoint(1, for_mean_teacher_model=True))
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    student_model_weight = next(_get_parameters_of_model(model))

    # Assert that the student weight corresponds to the weight of a simple training without mean teacher
    # computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters. Compare with a tolerance rather than exact equality:
    # the training code may compute the moving average with a different operation order, which
    # gives different float rounding.
    expected_weight = alpha * initial_weight_mean_teacher_model + (1 - alpha) * student_model_weight
    assert torch.allclose(expected_weight, result_weight)
Exemple #2
0
def test_visualization_for_different_target_weeks(test_output_dirs: TestOutputDirectories) -> None:
    """
    Checks that the visualization maps depend on the target week being explained.

    Explaining the prediction at target position 2 and at target position 0 must produce
    pseudo-CAMs with 3 and 1 sequence entries respectively. The two pseudo-CAMs must differ, and
    so must the two sets of probabilities.
    """
    config = ToyMultiLabelSequenceModel(should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_multi_label_sequence_dataframe()
    config.pre_process_dataset_dataframe()
    model = create_model_with_temperature_scaling(config)

    sequence_dataset = SequenceDataset(config, data_frame=config.dataset_data_frame)
    loader = sequence_dataset.as_data_loader(shuffle=False, batch_size=2)
    first_batch = next(iter(loader))
    inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, first_batch)  # type: ignore
    model_inputs = inputs_and_labels.model_inputs

    visualizer = VisualizationMaps(model, config)
    # Explain the prediction at target sequence position 2 (the third week)
    _, _, cam_third_week, probas_third_week = visualizer.generate(model_inputs,
                                                                  target_position=2,
                                                                  target_label_index=2)
    # Explain the prediction at target sequence position 0 (the first week)
    _, _, cam_first_week, probas_first_week = visualizer.generate(model_inputs,
                                                                  target_position=0,
                                                                  target_label_index=0)
    # The pseudo-CAM covers every sequence position up to and including the target one
    assert cam_first_week.shape[1] == 1
    assert cam_third_week.shape[1] == 3
    # The two explanations must not be identical
    assert np.any(cam_first_week != cam_third_week)
    assert np.any(probas_third_week != probas_first_week)
Exemple #3
0
def test_visualization_with_sequence_model(use_combined_model: bool,
                                           imaging_feature_type: ImagingFeatureType,
                                           test_output_dirs: TestOutputDirectories) -> None:
    """
    Generates GradCam / pseudo-CAM visualizations for a toy sequence model.

    Checks the shapes of the returned maps for combined and non-combined models. Also checks the
    labels used when plotting the non-imaging features up to target position 3.
    """
    config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    config.num_epochs = 1

    model = create_model_with_temperature_scaling(config)
    # NOTE(review): the return value of this call is discarded. Unless the function adjusts
    # `model` in place, it has no effect on the model used below; confirm.
    update_model_for_multiple_gpus(ModelAndInfo(model), config)
    loader = SequenceDataset(config, data_frame=config.dataset_data_frame) \
        .as_data_loader(shuffle=False, batch_size=2)
    # Reading a dataset item calls the image loader; return random data in its place
    fake_image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
                                                           segmentations=np.random.randint(0, 2, SCAN_SIZE))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=fake_image_and_seg):
        first_batch = next(iter(loader))
        inputs_and_labels = get_scalar_model_inputs_and_labels(config, model, first_batch)  # type: ignore

    num_sequences = inputs_and_labels.model_inputs[0].shape[1]
    num_subjects = len(inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        inputs_and_labels.model_inputs)

    if not use_combined_model:
        # Non-combined model: no GradCam maps, only the non-imaging pseudo-CAM and probabilities
        assert guided_grad_cams is None
        assert grad_cams is None
        assert pseudo_cam_non_img.shape[:2] == (num_subjects, num_sequences)
        assert probas.shape[0] == num_subjects
    else:
        # Image + segmentation input doubles the number of channels per sequence position
        channels_per_position = 2 if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation else 1
        expected_cam_prefix = (num_subjects, num_sequences * channels_per_position)
        assert guided_grad_cams.shape[:2] == expected_cam_prefix
        assert grad_cams.shape[:2] == expected_cam_prefix

    non_image_features = config.numerical_columns + config.categorical_columns
    plot_labels = visualizer._get_non_imaging_plot_labels(inputs_and_labels.data_item,
                                                          non_image_features,
                                                          index=0,
                                                          target_position=3)
    # Labels are "<feature>_<position>", grouped by position 0..3
    expected_labels = [f"{feature}_{position}"
                       for position in range(4)
                       for feature in ("numerical1", "numerical2", "cat1")]
    assert plot_labels == expected_labels
Exemple #4
0
def test_load_all_configs(model_name: str) -> None:
    """
    Loads the model configuration called ``model_name`` and carries out basic validations of it.

    Presumably this test is parametrized over all model configurations in the ML/src/configs
    folder; the parametrization is not visible here.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    config = ModelConfigLoader().create_model_config_from_name(model_name)
    assert config.model_name == model_name, "Mismatch between definition .py file and model name"
    if config.is_segmentation_model:
        # Reduce the feature channels to a minimum, to make tests run fast on CPU.
        minimal_feature_channels = 1
        config.feature_channels = [minimal_feature_channels] * len(config.feature_channels)
        # Report the actual restriction instead of a hard-coded number that could drift from the value above.
        print(f"Model architecture after restricting to {minimal_feature_channels} feature channels only:")
        model = create_model_with_temperature_scaling(config)
        generate_and_print_model_summary(config, model)  # type: ignore
    else:
        # For classification models, we can't always print a model summary: The model could require arbitrary
        # numbers of input tensors, and we'd only know once we load the training data.
        # Hence, only try to create the model, but don't attempt to print the summary.
        create_model_with_temperature_scaling(config)
Exemple #5
0
def extract_activation_maps(args: ModelConfigBase) -> None:
    """
    Extracts and visualizes the activation maps of a specific layer of a trained network.

    The model is created from ``args`` and its weights are loaded from ``args.get_path_to_checkpoint()``.
    The layer(s) given by ``args.activation_map_layers`` are hooked. Only the first image of the
    first validation batch is visualized.

    :param args: The model configuration. Provides ``use_gpu``, ``activation_map_layers``, the checkpoint
        path and the data loaders.
    :raises FileNotFoundError: If no checkpoint file exists at the expected path.
    :raises NotImplementedError: If the extracted activation map is neither 3- nor 4-dimensional.
    """
    model = create_model_with_temperature_scaling(args)
    if args.use_gpu:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()
    # Inputs must be placed on the same device as the model.
    device = torch.device("cuda" if args.use_gpu else "cpu")

    checkpoint_path = args.get_path_to_checkpoint()
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Could not find checkpoint at {checkpoint_path}")
    # Without a GPU, map any CUDA tensors stored in the checkpoint onto the CPU.
    checkpoint = torch.load(checkpoint_path, map_location=None if args.use_gpu else "cpu")  # type: ignore
    # NOTE(review): with use_gpu the model is wrapped in DataParallel, whose state_dict keys carry a
    # "module." prefix. Confirm the checkpoint was saved in a matching form.
    model.load_state_dict(checkpoint['state_dict'])

    model.eval()

    val_loader = args.create_data_loaders()[ModelExecutionMode.VAL]

    feature_extractor = model_hooks.HookBasedFeatureExtractor(
        model, layer_name=args.activation_map_layers)

    # Only visualize the first validation example; do nothing if the loader is empty.
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        return

    sample = CroppedSample.from_dict(sample=first_batch)
    input_image = sample.image.to(device).float()

    feature_extractor(input_image)

    # access first image of batch of feature maps
    activation_map = feature_extractor.outputs[0][0].cpu().numpy()

    if activation_map.ndim == 4:
        visualize_3d_activation_map(activation_map, args)
    elif activation_map.ndim == 3:
        visualize_2d_activation_map(activation_map, args)
    else:
        raise NotImplementedError(
            f"cannot visualize activation map of shape {activation_map.shape}")
def run_inference_on_unet(size: TupleInt3) -> None:
    """
    Builds a fresh UNet3D segmentation model and runs inference on a random image of the given size.

    Every posterior and the final segmentation must have exactly the input size, and each must
    pass ``image_util.check_array_range``. That check guards against, for example, partially-NaN
    results caused by a badly adjusted stride size.
    """
    foreground_classes = ["tumour_mass", "subtract"]
    num_foreground = len(foreground_classes)
    # One extra class on top of the foreground ones (presumably background)
    num_classes = num_foreground + 1
    config = SegmentationModelBase(
        architecture="UNet3D",
        local_dataset=Path("dummy"),
        feature_channels=[1],
        kernel_size=3,
        largest_connected_component_foreground_classes=foreground_classes,
        posterior_smoothing_mm=(2, 2, 2),
        crop_size=(64, 64, 64),
        # test_crop_size must be larger than `size` for the bug to trigger
        test_crop_size=(80, 80, 80),
        image_channels=["mr"],
        ground_truth_ids=foreground_classes,
        ground_truth_ids_display_names=foreground_classes,
        colours=[(255, 0, 0)] * num_foreground,
        fill_holes=[False] * num_foreground,
        mask_id=None,
        class_weights=[1.0 / num_classes] * num_classes,
        train_batch_size=8,
        inference_batch_size=1,
        inference_stride_size=(40, 40, 40),
        use_mixed_precision=True)
    pipeline = InferencePipeline(model=create_model_with_temperature_scaling(config),
                                 model_config=config,
                                 epoch=1)
    # Single-channel random image: a leading channel axis in front of the spatial size
    random_image = np.random.uniform(-1, 1, (1,) + size)
    result = pipeline.predict_and_post_process_whole_image(
        random_image, mask=np.ones(size), voxel_spacing_mm=(1, 1, 1))
    outputs_to_check = list(result.posteriors) + [result.segmentation]
    for output in outputs_to_check:
        assert output.shape == size
        image_util.check_array_range(output)
def test_model_summary_on_classification1() -> None:
    """
    Smoke test: builds the GlaucomaPublic classification model and generates its summary
    for a single input of size (1, 6, 64, 60).
    """
    glaucoma_config = GlaucomaPublic()
    glaucoma_model = create_model_with_temperature_scaling(glaucoma_config)
    summarizer = ModelSummary(glaucoma_model)
    summarizer.generate_summary(input_sizes=[(1, 6, 64, 60)])
Exemple #8
0
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: OutputFolderForTests) -> None:
    """
    Generates GradCam / pseudo-CAM visualizations for an ImageEncoder scalar model.

    Checks the shapes of the returned maps and, when non-imaging features are used, the plot
    labels of those features.
    """
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
    S1,week0,scan1.npy,,1,10,Male,Val1
    S1,week1,scan2.npy,True,2,20,Female,Val2
    S2,week0,scan3.npy,,3,30,Female,Val3
    S2,week1,scan4.npy,False,4,40,Female,Val1
    S3,week0,scan1.npy,,5,50,Male,Val2
    S3,week1,scan3.npy,True,6,60,Male,Val2
    """
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    visualizer = VisualizationMaps(model, config)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            device = visualizer.grad_cam.device
            batch = transfer_batch_to_device(batch, device)
            visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
                                                                     target_indices=[],
                                                                     sample=batch)
    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

    # Compute the expected shape first. Written inline as `assert a == b if c else d`, Python
    # parses it as `assert (a == b) if c else d`. In the else-branch that asserts a non-empty
    # tuple, which always passes.
    expected_grad_cam_shape = (number_subjects, 1) if encode_channels_jointly \
        else (number_subjects, number_channels)
    assert grad_cams.shape[:2] == expected_grad_cam_shape

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
Exemple #9
0
def test_image_encoder(test_output_dirs: OutputFolderForTests,
                       encode_channels_jointly: bool,
                       use_non_imaging_features: bool,
                       kernel_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       stride_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       reduction_factor: float,
                       expected_num_reduced_features: int,
                       aggregation_type: AggregationType) -> None:
    """
    Test if the image encoder networks can be trained without errors (including GradCam computation and data
    augmentation).
    """
    logging_to_stdout()
    set_random_seed(0)
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    scan_size = (6, 64, 60)
    # Write four random scans named scan1..scan4 (with the numpy file suffix) that the dataset CSV refers to.
    for scan_number in range(1, 5):
        random_scan = np.random.uniform(0, 1, scan_size)
        scan_file_name = f"scan{scan_number}{NumpyFile.NUMPY.value}"
        np.save(str(dataset_folder / scan_file_name), random_scan)

    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    (dataset_folder / "dataset.csv").write_text(dataset_contents)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}
    config_for_dataset = ScalarModelBase(
        local_dataset=dataset_folder,
        image_channels=["week0", "week1"],
        image_file_column="path",
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=non_image_feature_channels,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        should_validate=False
    )
    config_for_dataset.read_dataset_into_dataframe_and_pre_process()

    dataset = ScalarDataset(config_for_dataset,
                            sample_transforms=ScalarItemAugmentation(
                                RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
    # Three subjects (S1, S2, S3) in the CSV above
    assert len(dataset) == 3

    config = ImageEncoder(
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=config_for_dataset.categorical_feature_encoder,
        encoder_dimensionality_reduction_factor=reduction_factor,
        aggregation_type=aggregation_type,
        # Same size as the scans written to disk above
        scan_size=scan_size
    )

    if kernel_size_per_encoding_block:
        config.kernel_size_per_encoding_block = kernel_size_per_encoding_block
    if stride_size_per_encoding_block:
        config.stride_size_per_encoding_block = stride_size_per_encoding_block

    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    model = create_model_with_temperature_scaling(config)
    input_size: List[Tuple] = [(len(config.image_channels), *scan_size)]
    if use_non_imaging_features:
        input_size.append((config.get_total_number_of_non_imaging_features(),))

        # Original number output channels (unreduced) is
        # num initial channel * (num encoder block - 1) = 4 * (3-1) = 8
        if encode_channels_jointly:
            # reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()
        else:
            # num_img_channels * reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == len(config.image_channels) * expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()

    summarizer = ModelSummary(model)
    summarizer.generate_summary(input_sizes=input_size)
    config.local_dataset = dataset_folder
    config.validate()
    model_train(config, checkpoint_handler=get_default_checkpoint_handler(model_config=config,
                                                                          project_root=Path(test_output_dirs.root_dir)))