Example #1
def test_get_class_counts_multilabel(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test the get_class_counts method for multilabel scalar datasets.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,CAT1
   S1,week0,scan1.npy,,A
   S1,week1,scan2.npy,0|1|2,A
   S2,week0,scan3.npy,,A
   S2,week1,scan4.npy,1|2,A
   S3,week0,scan1.npy,,A
   S3,week1,scan3.npy,1,A
   """
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        class_names=["class0", "class1", "class2", "class3"],
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    class_counts = train_dataset.get_class_counts()
    assert class_counts == {0: 1, 1: 3, 2: 2, 3: 0}
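The expected dictionary follows directly from the pipe-separated week1 label strings in the CSV above. A minimal sketch of that arithmetic (illustration only, not the ScalarDataset implementation):

# Illustration of how the week1 label strings above yield the expected counts
# (not the ScalarDataset implementation).
from collections import Counter

label_strings = ["0|1|2", "1|2", "1"]          # the week1 "label" values for S1, S2, S3
counts = Counter(int(idx) for s in label_strings for idx in s.split("|"))
assert {i: counts.get(i, 0) for i in range(4)} == {0: 1, 1: 3, 2: 2, 3: 0}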
Example #2
def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test the get_class_counts method for binary scalar datasets.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
   S1,week0,scan1.npy,,1,10,A
   S1,week1,scan2.npy,True,2,20,A
   S2,week0,scan3.npy,,3,30,A
   S2,week1,scan4.npy,False,4,40,A
   S3,week0,scan1.npy,,5,50,A
   S3,week1,scan3.npy,True,6,60,A
   """
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        numerical_columns=["numerical1", "numerical2"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    class_counts = train_dataset.get_class_counts()
    assert class_counts == {0.0: 1, 1.0: 2}
Example #3
def test_get_labels_for_imbalanced_sampler_multilabel(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test that the get_labels_for_imbalanced_sampler method raises an error for multilabel scalar datasets.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,CAT1
    S1,week0,scan1.npy,,A
    S1,week1,scan2.npy,0|1|2,A
    S2,week0,scan3.npy,,A
    S2,week1,scan4.npy,1|2,A
    S3,week0,scan1.npy,,A
    S3,week1,scan3.npy,1,A
    """
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        class_names=["class0", "class1", "class2", "class3"],
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    with pytest.raises(NotImplementedError) as ex:
        train_dataset.get_labels_for_imbalanced_sampler()
    assert "ImbalancedSampler is not supported for multilabel tasks." in str(ex)
Example #4
def test_get_labels_for_imbalanced_sampler_binary(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test the get_labels_for_imbalanced_sampler method for binary scalar datasets.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
    S1,week0,scan1.npy,,1,10,A
    S1,week1,scan2.npy,True,2,20,A
    S2,week0,scan3.npy,,3,30,A
    S2,week1,scan4.npy,False,4,40,A
    S3,week0,scan1.npy,,5,50,A
    S3,week1,scan3.npy,True,6,60,A
    """
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        numerical_columns=["numerical1", "numerical2"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    labels = train_dataset.get_labels_for_imbalanced_sampler()
    assert labels == [1.0, 0.0, 1.0]
Example #5
    def create_torch_datasets(
            self,
            dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
        from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
        sample_transform = self.get_scalar_item_transform()
        assert sample_transform.train is not None  # for mypy
        assert sample_transform.val is not None  # for mypy
        assert sample_transform.test is not None  # for mypy
        train = ScalarDataset(args=self,
                              data_frame=dataset_splits.train,
                              name="training",
                              sample_transform=sample_transform.train)
        val = ScalarDataset(args=self,
                            data_frame=dataset_splits.val,
                            feature_statistics=train.feature_statistics,
                            name="validation",
                            sample_transform=sample_transform.val)
        test = ScalarDataset(args=self,
                             data_frame=dataset_splits.test,
                             feature_statistics=train.feature_statistics,
                             name="test",
                             sample_transform=sample_transform.test)

        return {
            ModelExecutionMode.TRAIN: train,
            ModelExecutionMode.VAL: val,
            ModelExecutionMode.TEST: test
        }
Example #6
    def create_torch_datasets(
            self,
            dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
        from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
        image_transforms = self.get_image_sample_transforms()
        train = ScalarDataset(
            args=self,
            data_frame=dataset_splits.train,
            name="training",
            sample_transforms=image_transforms.train)  # type: ignore
        val = ScalarDataset(
            args=self,
            data_frame=dataset_splits.val,
            feature_statistics=train.feature_statistics,
            name="validation",
            sample_transforms=image_transforms.val)  # type: ignore
        test = ScalarDataset(
            args=self,
            data_frame=dataset_splits.test,
            feature_statistics=train.feature_statistics,
            name="test",
            sample_transforms=image_transforms.test)  # type: ignore

        return {
            ModelExecutionMode.TRAIN: train,
            ModelExecutionMode.VAL: val,
            ModelExecutionMode.TEST: test
        }
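In both create_torch_datasets variants above, the validation and test datasets reuse train.feature_statistics, so normalisation statistics are computed on the training split only. A minimal, self-contained sketch of that pattern (illustration only; FeatureStats and compute_stats below are hypothetical names, not the library's FeatureStatistics class):

# Sketch of the train-statistics-reuse pattern (hypothetical names, not library code).
from dataclasses import dataclass
from typing import List

@dataclass
class FeatureStats:
    mean: float
    std: float

def compute_stats(values: List[float]) -> FeatureStats:
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return FeatureStats(mean=mean, std=var ** 0.5)

def standardise(values: List[float], stats: FeatureStats) -> List[float]:
    return [(v - stats.mean) / stats.std for v in values]

train_stats = compute_stats([1.0, 2.0, 3.0, 4.0])        # statistics from the training split only
val_standardised = standardise([2.5, 3.5], train_stats)  # the same statistics applied to validation data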
Example #7
def test_imbalanced_sampler() -> None:
    # Simulate a highly imbalanced dataset with only one data point
    # with a negative label.
    csv_string = StringIO("""subject,channel,value,scalar1
    S1,label,True,1.0
    S2,label,True,1.0
    S3,label,True,1.0
    S4,label,True,1.0
    S5,label,True,1.0
    S6,label,False,1.0
    """)
    torch.manual_seed(0)
    df = pd.read_csv(csv_string, sep=",", dtype=str)
    args = ScalarModelBase(label_value_column="value",
                           numerical_columns=["scalar1"],
                           local_dataset=Path("fakepath"))
    dataset = ScalarDataset(args, data_frame=df)
    drawn_subjects = []
    for _ in range(10):
        data_loader = dataset.as_data_loader(use_imbalanced_sampler=True,
                                             shuffle=True, batch_size=6,
                                             num_dataload_workers=0)
        for batch in data_loader:
            drawn_subjects.extend([i.id.strip() for i in batch["metadata"]])
    counts_per_subjects = Counter(drawn_subjects)
    count_negative_subjects = counts_per_subjects["S6"]
    assert count_negative_subjects / float(len(drawn_subjects)) > 0.3
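A minimal sketch of the oversampling idea being tested, assuming the imbalanced sampler is backed by PyTorch's WeightedRandomSampler with weights inversely proportional to class frequency (an assumption about the mechanism, not the as_data_loader implementation):

# Inverse-frequency oversampling with WeightedRandomSampler (assumed mechanism).
import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([1, 1, 1, 1, 1, 0])       # 5 positives, 1 negative, as in the CSV above
class_counts = torch.bincount(labels).float()   # tensor([1., 5.])
weights = 1.0 / class_counts[labels]            # the rare negative gets the largest weight
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
drawn = [labels[i].item() for i in sampler]     # the single negative is drawn ~50% of the time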
Example #8
def test_dataset_normalize_image(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Test dataset loading with window normalization image processing.
    """
    source_folder = str(full_ml_test_data_path() / "classification_data")
    target_folder = str(Path(test_output_dirs.make_sub_dir("foo")) / "bar")
    shutil.copytree(source_folder, target_folder)
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    df = pd.read_csv(csv_string, sep=",", dtype=str)
    args = ScalarModelBase(image_channels=["image"],
                           image_file_column="path",
                           label_channels=["label"],
                           label_value_column="value",
                           non_image_feature_channels={},
                           numerical_columns=[],
                           traverse_dirs_when_loading=True,
                           local_dataset=test_output_dirs.root_dir)
    raw_dataset = ScalarDataset(args, data_frame=df)
    normalized = ScalarDataset(
        args,
        data_frame=df,
        sample_transforms=WindowNormalizationForScalarItem())
    assert len(raw_dataset) == 4
    for i in range(4):
        raw_item = raw_dataset[i]
        normalized_item = normalized[i]
        normalized_images = normalized_item["images"]
        assert isinstance(raw_item, dict)
        expected_normalized_images = torch.tensor(
            mri_window(raw_item["images"].numpy(),
                       mask=None,
                       output_range=(0, 1))[0])
        assert normalized_images is not None
        assert torch.is_tensor(normalized_images)
        assert expected_normalized_images.shape == normalized_images.shape
        expected_image_size = (4, 5, 7)
        assert normalized_images.shape == (1, ) + expected_image_size
        assert torch.all(expected_normalized_images == normalized_images)
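The expected values above are computed with mri_window. As a rough sketch of what window normalisation to an output range means, assuming a simple clip-and-rescale (the library's mri_window may choose the window differently):

# Clip-and-rescale window normalisation (assumption for illustration only).
import numpy as np

def window_normalise(image: np.ndarray, low: float, high: float,
                     output_range=(0.0, 1.0)) -> np.ndarray:
    clipped = np.clip(image, low, high)          # restrict intensities to the window
    scaled = (clipped - low) / (high - low)      # map the window onto [0, 1]
    out_low, out_high = output_range
    return scaled * (out_high - out_low) + out_low

normalised = window_normalise(np.random.uniform(0, 100, (4, 5, 7)), low=20.0, high=80.0)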
Example #9
def test_image_labels_from_subject_id_single(
        test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(label_value_column="label",
                             subject_column="subject")

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    dataset_csv.write_text("subject,channel,label\n"
                           "0,label,0\n"
                           "1,label,1\n")

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    labels = get_image_labels_from_subject_id(subject_id="0",
                                              dataset=dataset,
                                              config=config)
    assert not labels

    labels = get_image_labels_from_subject_id(subject_id="1",
                                              dataset=dataset,
                                              config=config)
    assert labels
    assert len(labels) == 1
    assert labels[0] == MetricsDict.DEFAULT_HUE_KEY
Example #10
def test_get_image_filepath_from_subject_id_with_image_channels(
        test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(label_channels=["label"],
                             image_file_column="filePath",
                             label_value_column="label",
                             image_channels=["image"],
                             subject_column="subject")

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    image_file_name = "image.npy"
    dataset_csv.write_text(f"subject,channel,filePath,label\n"
                           f"0,label,,0\n"
                           f"0,image,0_{image_file_name},\n"
                           f"1,label,,1\n"
                           f"1,image,1_{image_file_name},\n")

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    Path(config.local_dataset / f"0_{image_file_name}").touch()
    Path(config.local_dataset / f"1_{image_file_name}").touch()

    filepath = get_image_filepath_from_subject_id(subject_id="1",
                                                  dataset=dataset,
                                                  config=config)
    expected_path = Path(config.local_dataset / f"1_{image_file_name}")

    assert filepath
    assert len(filepath) == 1
    assert filepath[0].samefile(expected_path)
Example #11
def test_get_image_filepath_from_subject_id_single(
        test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(image_file_column="filePath",
                             label_value_column="label",
                             subject_column="subject")

    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    image_file_name = "image.npy"
    dataset_csv.write_text(f"subject,filePath,label\n"
                           f"0,0_{image_file_name},0\n"
                           f"1,1_{image_file_name},1\n")

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    Path(config.local_dataset / f"0_{image_file_name}").touch()
    Path(config.local_dataset / f"1_{image_file_name}").touch()

    filepath = get_image_filepath_from_subject_id(subject_id="1",
                                                  dataset=dataset,
                                                  config=config)
    expected_path = Path(config.local_dataset / f"1_{image_file_name}")

    assert filepath
    assert len(filepath) == 1
    assert expected_path.samefile(filepath[0])

    # Check error is raised if the subject does not exist
    with pytest.raises(ValueError) as ex:
        get_image_filepath_from_subject_id(subject_id="100",
                                           dataset=dataset,
                                           config=config)
    assert "Could not find subject" in str(ex)
Example #12
def get_unique_prediction_target_combinations(
        config: ScalarModelBase) -> Set[FrozenSet[str]]:
    """
    Get a list of all the combinations of labels that exist in the dataset.

    For multilabel classification tasks, this function will return all unique combinations of labels that
    occur in the dataset csv.
    For example, if there are 6 samples in the dataset with the following ground truth labels
    Sample1: class1, class2
    Sample2: class0
    Sample3: class1
    Sample4: class2, class3
    Sample5: (all label classes are negative in Sample 5)
    Sample6: class1, class2
    This function will return {{"class1", "class2"}, {"class0"}, {"class1"},  {"class2", "class3"}, {}}

    For binary classification tasks (assume class_names has not been changed from ["Default"]):
    This function will return a set with two members - {{"Default"}, {}} if there is at least one positive example
    in the dataset. If there are no positive examples, it returns {{}}.
    """
    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    all_labels = [
        torch.flatten(torch.nonzero(item.label)).tolist()
        for item in dataset.items
    ]
    label_set = set(
        frozenset([config.class_names[i] for i in labels if not math.isnan(i)])
        for labels in all_labels)

    return label_set
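The worked example in the docstring can be checked in a few lines (an illustration of the frozenset collapsing only, not the dataset-reading code above):

# The six samples from the docstring collapse to five unique label combinations.
samples = [{"class1", "class2"}, {"class0"}, {"class1"},
           {"class2", "class3"}, set(), {"class1", "class2"}]
label_set = {frozenset(s) for s in samples}
assert len(label_set) == 5          # Sample1 and Sample6 collapse into one entry
assert frozenset() in label_set     # the all-negative Sample5 contributes the empty set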
Example #13
def _create_test_dataset(csv_path: Path, scalar_loss: ScalarLoss = ScalarLoss.BinaryCrossEntropyWithLogits,
                         categorical_columns: Optional[List[str]] = None) -> ScalarDataset:
    # Load items indirectly via a ScalarDataset object, to see if the wiring up of all column names works
    args = ScalarModelBase(image_channels=["image"],
                           image_file_column="path",
                           label_channels=["label"],
                           label_value_column="value",
                           non_image_feature_channels=["label"],
                           numerical_columns=["scalar1", "scalar2"],
                           categorical_columns=categorical_columns or list(),
                           subject_column="USUBJID",
                           channel_column="week",
                           local_dataset=csv_path,
                           should_validate=False,
                           loss_type=scalar_loss,
                           num_dataload_workers=0)
    args.read_dataset_into_dataframe_and_pre_process()
    return ScalarDataset(args)
Example #14
def test_dataset_traverse_dirs(test_output_dirs: OutputFolderForTests,
                               center_crop_size: Optional[TupleInt3]) -> None:
    """
    Test dataset loading when the dataset file only contains file name stems, not full paths.
    """
    # Copy the existing test dataset to a new folder, two levels deep. Later will initialize the
    # dataset with only the root folder given, to check if the files are still found.
    source_folder = str(full_ml_test_data_path() / "classification_data")
    target_folder = str(Path(test_output_dirs.make_sub_dir("foo")) / "bar")
    shutil.copytree(source_folder, target_folder)
    # The dataset should only contain the file name stem, without extension.
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    df = pd.read_csv(csv_string, sep=",", dtype=str)
    args = ScalarModelBase(image_channels=["image"],
                           image_file_column="path",
                           label_channels=["label"],
                           label_value_column="value",
                           non_image_feature_channels={},
                           numerical_columns=[],
                           traverse_dirs_when_loading=True,
                           center_crop_size=center_crop_size,
                           local_dataset=test_output_dirs.root_dir)
    dataset = ScalarDataset(args, data_frame=df)
    assert len(dataset) == 4
    for i in range(4):
        item = dataset[i]
        assert isinstance(item, dict)
        images = item["images"]
        assert images is not None
        assert torch.is_tensor(images)
        expected_image_size = center_crop_size or (4, 5, 7)
        assert images.shape == (1, ) + expected_image_size
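A minimal sketch of the idea behind traverse_dirs_when_loading: resolve a file-name stem to a full path by searching recursively below the dataset root (an assumption about the mechanism for illustration; the library may match stems differently):

# Resolve a file-name stem by recursive search under the dataset root (assumed mechanism).
from pathlib import Path

def find_file_by_stem(root: Path, stem: str) -> Path:
    matches = [p for p in root.rglob("*") if p.is_file() and p.stem == stem]
    if not matches:
        raise FileNotFoundError(f"No file with stem '{stem}' under {root}")
    return matches[0]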
Example #15
def test_image_labels_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
    config = ScalarModelBase(label_channels=["label"],
                             label_value_column="label",
                             subject_column="subject",
                             class_names=["class1", "class2", "class3"])
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    dataset_csv.write_text("subject,channel,label\n"
                           "0,label,0\n"
                           "0,image,\n"
                           "1,label,1|2\n"
                           "1,image,\n")

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    labels = get_image_labels_from_subject_id(subject_id="1",
                                                  dataset=dataset,
                                                  config=config)
    assert labels
    assert len(labels) == 2
    assert set(labels) == {config.class_names[1], config.class_names[2]}
Example #16
def test_dataloader_speed(test_output_dirs: OutputFolderForTests,
                          num_dataload_workers: int, shuffle: bool) -> None:
    """
    Test how dataloaders work when using multiple processes.
    """
    ml_util.set_random_seed(0)
    # The dataset should only contain the file name stem, without extension.
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    args = ScalarModelBase(image_channels=[],
                           label_channels=["label"],
                           label_value_column="value",
                           non_image_feature_channels=["label"],
                           numerical_columns=["scalar1"],
                           num_dataload_workers=num_dataload_workers,
                           num_dataset_reader_workers=num_dataload_workers,
                           avoid_process_spawn_in_data_loaders=True,
                           should_validate=False)
    dataset = ScalarDataset(args,
                            data_frame=pd.read_csv(csv_string, dtype=str))
    assert len(dataset) == 4
    num_epochs = 2
    total_start_time = time.time()
    loader = dataset.as_data_loader(shuffle=shuffle, batch_size=1)
    # The order in which items are expected in each epoch, when applying shuffling, and using 1 dataloader worker
    # This was determined before making any changes to the dataloader logic
    # (that is, when the as_data_loader method returns an instance of DataLoader, rather than RepeatDataLoader)
    expected_item_order = [
        ["S2", "S1", "S4", "S3"],
        ["S4", "S3", "S1", "S2"],
    ]
    for epoch in range(num_epochs):
        actual_item_order = []
        print(f"Starting epoch {epoch}")
        epoch_start_time = time.time()
        item_start_time = time.time()
        for i, item_dict in enumerate(loader):
            item_load_time = time.time() - item_start_time
            item = ScalarItem.from_dict(item_dict)
            # noinspection PyTypeChecker
            sample_id = item.metadata[0].id  # type: ignore
            print(
                f"Loading item {i} with ID = {sample_id} in {item_load_time:0.8f} sec"
            )
            if shuffle:
                actual_item_order.append(sample_id)
            else:
                assert sample_id == f"S{i + 1}"
            if not (epoch == 0 and i == 0):
                assert item_load_time < 0.1, f"We should only see significant item load times in the first batch " \
                                             f"of the first epoch, but got loading time of {item_load_time:0.2f} sec" \
                                             f" in epoch {epoch} batch {i}"
            # Sleep a bit so that the worker process can fill in items
            if num_dataload_workers > 0:
                time.sleep(0.05)
            item_start_time = time.time()
        if shuffle and num_dataload_workers == 1:
            assert actual_item_order == expected_item_order[epoch], \
                f"Item in wrong order for epoch {epoch}"
        total_epoch_time = time.time() - epoch_start_time
        print(f"Total time for epoch {epoch}: {total_epoch_time} sec")
    total_time = time.time() - total_start_time
    print(f"Total time for all epochs: {total_time} sec")
Example #17
def test_image_encoder(test_output_dirs: OutputFolderForTests,
                       encode_channels_jointly: bool,
                       use_non_imaging_features: bool,
                       kernel_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       stride_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       reduction_factor: float,
                       expected_num_reduced_features: int,
                       aggregation_type: AggregationType) -> None:
    """
    Test if the image encoder networks can be trained without errors (including GradCam computation and data
    augmentation).
    """
    logging_to_stdout()
    set_random_seed(0)
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    scan_size = (6, 64, 60)
    scan_files: List[str] = []
    for s in range(4):
        random_scan = np.random.uniform(0, 1, scan_size)
        scan_file_name = f"scan{s + 1}{NumpyFile.NUMPY.value}"
        np.save(str(dataset_folder / scan_file_name), random_scan)
        scan_files.append(scan_file_name)

    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    (dataset_folder / "dataset.csv").write_text(dataset_contents)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}
    config_for_dataset = ScalarModelBase(
        local_dataset=dataset_folder,
        image_channels=["week0", "week1"],
        image_file_column="path",
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=non_image_feature_channels,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        should_validate=False
    )
    config_for_dataset.read_dataset_into_dataframe_and_pre_process()

    dataset = ScalarDataset(config_for_dataset,
                            sample_transforms=ScalarItemAugmentation(
                                RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
    assert len(dataset) == 3

    config = ImageEncoder(
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=config_for_dataset.categorical_feature_encoder,
        encoder_dimensionality_reduction_factor=reduction_factor,
        aggregation_type=aggregation_type,
        scan_size=(6, 64, 60)
    )

    if kernel_size_per_encoding_block:
        config.kernel_size_per_encoding_block = kernel_size_per_encoding_block
    if stride_size_per_encoding_block:
        config.stride_size_per_encoding_block = stride_size_per_encoding_block

    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    model = create_model_with_temperature_scaling(config)
    input_size: List[Tuple] = [(len(config.image_channels), *scan_size)]
    if use_non_imaging_features:
        input_size.append((config.get_total_number_of_non_imaging_features(),))

        # Original number output channels (unreduced) is
        # num initial channel * (num encoder block - 1) = 4 * (3-1) = 8
        if encode_channels_jointly:
            # reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()
        else:
            # num_img_channels * reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == len(config.image_channels) * expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()

    summarizer = ModelSummary(model)
    summarizer.generate_summary(input_sizes=input_size)
    config.local_dataset = dataset_folder
    config.validate()
    model_train(config, checkpoint_handler=get_default_checkpoint_handler(model_config=config,
                                                                          project_root=Path(test_output_dirs.root_dir)))
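The final_num_feature_channels assertions above reduce to simple arithmetic. A worked example with hypothetical numbers (the unreduced channel count of 8 comes from the comment in the test; the other values are placeholders):

# Worked arithmetic behind the feature-size checks (hypothetical placeholder values).
unreduced_channels = 4 * (3 - 1)                                   # 8, per the comment in the test
reduction_factor = 0.5                                             # hypothetical
reduced_channels = int(unreduced_channels * reduction_factor)      # 4
num_image_channels = 2                                             # ["week0", "week1"]
num_non_image_features = 7                                         # hypothetical
if_jointly = reduced_channels + num_non_image_features                          # 11
if_separately = num_image_channels * reduced_channels + num_non_image_features  # 15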
Example #18
def plot_k_best_and_worst_performing(val_metrics_csv: Path,
                                     test_metrics_csv: Path, k: int,
                                     prediction_target: str,
                                     config: ScalarModelBase) -> None:
    """
    Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
    and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).
    :param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
                            used to determine the thresholds for predicting labels on the test set. The best and worst
                            performing subjects will not be printed out for this csv.
    :param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
                            best and worst performing subjects will be printed out.
    :param k: Number of subjects of each category to print out.
    :param prediction_target: The class label to filter on
    :param config: scalar model config object
    """
    results = get_k_best_and_worst_performing(
        val_metrics_csv=val_metrics_csv,
        test_metrics_csv=test_metrics_csv,
        k=k,
        prediction_target=prediction_target)
    if results is None:
        print_header("Empty validation or test set", level=4)
        return

    test_metrics = pd.read_csv(test_metrics_csv, dtype=str)

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    im_width = 800

    print_header("", level=2)
    print_header(f"Top {k} false positives", level=2)
    for index, (subject, model_output) in enumerate(
            zip(results.false_positives[LoggingColumns.Patient.value],
                results.false_positives[LoggingColumns.ModelOutput.value])):
        plot_image_for_subject(subject_id=str(subject),
                               dataset=dataset,
                               im_width=im_width,
                               model_output=model_output,
                               header="False Positive",
                               config=config,
                               metrics_df=test_metrics)

    print_header(f"Top {k} false negatives", level=2)
    for index, (subject, model_output) in enumerate(
            zip(results.false_negatives[LoggingColumns.Patient.value],
                results.false_negatives[LoggingColumns.ModelOutput.value])):
        plot_image_for_subject(subject_id=str(subject),
                               dataset=dataset,
                               im_width=im_width,
                               model_output=model_output,
                               header="False Negative",
                               config=config,
                               metrics_df=test_metrics)

    print_header(f"Top {k} true positives", level=2)
    for index, (subject, model_output) in enumerate(
            zip(results.true_positives[LoggingColumns.Patient.value],
                results.true_positives[LoggingColumns.ModelOutput.value])):
        plot_image_for_subject(subject_id=str(subject),
                               dataset=dataset,
                               im_width=im_width,
                               model_output=model_output,
                               header="True Positive",
                               config=config,
                               metrics_df=test_metrics)

    print_header(f"Top {k} true negatives", level=2)
    for index, (subject, model_output) in enumerate(
            zip(results.true_negatives[LoggingColumns.Patient.value],
                results.true_negatives[LoggingColumns.ModelOutput.value])):
        plot_image_for_subject(subject_id=str(subject),
                               dataset=dataset,
                               im_width=im_width,
                               model_output=model_output,
                               header="True Negative",
                               config=config,
                               metrics_df=test_metrics)
Example #19
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: OutputFolderForTests) -> None:
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
    S1,week0,scan1.npy,,1,10,Male,Val1
    S1,week1,scan2.npy,True,2,20,Female,Val2
    S2,week0,scan3.npy,,3,30,Female,Val3
    S2,week1,scan4.npy,False,4,40,Female,Val1
    S3,week0,scan1.npy,,5,50,Male,Val2
    S3,week1,scan3.npy,True,6,60,Male,Val2
    """
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    visualizer = VisualizationMaps(model, config)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            device = visualizer.grad_cam.device
            batch = transfer_batch_to_device(batch, device)
            visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
                                                                     target_indices=[],
                                                                     sample=batch)
    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

    expected_grad_cam_shape = (number_subjects, 1) if encode_channels_jointly \
        else (number_subjects, number_channels)
    assert grad_cams.shape[:2] == expected_grad_cam_shape

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))