def test_get_class_counts_multilabel(test_output_dirs: OutputFolderForTests) -> None:
    """
    Check get_class_counts on a multilabel scalar dataset.

    The week1 channel holds "|"-separated class indices. class3 never occurs in the data and must still
    be reported, with a count of zero.
    """
    rows = [
        "subject,channel,path,label,CAT1",
        "S1,week0,scan1.npy,,A",
        "S1,week1,scan2.npy,0|1|2,A",
        "S2,week0,scan3.npy,,A",
        "S2,week1,scan4.npy,1|2,A",
        "S3,week0,scan1.npy,,A",
        "S3,week1,scan3.npy,1,A",
    ]
    dataset_contents = "\n".join(rows) + "\n"
    config = ScalarModelBase(
        local_dataset=Path(test_output_dirs.make_sub_dir("dataset")),
        class_names=["class0", "class1", "class2", "class3"],
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    data_frame = pd.read_csv(StringIO(dataset_contents), dtype=str)
    counts = ScalarDataset(config, data_frame).get_class_counts()
    # class0: S1 only; class1: S1, S2, S3; class2: S1, S2; class3: never.
    assert counts == {0: 1, 1: 3, 2: 2, 3: 0}
def test_get_class_weights_dataset(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test the get_class_counts method for binary scalar datasets.

    Despite the test name, this exercises get_class_counts. The week1 channel carries True/False labels
    (True for S1 and S3, False for S2), which must be counted under the keys 1.0 and 0.0 respectively.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,CAT1
S1,week0,scan1.npy,,1,10,A
S1,week1,scan2.npy,True,2,20,A
S2,week0,scan3.npy,,3,30,A
S2,week1,scan4.npy,False,4,40,A
S3,week0,scan1.npy,,5,50,A
S3,week1,scan3.npy,True,6,60,A
"""
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        numerical_columns=["numerical1", "numerical2"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    class_counts = train_dataset.get_class_counts()
    assert class_counts == {0.0: 1, 1.0: 2}
def test_get_labels_for_imbalanced_sampler_multilabel(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test that the get_labels_for_imbalanced_sampler method raises an error for multilabel scalar datasets.
    """
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    dataset_contents = """subject,channel,path,label,CAT1
S1,week0,scan1.npy,,A
S1,week1,scan2.npy,0|1|2,A
S2,week0,scan3.npy,,A
S2,week1,scan4.npy,1|2,A
S3,week0,scan1.npy,,A
S3,week1,scan3.npy,1,A
"""
    config = ScalarModelBase(
        local_dataset=dataset_folder,
        class_names=["class0", "class1", "class2", "class3"],
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        should_validate=False
    )
    config.set_output_to(test_output_dirs.root_dir)
    train_dataset = ScalarDataset(config, pd.read_csv(StringIO(dataset_contents), dtype=str))
    with pytest.raises(NotImplementedError) as ex:
        train_dataset.get_labels_for_imbalanced_sampler()
    # Inspect the raised exception itself: str(ex) is the repr of pytest's ExceptionInfo wrapper and only
    # contains the message incidentally.
    assert "ImbalancedSampler is not supported for multilabel tasks." in str(ex.value)
def test_get_labels_for_imbalanced_sampler_binary(test_output_dirs: OutputFolderForTests) -> None:
    """
    For a binary scalar dataset, get_labels_for_imbalanced_sampler yields one label per subject
    (S1, S2, S3 in that order), taken from the True/False values on the week1 channel.
    """
    rows = [
        "subject,channel,path,label,numerical1,numerical2,CAT1",
        "S1,week0,scan1.npy,,1,10,A",
        "S1,week1,scan2.npy,True,2,20,A",
        "S2,week0,scan3.npy,,3,30,A",
        "S2,week1,scan4.npy,False,4,40,A",
        "S3,week0,scan1.npy,,5,50,A",
        "S3,week1,scan3.npy,True,6,60,A",
    ]
    csv_text = "\n".join(rows) + "\n"
    model_config = ScalarModelBase(
        local_dataset=Path(test_output_dirs.make_sub_dir("dataset")),
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=["week0", "week1"],
        numerical_columns=["numerical1", "numerical2"],
        should_validate=False
    )
    model_config.set_output_to(test_output_dirs.root_dir)
    frame = pd.read_csv(StringIO(csv_text), dtype=str)
    sampler_labels = ScalarDataset(model_config, frame).get_labels_for_imbalanced_sampler()
    assert sampler_labels == [1.0, 0.0, 1.0]
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
    """
    Build the training, validation and test ScalarDatasets from the given data splits.

    Each split gets its own item transform from get_scalar_item_transform. The validation and test
    datasets are handed the feature statistics of the training dataset.

    :param dataset_splits: The data frames for the train, val and test splits.
    :return: A dictionary mapping ModelExecutionMode.TRAIN / VAL / TEST to the corresponding dataset.
    """
    from InnerEye.ML.dataset.scalar_dataset import ScalarDataset

    transforms = self.get_scalar_item_transform()
    # Narrow the per-split transforms to non-None for the type checker.
    assert transforms.train is not None  # for mypy
    assert transforms.val is not None  # for mypy
    assert transforms.test is not None  # for mypy

    training_dataset = ScalarDataset(args=self,
                                     data_frame=dataset_splits.train,
                                     name="training",
                                     sample_transform=transforms.train)
    # The non-training splits reuse the statistics computed on the training data.
    validation_dataset = ScalarDataset(args=self,
                                       data_frame=dataset_splits.val,
                                       feature_statistics=training_dataset.feature_statistics,
                                       name="validation",
                                       sample_transform=transforms.val)
    test_dataset = ScalarDataset(args=self,
                                 data_frame=dataset_splits.test,
                                 feature_statistics=training_dataset.feature_statistics,
                                 name="test",
                                 sample_transform=transforms.test)
    datasets_by_mode: Dict[ModelExecutionMode, Any] = {}
    datasets_by_mode[ModelExecutionMode.TRAIN] = training_dataset
    datasets_by_mode[ModelExecutionMode.VAL] = validation_dataset
    datasets_by_mode[ModelExecutionMode.TEST] = test_dataset
    return datasets_by_mode
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
    """
    Build the training, validation and test ScalarDatasets, each with its image sample transforms.

    The validation and test datasets are handed the feature statistics of the training dataset.

    :param dataset_splits: The data frames for the train, val and test splits.
    :return: A dictionary mapping ModelExecutionMode.TRAIN / VAL / TEST to the corresponding dataset.
    """
    from InnerEye.ML.dataset.scalar_dataset import ScalarDataset

    split_transforms = self.get_image_sample_transforms()
    training_set = ScalarDataset(args=self,
                                 data_frame=dataset_splits.train,
                                 name="training",
                                 sample_transforms=split_transforms.train)  # type: ignore
    # The non-training splits reuse the statistics computed on the training data.
    validation_set = ScalarDataset(args=self,
                                   data_frame=dataset_splits.val,
                                   feature_statistics=training_set.feature_statistics,
                                   name="validation",
                                   sample_transforms=split_transforms.val)  # type: ignore
    test_set = ScalarDataset(args=self,
                             data_frame=dataset_splits.test,
                             feature_statistics=training_set.feature_statistics,
                             name="test",
                             sample_transforms=split_transforms.test)  # type: ignore
    return dict([
        (ModelExecutionMode.TRAIN, training_set),
        (ModelExecutionMode.VAL, validation_set),
        (ModelExecutionMode.TEST, test_set),
    ])
def test_imbalanced_sampler() -> None:
    """
    With use_imbalanced_sampler=True, the only negative subject (S6) of a 5-to-1 imbalanced dataset must
    account for more than 30% of all subjects drawn over ten passes through the data loader.
    """
    csv_string = StringIO("""subject,channel,value,scalar1
S1,label,True,1.0
S2,label,True,1.0
S3,label,True,1.0
S4,label,True,1.0
S5,label,True,1.0
S6,label,False,1.0
""")
    torch.manual_seed(0)
    data_frame = pd.read_csv(csv_string, sep=",", dtype=str)
    model_args = ScalarModelBase(label_value_column="value",
                                 numerical_columns=["scalar1"],
                                 local_dataset=Path("fakepath"))
    dataset = ScalarDataset(model_args, data_frame=data_frame)
    drawn_subjects = []
    for _ in range(10):
        loader = dataset.as_data_loader(use_imbalanced_sampler=True,
                                        shuffle=True,
                                        batch_size=6,
                                        num_dataload_workers=0)
        for batch in loader:
            drawn_subjects += [metadata.id.strip() for metadata in batch["metadata"]]
    negative_share = Counter(drawn_subjects)["S6"] / float(len(drawn_subjects))
    assert negative_share > 0.3
def test_dataset_normalize_image(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test dataset loading with window normalization image processing.

    Each image loaded through WindowNormalizationForScalarItem must equal mri_window applied (with output
    range (0, 1)) to the corresponding raw image.
    """
    # Copy the test images two folders below the dataset root; traverse_dirs_when_loading has to find them.
    shutil.copytree(str(full_ml_test_data_path() / "classification_data"),
                    str(Path(test_output_dirs.make_sub_dir("foo")) / "bar"))
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    data_frame = pd.read_csv(csv_string, sep=",", dtype=str)
    model_args = ScalarModelBase(image_channels=["image"],
                                 image_file_column="path",
                                 label_channels=["label"],
                                 label_value_column="value",
                                 non_image_feature_channels={},
                                 numerical_columns=[],
                                 traverse_dirs_when_loading=True,
                                 local_dataset=test_output_dirs.root_dir)
    raw_dataset = ScalarDataset(model_args, data_frame=data_frame)
    normalized_dataset = ScalarDataset(model_args,
                                       data_frame=data_frame,
                                       sample_transforms=WindowNormalizationForScalarItem())
    assert len(raw_dataset) == 4
    for index in range(4):
        raw_item = raw_dataset[index]
        normalized_images = normalized_dataset[index]["images"]
        assert isinstance(raw_item, dict)
        windowed = mri_window(raw_item["images"].numpy(), mask=None, output_range=(0, 1))[0]
        expected_normalized_images = torch.tensor(windowed)
        assert normalized_images is not None
        assert torch.is_tensor(normalized_images)
        assert expected_normalized_images.shape == normalized_images.shape
        assert normalized_images.shape == (1,) + (4, 5, 7)
        assert torch.all(expected_normalized_images == normalized_images)
def test_image_labels_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
    """
    For a single-class dataset, get_image_labels_from_subject_id returns no labels for a negative subject
    and exactly the default hue key for a positive subject.
    """
    config = ScalarModelBase(label_value_column="label", subject_column="subject")
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    dataset_dir = config.local_dataset
    dataset_dir.mkdir()
    (dataset_dir / "dataset.csv").write_text("subject,channel,label\n0,label,0\n1,label,1\n")
    dataset = ScalarDataset(args=config, data_frame=config.read_dataset_if_needed())

    negative_labels = get_image_labels_from_subject_id(subject_id="0", dataset=dataset, config=config)
    assert not negative_labels

    positive_labels = get_image_labels_from_subject_id(subject_id="1", dataset=dataset, config=config)
    assert positive_labels
    assert len(positive_labels) == 1
    assert positive_labels[0] == MetricsDict.DEFAULT_HUE_KEY
def test_get_image_filepath_from_subject_id_with_image_channels(test_output_dirs: OutputFolderForTests) -> None:
    """
    With separate "image" and "label" channels, get_image_filepath_from_subject_id returns exactly the
    image file listed for the requested subject.
    """
    config = ScalarModelBase(label_channels=["label"],
                             image_file_column="filePath",
                             label_value_column="label",
                             image_channels=["image"],
                             subject_column="subject")
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    dataset_dir = config.local_dataset
    dataset_dir.mkdir()
    image_file_name = "image.npy"
    csv_lines = ["subject,channel,filePath,label"]
    for subject in ("0", "1"):
        csv_lines.append(f"{subject},label,,{subject}")
        csv_lines.append(f"{subject},image,{subject}_{image_file_name},")
    (dataset_dir / "dataset.csv").write_text("\n".join(csv_lines) + "\n")
    dataset = ScalarDataset(args=config, data_frame=config.read_dataset_if_needed())
    for subject in ("0", "1"):
        Path(dataset_dir / f"{subject}_{image_file_name}").touch()

    found_paths = get_image_filepath_from_subject_id(subject_id="1", dataset=dataset, config=config)
    expected_path = Path(dataset_dir / f"1_{image_file_name}")
    assert found_paths
    assert len(found_paths) == 1
    assert found_paths[0].samefile(expected_path)
def test_get_image_filepath_from_subject_id_single(test_output_dirs: OutputFolderForTests) -> None:
    """
    For a dataset without a channel column, get_image_filepath_from_subject_id returns the single image file
    of the requested subject, and raises a ValueError for a subject that is not in the dataset.
    """
    config = ScalarModelBase(image_file_column="filePath",
                             label_value_column="label",
                             subject_column="subject")
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    config.local_dataset.mkdir()
    dataset_csv = config.local_dataset / "dataset.csv"
    image_file_name = "image.npy"
    dataset_csv.write_text(f"subject,filePath,label\n"
                           f"0,0_{image_file_name},0\n"
                           f"1,1_{image_file_name},1\n")
    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)
    Path(config.local_dataset / f"0_{image_file_name}").touch()
    Path(config.local_dataset / f"1_{image_file_name}").touch()
    filepath = get_image_filepath_from_subject_id(subject_id="1",
                                                  dataset=dataset,
                                                  config=config)
    expected_path = Path(config.local_dataset / f"1_{image_file_name}")
    assert filepath
    assert len(filepath) == 1
    assert expected_path.samefile(filepath[0])

    # Check error is raised if the subject does not exist
    with pytest.raises(ValueError) as ex:
        get_image_filepath_from_subject_id(subject_id="100",
                                           dataset=dataset,
                                           config=config)
    # Inspect the raised exception itself, not the repr of pytest's ExceptionInfo wrapper.
    assert "Could not find subject" in str(ex.value)
def get_unique_prediction_target_combinations(config: ScalarModelBase) -> Set[FrozenSet[str]]:
    """
    Get a list of all the combinations of labels that exist in the dataset.

    For multilabel classification tasks, this function will return all unique combinations of labels that
    occur in the dataset csv.
    For example, if there are 6 samples in the dataset with the following ground truth labels
    Sample1: class1, class2
    Sample2: class0
    Sample3: class1
    Sample4: class2, class3
    Sample5: (all label classes are negative in Sample 5)
    Sample6: class1, class2
    This function will return {{"class1", "class2"}, {"class0"}, {"class1"}, {"class2", "class3"}, {}}

    For binary classification tasks (assume class_names has not been changed from ["Default"]):
    This function will return a set with two members - {{"Default"}, {}} if there is at least one positive
    example in the dataset. If there are no positive examples, it returns {{}}.

    Label entries that are NaN are treated as not positive, so they never contribute a class name.

    :param config: The scalar model config whose dataset is read.
    :return: The set of distinct positive-class combinations, one frozenset of class names per combination.
    """
    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)
    label_set: Set[FrozenSet[str]] = set()
    for item in dataset.items:
        label = item.label
        # torch.nonzero counts NaN as non-zero, so NaN entries must be masked out explicitly. The previous
        # `math.isnan(i)` check was applied to the integer index returned by nonzero and could never fire.
        is_positive = (label != 0) & ~torch.isnan(label)
        positive_indices = torch.flatten(torch.nonzero(is_positive)).tolist()
        label_set.add(frozenset(config.class_names[i] for i in positive_indices))
    return label_set
def _create_test_dataset(csv_path: Path,
                         scalar_loss: ScalarLoss = ScalarLoss.BinaryCrossEntropyWithLogits,
                         categorical_columns: Optional[List[str]] = None) -> ScalarDataset:
    """
    Create a ScalarDataset for tests, going through a ScalarModelBase config so that the wiring of all
    column names (subject "USUBJID", channel "week", image "path", label "value", ...) is exercised too.

    :param csv_path: Passed to the config as local_dataset.
    :param scalar_loss: The loss type to set on the config.
    :param categorical_columns: Categorical feature columns; None means no categorical columns.
    :return: A ScalarDataset built from the pre-processed config.
    """
    chosen_categorical_columns = categorical_columns if categorical_columns else []
    model_args = ScalarModelBase(
        image_channels=["image"],
        image_file_column="path",
        label_channels=["label"],
        label_value_column="value",
        non_image_feature_channels=["label"],
        numerical_columns=["scalar1", "scalar2"],
        categorical_columns=chosen_categorical_columns,
        subject_column="USUBJID",
        channel_column="week",
        local_dataset=csv_path,
        should_validate=False,
        loss_type=scalar_loss,
        num_dataload_workers=0,
    )
    model_args.read_dataset_into_dataframe_and_pre_process()
    return ScalarDataset(model_args)
def test_dataset_traverse_dirs(test_output_dirs: OutputFolderForTests,
                               center_crop_size: Optional[TupleInt3]) -> None:
    """
    Test dataset loading when the dataset file only contains file name stems, not full paths.

    Every loaded item must contain an image tensor of shape (1, *size), where size is center_crop_size if
    given and (4, 5, 7) otherwise.
    """
    # Place a copy of the test images two folders below the dataset root. Only the root is given to the
    # config, so the images have to be located by traversing directories.
    source_folder = str(full_ml_test_data_path() / "classification_data")
    nested_target = str(Path(test_output_dirs.make_sub_dir("foo")) / "bar")
    shutil.copytree(source_folder, nested_target)
    # The "path" column holds file name stems only, without extension.
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    data_frame = pd.read_csv(csv_string, sep=",", dtype=str)
    model_args = ScalarModelBase(image_channels=["image"],
                                 image_file_column="path",
                                 label_channels=["label"],
                                 label_value_column="value",
                                 non_image_feature_channels={},
                                 numerical_columns=[],
                                 traverse_dirs_when_loading=True,
                                 center_crop_size=center_crop_size,
                                 local_dataset=test_output_dirs.root_dir)
    dataset = ScalarDataset(model_args, data_frame=data_frame)
    assert len(dataset) == 4
    expected_image_size = center_crop_size or (4, 5, 7)
    for index in range(4):
        item = dataset[index]
        assert isinstance(item, dict)
        images = item["images"]
        assert images is not None
        assert torch.is_tensor(images)
        assert images.shape == (1,) + expected_image_size
def test_image_labels_from_subject_id_multiple(test_output_dirs: OutputFolderForTests) -> None:
    """
    For a multilabel dataset, get_image_labels_from_subject_id returns the names of all positive classes
    of the subject (subject "1" has labels 1|2, i.e. class2 and class3).
    """
    config = ScalarModelBase(label_channels=["label"],
                             label_value_column="label",
                             subject_column="subject",
                             class_names=["class1", "class2", "class3"])
    config.local_dataset = test_output_dirs.root_dir / "dataset"
    dataset_dir = config.local_dataset
    dataset_dir.mkdir()
    csv_lines = [
        "subject,channel,label",
        "0,label,0",
        "0,image,",
        "1,label,1|2",
        "1,image,",
    ]
    (dataset_dir / "dataset.csv").write_text("\n".join(csv_lines) + "\n")
    dataset = ScalarDataset(args=config, data_frame=config.read_dataset_if_needed())
    subject_labels = get_image_labels_from_subject_id(subject_id="1", dataset=dataset, config=config)
    assert subject_labels
    assert len(subject_labels) == 2
    expected_names = {config.class_names[1], config.class_names[2]}
    assert set(subject_labels) == expected_names
def test_dataloader_speed(test_output_dirs: OutputFolderForTests,
                          num_dataload_workers: int,
                          shuffle: bool) -> None:
    """
    Test how dataloaders work when using multiple processes.

    Every item except the very first one must be delivered in under 0.1 sec. Without shuffling, items must
    arrive in subject order S1..S4; with shuffling and exactly one worker, they must match a recorded order.
    """
    ml_util.set_random_seed(0)
    # Four subjects, each with an image row and a label row.
    csv_string = StringIO("""subject,channel,path,value,scalar1
S1,image,4be9beed-5861-fdd2-72c2-8dd89aadc1ef
S1,label,,True,1.0
S2,image,6ceacaf8-abd2-ffec-2ade-d52afd6dd1be
S2,label,,True,2.0
S3,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S3,label,,False,3.0
S4,image,61bc9d73-9fbb-bd7d-c06b-eeffbafabcc4
S4,label,,False,3.0
""")
    args = ScalarModelBase(image_channels=[],
                           label_channels=["label"],
                           label_value_column="value",
                           non_image_feature_channels=["label"],
                           numerical_columns=["scalar1"],
                           num_dataload_workers=num_dataload_workers,
                           num_dataset_reader_workers=num_dataload_workers,
                           avoid_process_spawn_in_data_loaders=True,
                           should_validate=False)
    dataset = ScalarDataset(args, data_frame=pd.read_csv(csv_string, dtype=str))
    assert len(dataset) == 4
    num_epochs = 2
    total_start_time = time.time()
    loader = dataset.as_data_loader(shuffle=shuffle, batch_size=1)
    # Order of items per epoch when shuffling with a single dataloader worker. It was recorded before any
    # changes to the dataloader logic, i.e. when as_data_loader returned a plain DataLoader rather than a
    # RepeatDataLoader.
    expected_item_order = [
        ["S2", "S1", "S4", "S3"],
        ["S4", "S3", "S1", "S2"],
    ]
    for epoch in range(num_epochs):
        seen_ids = []
        print(f"Starting epoch {epoch}")
        epoch_start_time = time.time()
        item_start_time = time.time()
        for batch_index, item_dict in enumerate(loader):
            item_load_time = time.time() - item_start_time
            item = ScalarItem.from_dict(item_dict)
            # noinspection PyTypeChecker
            sample_id = item.metadata[0].id  # type: ignore
            print(f"Loading item {batch_index} with ID = {sample_id} in {item_load_time:0.8f} sec")
            if shuffle:
                seen_ids.append(sample_id)
            else:
                assert sample_id == f"S{batch_index + 1}"
            # Only the very first item of the whole run may take a noticeable time to load.
            is_very_first_item = epoch == 0 and batch_index == 0
            if not is_very_first_item:
                assert item_load_time < 0.1, f"We should only see significant item load times in the first batch " \
                                             f"of the first epoch, but got loading time of {item_load_time:0.2f} sec" \
                                             f" in epoch {epoch} batch {batch_index}"
            # Give worker processes time to fill in the next items before timing resumes.
            if num_dataload_workers > 0:
                time.sleep(0.05)
            item_start_time = time.time()
        if shuffle and num_dataload_workers == 1:
            assert seen_ids == expected_item_order[epoch], f"Item in wrong order for epoch {epoch}"
        total_epoch_time = time.time() - epoch_start_time
        print(f"Total time for epoch {epoch}: {total_epoch_time} sec")
    total_time = time.time() - total_start_time
    print(f"Total time for all epochs: {total_time} sec")
def test_image_encoder(test_output_dirs: OutputFolderForTests,
                       encode_channels_jointly: bool,
                       use_non_imaging_features: bool,
                       kernel_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       stride_size_per_encoding_block: Optional[Union[TupleInt3, List[TupleInt3]]],
                       reduction_factor: float,
                       expected_num_reduced_features: int,
                       aggregation_type: AggregationType) -> None:
    """
    Test if the image encoder networks can be trained without errors (including GradCam computation and data
    augmentation).

    Also checks that the model's final number of feature channels matches the expected reduced number of
    image features (per image channel unless channels are encoded jointly) plus the non-imaging features.
    """
    logging_to_stdout()
    set_random_seed(0)
    dataset_folder = Path(test_output_dirs.make_sub_dir("dataset"))
    scan_size = (6, 64, 60)
    # Write four random scans scan1..scan4 (NumpyFile.NUMPY extension) that the CSV below refers to.
    for s in range(4):
        random_scan = np.random.uniform(0, 1, scan_size)
        scan_file_name = f"scan{s + 1}{NumpyFile.NUMPY.value}"
        np.save(str(dataset_folder / scan_file_name), random_scan)
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    (dataset_folder / "dataset.csv").write_text(dataset_contents)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}
    config_for_dataset = ScalarModelBase(
        local_dataset=dataset_folder,
        image_channels=["week0", "week1"],
        image_file_column="path",
        label_channels=["week1"],
        label_value_column="label",
        non_image_feature_channels=non_image_feature_channels,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        should_validate=False
    )
    config_for_dataset.read_dataset_into_dataframe_and_pre_process()

    dataset = ScalarDataset(config_for_dataset,
                            sample_transforms=ScalarItemAugmentation(
                                RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
    assert len(dataset) == 3

    config = ImageEncoder(
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=config_for_dataset.categorical_feature_encoder,
        encoder_dimensionality_reduction_factor=reduction_factor,
        aggregation_type=aggregation_type,
        scan_size=scan_size
    )

    if kernel_size_per_encoding_block:
        config.kernel_size_per_encoding_block = kernel_size_per_encoding_block
    if stride_size_per_encoding_block:
        config.stride_size_per_encoding_block = stride_size_per_encoding_block

    config.set_output_to(test_output_dirs.root_dir)
    config.max_batch_grad_cam = 1
    model = create_model_with_temperature_scaling(config)
    input_size: List[Tuple] = [(len(config.image_channels), *scan_size)]
    if use_non_imaging_features:
        input_size.append((config.get_total_number_of_non_imaging_features(),))

        # Original number output channels (unreduced) is
        # num initial channel * (num encoder block - 1) = 4 * (3-1) = 8
        if encode_channels_jointly:
            # reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()
        else:
            # num_img_channels * reduced_num_channels + num_non_img_features
            assert model.final_num_feature_channels == len(config.image_channels) * expected_num_reduced_features + \
                   config.get_total_number_of_non_imaging_features()

    summarizer = ModelSummary(model)
    summarizer.generate_summary(input_sizes=input_size)
    config.local_dataset = dataset_folder
    config.validate()
    model_train(config, checkpoint_handler=get_default_checkpoint_handler(model_config=config,
                                                                          project_root=Path(test_output_dirs.root_dir)))
def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Path, k: int, prediction_target: str,
                                     config: ScalarModelBase) -> None:
    """
    Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
    and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).

    Sections are printed in the order false positives, false negatives, true positives, true negatives. If either
    metrics set is empty, only a notice is printed and nothing is plotted.

    :param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
                            used to determine the thresholds for predicting labels on the test set. The best and worst
                            performing subjects will not be printed out for this csv.
    :param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
                            best and worst performing subjects will be printed out.
    :param k: Number of subjects of each category to print out.
    :param prediction_target: The class label to filter on
    :param config: scalar model config object
    """
    results = get_k_best_and_worst_performing(val_metrics_csv=val_metrics_csv,
                                              test_metrics_csv=test_metrics_csv,
                                              k=k,
                                              prediction_target=prediction_target)
    if results is None:
        print_header("Empty validation or test set", level=4)
        return

    test_metrics = pd.read_csv(test_metrics_csv, dtype=str)

    df = config.read_dataset_if_needed()
    dataset = ScalarDataset(args=config, data_frame=df)

    im_width = 800

    print_header("", level=2)

    # (section title, header for each plotted image, subjects in that outcome category), in display order.
    sections = [
        ("false positives", "False Positive", results.false_positives),
        ("false negatives", "False Negative", results.false_negatives),
        ("true positives", "True Positive", results.true_positives),
        ("true negatives", "True Negative", results.true_negatives),
    ]
    for section_title, image_header, subjects in sections:
        print_header(f"Top {k} {section_title}", level=2)
        for subject, model_output in zip(subjects[LoggingColumns.Patient.value],
                                         subjects[LoggingColumns.ModelOutput.value]):
            plot_image_for_subject(subject_id=str(subject),
                                   dataset=dataset,
                                   im_width=im_width,
                                   model_output=model_output,
                                   header=image_header,
                                   config=config,
                                   metrics_df=test_metrics)
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: OutputFolderForTests) -> None:
    """
    Check the shapes of the maps produced by VisualizationMaps.generate for an ImageEncoder model, and, when
    non-imaging features are used, the labels of the non-imaging pseudo-CAM plot.
    """
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
S1,week0,scan1.npy,,1,10,Male,Val1
S1,week1,scan2.npy,True,2,20,Female,Val2
S2,week0,scan3.npy,,3,30,Female,Val3
S2,week1,scan4.npy,False,4,40,Female,Val1
S3,week0,scan1.npy,,5,50,Male,Val2
S3,week1,scan3.npy,True,6,60,Male,Val2
"""
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    visualizer = VisualizationMaps(model, config)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            device = visualizer.grad_cam.device
            batch = transfer_batch_to_device(batch, device)
            visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
                                                                     target_indices=[],
                                                                     sample=batch)
    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

    # Compute the expected shape first. Written inline as `assert x == a if cond else b`, the conditional
    # binds looser than `==`, so the non-joint case asserted a non-empty (always truthy) tuple and checked nothing.
    expected_grad_cam_shape = (number_subjects, 1) if encode_channels_jointly \
        else (number_subjects, number_channels)
    assert grad_cams.shape[:2] == expected_grad_cam_shape

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))