Example #1
    def _define_model(self) -> None:
        """Instantiate the U-Net architecture."""

        # Select the device (GPU if available, CPU otherwise)
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device '{self._device}'.", file=self._stdout)

        # Try to free memory on the GPU
        if self._device.type != "cpu":
            torch.cuda.empty_cache()

        # MONAI's HighResNet
        self._model = HighResNet(spatial_dims=2,
                                 in_channels=self._in_channels,
                                 out_channels=self._out_channels,
                                 norm_type=('batch', {
                                     'affine': True
                                 }),
                                 acti_type=('relu', {
                                     'inplace': True
                                 }),
                                 dropout_prob=0.2)
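
The snippet above only defines the model. A minimal standalone sketch of the same 2-D configuration with a dummy forward pass is shown below; the channel counts and the 64x64 input size are illustrative choices, not taken from the example.

import torch
from monai.networks.nets import HighResNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same normalization, activation, and dropout settings as in the example above
model = HighResNet(spatial_dims=2,
                   in_channels=1,
                   out_channels=3,
                   norm_type=('batch', {'affine': True}),
                   acti_type=('relu', {'inplace': True}),
                   dropout_prob=0.2).to(device)
model.eval()

# Dummy forward pass: (batch, channels, height, width)
x = torch.randn(1, 1, 64, 64, device=device)
with torch.no_grad():
    y = model(x)
print(y.shape)  # HighResNet preserves the spatial size: torch.Size([1, 3, 64, 64])
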
Example #2
    def test_shape(self, input_param, input_data, expected_shape):
        net = HighResNet(**input_param)
        net.eval()
        with torch.no_grad():
            result = net.forward(input_data)
            self.assertEqual(result.shape, expected_shape)
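
The test above receives its arguments from a parameterized test case. A hypothetical (input_param, input_data, expected_shape) tuple for a small 2-D network could look like the sketch below; the parameter values and sizes are illustrative, not taken from MONAI's test suite.

import torch

# Hypothetical 2-D test case: constructor kwargs, an input tensor, and the expected output shape
TEST_CASE_2D = [
    {"spatial_dims": 2, "in_channels": 1, "out_channels": 3},
    torch.randn(1, 1, 32, 32),  # input_data: (batch, channels, H, W)
    (1, 3, 32, 32),             # expected_shape
]
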
Example #3
    def test_shape(self, input_param, input_shape, expected_shape):
        net = HighResNet(**input_param).to(device)
        net.eval()
        with torch.no_grad():
            result = net.forward(torch.randn(input_shape).to(device))
            self.assertEqual(result.shape, expected_shape)
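
This variant builds its input from a shape tuple and moves everything to a module-level `device`. A minimal sketch of that setup follows; the case values are illustrative, and `TEST_CASE_1` here is only a stand-in for the same kind of tuple referenced by the scripting tests in Examples #5 and #6.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical (input_param, input_shape, expected_shape) case
TEST_CASE_1 = [
    {"spatial_dims": 2, "in_channels": 1, "out_channels": 3},
    (1, 1, 32, 32),  # input_shape
    (1, 3, 32, 32),  # expected_shape
]
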
Example #4
class ResnetSegmenter(AbstractBaseLearner):
    """Segmenter based on the U-Net architecture."""
    def __init__(self,
                 mask_type: MaskType = MaskType.TIFF_LABELS,
                 in_channels: int = 1,
                 out_channels: int = 3,
                 roi_size: Tuple[int, int] = (384, 384),
                 num_epochs: int = 400,
                 batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
                 num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
                 validation_step: int = 2,
                 sliding_window_batch_size: int = 4,
                 class_names: Tuple[str,
                                    ...] = ("Background", "Object", "Border"),
                 experiment_name: str = "HighResNet",
                 model_name: str = "best_model",
                 seed: int = 4294967295,
                 working_dir: str = '.',
                 stdout: TextIOWrapper = sys.stdout,
                 stderr: TextIOWrapper = sys.stderr):
        """Constructor.

        @param mask_type: MaskType
            Type of mask: defines the file type, mask geometry, and the way pixels
            are assigned to the various classes.

            @see qu.data.model.MaskType

        @param in_channels: int, optional: default = 1
            Number of channels in the input (e.g. 1 for gray-value images).

        @param out_channels: int, optional: default = 3
            Number of channels in the output (classes).

        @param roi_size: Tuple[int, int], optional: default = (384, 384)
            Crop area (and input size of the network) used for training and validation/prediction.

        @param num_epochs: int, optional: default = 400
            Number of epochs for training.

        @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
            Batch sizes for training, validation, testing, and prediction, respectively.

        @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
            Number of workers for training, validation, testing, and prediction, respectively.

        @param validation_step: int, optional: default = 2
            Number of training steps before the next validation is performed.

        @param sliding_window_batch_size: int, optional: default = 4
            Number of batches for sliding window inference during validation and prediction.

        @param class_names: Tuple[str, ...], optional: default = ("Background", "Object", "Border")
            Names of the classes used when logging validation curves.

        @param experiment_name: str, optional: default = ""
            Name of the experiment that maps to the folder that contains training information (to
            be used by tensorboard). Please note, current datetime will be appended.

        @param model_name: str, optional: default = "best_model.ph"
            Name of the file that stores the best model. Please note, current datetime will be appended
            (before the extension).

        @param seed: int, optional: default = 4294967295
            Set random seed for modules to enable or disable deterministic training.

        @param working_dir: str, optional: default = "."
            Working folder where to save the model weights and the logs for tensorboard.

        """

        # Call base constructor
        super().__init__()

        # Standard pipe wrappers
        self._stdout = stdout
        self._stderr = stderr

        # Device (initialize as "cpu")
        self._device = "cpu"

        # Mask type
        self._mask_type = mask_type

        # Input and output channels
        self._in_channels = in_channels
        self._out_channels = out_channels

        # Define hyper parameters
        self._roi_size = roi_size
        self._training_batch_size = batch_sizes[0]
        self._validation_batch_size = batch_sizes[1]
        self._test_batch_size = batch_sizes[2]
        self._prediction_batch_size = batch_sizes[3]
        self._training_num_workers = num_workers[0]
        self._validation_num_workers = num_workers[1]
        self._test_num_workers = num_workers[2]
        self._prediction_num_workers = num_workers[3]
        self._n_epochs = num_epochs
        self._validation_step = validation_step
        self._sliding_window_batch_size = sliding_window_batch_size

        # Other parameters
        self._class_names = out_channels * ["Unknown"]
        for i in range(min(out_channels, len(class_names))):
            self._class_names[i] = class_names[i]

        # Set monai seed
        set_determinism(seed=seed)

        # All file names
        self._train_image_names: list = []
        self._train_mask_names: list = []
        self._validation_image_names: list = []
        self._validation_mask_names: list = []
        self._test_image_names: list = []
        self._test_mask_names: list = []

        # Transforms
        self._train_image_transforms = None
        self._train_mask_transforms = None
        self._validation_image_transforms = None
        self._validation_mask_transforms = None
        self._test_image_transforms = None
        self._test_mask_transforms = None
        self._prediction_image_transforms = None
        self._validation_post_transforms = None
        self._test_post_transforms = None
        self._prediction_post_transforms = None

        # Datasets and data loaders
        self._train_dataset = None
        self._train_dataloader = None
        self._validation_dataset = None
        self._validation_dataloader = None
        self._test_dataset = None
        self._test_dataloader = None
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Set model architecture, loss function, metric and optimizer
        self._model = None
        self._training_loss_function = None
        self._optimizer = None
        self._validation_metric = None

        # Working directory, model file name and experiment name for Tensorboard logs.
        # The file names will be redefined at the beginning of the training.
        self._working_dir = Path(working_dir).resolve()
        self._raw_experiment_name = experiment_name
        self._raw_model_file_name = model_name

        # Keep track of the full path of the best model
        self._best_model = ''

        # Keep track of last error message
        self._message = ""

    def train(self) -> bool:
        """Run training in a separate thread (added to the global application ThreadPool)."""

        # Free memory on the GPU
        self._clear_session()

        # Check that the data is set properly
        if len(self._train_image_names) == 0 or \
                len(self._train_mask_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_mask_names) == 0:
            self._message = "No training/validation data found."
            return False

        if len(self._train_image_names) != len(self._train_mask_names):
            self._message = "The number of training images does not match the number of training masks."
            return False

        if len(self._validation_image_names) != len(
                self._validation_mask_names):
            self._message = "The number of validation images does not match the number of validation masks."
            return False

        # Define the transforms
        self._define_transforms()

        # Define the datasets and data loaders
        self._define_training_data_loaders()

        # Instantiate the model
        self._define_model()

        # Define the loss function
        self._define_training_loss()

        # Define the optimizer (with default parameters)
        self._define_optimizer()

        # Define the validation metric
        self._define_validation_metric()

        # Define experiment name and model name
        experiment_name, model_file_name = self._prepare_experiment_and_model_names(
        )

        # Keep track of the best model file name
        self._best_model = model_file_name

        # Enter the main training loop
        best_metric = -1
        best_metric_epoch = -1

        epoch_loss_values = list()
        metric_values = list()

        # Initialize TensorBoard's SummaryWriter
        writer = SummaryWriter(experiment_name)

        for epoch in range(self._n_epochs):

            # Inform
            self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}")

            # Switch to training mode
            self._model.train()

            epoch_loss = 0
            step = 0
            for batch_data in self._train_dataloader:

                # Update step
                step += 1

                # Get the next batch and move it to device
                inputs, labels = batch_data[0].to(
                    self._device), batch_data[1].to(self._device)

                # Zero the gradient buffers
                self._optimizer.zero_grad()

                # Forward pass
                outputs = self._model(inputs)

                # Calculate the loss
                loss = self._training_loss_function(outputs, labels)

                # Back-propagate
                loss.backward()

                # Update weights (optimize)
                self._optimizer.step()

                # Update and store metrics
                epoch_loss += loss.item()
                # Number of batches per epoch (ceiling division)
                epoch_len = (len(self._train_dataset) - 1) \
                    // self._train_dataloader.batch_size + 1

                print(
                    f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}",
                    file=self._stdout)

            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"Average loss = {epoch_loss:.4f}", file=self._stdout)
            writer.add_scalar("average_train_loss", epoch_loss, epoch + 1)

            # Validation
            if (epoch + 1) % self._validation_step == 0:

                self._print_header("Validation")

                # Switch to evaluation mode
                self._model.eval()

                # Make sure not to update the gradients
                with torch.no_grad():

                    # Global metrics
                    metric_sum = 0.0
                    metric_count = 0
                    metric = 0.0

                    # Keep track of the metrics for all classes
                    metric_sum_classes = self._out_channels * [0.0]
                    metric_count_classes = self._out_channels * [0]
                    metric_classes = self._out_channels * [0.0]

                    for val_data in self._validation_dataloader:

                        # Get the next batch and move it to device
                        val_images, val_labels = val_data[0].to(
                            self._device), val_data[1].to(self._device)

                        # Apply sliding inference over ROI size
                        val_outputs = sliding_window_inference(
                            val_images, self._roi_size,
                            self._sliding_window_batch_size, self._model)
                        val_outputs = self._validation_post_transforms(
                            val_outputs)

                        # Compute overall metric
                        value, not_nans = self._validation_metric(
                            y_pred=val_outputs, y=val_labels)
                        not_nans = not_nans.item()
                        metric_count += not_nans
                        metric_sum += value.item() * not_nans

                        # Compute metric for each class
                        for c in range(self._out_channels):
                            value_obj, not_nans = self._validation_metric(
                                y_pred=val_outputs[:, c:c + 1],
                                y=val_labels[:, c:c + 1])
                            not_nans = not_nans.item()
                            metric_count_classes[c] += not_nans
                            metric_sum_classes[c] += value_obj.item(
                            ) * not_nans

                    # Global metric
                    metric = metric_sum / metric_count
                    metric_values.append(metric)

                    # Metric per class
                    for c in range(self._out_channels):
                        metric_classes[c] = metric_sum_classes[
                            c] / metric_count_classes[c]

                    # Print summary
                    print(f"Global metric = {metric:.4f} ", file=self._stdout)
                    for c in range(self._out_channels):
                        print(
                            f"Class '{self._class_names[c]}' metric = {metric_classes[c]:.4f} ",
                            file=self._stdout)

                    # Do we have the best metric so far?
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(self._model.state_dict(), model_file_name)
                        print(
                            f"New best global metric = {best_metric:.4f} at epoch: {best_metric_epoch}",
                            file=self._stdout)
                        print(
                            f"Saved best model '{Path(model_file_name).name}'",
                            file=self._stdout)

                    # Add validation loss and metrics to log
                    writer.add_scalar("val_mean_dice_loss", metric, epoch + 1)
                    for c in range(self._out_channels):
                        metric_name = f"val_{self._class_names[c].lower()}_metric"
                        writer.add_scalar(metric_name, metric_classes[c],
                                          epoch + 1)

        print(
            f"Training completed. Best_metric = {best_metric:.4f} at epoch: {best_metric_epoch}",
            file=self._stdout)
        writer.close()

        # Return success
        return True

    def test_predict(self,
                     target_folder: Union[Path, str] = '',
                     model_path: Union[Path, str] = '') -> bool:
        """Run prediction on predefined test data.

        @param target_folder: Path|str, optional: default = ''
            Path to the folder where to store the predicted images. If not specified,
            it defaults to '{working_dir}/tests'. See constructor.

        @param model_path: Path|str, optional: default = ''
            Full path to the model to use. If omitted and a training was
            just run, the path to the model with the best metric is
            already stored and will be used.

            @see get_best_model_path()

        @return True if the prediction was successful, False otherwise.
        """

        # Inform
        self._print_header("Test prediction")

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # If the path to the best model was not set, use current one (if set)
        if model_path == '':
            model_path = self.get_best_model_path()

        # Try loading the model weights: they must be compatible
        # with the model in memory
        try:
            checkpoint = torch.load(model_path,
                                    map_location=torch.device('cpu'))
            self._model.load_state_dict(checkpoint)
            print(f"Loaded best metric model {model_path}.", file=self._stdout)
        except Exception as e:
            self._message = f"Error: there was a problem loading the model ({e})! Aborting."
            return False

        # If the target folder is not specified, set it to the standard predictions out
        if target_folder == '':
            target_folder = Path(self._working_dir) / "tests"
        else:
            target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Switch to evaluation mode
        self._model.eval()

        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for test_data in self._test_dataloader:

                # Get the next batch and move it to device
                test_images, test_masks = test_data[0].to(
                    self._device), test_data[1].to(self._device)

                # Apply sliding inference over ROI size
                test_outputs = sliding_window_inference(
                    test_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                test_outputs = self._test_post_transforms(test_outputs)

                # Retrieve the image from the GPU (if needed)
                pred = test_outputs.cpu().numpy().squeeze()

                # Prepare the output file name
                basename = os.path.splitext(
                    os.path.basename(self._test_image_names[indx]))[0]
                basename = basename.replace('train_', 'pred_')

                # Convert to label image
                label_img = self._prediction_to_label_tiff_image(pred)

                # Save label image as tiff file
                label_file_name = os.path.join(str(target_folder),
                                               basename + '.tif')
                with TiffWriter(label_file_name) as tif:
                    tif.save(label_img)

                # Inform
                print(f"Saved {str(target_folder)}/{basename}.tif",
                      file=self._stdout)

                # Update the index
                indx += 1

        # Inform
        print(f"Test prediction completed.", file=self._stdout)

        # Return success
        return True

    def predict(self, input_folder: Union[Path, str],
                target_folder: Union[Path, str],
                model_path: Union[Path, str]) -> bool:
        """Run prediction.

        @param input_folder: Path|str
            Path to the folder containing the images to run prediction on.

        @param target_folder: Path|str
            Path to the folder where to store the predicted images.

        @param model_path: Path|str
            Full path to the model to use.

        @return True if the prediction was successful, False otherwise.
        """
        # Inform
        self._print_header("Prediction")

        # Get the device
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # If the model is not in memory, instantiate it first
        if self._model is None:
            self._define_model()

        # Try loading the model weights: they must be compatible
        # with the model in memory
        try:
            checkpoint = torch.load(model_path,
                                    map_location=torch.device('cpu'))
            self._model.load_state_dict(checkpoint)
            print(f"Loaded best metric model {model_path}.", file=self._stdout)
        except Exception as e:
            self._message = f"Error: there was a problem loading the model ({e})! Aborting."
            return False

        # Make sure the target folder exists
        if isinstance(target_folder, str) and target_folder == '':
            self._message = "Error: please specify a valid target folder! Aborting."
            return False

        target_folder = Path(target_folder)
        target_folder.mkdir(parents=True, exist_ok=True)

        # Get prediction dataloader
        if not self._define_prediction_data_loaders(input_folder):
            self._message = "Error: could not instantiate prediction dataloader! Aborting."
            return False

        # Switch to evaluation mode
        self._model.eval()

        indx = 0

        # Make sure not to update the gradients
        with torch.no_grad():
            for prediction_data in self._prediction_dataloader:

                # Get the next batch and move it to device
                prediction_images = prediction_data.to(self._device)

                # Apply sliding inference over ROI size
                prediction_outputs = sliding_window_inference(
                    prediction_images, self._roi_size,
                    self._sliding_window_batch_size, self._model)
                prediction_outputs = self._prediction_post_transforms(
                    prediction_outputs)

                # Retrieve the image from the GPU (if needed)
                pred = prediction_outputs.cpu().numpy().squeeze()

                # Prepare the output file name
                basename = os.path.splitext(
                    os.path.basename(self._prediction_image_names[indx]))[0]
                basename = "pred_" + basename

                # Convert to label image
                label_img = self._prediction_to_label_tiff_image(pred)

                # Save label image as tiff file
                label_file_name = os.path.join(str(target_folder),
                                               basename + '.tif')
                with TiffWriter(label_file_name) as tif:
                    tif.save(label_img)

                # Inform
                print(f"Saved {str(target_folder)}/{basename}.tif",
                      file=self._stdout)

                # Update the index
                indx += 1

        # Inform
        print(f"Prediction completed.", file=self._stdout)

        # Return success
        return True

    def set_training_data(self, train_image_names, train_mask_names,
                          val_image_names, val_mask_names, test_image_names,
                          test_mask_names) -> None:
        """Set all training files names.

        @param train_image_names: list
            List of training image names.

        @param train_mask_names: list
            List of training mask names.

        @param val_image_names: list
            List of validation image names.

        @param val_mask_names: list
            List of validation mask names.

        @param test_image_names: list
            List of test image names.

        @param test_mask_names: list
            List of test mask names.
        """

        # First validate all data
        if len(train_image_names) != len(train_mask_names):
            raise ValueError(
                "The number of training images does not match the number of training masks."
            )

        if len(val_image_names) != len(val_mask_names):
            raise ValueError(
                "The number of validation images does not match the number of validation masks."
            )

        if len(test_image_names) != len(test_mask_names):
            raise ValueError(
                "The number of test images does not match the number of test masks."
            )

        # Training data
        self._train_image_names = train_image_names
        self._train_mask_names = train_mask_names

        # Validation data
        self._validation_image_names = val_image_names
        self._validation_mask_names = val_mask_names

        # Test data
        self._test_image_names = test_image_names
        self._test_mask_names = test_mask_names

    @staticmethod
    def _prediction_to_label_tiff_image(prediction):
        """Save the prediction to a label image (TIFF)"""

        # Convert to label image
        label_img = one_hot_stack_to_label_image(
            prediction,
            first_index_is_background=True,
            channels_first=True,
            dtype=np.uint16)

        return label_img

    def _define_transforms(self):
        """Define and initialize all data transforms.

          * training set images transform
          * training set masks transform
          * validation set images transform
          * validation set masks transform
          * validation set images post-transform
          * test set images transform
          * test set masks transform
          * test set images post-transform
          * prediction set images transform
          * prediction set images post-transform

        Raises an exception if the mask type is unknown.
        """

        if self._mask_type == MaskType.UNKNOWN:
            raise Exception("The mask type is unknown. Cannot continue!")

        # Depending on the mask type, we will need to adapt the Mask Loader
        # and Transform. We start by initializing the most common types.
        MaskLoader = LoadMask(self._mask_type)
        MaskTransform = Identity

        # Adapt the transform for the LABEL types
        if self._mask_type == MaskType.TIFF_LABELS or self._mask_type == MaskType.NUMPY_LABELS:
            MaskTransform = ToOneHot(num_classes=self._out_channels)

        # The H5_ONE_HOT type requires a different loader
        if self._mask_type == MaskType.H5_ONE_HOT:
            # MaskLoader: still missing
            raise Exception("HDF5 one-hot masks are not supported yet!")

        # Define transforms for training
        self._train_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            RandGaussianNoise(prob=0.5, mean=0., std=0.2),
            ToTensor()
        ])
        self._train_mask_transforms = Compose([
            MaskLoader, MaskTransform,
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            RandGaussianNoise(prob=0.5, mean=0., std=0.2),
            ToTensor()
        ])

        # Define transforms for validation
        self._validation_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ])
        self._validation_mask_transforms = Compose(
            [MaskLoader, MaskTransform, ToTensor()])

        # Define transforms for testing
        self._test_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ])
        self._test_mask_transforms = Compose(
            [MaskLoader, MaskTransform, ToTensor()])

        # Define transforms for prediction
        self._prediction_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensity(),
            AddChannel(),
            ToTensor()
        ])

        # Post transforms
        self._validation_post_transforms = Compose(
            [Activations(softmax=True),
             AsDiscrete(threshold_values=True)])

        self._test_post_transforms = Compose(
            [Activations(softmax=True),
             AsDiscrete(threshold_values=True)])

        self._prediction_post_transforms = Compose(
            [Activations(softmax=True),
             AsDiscrete(threshold_values=True)])

    def _define_training_data_loaders(self) -> bool:
        """Initialize training datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!

        @return True if datasets and data loaders could be instantiated, False otherwise.
        """

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        if len(self._train_image_names) == 0 or \
                len(self._train_mask_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_mask_names) == 0 or \
                len(self._test_image_names) == 0 or \
                len(self._test_mask_names) == 0:

            self._train_dataset = None
            self._train_dataloader = None
            self._validation_dataset = None
            self._validation_dataloader = None
            self._test_dataset = None
            self._test_dataloader = None

            return False

        # Training
        self._train_dataset = ArrayDataset(self._train_image_names,
                                           self._train_image_transforms,
                                           self._train_mask_names,
                                           self._train_mask_transforms)
        self._train_dataloader = DataLoader(
            self._train_dataset,
            batch_size=self._training_batch_size,
            shuffle=False,
            num_workers=self._training_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory)

        # Validation
        self._validation_dataset = ArrayDataset(
            self._validation_image_names, self._validation_image_transforms,
            self._validation_mask_names, self._validation_mask_transforms)
        self._validation_dataloader = DataLoader(
            self._validation_dataset,
            batch_size=self._validation_batch_size,
            shuffle=False,
            num_workers=self._validation_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory)

        # Test
        self._test_dataset = ArrayDataset(self._test_image_names,
                                          self._test_image_transforms,
                                          self._test_mask_names,
                                          self._test_mask_transforms)
        self._test_dataloader = DataLoader(
            self._test_dataset,
            batch_size=self._test_batch_size,
            shuffle=False,
            num_workers=self._test_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory)

        return True

    def _define_prediction_data_loaders(
            self, prediction_folder_path: Union[Path, str]) -> bool:
        """Initialize prediction datasets and data loaders.

        @Note: in Windows, it is essential to set `persistent_workers=True` in the data loaders!

        @return True if datasets and data loaders could be instantiated, False otherwise.
        """

        # Check that the path exists
        prediction_folder_path = Path(prediction_folder_path)
        if not prediction_folder_path.is_dir():
            return False

        # Scan for images
        self._prediction_image_names = natsorted(
            glob(str(Path(prediction_folder_path) / "*.tif")))

        # Optimize arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        if len(self._prediction_image_names) == 0:

            self._prediction_dataset = None
            self._prediction_dataloader = None

            return False

        # Prediction
        self._prediction_dataset = Dataset(self._prediction_image_names,
                                           self._prediction_image_transforms)
        self._prediction_dataloader = DataLoader(
            self._prediction_dataset,
            batch_size=self._prediction_batch_size,
            shuffle=False,
            num_workers=self._prediction_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory)

        return True

    def get_message(self):
        """Return last error message."""
        return self._message

    def get_best_model_path(self):
        """Return the full path to the best model."""
        return self._best_model

    def _clear_session(self) -> None:
        """Try clearing cache on the GPU."""
        # self._device may still be the initial "cpu" string or a torch.device
        if str(self._device) != "cpu":
            torch.cuda.empty_cache()

    def _define_model(self) -> None:
        """Instantiate the U-Net architecture."""

        # Select the device (GPU if available, CPU otherwise)
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device '{self._device}'.", file=self._stdout)

        # Try to free memory on the GPU
        if self._device.type != "cpu":
            torch.cuda.empty_cache()

        # MONAI's HighResNet
        self._model = HighResNet(spatial_dims=2,
                                 in_channels=self._in_channels,
                                 out_channels=self._out_channels,
                                 norm_type=('batch', {
                                     'affine': True
                                 }),
                                 acti_type=('relu', {
                                     'inplace': True
                                 }),
                                 dropout_prob=0.2)

    def _define_training_loss(self) -> None:
        """Define the loss function."""

        self._training_loss_function = GeneralizedDiceLoss(
            include_background=True,
            to_onehot_y=False,
            softmax=True,
            batch=True,
        )

    def _define_optimizer(self,
                          learning_rate: float = 1e-3,
                          weight_decay: float = 1e-4) -> None:
        """Define the optimizer.

        @param learning_rate: float, optional, default = 1e-3
            Initial learning rate for the optimizer.

        @param weight_decay: float, optional, default = 1e-4
            Weight decay (L2 penalty) used by the optimizer.

        """

        if self._model is None:
            return

        self._optimizer = Adam(self._model.parameters(),
                               learning_rate,
                               weight_decay=weight_decay,
                               amsgrad=True)

    def _define_validation_metric(self):
        """Define the metric for validation function."""

        self._validation_metric = DiceMetric(include_background=True,
                                             reduction="mean")

    def _prepare_experiment_and_model_names(self) -> Tuple[str, str]:
        """Prepare the experiment and model names.

        The current date and time are appended and the full paths are returned.

        @return experiment_name, model_file_name
        """

        # Make sure the "runs" subfolder exists
        runs_dir = Path(self._working_dir) / "runs"
        runs_dir.mkdir(parents=True, exist_ok=True)

        now = datetime.now()  # current date and time
        date_time = now.strftime("%Y%m%d_%H%M%S")

        # Experiment name
        experiment_name = f"{self._raw_experiment_name}_{date_time}" \
            if self._raw_experiment_name != "" \
            else f"{date_time}"
        experiment_name = runs_dir / experiment_name

        # Best model file name
        name = Path(self._raw_model_file_name).stem
        model_file_name = f"{name}_{date_time}.pth"
        model_file_name = runs_dir / model_file_name

        return str(experiment_name), str(model_file_name)

    def _print_header(self, header_text, line_length=80, file=None):
        """Print a section header."""
        if file is None:
            file = self._stdout
        print(f"{line_length * '-'}", file=file)
        print(f"{header_text}", file=self._stdout)
        print(f"{line_length * '-'}", file=file)
Example #5
    def test_script(self):
        input_param, input_shape, expected_shape = TEST_CASE_1
        net = HighResNet(**input_param)
        test_data = torch.randn(input_shape)
        test_script_save(net, test_data)
Example #6
    def test_script(self):
        input_param, input_shape, expected_shape = TEST_CASE_1
        net = HighResNet(**input_param)
        test_data = torch.randn(input_shape)
        out_orig, out_reloaded = test_script_save(net, test_data)
        assert torch.allclose(out_orig, out_reloaded)
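
Examples #5 and #6 exercise TorchScript conversion through the test_script_save helper. Roughly the same check can be written directly with torch.jit, as sketched below; the network configuration, input size, and file name are illustrative, and the sketch assumes HighResNet is TorchScript-compatible (which is exactly what these tests verify).

import torch
from monai.networks.nets import HighResNet

net = HighResNet(spatial_dims=2, in_channels=1, out_channels=3)
net.eval()

# Script, save to disk, and reload
scripted = torch.jit.script(net)
torch.jit.save(scripted, "highresnet_scripted.pt")
reloaded = torch.jit.load("highresnet_scripted.pt")

# The reloaded module should reproduce the original outputs
x = torch.randn(1, 1, 32, 32)
with torch.no_grad():
    assert torch.allclose(net(x), reloaded(x))
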