Пример #1
0
 def __init__(self, model_config: SegmentationModelBase,
              train_val_params: TrainValidateParameters[DeviceAwareModule]):
     """
     Sets up the state needed to run the training steps for a segmentation model: a randomly chosen
     batch index, the forward pass pipeline, and the metrics storage.
     :param model_config: The configuration of a segmentation model.
     :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
     """
     super().__init__(model_config, train_val_params)
     # Index of one batch of the data loader, drawn uniformly at random.
     num_batches = len(train_val_params.data_loader)
     self.example_to_save = np.random.randint(0, num_batches)
     # All settings of the forward pass, collected in one place before the pipeline is built.
     forward_pass_settings = dict(
         model=self.train_val_params.model,
         model_config=self.model_config,
         batch_size=self.model_config.train_batch_size,
         optimizer=self.train_val_params.optimizer,
         in_training_mode=self.train_val_params.in_training_mode,
         criterion=self.compute_loss,
         gradient_scaler=train_val_params.gradient_scaler)
     self.pipeline = SegmentationForwardPass(**forward_pass_settings)
     # Metrics are tracked per hue: the background class first, then every ground truth structure.
     hue_names = [BACKGROUND_CLASS_NAME] + model_config.ground_truth_ids
     self.metrics = MetricsDict(hues=hue_names)
def test_anomaly_detection(value_to_insert: float,
                           in_training_mode: bool) -> None:
    """
    Test anomaly detection for the segmentation forward pass.
    The loss criterion used here ignores its inputs and always returns `value_to_insert`, and the same value
    is written into the first voxel of the image. With `detect_anomaly=True` in the model config, a NaN or
    infinite value is expected to make the forward pass raise a RuntimeError that names the offending value;
    a finite value must pass without an exception.
    :param value_to_insert: The value to insert in the image and to return as the loss (nan, inf, or a valid float)
    :param in_training_mode: If true, run the segmentation forward pass in training mode, otherwise use the
    settings for running on the validation set.
    """
    image_size = [1, 1, 4, 4, 4]
    labels_size = [1, 2, 4, 4, 4]
    mask_size = [1, 4, 4, 4]
    crop_size = (4, 4, 4)
    inference_stride_size = (2, 2, 2)
    ground_truth_ids = ["Lung"]

    # image to run inference on
    image = torch.from_numpy(
        np.random.uniform(size=image_size).astype(ImageDataType.IMAGE.value))
    # labels for criterion
    labels = torch.from_numpy(
        np.random.uniform(size=labels_size).astype(
            ImageDataType.SEGMENTATION.value))
    # random mask: uniform samples rounded to 0 or 1
    mask = torch.from_numpy((np.round(np.random.uniform(
        size=mask_size)).astype(dtype=ImageDataType.MASK.value)))

    config = SegmentationModelBase(crop_size=crop_size,
                                   inference_stride_size=inference_stride_size,
                                   image_channels=["ct"],
                                   ground_truth_ids=ground_truth_ids,
                                   should_validate=False,
                                   detect_anomaly=True)

    # instantiate the model
    model = SimpleModel(1, [1], 2, 2)
    config.adjust_after_mixed_precision_and_parallel(model)
    config.use_gpu = False

    # Create the optimizer_type and loss criterion
    optimizer = model_util.create_optimizer(config, model)

    def criterion(x, y):  # type: ignore
        # The loss is independent of model output and labels: it is always the value under test.
        return torch.tensor(value_to_insert, requires_grad=True)

    pipeline = SegmentationForwardPass(model,
                                       config,
                                       batch_size=1,
                                       optimizer=optimizer,
                                       in_training_mode=in_training_mode,
                                       criterion=criterion)
    image[0, 0, 0, 0, 0] = value_to_insert
    if np.isnan(value_to_insert) or np.isinf(value_to_insert):
        with pytest.raises(RuntimeError) as ex:
            pipeline.forward_pass_patches(patches=image,
                                          mask=mask,
                                          labels=labels)
        # Inspect the message of the raised exception itself (ex.value). str(ex) would render the
        # pytest ExceptionInfo wrapper, whose text format is not a stable contract.
        assert f"loss computation returned {value_to_insert}" in str(ex.value)
    else:
        pipeline.forward_pass_patches(patches=image, mask=mask, labels=labels)
Пример #3
0
    def _model_fn(self, patches: np.ndarray) -> np.ndarray:
        """
        Runs a forward pass of the model over a set of image patches, outside of training mode and
        without an optimizer.
        :param patches: Image patches to be passed to the model in format Patches x Channels x Z x Y x X
        :return posteriors: Confidence maps [0,1] for each patch per class
        in format: Patches x Channels x Class x Z x Y x X
        """
        config = self.get_configs()

        # The model is stored as a variable in the pipeline environment.
        model = self.pipeline.get_variable(InferencePipeline.Variables.Model)

        # The forward pass works on float Torch tensors rather than numpy arrays.
        patches_tensor = torch.from_numpy(patches).float()

        forward_pass = SegmentationForwardPass(
            model=model,
            model_config=config,
            batch_size=config.inference_batch_size,
            optimizer=None,
            in_training_mode=False)
        result = forward_pass.forward_pass_patches(patches=patches_tensor)
        return result.posteriors
def test_amp_activated(use_model_parallel: bool,
                       execution_mode: ModelExecutionMode,
                       use_mixed_precision: bool) -> None:
    """
    Tests the mixed precision flag and the model parallel flag.
    The model is prepared for multiple GPUs, and the dtype of the logits returned by the loss computation
    is checked: float16 when mixed precision is enabled, float32 otherwise. Finally, a full forward and
    backward pass is run to verify that it does not raise.
    :param use_model_parallel: Value for the `use_model_parallel` flag of the model config. The SimpleModel
    used here is expected to fail with NotImplementedError in that case, which ends the test early.
    :param execution_mode: The execution mode that is passed to `update_model_for_multiple_gpus`.
    :param use_mixed_precision: Value for the `use_mixed_precision` flag of the model config.
    """
    assert machine_has_gpu, "This test must be executed on a GPU machine."
    assert torch.cuda.device_count(
    ) > 1, "This test must be executed on a multi-GPU machine"
    # image, labels, and mask to run forward and backward passes
    image = torch.from_numpy(
        np.random.uniform(size=[1, 1, 4, 4, 4]).astype(
            ImageDataType.IMAGE.value))
    labels = torch.from_numpy(
        np.random.uniform(size=[1, 2, 4, 4, 4]).astype(
            ImageDataType.SEGMENTATION.value))
    mask = torch.from_numpy((np.round(np.random.uniform(
        size=[1, 4, 4, 4])).astype(dtype=ImageDataType.MASK.value)))

    crop_size = (4, 4, 4)

    model = SimpleModel(1, [1], 2, 2)
    model_config = SegmentationModelBase(
        crop_size=crop_size,
        image_channels=["ct"],
        ground_truth_ids=["Lung"],
        use_mixed_precision=use_mixed_precision,
        use_model_parallel=use_model_parallel,
        should_validate=False)
    assert model_config.use_gpu
    # Move the model to the GPU. This is mostly to avoid issues with AMP, which has trouble
    # with first using a GPU model and later using a CPU-based one.
    model = model.cuda()
    optimizer = model_util.create_optimizer(model_config, model)
    model_and_info = ModelAndInfo(model, optimizer)
    try:
        model_and_info_amp = model_util.update_model_for_multiple_gpus(
            model_and_info, model_config, execution_mode)
    except NotImplementedError as ex:
        if not use_model_parallel:
            # Keep the original error attached as the explicit cause of the test failure.
            raise ValueError(
                f"Expected this call to succeed, but got: {ex}") from ex
        # The SimpleModel does not implement model partitioning, and should hence fail at this step.
        assert "Model partitioning is not implemented" in str(ex)
        return

    # This is the same logic spelt out in update_model_for_multiple_gpu
    use_data_parallel = (execution_mode == ModelExecutionMode.TRAIN) or (
        not use_model_parallel)
    if use_data_parallel:
        # NOTE(review): this inspects the input `model_and_info`, not the returned `model_and_info_amp`.
        # That is only meaningful if update_model_for_multiple_gpus modifies its argument in place - confirm.
        assert isinstance(model_and_info.model, DataParallelModel)
    gradient_scaler = GradScaler() if use_mixed_precision else None

    def criterion(x, y):  # type: ignore
        # Constant loss on the GPU: only dtypes and the absence of exceptions matter in this test.
        return torch.tensor([0.0], requires_grad=True).cuda()

    pipeline = SegmentationForwardPass(model_and_info_amp.model,
                                       model_config,
                                       batch_size=1,
                                       optimizer=optimizer,
                                       gradient_scaler=gradient_scaler,
                                       criterion=criterion)
    logits, _ = pipeline._compute_loss(image, labels)
    # When using DataParallel, we expect to get a list of tensors back, one per GPU.
    if use_data_parallel:
        assert isinstance(logits, list)
        first_logit = logits[0]
    else:
        first_logit = logits
    if use_mixed_precision:
        assert first_logit.dtype == torch.float16
    else:
        assert first_logit.dtype == torch.float32
    # Verify that forward and backward passes do not throw an exception
    pipeline._forward_pass(patches=image, mask=mask, labels=labels)
Пример #5
0
class ModelTrainingStepsForSegmentation(
        ModelTrainingStepsBase[SegmentationModelBase, DeviceAwareModule]):
    """
    Implements the steps for training an image segmentation model over a single epoch: building the loss
    function, processing one minibatch at a time, and assembling the metrics of the epoch.
    """
    def __init__(self, model_config: SegmentationModelBase,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Sets up the forward pass pipeline and the metrics storage for one epoch.
        :param model_config: The configuration of a segmentation model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        super().__init__(model_config, train_val_params)
        # Randomly chosen batch index. In training mode, the sample of that batch can be stored for
        # visualization (see forward_and_backward_minibatch).
        num_batches = len(train_val_params.data_loader)
        self.example_to_save = np.random.randint(0, num_batches)
        forward_pass_settings = dict(
            model=self.train_val_params.model,
            model_config=self.model_config,
            batch_size=self.model_config.train_batch_size,
            optimizer=self.train_val_params.optimizer,
            in_training_mode=self.train_val_params.in_training_mode,
            criterion=self.compute_loss,
            gradient_scaler=train_val_params.gradient_scaler)
        self.pipeline = SegmentationForwardPass(**forward_pass_settings)
        # Metrics are tracked per hue: the background class first, then every ground truth structure.
        hue_names = [BACKGROUND_CLASS_NAME] + model_config.ground_truth_ids
        self.metrics = MetricsDict(hues=hue_names)

    def create_loss_function(self) -> torch.nn.Module:
        """
        Builds the loss function that is specified in the model configuration of this object.
        :return: A torch module that computes the loss.
        """
        config = self.model_config
        return self.construct_loss_function(config)

    @classmethod
    def construct_loss_function(
            cls, model_config: SegmentationModelBase
    ) -> SupervisedLearningCriterion:
        """
        Builds the loss function described by the model config. A mixture loss is assembled as a weighted
        combination of other loss functions, with the weights normalized to sum to 1.
        :param model_config: The model configuration that specifies the loss type and its parameters.
        :return: The loss function.
        """
        if model_config.loss_type != SegmentationLoss.Mixture:
            # Plain (non-mixture) loss: type and class weight power come directly from the config.
            return cls.construct_non_mixture_loss_function(
                model_config, model_config.loss_type,
                model_config.loss_class_weight_power)

        mixture_parts = model_config.mixture_loss_components
        assert mixture_parts is not None
        total_weight = sum(part.weight for part in mixture_parts)
        # One (normalized weight, loss function) pair per mixture component.
        weighted_losses = [
            (part.weight / total_weight,
             cls.construct_non_mixture_loss_function(
                 model_config, part.loss_type, part.class_weight_power))
            for part in mixture_parts
        ]
        return MixtureLoss(weighted_losses)

    @classmethod
    def construct_non_mixture_loss_function(
            cls, model_config: SegmentationModelBase,
            loss_type: SegmentationLoss,
            power: Optional[float]) -> SupervisedLearningCriterion:
        """
        Builds a single loss function, that is, one that is not a mixture of other losses.
        :param model_config: model configuration to get some parameters from
        :param loss_type: type of loss function
        :param power: value for class_weight_power for the loss function
        :return: instance of loss function
        :raises NotImplementedError: if the loss type is not SoftDice, CrossEntropy or Focal.
        """
        if loss_type == SegmentationLoss.SoftDice:
            return SoftDiceLoss(class_weight_power=power)
        if loss_type == SegmentationLoss.CrossEntropy:
            # Cross entropy is built without a focal loss gamma.
            return CrossEntropyLoss(
                class_weight_power=power,
                smoothing_eps=model_config.label_smoothing_eps,
                focal_loss_gamma=None)
        if loss_type == SegmentationLoss.Focal:
            # Focal loss is cross entropy with the gamma taken from the config.
            return CrossEntropyLoss(
                class_weight_power=power,
                smoothing_eps=model_config.label_smoothing_eps,
                focal_loss_gamma=model_config.focal_loss_gamma)
        raise NotImplementedError(
            "Loss type {} is not implemented".format(loss_type))

    def forward_and_backward_minibatch(
            self, sample: Dict[str, Any], batch_index: int,
            epoch: int) -> ModelForwardAndBackwardsOutputs:
        """
        Processes a single minibatch: runs it through the forward pass pipeline, and records loss, Dice
        scores, voxel counts and patch centers in the metrics of this object.
        :param sample: The batched sample on which the model should be trained.
        :param batch_index: The index of the present batch. It is compared against the randomly chosen
        batch index to decide if the sample should be stored for visualization.
        :param epoch: The number of the present epoch.
        :return: The loss, the posteriors (in the `logits` field) and the segmentations (in the `labels` field).
        :raises ValueError: If the forward pass did not return a loss value.
        """
        cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample)
        labels = self.model_config.get_gpu_tensor_if_possible(
            cropped_sample.labels_center_crop)

        # The mask is only passed to the forward pass when not in training mode.
        if self.train_val_params.in_training_mode:
            mask = None
        else:
            mask = cropped_sample.mask_center_crop
        result = self.pipeline.forward_pass_patches(
            patches=cropped_sample.image, labels=labels, mask=mask)
        # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory
        torch.cuda.empty_cache()
        predicted = torch.tensor(result.segmentations).long()
        # dice_per_crop has one row per crop.
        dice_per_crop = metrics.compute_dice_across_patches(
            segmentation=predicted,
            ground_truth=labels,
            use_cuda=self.model_config.use_gpu,
            allow_multiple_classes_for_each_pixel=True).cpu().numpy()
        voxels_per_class = metrics_util.get_number_of_voxels_per_class(
            cropped_sample.labels)
        # The loss is a scalar, also when running the forward pass over multiple crops.
        if result.loss is None:
            raise ValueError(
                "During training, the loss should always be computed, but the value is None."
            )
        loss = result.loss

        # store metrics per batch
        self.metrics.add_metric(MetricType.LOSS, loss)
        hue_names = self.metrics.get_hue_names(include_default=False)
        for class_index, hue_name in enumerate(hue_names):
            # One Dice value per crop for this class, followed by a single voxel count.
            for dice_value in dice_per_crop[:, class_index]:
                self.metrics.add_metric(MetricType.DICE,
                                        dice_value.item(),
                                        hue=hue_name,
                                        skip_nan_when_averaging=True)
            self.metrics.add_metric(MetricType.VOXEL_COUNT,
                                    voxels_per_class[class_index],
                                    hue=hue_name)
        # store diagnostics per batch
        patch_centers = cropped_sample.center_indices
        if isinstance(patch_centers, torch.Tensor):
            patch_centers = patch_centers.cpu().numpy()
        self.metrics.add_diagnostics(MetricType.PATCH_CENTER.value,
                                     np.copy(patch_centers))
        # store the sample train patch from this epoch for visualization
        if (self.train_val_params.in_training_mode
                and batch_index == self.example_to_save
                and self.model_config.store_dataset_sample):
            _store_dataset_sample(self.model_config,
                                  self.train_val_params.epoch,
                                  result, cropped_sample)

        return ModelForwardAndBackwardsOutputs(
            loss=loss,
            logits=result.posteriors,
            labels=result.segmentations)

    def get_epoch_results_and_store(self,
                                    epoch_time_seconds: float) -> MetricsDict:
        """
        Adds the epoch time and learning rate to the metrics collected over all minibatches, aggregates
        them, writes the aggregates to Tensorboard and AzureML, and returns them as a MetricsDict object.
        :param epoch_time_seconds: For diagnostics, this is the total time in seconds for training the present epoch.
        :return: A dictionary that holds all metrics averaged over the epoch.
        """
        params = self.train_val_params
        self.metrics.add_metric(MetricType.SECONDS_PER_EPOCH,
                                epoch_time_seconds)
        assert len(params.epoch_learning_rate
                   ) == 1, "Expected a single entry for learning rate."
        self.metrics.add_metric(MetricType.LEARNING_RATE,
                                params.epoch_learning_rate[0])
        epoch_metrics = metrics.aggregate_segmentation_metrics(self.metrics)
        metrics.store_epoch_metrics(self.azure_and_tensorboard_logger,
                                    self.df_logger,
                                    params.epoch, epoch_metrics,
                                    params.epoch_learning_rate,
                                    self.model_config)
        return epoch_metrics