예제 #1
0
def test_visualization_for_different_target_weeks(test_output_dirs: TestOutputDirectories) -> None:
    """
    Checks that the visualizer output depends on the requested target: the non-imaging
    pseudo Grad-CAM and the probabilities obtained for target position 2 must differ from
    those obtained for target position 0.
    :param test_output_dirs: Provides the root folder that the model config writes its outputs to.
    """
    model_config = ToyMultiLabelSequenceModel(should_validate=False)
    model_config.set_output_to(test_output_dirs.root_dir)
    model_config.dataset_data_frame = _get_multi_label_sequence_dataframe()
    model_config.pre_process_dataset_dataframe()
    network = create_model_with_temperature_scaling(model_config)

    # Only the first batch (two items, no shuffling) is needed for the comparison.
    sequence_dataset = SequenceDataset(model_config, data_frame=model_config.dataset_data_frame)
    loader = sequence_dataset.as_data_loader(shuffle=False, batch_size=2)
    first_batch = next(iter(loader))
    inputs_and_labels = get_scalar_model_inputs_and_labels(model_config, network, first_batch)  # type: ignore
    model_inputs = inputs_and_labels.model_inputs

    maps = VisualizationMaps(network, model_config)
    # Pseudo-grad cam explaining the prediction at target sequence 2
    _, _, pseudo_cam_non_img_3, probas_3 = maps.generate(model_inputs,
                                                         target_position=2,
                                                         target_label_index=2)
    # Pseudo-grad cam explaining the prediction at target sequence 0
    _, _, pseudo_cam_non_img_1, probas_1 = maps.generate(model_inputs,
                                                         target_position=0,
                                                         target_label_index=0)

    # The second axis is expected to have one entry for position 0 and three for position 2.
    assert pseudo_cam_non_img_1.shape[1] == 1
    assert pseudo_cam_non_img_3.shape[1] == 3
    # Both visualizations should not be equal
    assert np.any(pseudo_cam_non_img_1 != pseudo_cam_non_img_3)
    assert np.any(probas_3 != probas_1)
def test_visualization_with_sequence_model(
        use_combined_model: bool, imaging_feature_type: ImagingFeatureType,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Runs VisualizationMaps on a ToySequenceModel and checks the leading dimensions of the
    returned maps, plus the labels produced for the non-imaging features.
    :param use_combined_model: Forwarded to ToySequenceModel; when True, image Grad-CAMs are
    expected, otherwise they are expected to be None.
    :param imaging_feature_type: Forwarded to ToySequenceModel; for ImageAndSegmentation the
    second dimension of the image maps is expected to be doubled.
    :param test_output_dirs: Provides the root folder that the model config writes its outputs to.
    """
    config = ToySequenceModel(use_combined_model, imaging_feature_type, should_validate=False)
    config.set_output_to(test_output_dirs.root_dir)
    config.dataset_data_frame = _get_mock_sequence_dataset()
    config.num_epochs = 1
    model = config.create_model()
    if config.use_gpu:
        model = model.cuda()
    sequence_dataset = SequenceDataset(config, data_frame=config.dataset_data_frame)
    dataloader = sequence_dataset.as_data_loader(shuffle=False, batch_size=2)

    # Patch the load_images function that will be called once we access a dataset item
    fake_images = np.random.uniform(0, 1, SCAN_SIZE)
    fake_segmentations = np.random.randint(0, 2, SCAN_SIZE)
    image_and_seg = ImageAndSegmentations[np.ndarray](images=fake_images,
                                                      segmentations=fake_segmentations)
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats',
                    return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            batch = transfer_batch_to_device(batch, torch.device(0))
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(
            model, target_indices=config.get_target_indices(),
            sample=batch)  # type: ignore

    number_sequences = model_inputs_and_labels.model_inputs[0].shape[1]
    number_subjects = len(model_inputs_and_labels.subject_ids)
    visualizer = VisualizationMaps(model, config)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if use_combined_model:
        # Image and segmentation inputs double the number of maps per sequence position.
        maps_per_position = 2 if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation else 1
        expected_image_map_shape = (number_subjects, number_sequences * maps_per_position)
        assert guided_grad_cams.shape[:2] == expected_image_map_shape
        assert grad_cams.shape[:2] == expected_image_map_shape
    else:
        assert guided_grad_cams is None
        assert grad_cams is None
        assert pseudo_cam_non_img.shape[:2] == (number_subjects, number_sequences)
        assert probas.shape[0] == number_subjects

    non_image_features = config.numerical_columns + config.categorical_columns
    non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(
        model_inputs_and_labels.data_item,
        non_image_features,
        index=0,
        target_position=3)
    # One group of feature labels per sequence position 0..3.
    expected_plot_labels = [
        'numerical1_0', 'numerical2_0', 'cat1_0',
        'numerical1_1', 'numerical2_1', 'cat1_1',
        'numerical1_2', 'numerical2_2', 'cat1_2',
        'numerical1_3', 'numerical2_3', 'cat1_3',
    ]
    assert non_imaging_plot_labels == expected_plot_labels
예제 #3
0
    def __init__(self, config: F,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Creates a new instance of the class.
        :param config: The configuration of a classification model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        # The label dtype is assigned before the base class constructor runs: per the original note,
        # the base class constructor is expected to invoke create_loss_function, and declaring the
        # field here also keeps pycharm happy.
        self.label_tensor_dtype = torch.float32
        super().__init__(config, train_val_params)
        self.metrics = create_metrics_dict_from_config(config)
        self.compute_mean_teacher_model = self.model_config.compute_mean_teacher_model

        # The visualizer and its output folder are only needed when Grad-CAM is requested.
        if not self.model_config.compute_grad_cam:
            return

        # Visualize the mean teacher model when one is computed, otherwise the plain model.
        if self.model_config.compute_mean_teacher_model:
            model_to_evaluate = self.train_val_params.mean_teacher_model
        else:
            model_to_evaluate = self.train_val_params.model
        self.guided_grad_cam = VisualizationMaps(model_to_evaluate, self.model_config)
        self.model_config.visualization_folder.mkdir(exist_ok=True)
예제 #4
0
def test_visualization_with_scalar_model(use_non_imaging_features: bool,
                                         imaging_feature_type: ImagingFeatureType,
                                         encode_channels_jointly: bool,
                                         test_output_dirs: OutputFolderForTests) -> None:
    """
    Runs VisualizationMaps on an ImageEncoder scalar model built from a small in-memory dataset
    and checks the leading dimensions of the returned maps, plus the labels produced for the
    non-imaging features.
    :param use_non_imaging_features: If True, the numerical and categorical columns of the dataset
    are passed to the model config; otherwise no non-imaging features are used.
    :param imaging_feature_type: Forwarded to ImageEncoder; for ImageAndSegmentation the second
    dimension of the guided Grad-CAM is expected to be doubled.
    :param encode_channels_jointly: Forwarded to ImageEncoder; when True a single Grad-CAM per
    subject is expected, otherwise one per image channel.
    :param test_output_dirs: Provides the root folder that the model config writes its outputs to.
    """
    dataset_contents = """subject,channel,path,label,numerical1,numerical2,categorical1,categorical2
    S1,week0,scan1.npy,,1,10,Male,Val1
    S1,week1,scan2.npy,True,2,20,Female,Val2
    S2,week0,scan3.npy,,3,30,Female,Val3
    S2,week1,scan4.npy,False,4,40,Female,Val1
    S3,week0,scan1.npy,,5,50,Male,Val2
    S3,week1,scan3.npy,True,6,60,Male,Val2
    """
    dataset_dataframe = pd.read_csv(StringIO(dataset_contents), dtype=str)
    numerical_columns = ["numerical1", "numerical2"] if use_non_imaging_features else []
    categorical_columns = ["categorical1", "categorical2"] if use_non_imaging_features else []
    non_image_feature_channels = get_non_image_features_dict(default_channels=["week1", "week0"],
                                                             specific_channels={"categorical2": ["week1"]}) \
        if use_non_imaging_features else {}

    config = ImageEncoder(
        local_dataset=Path(),
        encode_channels_jointly=encode_channels_jointly,
        should_validate=False,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
        imaging_feature_type=imaging_feature_type,
        non_image_feature_channels=non_image_feature_channels,
        categorical_feature_encoder=CategoricalToOneHotEncoder.create_from_dataframe(
            dataframe=dataset_dataframe, columns=categorical_columns)
    )

    dataloader = ScalarDataset(config, data_frame=dataset_dataframe) \
        .as_data_loader(shuffle=False, batch_size=2)

    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = 1
    model = create_model_with_temperature_scaling(config)
    visualizer = VisualizationMaps(model, config)
    # Patch the load_images function that will be called once we access a dataset item
    image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, (6, 64, 60)),
                                                      segmentations=np.random.randint(0, 2, (6, 64, 60)))
    with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
        batch = next(iter(dataloader))
        if config.use_gpu:
            # Keep the batch and the model used by Grad-CAM on the same device.
            device = visualizer.grad_cam.device
            batch = transfer_batch_to_device(batch, device)
            visualizer.grad_cam.model = visualizer.grad_cam.model.to(device)
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(model,
                                                                     target_indices=[],
                                                                     sample=batch)
    number_channels = len(config.image_channels)
    number_subjects = len(model_inputs_and_labels.subject_ids)
    guided_grad_cams, grad_cams, pseudo_cam_non_img, probas = visualizer.generate(
        model_inputs_and_labels.model_inputs)

    if imaging_feature_type == ImagingFeatureType.ImageAndSegmentation:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels * 2)
    else:
        assert guided_grad_cams.shape[:2] == (number_subjects, number_channels)

    # The expected shape has to be computed before the comparison: `assert a == b if cond else c`
    # parses as `assert ((a == b) if cond else c)`, which always passes when cond is False because
    # c is a non-empty (truthy) tuple.
    expected_grad_cam_shape = (number_subjects, 1) if encode_channels_jointly \
        else (number_subjects, number_channels)
    assert grad_cams.shape[:2] == expected_grad_cam_shape

    if use_non_imaging_features:
        non_image_features = config.numerical_columns + config.categorical_columns
        non_imaging_plot_labels = visualizer._get_non_imaging_plot_labels(model_inputs_and_labels.data_item,
                                                                          non_image_features,
                                                                          index=0)
        # "categorical2" is restricted to channel "week1" via specific_channels above, hence 7 labels.
        assert non_imaging_plot_labels == ['numerical1_week1',
                                           'numerical1_week0',
                                           'numerical2_week1',
                                           'numerical2_week0',
                                           'categorical1_week1',
                                           'categorical1_week0',
                                           'categorical2_week1']
        assert pseudo_cam_non_img.shape == (number_subjects, 1, len(non_imaging_plot_labels))
예제 #5
0
class ModelTrainingStepsForScalarModel(
        ModelTrainingStepsBase[F, DeviceAwareModule]):
    """
    This class implements all steps necessary for training an image classification model during a single epoch.
    """

    def __init__(self, config: F,
                 train_val_params: TrainValidateParameters[DeviceAwareModule]):
        """
        Creates a new instance of the class.
        :param config: The configuration of a classification model.
        :param train_val_params: The parameters for training the model, including the optimizer and the data loaders.
        """
        # This field needs to be defined in the constructor to keep pycharm happy, but before the call to the
        # base class because the base class constructor is expected to call create_loss_function.
        self.label_tensor_dtype = torch.float32
        super().__init__(config, train_val_params)
        self.metrics = create_metrics_dict_from_config(config)
        self.compute_mean_teacher_model = self.model_config.compute_mean_teacher_model

        if self.model_config.compute_grad_cam:
            # Visualize the mean teacher model when one is computed, otherwise the plain model.
            model_to_evaluate = self.train_val_params.mean_teacher_model if \
                self.model_config.compute_mean_teacher_model else self.train_val_params.model
            self.guided_grad_cam = VisualizationMaps(model_to_evaluate,
                                                     self.model_config)
            self.model_config.visualization_folder.mkdir(exist_ok=True)

    def create_loss_function(self) -> torch.nn.Module:
        """
        Returns a torch module that computes a loss function.
        Depending on the chosen loss function, the required data type for the labels tensor is set in
        self.label_tensor_dtype.
        :raises NotImplementedError: If the loss type of the model config is not handled here.
        """
        loss_type = self.model_config.loss_type
        if loss_type == ScalarLoss.BinaryCrossEntropyWithLogits:
            return BinaryCrossEntropyWithLogitsLoss(
                smoothing_eps=self.model_config.label_smoothing_eps)
        elif loss_type == ScalarLoss.WeightedCrossEntropyWithLogits:
            # Same loss module as above, additionally given the training class counts.
            return BinaryCrossEntropyWithLogitsLoss(
                smoothing_eps=self.model_config.label_smoothing_eps,
                class_counts=self.model_config.get_training_class_counts())
        elif loss_type == ScalarLoss.MeanSquaredError:
            self.label_tensor_dtype = torch.float32
            return MSELoss()
        else:
            raise NotImplementedError("Loss type {} is not implemented".format(
                self.model_config.loss_type))

    def get_label_tensor(self, labels: torch.Tensor) -> torch.Tensor:
        """
        Converts the given tensor to the right data format, depending on the chosen loss function.
        :param labels: The label tensor that should be converted.
        :return: The labels converted to self.label_tensor_dtype, passed through the model config's
        get_gpu_tensor_if_possible.
        :raises ValueError: If the dtype conversion raises a ValueError; the original error is chained as the cause.
        """
        try:
            labels = labels.to(dtype=self.label_tensor_dtype)
        except ValueError as ex:
            # Chain explicitly so that the traceback marks the original error as the direct cause.
            raise ValueError(
                f"Unable to convert tensor {labels} to data type {self.label_tensor_dtype}: {str(ex)}"
            ) from ex
        return self.model_config.get_gpu_tensor_if_possible(labels)

    def get_logits_and_posteriors(self, *model_inputs: torch.Tensor, use_mean_teacher_model: bool = False) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns a Tuple containing the logits and the final model output. Note that the logits might be
        distributed over multiple GPU if the model is an instance of DataParallel. In this case,
        the posteriors will be gathered to GPU_0.
        :param model_inputs: input to evaluate the model on
        :param use_mean_teacher_model: If True, logits and posteriors are produced for the mean teacher model. Else
        logits and posteriors are produced for the standard (student) model.
        :return: Tuple (logits, posteriors).
        """
        if use_mean_teacher_model:
            logits = self.train_val_params.mean_teacher_model(*model_inputs)
        else:
            logits = self.train_val_params.model(*model_inputs)
        normalize = self.model_config.get_post_loss_logits_normalization_function()
        posteriors = normalize(gather_tensor(logits))
        return logits, posteriors

    def _compute_model_output_and_loss(self, model_inputs_and_labels: ScalarModelInputsAndLabels) -> \
            Tuple[Tensor, Tensor, Tensor]:
        """
        Computes the output of the model for a given set of inputs and labels.
        Returns a tuple of (logits, posteriors, loss). For multi-GPU computation, the logits are returned
        as a list.
        :param model_inputs_and_labels: The model inputs and the labels for the current minibatch.
        """
        model = self.train_val_params.model
        label_gpu = self.get_label_tensor(model_inputs_and_labels.labels)
        if self.model_config.use_mixed_precision and self.model_config.use_gpu:
            label_gpu = label_gpu.to(dtype=torch.float16)

        def compute() -> Tuple[Tensor, Tensor, Tensor]:
            if self.in_training_mode:
                model.train()
                logits, posteriors = self.get_logits_and_posteriors(
                    *model_inputs_and_labels.model_inputs)
            else:
                # Evaluate without gradients, then put the model back into training mode.
                model.eval()
                with torch.no_grad():
                    logits, posteriors = self.get_logits_and_posteriors(
                        *model_inputs_and_labels.model_inputs)
                model.train()
            loss = self.compute_loss(logits, label_gpu)
            return logits, posteriors, loss

        return execute_within_autocast_if_needed(
            func=compute, use_autocast=self.model_config.use_mixed_precision)

    def forward_and_backward_minibatch(
            self, sample: Dict[str, Any], batch_index: int,
            epoch: int) -> ModelForwardAndBackwardsOutputs:
        """
        Runs training for a single minibatch of training data, and computes all metrics.
        :param sample: The batched sample on which the model should be trained.
        :param batch_index: The index of the present batch (supplied only for diagnostics).
        :param epoch: The number of the present epoch.
        :return: The scalar loss, the gathered logits (detached, float32, on the CPU) and the labels
        of the minibatch.
        """
        start_time = time.time()
        model = self.train_val_params.model
        mean_teacher_model = self.train_val_params.mean_teacher_model
        model_inputs_and_labels = get_scalar_model_inputs_and_labels(
            self.model_config, model, sample)
        label_gpu = self.get_label_tensor(model_inputs_and_labels.labels)
        logits, posteriors, loss = self._compute_model_output_and_loss(
            model_inputs_and_labels)
        gathered_logits = gather_tensor(logits)
        if self.in_training_mode:
            single_optimizer_step(loss, self.train_val_params.optimizer,
                                  self.train_val_params.gradient_scaler)
            if self.model_config.compute_mean_teacher_model:
                self.update_mean_teacher_parameters()

        if self.compute_mean_teacher_model:
            # If the mean teacher model is computed, use the output of the mean teacher for the metrics report
            # instead of the output of the student model.
            mean_teacher_model.eval()
            with torch.no_grad():
                logits, posteriors = self.get_logits_and_posteriors(
                    *model_inputs_and_labels.model_inputs,
                    use_mean_teacher_model=True)
                gathered_logits = gather_tensor(logits)

        # Autocast may have returned float16 tensors. Documentation suggests to simply cast back to float32.
        # If tensor was already float32, no overhead is incurred.
        posteriors = posteriors.detach().float()
        gathered_logits = gathered_logits.detach().float().cpu()
        loss_scalar = loss.float().item()

        if self.train_val_params.save_metrics:
            if self._should_save_grad_cam_output(epoch=epoch,
                                                 batch_index=batch_index):
                self.save_grad_cam(epoch, model_inputs_and_labels.subject_ids,
                                   model_inputs_and_labels.data_item,
                                   model_inputs_and_labels.model_inputs,
                                   label_gpu)

            self.metrics.add_metric(MetricType.LOSS, loss_scalar)
            self.update_metrics(model_inputs_and_labels.subject_ids,
                                posteriors, label_gpu)
            logging.debug(f"Batch {batch_index}: {self.metrics.to_string()}")
            minibatch_time = time.time() - start_time
            self.metrics.add_metric(MetricType.SECONDS_PER_BATCH,
                                    minibatch_time)

        return ModelForwardAndBackwardsOutputs(
            loss=loss_scalar,
            logits=gathered_logits,
            labels=model_inputs_and_labels.labels)

    def get_epoch_results_and_store(self,
                                    epoch_time_seconds: float) -> MetricsDict:
        """
        Assembles all training results that were achieved over all minibatches, returns them as a dictionary
        mapping from metric name to metric value.
        :param epoch_time_seconds: For diagnostics, this is the total time in seconds for training the present epoch.
        :return: A dictionary that holds all metrics averaged over the epoch.
        """
        self.metrics.add_metric(MetricType.SECONDS_PER_EPOCH,
                                epoch_time_seconds)
        assert len(self.train_val_params.epoch_learning_rate
                   ) == 1, "Expected a single entry for learning rate."
        self.metrics.add_metric(MetricType.LEARNING_RATE,
                                self.train_val_params.epoch_learning_rate[0])
        averaged_across_hues = self.metrics.average(across_hues=False)
        mode = ModelExecutionMode.TRAIN if self.in_training_mode else ModelExecutionMode.VAL
        diagnostics_lines = averaged_across_hues.to_string()
        logging.info(
            f"Results for epoch {self.train_val_params.epoch:3d} {mode.value}\n{diagnostics_lines}"
        )

        # Store subject level metrics
        subject_logger = self.train_val_params.dataframe_loggers.train_subject_metrics if \
            self.train_val_params.in_training_mode \
            else self.train_val_params.dataframe_loggers.val_subject_metrics
        self.metrics.store_metrics_per_subject(
            epoch=self.train_val_params.epoch,
            df_logger=subject_logger,
            mode=mode,
            cross_validation_split_index=self.model_config.
            cross_validation_split_index)

        if self._should_save_regression_error_plot(
                self.train_val_params.epoch):
            error_plot_name = f"error_plot_{self.train_val_params.epoch}"
            path = str(self.model_config.outputs_folder /
                       f"{error_plot_name}.png")
            plot_variation_error_prediction(self.metrics.get_labels(),
                                            self.metrics.get_predictions(),
                                            path)
            self.azure_and_tensorboard_logger.log_image(error_plot_name, path)

        # Write metrics to Azure and TensorBoard
        metrics.store_epoch_metrics(self.azure_and_tensorboard_logger,
                                    self.df_logger,
                                    self.train_val_params.epoch,
                                    averaged_across_hues,
                                    self.train_val_params.epoch_learning_rate,
                                    self.model_config)
        return self.metrics.average(across_hues=True)

    def update_metrics(self, subject_ids: List[str],
                       model_output: torch.Tensor,
                       labels: torch.Tensor) -> None:
        """
        Handle metrics updates based on the provided model outputs and labels.
        :param subject_ids: The subject identifiers of the items in the minibatch.
        :param model_output: The model output for the minibatch (the posteriors at the call site in this class).
        :param labels: The labels for the minibatch.
        """
        compute_scalar_metrics(self.metrics, subject_ids, model_output, labels,
                               self.model_config.loss_type)

    def save_grad_cam(self, epoch: int, subject_ids: List,
                      classification_item: Union[
                          List[ClassificationItemSequence[ScalarItem]],
                          ScalarItem], model_inputs: List[torch.Tensor],
                      labels: torch.Tensor) -> None:
        """
        Saves the visualizations for a minibatch into the visualization folder of the model config,
        using one file name of the form "<epoch>_viz_<subject id>" per subject.
        Requires self.guided_grad_cam, which the constructor only creates when compute_grad_cam is set.
        :param epoch: The number of the present epoch, used as the file name prefix.
        :param subject_ids: The subject identifiers of the items in the minibatch.
        :param classification_item: The data item(s) that the model inputs were created from.
        :param model_inputs: The inputs that were fed into the model.
        :param labels: The ground truth labels; they are moved to the CPU and converted to numpy.
        """
        # Avoid shadowing the builtin `id` with the loop variable.
        filenames = [f"{epoch}_viz_{subject_id}" for subject_id in subject_ids]
        self.guided_grad_cam.save_visualizations_in_notebook(
            classification_item,  # type: ignore
            model_inputs,
            filenames,
            ground_truth_labels=labels.cpu().numpy(),
            gradcam_dir=self.model_config.visualization_folder)

    def update_mean_teacher_parameters(self) -> None:
        """
        Updates the mean teacher model parameters as per the update formula
        mean_teacher_model_weight = alpha * (mean_teacher_model_weight) + (1-alpha) * (student_model_weight)
        see https://arxiv.org/abs/1703.01780
        """
        mean_teacher_model = self.train_val_params.mean_teacher_model
        model = self.train_val_params.model
        if isinstance(mean_teacher_model, DataParallelModel):
            # Unwrap both models so that the underlying parameters are updated.
            mean_teacher_model = mean_teacher_model.module  # type: ignore
            model = model.module  # type: ignore
        alpha = self.model_config.mean_teacher_alpha
        for mean_teacher_param, student_param in zip(
                mean_teacher_model.parameters(), model.parameters()):
            mean_teacher_param.data = alpha * mean_teacher_param.data \
                                      + (1 - alpha) * student_param.data

    def _should_save_grad_cam_output(self, epoch: int,
                                     batch_index: int) -> bool:
        """
        Returns True if visualizations should be saved for the given batch: Grad-CAM computation is
        enabled, the model is a classification model, we are not in training mode, the epoch is one for
        which the config's should_save_epoch is true, and the batch index is below max_batch_grad_cam.
        :param epoch: The number of the present epoch.
        :param batch_index: The index of the present batch.
        """
        # self.guided_grad_cam only exists when compute_grad_cam was set at construction time, so
        # never request saving without it.
        return self.model_config.compute_grad_cam \
               and self.model_config.is_classification_model \
               and (not self.in_training_mode) \
               and self.model_config.should_save_epoch(epoch) \
               and (batch_index < self.model_config.max_batch_grad_cam)

    def _should_save_regression_error_plot(self, epoch: int) -> bool:
        """
        Returns True if the regression error plot should be saved: the model is a regression model,
        we are not in training mode, and the epoch is one for which the config's should_save_epoch is true.
        :param epoch: The number of the present epoch.
        """
        return self.model_config.is_regression_model \
               and (not self.in_training_mode) \
               and self.model_config.should_save_epoch(epoch)