def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset,
                                         num_dataload_workers: int) -> None:
    """
    Iterates over the data loader of a cropping dataset and checks, for every batch, that all tensors have
    the expected shapes and that the center crops match a center crop taken from the full-size tensors.

    :param cropping_dataset: The dataset under test.
    :param num_dataload_workers: The number of workers that is passed on to the data loader.
    """
    batch_size = 2
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=batch_size,
                                             num_dataload_workers=num_dataload_workers)
    for batch in loader:
        sample = CroppedSample.from_dict(sample=batch)
        assert sample is not None
        config = cropping_dataset.args
        # Full-size crops: image has a channel dimension, labels a class dimension, the mask has neither.
        image_shape = (batch_size, config.number_of_image_channels) + config.crop_size  # type: ignore
        assert sample.image.shape == image_shape
        mask_shape = (batch_size,) + config.crop_size  # type: ignore
        assert sample.mask.shape == mask_shape
        labels_shape = (batch_size, config.number_of_classes) + config.crop_size  # type: ignore
        assert sample.labels.shape == labels_shape
        # The center crops of mask and labels use center_size instead of crop_size.
        mask_center_shape = (batch_size,) + config.center_size  # type: ignore
        assert sample.mask_center_crop.shape == mask_center_shape
        labels_center_shape = (batch_size, config.number_of_classes) + config.center_size  # type: ignore
        assert sample.labels_center_crop.shape == labels_center_shape
        # Compare the contents of the center crops against crops computed from the full-size tensors.
        for b in range(batch_size):
            mask_crop = image_util.get_center_crop(image=sample.mask[b],
                                                   crop_shape=config.center_size)
            assert np.array_equal(sample.mask_center_crop[b], mask_crop)
            for class_labels, class_center_crop in zip(sample.labels[b], sample.labels_center_crop[b]):
                labels_crop = image_util.get_center_crop(image=class_labels,
                                                         crop_shape=config.center_size)
                assert np.array_equal(class_center_crop, labels_crop)
def check_patient_id_in_dataset(loader: DataLoader, split: pd.DataFrame) -> None:
    """
    Asserts that every batch produced by the loader carries exactly one metadata entry, and that the
    patient ID of that entry is one of the subjects listed in the given dataset split.

    :param loader: The data loader to iterate over.
    :param split: A data frame with a `subject` column that holds the permitted subject IDs.
    """
    known_subjects = list(split.subject.unique())
    for batch in loader:
        metadata = CroppedSample.from_dict(batch).metadata
        assert isinstance(metadata, list)
        assert len(metadata) == 1
        assert metadata[0].patient_id in known_subjects
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks that image crops come out at the requested crop size when the dataset is configured with a
    crop size of (300, 300, 300) and zero padding. The crop size is presumably larger than the images in
    the test dataset, such that padding is exercised (TODO confirm against the fixture).

    :param cropping_dataset: The dataset under test. Its crop size and padding mode are modified in place.
    :param num_dataload_workers: Not read by this test, see the note at the data loader creation.
    """
    requested_crop_size = (300, 300, 300)
    cropping_dataset.args.crop_size = requested_crop_size
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # NOTE(review): the number of workers is hard-coded to 1 here, the num_dataload_workers argument
    # is ignored. Confirm whether that is intentional.
    loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2, num_dataload_workers=1)
    for batch in loader:
        image = CroppedSample.from_dict(batch).image
        # Only the trailing 3 (spatial) dimensions are compared, leading dimensions are not checked.
        assert image.shape[-3:] == cropping_dataset.args.crop_size
def extract_activation_maps(args: ModelConfigBase) -> None:
    """
    Extracts and visualizes the activation maps of a specific layer of a trained network. The model is
    restored from the checkpoint that the configuration points to, and is run on the first batch of the
    validation data only. The activation map for the first image of that batch is passed on to the
    2D or 3D visualization function, depending on its number of dimensions.

    :param args: The model configuration. It provides the checkpoint path, the data loaders, the name of
        the layer(s) to extract (`activation_map_layers`), and the `use_gpu` flag.
    :raises FileNotFoundError: If the checkpoint file does not exist.
    :raises NotImplementedError: If the extracted activation map is neither 3- nor 4-dimensional.
    :return:
    """
    model = create_model_with_temperature_scaling(args)
    if args.use_gpu:
        model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
        model = model.cuda()

    checkpoint_path = args.get_path_to_checkpoint()
    if not checkpoint_path.is_file():
        raise FileNotFoundError("Could not find checkpoint")
    # When running without GPU, remap all tensors to the CPU: otherwise a checkpoint that was written
    # from a GPU can't be read on a machine that has no CUDA device.
    map_location = None if args.use_gpu else "cpu"
    checkpoint = torch.load(checkpoint_path, map_location=map_location)  # type: ignore
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    val_loader = args.create_data_loaders()[ModelExecutionMode.VAL]
    feature_extractor = model_hooks.HookBasedFeatureExtractor(model, layer_name=args.activation_map_layers)

    for batch in val_loader:
        sample = CroppedSample.from_dict(sample=batch)
        input_image = sample.image
        # The input must live on the same device as the model: only move it if the model was moved too.
        if args.use_gpu:
            input_image = input_image.cuda()
        input_image = input_image.float()
        # Pure inference: no autograd graph is needed for reading out activations.
        with torch.no_grad():
            feature_extractor(input_image)

        # access first image of batch of feature maps
        activation_map = feature_extractor.outputs[0][0].cpu().numpy()

        if activation_map.ndim == 4:
            visualize_3d_activation_map(activation_map, args)
        elif activation_map.ndim == 3:
            visualize_2d_activation_map(activation_map, args)
        else:
            raise NotImplementedError('cannot visualize activation map of shape', activation_map.shape)

        # Only visualize the first validation example
        break
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Checks that the crop center indices are reproducible: after resetting the random seed to the same
    value, every batch from a freshly created data loader must have the same center indices as the
    very first batch that was seen.

    :param cropping_dataset: The dataset under test. Its dataset indices are overwritten.
    :param num_dataload_workers: The number of workers that is passed on to the data loader.
    """
    cropping_dataset.dataset_indices = [1, 2, 1, 2]
    reference_centers = None
    for _ in range(3):
        ml_util.set_random_seed(1)
        loader = cropping_dataset.as_data_loader(shuffle=True,
                                                 batch_size=4,
                                                 num_dataload_workers=num_dataload_workers)
        for batch in loader:
            centers = CroppedSample.from_dict(sample=batch).center_indices
            if reference_centers is None:
                # The first batch overall provides the reference that all later batches are compared to.
                reference_centers = centers
                continue
            assert np.array_equal(reference_centers, centers)
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks the data types of the torch tensors (image, labels, mask, and the center crops of mask and
    labels) that the dataset's data loader produces, against the values defined in ImageDataType.

    :param cropping_dataset: The dataset under test.
    :param num_dataload_workers: The number of workers that is passed on to the data loader.
    :return:
    """
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=2,
                                             num_dataload_workers=num_dataload_workers)
    for batch in loader:
        sample = CroppedSample.from_dict(batch)
        # Pairs of (tensor, expected data type), checked in this order.
        tensors_and_types = [
            (sample.image, ImageDataType.IMAGE),
            (sample.labels, ImageDataType.SEGMENTATION),
            (sample.mask, ImageDataType.MASK),
            (sample.mask_center_crop, ImageDataType.MASK),
            (sample.labels_center_crop, ImageDataType.SEGMENTATION),
        ]
        for tensor, expected_type in tensors_and_types:
            assert tensor.numpy().dtype == expected_type.value
def training_or_validation_step(self,
                                sample: Dict[str, Any],
                                batch_index: int,
                                is_training: bool) -> torch.Tensor:
    """
    Runs the model on a single minibatch of training or validation data, computes the loss and all
    metrics, and writes the loss via `write_loss`.

    :param sample: The batched sample on which the model should be run.
    :param batch_index: The index of the present batch (supplied only for diagnostics, not read here).
    :param is_training: If true, the method is called from `training_step`, otherwise it is called from
        `validation_step`. For validation, the forward pass runs inside of `torch.no_grad()`.
    :return: The loss that the loss function computed on the model output for this minibatch.
    """
    crops: CroppedSample = CroppedSample.from_dict(sample=sample)
    if is_training:
        logits = self.model(crops.image)
        # NOTE(review): the mask is applied to the posteriors in training only - confirm that this is
        # the intended direction.
        center_mask = crops.mask_center_crop
    else:
        with torch.no_grad():
            logits = self.model(crops.image)
        center_mask = None
    # Forward propagation can lead to a model output that is smaller than the input image (crop), hence
    # the loss is computed against the center crop of the labels.
    loss = self.loss_fn(logits, crops.labels_center_crop)

    # Convert the model output into posteriors, optionally mask them, and derive the segmentation.
    posteriors = self.logits_to_posterior(logits)
    if center_mask is not None:
        posteriors = image_util.apply_mask_to_posteriors(posteriors=posteriors, mask=center_mask)  # type: ignore
    segmentation = image_util.posteriors_to_segmentation(posteriors=posteriors)  # type: ignore

    self.compute_metrics(crops, segmentation, is_training)  # type: ignore
    self.write_loss(is_training, loss)
    return loss
def forward_and_backward_minibatch(self,
                                   sample: Dict[str, Any],
                                   batch_index: int,
                                   epoch: int) -> ModelForwardAndBackwardsOutputs:
    """
    Runs the forward pass of the pipeline for a single minibatch, and stores loss, Dice per class,
    voxel counts per class, and the patch centers in `self.metrics`.

    :param sample: The batched sample on which the model should be trained.
    :param batch_index: The index of the present batch. It is compared against `self.example_to_save`
        to decide whether the sample should be stored for visualization.
    :param epoch: The number of the present epoch. It is not read in this method, the epoch used for
        storing a sample comes from `self.train_val_params`.
    :raises ValueError: If the forward pass did not return a loss value.
    :return: The loss, together with the posteriors (as `logits`) and the segmentations (as `labels`)
        of the forward pass.
    """
    patches: CroppedSample = CroppedSample.from_dict(sample=sample)
    center_labels = self.model_config.get_gpu_tensor_if_possible(patches.labels_center_crop)
    # The mask is only handed to the forward pass when not in training mode.
    if self.train_val_params.in_training_mode:
        center_mask = None
    else:
        center_mask = patches.mask_center_crop
    result = self.pipeline.forward_pass_patches(patches=patches.image,
                                                labels=center_labels,
                                                mask=center_mask)
    # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory
    torch.cuda.empty_cache()

    predicted = torch.tensor(result.segmentations).long()
    dice_tensor = metrics.compute_dice_across_patches(segmentation=predicted,
                                                      ground_truth=center_labels,
                                                      use_cuda=self.model_config.use_gpu,
                                                      allow_multiple_classes_for_each_pixel=True)
    # Indexed as [patch, class] below.
    dice_per_patch = dice_tensor.cpu().numpy()
    voxels_per_class = metrics_util.get_number_of_voxels_per_class(patches.labels)

    # The loss is a scalar, also when running the forward pass over multiple crops.
    loss = result.loss
    if loss is None:
        raise ValueError("During training, the loss should always be computed, but the value is None.")

    # Store metrics per batch: one Dice value per patch and class, one voxel count per class.
    self.metrics.add_metric(MetricType.LOSS, loss)
    structure_names = self.metrics.get_hue_names(include_default=False)
    for class_index, structure_name in enumerate(structure_names):
        for dice_value in dice_per_patch[:, class_index]:
            self.metrics.add_metric(MetricType.DICE,
                                    dice_value.item(),
                                    hue=structure_name,
                                    skip_nan_when_averaging=True)
        self.metrics.add_metric(MetricType.VOXEL_COUNT, voxels_per_class[class_index], hue=structure_name)

    # Store diagnostics per batch. Tensors are converted to numpy first, and a copy is stored.
    patch_centers = patches.center_indices
    if isinstance(patch_centers, torch.Tensor):
        patch_centers = patch_centers.cpu().numpy()
    self.metrics.add_diagnostics(MetricType.PATCH_CENTER.value, np.copy(patch_centers))

    # Store the sample train patch from this epoch for visualization, if this is the chosen batch.
    store_sample = (self.train_val_params.in_training_mode
                    and batch_index == self.example_to_save
                    and self.model_config.store_dataset_sample)
    if store_sample:
        _store_dataset_sample(self.model_config, self.train_val_params.epoch, result, patches)

    return ModelForwardAndBackwardsOutputs(loss=loss,
                                           logits=result.posteriors,
                                           labels=result.segmentations)