示例#1
0
class UNetExperiment:
    """
    This class implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    The basic life cycle of a UNetExperiment is:

        run():
            for epoch in n_epochs:
                train()
                validate()
        run_test()

    Args:
        config: object providing ``n_epochs``, ``name``, ``test_results_dir``,
            ``batch_size`` and ``learning_rate`` attributes.
        split: mapping with ``"train"``, ``"val"`` and ``"test"`` entries that
            are used to index ``dataset``.
        dataset: indexable collection of volumes. ``dataset[split["test"]]``
            must iterate over dicts with ``"image"``, ``"seg"`` and
            ``"filename"`` keys (presumably a numpy array of dicts indexed
            with index arrays - confirm against the caller).
    """
    def __init__(self, config, split, dataset):
        self.n_epochs = config.n_epochs
        self.split = split
        self._time_start = ""
        self._time_end = ""
        self.epoch = 0
        self.name = config.name

        # Create a timestamped output folder so that runs do not overwrite each other
        dirname = f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{self.name}'
        self.out_dir = os.path.join(config.test_results_dir, dirname)
        os.makedirs(self.out_dir, exist_ok=True)

        # Create data loaders. We are using a 2D version of UNet, which expects
        # batches of 2D slices; SlicesDataset is expected to provide them.
        self.train_loader = DataLoader(SlicesDataset(dataset[split["train"]]),
                                       batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=0)
        # Validation does not need shuffling; a fixed order keeps the validation
        # loss and the logged images deterministic from epoch to epoch.
        self.val_loader = DataLoader(SlicesDataset(dataset[split["val"]]),
                                     batch_size=config.batch_size,
                                     shuffle=False,
                                     num_workers=0)

        # We will access volumes directly for testing
        self.test_data = dataset[split["test"]]

        # Do we have CUDA available?
        if not torch.cuda.is_available():
            print(
                "WARNING: No CUDA device is found. This may take significantly longer!"
            )
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Configure our model and other training implements.
        # We use a recursive UNet model from German Cancer Research Center,
        # Division of Medical Image Computing.
        self.model = UNet(num_classes=3)
        self.model.to(self.device)

        # Standard cross-entropy loss. Note that torch's CrossEntropyLoss expects
        # raw logits (it applies log-softmax internally), so the model output is
        # passed to it directly; softmax is only computed for visualization.
        self.loss_function = torch.nn.CrossEntropyLoss()

        # Adam optimizer for the model weights
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=config.learning_rate)
        # Lower the learning rate when the validation loss (passed to
        # scheduler.step() in validate()) stops decreasing
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min')

        # Set up Tensorboard. By default SummaryWriter saves data into the
        # ./runs folder; launch `tensorboard --logdir runs` to inspect it.
        self.tensorboard_train_writer = SummaryWriter(comment="_train")
        self.tensorboard_val_writer = SummaryWriter(comment="_val")

    def train(self):
        """
        Run one training epoch over ``self.train_loader``, updating the model
        weights after every minibatch and logging progress every 10th batch.
        """
        print(f"Training epoch {self.epoch}...")
        self.model.train()
        n_batches = len(self.train_loader)

        # Loop over our minibatches
        for i, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Slices arrive as 4D tensors [BATCH_SIZE, 1, PATCH_SIZE, PATCH_SIZE].
            # CrossEntropyLoss needs integer class labels, hence the long target.
            data = batch["image"].to(self.device, dtype=torch.float)
            target = batch["seg"].to(self.device, dtype=torch.long)

            # prediction: [batch, classes, H, W] raw logits
            prediction = self.model(data)

            # Softmax'd version is only used to output a probability map so that
            # we can see how the model converges to the solution
            prediction_softmax = F.softmax(prediction, dim=1)

            # Drop the singleton channel dim so the target is [batch, H, W]
            loss = self.loss_function(prediction, target[:, 0, :, :])

            loss.backward()
            self.optimizer.step()

            if (i % 10) == 0:
                # Output to console on every 10th batch
                print(
                    f"\nEpoch: {self.epoch} Train loss: {loss.item()}, {100*(i+1)/n_batches:.1f}% complete"
                )

                counter = 100 * self.epoch + 100 * (i / n_batches)

                log_to_tensorboard(self.tensorboard_train_writer, loss, data,
                                   target, prediction_softmax, prediction,
                                   counter)

            print(".", end='')

        print("\nTraining complete")

    def validate(self):
        """
        Run one validation pass over ``self.val_loader`` using the same loss as
        ``train``, step the LR scheduler with the mean loss and log the last
        batch to Tensorboard.

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        print(f"Validating epoch {self.epoch}...")

        # eval() switches layers such as dropout/batch-norm to inference mode;
        # torch.no_grad() below is what disables gradient tracking.
        self.model.eval()
        loss_list = []

        with torch.no_grad():
            for i, batch in enumerate(self.val_loader):
                data = batch["image"].to(self.device, dtype=torch.float)
                target = batch["seg"].to(self.device, dtype=torch.long)

                prediction = self.model(data)

                prediction_softmax = F.softmax(prediction, dim=1)
                loss = self.loss_function(prediction, target[:, 0, :, :])

                print(f"Batch {i}. Data shape {data.shape} Loss {loss.item()}")

                # We report loss that is accumulated across all of validation set
                loss_list.append(loss.item())

        # Without any batch there is no loss to report and nothing to log
        if not loss_list:
            raise RuntimeError(
                "Validation loader yielded no batches; check the validation split")

        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        # Logs images of the last validation batch
        log_to_tensorboard(self.tensorboard_val_writer, mean_loss,
                           data, target, prediction_softmax, prediction,
                           (self.epoch + 1) * 100)
        print("Validation complete")

    def save_model_parameters(self):
        """
        Saves model parameters (state_dict) to ``model.pth`` in the results directory
        """
        path = os.path.join(self.out_dir, "model.pth")

        torch.save(self.model.state_dict(), path)

    def load_model_parameters(self, path=''):
        """
        Load model parameters from ``path`` or, if empty, from ``model.pth`` in
        the results directory.

        Tensors are mapped onto ``self.device`` so that a checkpoint saved on a
        GPU can also be loaded on a CPU-only machine.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        model_path = path if path else os.path.join(self.out_dir, "model.pth")

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Could not find path {model_path}")

        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def run_test(self):
        """
        Run inference on every full 3D volume of the test set and compute
        Dice and Jaccard similarity against the ground-truth segmentation.

        Unlike train/validate, inference is done on whole volumes, as it would
        be when the model is deployed.

        Returns:
            dict with ``"volume_stats"`` (one entry per volume holding
            filename, dice and jaccard) and ``"overall"`` (mean Dice and mean
            Jaccard), intended to be persisted as JSON.
        """
        print("Testing...")
        self.model.eval()

        inference_agent = UNetInferenceAgent(model=self.model,
                                             device=self.device)

        out_dict = {}
        out_dict["volume_stats"] = []
        dc_list = []
        jc_list = []

        # No gradients are needed for inference; this saves memory
        with torch.no_grad():
            # For every volume in the test set
            for i, x in enumerate(self.test_data):
                pred_label = inference_agent.single_volume_inference(x["image"])

                # Dice and Jaccard similarity coefficients assess how close the
                # predicted and reference volumes are to each other
                dc = Dice3d(pred_label, x["seg"])
                jc = Jaccard3d(pred_label, x["seg"])
                dc_list.append(dc)
                jc_list.append(jc)

                out_dict["volume_stats"].append({
                    "filename": x['filename'],
                    "dice": dc,
                    "jaccard": jc
                })
                print(
                    f"{x['filename']} Dice {dc:.4f}. {100*(i+1)/len(self.test_data):.2f}% complete"
                )

        out_dict["overall"] = {
            "mean_dice": np.mean(dc_list),
            "mean_jaccard": np.mean(jc_list)
        }

        print("\nTesting complete.")
        return out_dict

    def run(self):
        """
        Kick off the train cycle (train + validate for each epoch) and write
        the model parameter file at the end.
        """
        self._time_start = time.time()

        print("Experiment started.")

        # Iterate over epochs; self.epoch is used by train/validate for logging
        for self.epoch in range(self.n_epochs):
            self.train()
            self.validate()

        # Save model for inferencing
        self.save_model_parameters()

        self._time_end = time.time()
        print(
            f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(self._time_end - self._time_start))}"
        )
示例#2
0
class UNetExperiment(PytorchExperiment):
    """
    The UNetExperiment inherits from PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UNetExperiment is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """
    def setup(self):
        """
        Build data loaders, model, losses, optimizer and LR scheduler from
        ``self.config``, optionally restore model weights from a checkpoint,
        and save an initial checkpoint.
        """
        pkl_dir = self.config.split_dir
        # NOTE(review): pickle.load can execute arbitrary code; only load split
        # files from a trusted source.
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold_split = splits[self.config.fold]
        tr_keys = fold_split['train']
        val_keys = fold_split['val']
        test_keys = fold_split['test']

        # Fall back to CPU whenever CUDA is unavailable, regardless of config.device
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(
            self.config.data_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir,
                                            target_size=self.config.patch_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys,
                                            mode="val",
                                            do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(
            self.config.data_test_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=test_keys,
            mode="test",
            do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)

        self.model.to(self.device)

        # A combination of DICE-loss and CE-loss proved good in the medical
        # segmentation decathlon. Both are set up here, but train/validate
        # currently use only the CE-loss.
        # SoftDiceLoss expects softmax'd predictions.
        self.dice_loss = SoftDiceLoss(batch_dice=True)
        # CrossEntropyLoss expects raw logits; torch applies the softmax internally.
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If a directory for a checkpoint is provided, we load the model from it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print(
                    'checkpoint_dir is empty, please provide directory to load checkpoint.'
                )
            else:
                # Note the trailing comma: ("model") would be a plain string
                self.load_checkpoint(name=self.config.checkpoint_dir,
                                     save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        """
        Run one training epoch and periodically log the loss and image grids.

        Args:
            epoch: index of the current epoch, used as the logging x-axis.
        """
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        for batch_counter, data_batch in enumerate(self.train_data_loader):

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]
            # Desired shape = [b, c, w, h]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)

            # Only the CE-loss is used. squeeze(1) drops just the (assumed
            # singleton) channel dim, so a batch of size 1 keeps its batch dim
            # (a bare squeeze() would drop it and break the loss).
            # For the combined loss use:
            #   self.dice_loss(F.softmax(pred, dim=1), target.squeeze(1)) + self.ce_loss(pred, target.squeeze(1))
            loss = self.ce_loss(pred, target.squeeze(1))

            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(
                    self._epoch_idx, loss.item()))

                self.add_result(
                    value=loss.item(),
                    name='Train_Loss',
                    tag='Loss',
                    counter=epoch +
                    (batch_counter /
                     self.train_data_loader.data_loader.num_batches))

                # Detach so plotting does not hold on to the autograd graph
                pred_cpu = pred.detach().cpu()
                self.clog.show_image_grid(data.float().cpu(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred_cpu,
                                                       dim=1,
                                                       keepdim=True),
                                          name="unt_argmax",
                                          title="Unet",
                                          n_iter=epoch)
                self.clog.show_image_grid(pred_cpu[:, 1:2, ],
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """
        Compute the mean validation loss, step the LR scheduler with it and
        log the last validation batch.

        Args:
            epoch: index of the current epoch, used as the logging x-axis.
        """
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)

                # Same loss as in train(); see the note there on squeeze(1)
                loss = self.ce_loss(pred, target.squeeze(1))
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, mean_loss))

        self.add_result(value=mean_loss,
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)

        # Plot the last validation batch
        pred_cpu = pred.detach().cpu()
        self.clog.show_image_grid(data.float().cpu(),
                                  name="data_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)
        self.clog.show_image_grid(target.float().cpu(),
                                  name="mask_val",
                                  title="Mask",
                                  n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred_cpu,
                                               dim=1,
                                               keepdim=True),
                                  name="unt_argmax_val",
                                  title="Unet",
                                  n_iter=epoch)
        self.clog.show_image_grid(pred_cpu[:, 1:2, ],
                                  name="unt_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)

    def test(self):
        """
        Run inference on the test split, regroup slice predictions and targets
        by file name and score them with ``aggregate_scores``/``Evaluator``,
        writing the scores as JSON into the experiment work dir.
        """
        from evaluation.evaluator import aggregate_scores, Evaluator
        from collections import defaultdict

        self.elog.print('=====TEST=====')
        self.model.eval()

        pred_dict = defaultdict(list)
        gt_dict = defaultdict(list)

        with torch.no_grad():
            for batch_counter, data_batch in enumerate(self.test_data_loader):
                print('testing...', batch_counter)

                # Get data_batches
                mr_data = data_batch['data'][0].float().to(self.device)
                mr_target = data_batch['seg'][0].float().to(self.device)

                pred = self.model(mr_data)
                pred_argmax = torch.argmax(pred.cpu(), dim=1, keepdim=True)

                # NOTE(review): each entry of 'fnames' appears to be a sequence
                # whose first element is the volume file name - confirm
                # against NumpyDataSet.
                for i, fname in enumerate(data_batch['fnames']):
                    pred_dict[fname[0]].append(pred_argmax[i].numpy())
                    gt_dict[fname[0]].append(mr_target[i].cpu().numpy())

        # Stack the collected slices of every volume into (prediction, reference) pairs
        test_ref_list = [(np.stack(pred_dict[key]), np.stack(gt_dict[key]))
                         for key in pred_dict]

        json_output_file = os.path.join(
            self.elog.work_dir,
            "{}_{}.json".format(self.config.author, self.config.name))
        scores = aggregate_scores(test_ref_list,
                                  evaluator=Evaluator,
                                  json_author=self.config.author,
                                  json_task=self.config.name,
                                  json_name=self.config.name,
                                  json_output_file=json_output_file)

        print("Scores:\n", scores)

    def segment_single_image(self, data):
        """
        Segment a stack of 2D slices with a model freshly loaded from
        ``config.model_dir``.

        Args:
            data: tensor of slices, presumably shaped [b, c, w, h]; it is run
                through the model in chunks of ``config.batch_size`` slices.

        Returns:
            int16 numpy array with the per-slice argmax labels (shape
            [b, 1, w, h] for a model producing [b, classes, w, h]).

        Raises:
            ValueError: if ``data`` contains no slices or ``config.model_dir``
                is empty.
        """
        if data.shape[0] == 0:
            raise ValueError('data contains no slices to segment.')

        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        # The model created above is freshly initialized, so trained weights
        # must be loaded; segmenting with random weights is meaningless.
        if self.config.model_dir == '':
            raise ValueError(
                'model_dir is empty, please provide directory to load checkpoint.'
            )
        self.load_checkpoint(name=self.config.model_dir,
                             save_types=("model",))

        self.elog.print("=====SEGMENT_SINGLE_IMAGE=====")
        self.model.eval()
        self.model.to(self.device)

        # Run the model on chunks of `blocksize` slices to bound device memory
        # usage; this works on both CUDA and CPU.
        blocksize = self.config.batch_size
        with torch.no_grad():
            pred_list = []
            for start in range(0, data.shape[0], blocksize):
                mr_data = data[start:start + blocksize].float().to(self.device)
                pred_list.append(self.model(mr_data).cpu())

            pred = torch.cat(pred_list)
            pred_argmax = torch.argmax(pred, dim=1, keepdim=True)

        # Convert to a compact int16 numpy array
        return pred_argmax.short().numpy()
class UNetExperiment(PytorchExperiment):
    """
    The UNetExperiment inherits from PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UNetExperiment is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """
    def setup(self):
        """
        Build data loaders, model, losses, optimizer and LR scheduler from
        ``self.config``, optionally restore model weights from a checkpoint,
        and save an initial checkpoint.
        """
        pkl_dir = self.config.split_dir
        # NOTE(review): pickle.load can execute arbitrary code; only load split
        # files from a trusted source.
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold_split = splits[self.config.fold]
        tr_keys = fold_split['train']
        val_keys = fold_split['val']
        test_keys = fold_split['test']

        # Fall back to CPU whenever CUDA is unavailable, regardless of config.device
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(
            self.config.data_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir,
                                            target_size=self.config.patch_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys,
                                            mode="val",
                                            do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(
            self.config.data_test_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=test_keys,
            mode="test",
            do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-loss in this example.
        # This proved good in the medical segmentation decathlon.
        # SoftDiceLoss expects softmax'd predictions.
        self.dice_loss = SoftDiceLoss(batch_dice=True)
        # No softmax for the CE-loss: torch's CrossEntropyLoss applies it internally.
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If a directory for a checkpoint is provided, we load the model from it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print(
                    'checkpoint_dir is empty, please provide directory to load checkpoint.'
                )
            else:
                # Note the trailing comma: ("model") would be a plain string
                self.load_checkpoint(name=self.config.checkpoint_dir,
                                     save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        """
        Run one training epoch with the combined Dice + CE loss and
        periodically log the loss and image grids.

        Args:
            epoch: index of the current epoch, used as the logging x-axis.
        """
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        for batch_counter, data_batch in enumerate(self.train_data_loader):

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]
            # Desired shape = [b, c, w, h]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)
            # SoftDiceLoss expects softmax'd input; the CE-loss does the
            # softmax internally and therefore gets the raw logits.
            pred_softmax = F.softmax(pred, dim=1)

            # squeeze(1) drops just the (assumed singleton) channel dim, so a
            # batch of size 1 keeps its batch dim (a bare squeeze() would not).
            target_labels = target.squeeze(1)
            loss = (self.dice_loss(pred_softmax, target_labels) +
                    self.ce_loss(pred, target_labels))
            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: %d Loss: %.4f' %
                                (self._epoch_idx, loss.item()))

                self.add_result(
                    value=loss.item(),
                    name='Train_Loss',
                    tag='Loss',
                    counter=epoch +
                    (batch_counter /
                     self.train_data_loader.data_loader.num_batches))

                # Move to CPU (and detach) before plotting, as the sibling
                # experiment does
                pred_cpu = pred.detach().cpu()
                self.clog.show_image_grid(data.float().cpu(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred_cpu,
                                                       dim=1,
                                                       keepdim=True),
                                          name="unt_argmax",
                                          title="Unet",
                                          n_iter=epoch)
                self.clog.show_image_grid(pred_cpu[:, 1:2, ],
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """
        Compute the mean validation loss (Dice + CE), step the LR scheduler
        with it and log the last validation batch.

        Args:
            epoch: index of the current epoch, used as the logging x-axis.
        """
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                # Explicit class dim; SoftDiceLoss expects softmax'd input,
                # the CE-loss does the softmax internally.
                pred_softmax = F.softmax(pred, dim=1)

                target_labels = target.squeeze(1)
                loss = (self.dice_loss(pred_softmax, target_labels) +
                        self.ce_loss(pred, target_labels))
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, mean_loss))

        self.add_result(value=mean_loss,
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)

        # Plot the last validation batch
        pred_cpu = pred.detach().cpu()
        self.clog.show_image_grid(data.float().cpu(),
                                  name="data_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)
        self.clog.show_image_grid(target.float().cpu(),
                                  name="mask_val",
                                  title="Mask",
                                  n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred_cpu,
                                               dim=1,
                                               keepdim=True),
                                  name="unt_argmax_val",
                                  title="Unet",
                                  n_iter=epoch)
        self.clog.show_image_grid(pred_cpu[:, 1:2, ],
                                  name="unt_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)

    def test(self):
        """Placeholder: evaluation on the test split is not implemented yet."""
        # TODO: evaluate the model on self.test_data_loader
        print('TODO: Implement your test() method here')
class UNetExperiment(PytorchExperiment):
    """
    The UnetExperiment is inherited from the PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UnetExperiment is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()

    train(), validate() and test() all use the same objective (see _compute_loss):
    bce_weight * BCE(sigmoid(output), target) + (1 - bce_weight) * soft_dice(...).
    """

    def _make_loader(self, mode, keys):
        """Build a shuffled NucleusDataset DataLoader for split ``mode`` restricted to ``keys``.

        Image and mask both go through Normalize -> Rescale(patch_size) -> ToTensor.
        """
        def pipeline():
            # A fresh Compose (and fresh transform objects) for each use.
            return transforms.Compose([
                Normalize(),
                Rescale(self.config.patch_size),
                ToTensor()
            ])

        dataset = NucleusDataset(self.config.data_root_dir,
                                 # NOTE(review): train=True is passed for every split,
                                 # including val/test; confirm this is intended.
                                 train=True,
                                 transform=pipeline(),
                                 target_transform=pipeline(),
                                 mode=mode,
                                 keys=keys,
                                 taskname=self.config.dataset_name)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.config.batch_size,
                                           shuffle=True)

    def _compute_loss(self, pred, target):
        """Return the weighted BCE + soft-Dice loss for sigmoid probabilities ``pred``."""
        return self.bce_weight * F.binary_cross_entropy(pred, target) + (
            1 - self.bce_weight) * soft_dice(pred, target)

    def _log_eval_images(self, data, target, pred, suffix, n_iter):
        """Show input, mask, argmax prediction and raw prediction grids named ``*_<suffix>``."""
        self.clog.show_image_grid(data.float().cpu(),
                                  name="data_" + suffix,
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=n_iter)
        self.clog.show_image_grid(target.float().cpu(),
                                  name="mask_" + suffix,
                                  title="Mask",
                                  n_iter=n_iter)
        # NOTE(review): argmax over dim=1 is constant when the model has a single
        # output channel; confirm config.num_classes for this visualisation.
        self.clog.show_image_grid(torch.argmax(pred.data.cpu(),
                                               dim=1,
                                               keepdim=True),
                                  name="unt_argmax_" + suffix,
                                  title="Unet",
                                  n_iter=n_iter)
        self.clog.show_image_grid(pred.data.cpu(),
                                  name="unt_" + suffix,
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=n_iter)

    def setup(self):
        """Load the fold's key splits, build the data loaders, model, optimizer and scheduler."""
        pkl_dir = self.config.split_dir
        # NOTE: pickle must only be used on trusted split files.
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold = splits[self.config.fold]
        tr_keys = fold['train']
        val_keys = fold['val']
        test_keys = fold['test']
        print("pkl_dir: ", pkl_dir)
        print("tr_keys: ", tr_keys)
        print("val_keys: ", val_keys)
        print("test_keys: ", test_keys)
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = self._make_loader("train", tr_keys)
        self.val_data_loader = self._make_loader("val", val_keys)
        self.test_data_loader = self._make_loader("test", test_keys)

        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)
        self.model.to(self.device)
        # Weight of the BCE term in the combined BCE + soft-Dice loss.
        self.bce_weight = 0.5
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)
        # Stepped with the mean validation loss in validate().
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print(
                    'checkpoint_dir is empty, please provide directory to load checkpoint.'
                )
            else:
                # ("model",) is a one-element tuple; ("model") would be a plain string.
                self.load_checkpoint(name=self.config.checkpoint_dir,
                                     save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        """Run one training epoch, logging loss and image grids every ``plot_freq`` batches."""
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        for batch_counter, (images, masks) in enumerate(self.train_data_loader):
            data, target = images.to(self.device), masks.to(self.device)

            self.optimizer.zero_grad()
            # F.binary_cross_entropy needs probabilities, hence the sigmoid.
            pred = torch.sigmoid(self.model(data))
            loss = self._compute_loss(pred, target)
            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(
                    self._epoch_idx, loss.item()))

                self.add_result(value=loss.item(),
                                name='Train_Loss',
                                tag='Loss',
                                counter=epoch)
                self.clog.show_image_grid(data.float().cpu(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                # pred is attached to the autograd graph here; detach before plotting.
                self.clog.show_image_grid(pred.detach().cpu(),
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """Evaluate on the validation loader, step the LR scheduler and log mean loss/Dice."""
        self.elog.print('-------------VALIDATE-------------')
        self.model.eval()

        data = None
        loss_list = []
        acc_list = []
        with torch.no_grad():
            for images, masks in self.val_data_loader:
                data, target = images.to(self.device), masks.to(self.device)
                pred = torch.sigmoid(self.model(data))
                # soft_dice is used as a minimised loss term; its negation is reported as Dice.
                acc = -soft_dice(pred, target)
                acc_list.append(acc.item())
                loss_list.append(self._compute_loss(pred, target).item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        mean_loss = np.mean(loss_list)
        mean_dice = np.mean(acc_list)
        self.scheduler.step(mean_loss)

        self.elog.print('Epoch: %d Mean Loss: %.4f Mean Dice: %.4f' %
                        (self._epoch_idx, mean_loss, mean_dice))

        self.add_result(value=mean_loss,
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)
        self.add_result(value=mean_dice,
                        name='Val_Mean_Accuracy',
                        tag='Accuracy',
                        counter=epoch + 1)

        # Images of the last validation batch.
        self._log_eval_images(data, target, pred, "val", epoch)

    def test(self):
        """Evaluate on the test loader, logging per-batch loss/Dice/images and the means."""
        print(' In test() method here')
        self.elog.print('----------Test-------------')
        self.model.eval()
        data = None
        loss_list = []
        acc_list = []
        with torch.no_grad():
            for batch_idx, (images, masks) in enumerate(self.test_data_loader):
                data, target = images.to(self.device), masks.to(self.device)
                pred = torch.sigmoid(self.model(data))
                acc = -soft_dice(pred, target)
                loss = self._compute_loss(pred, target)
                acc_list.append(acc.item())
                loss_list.append(loss.item())

                # Per-batch results and images are indexed by batch position.
                self.add_result(value=loss.item(),
                                name='Test_Loss',
                                tag='Test_Loss',
                                counter=batch_idx + 1)
                self.add_result(value=acc.item(),
                                name='Test_Mean_Accuracy',
                                tag='Test_Accuracy',
                                counter=batch_idx + 1)
                self._log_eval_images(data, target, pred, "test", batch_idx)

        # Checked after the loop: inside it, data has always just been assigned.
        assert data is not None, 'data is None. Please check if your dataloader works properly'
        self.elog.print('Test Mean Loss: %.4f Test Mean Dice: %.4f' %
                        (np.mean(loss_list), np.mean(acc_list)))
示例#5
0
class UNetExperiment:
    """
    This class implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    The basic life cycle of a UNetExperiment is:
        run():
            for epoch in n_epochs:
                train()
                validate()
        test()

    Whole-volume evaluation is done by run_test(), which scores predictions with
    Dice3d and Jaccard3d.
    """
    def __init__(self, config, split, dataset):
        self.n_epochs = config.n_epochs
        self.split = split
        self._time_start = ""
        self._time_end = ""
        self.epoch = 0
        self.name = config.name

        # Create output folders
        dirname = f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{self.name}'
        self.out_dir = os.path.join(config.test_results_dir, dirname)
        os.makedirs(self.out_dir, exist_ok=True)

        # Create data loaders (training/validation iterate over 2D slices)
        self.train_loader = DataLoader(SlicesDataset(dataset[split["train"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)
        self.val_loader = DataLoader(SlicesDataset(dataset[split["val"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)

        # access volumes directly for testing
        self.test_data = dataset[split["test"]]

        if not torch.cuda.is_available():
            print("WARNING: No CUDA device is found. This may take significantly longer!")
        # NOTE: the device is pinned to CPU even when CUDA is available; the automatic
        # selection below was disabled on purpose. Restore it to train on a GPU.
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device('cpu')

        # recursive UNet model from German Cancer Research Center, Division of Medical Image Computing
        self.model = UNet()
        self.model.to(self.device)

        # Standard cross-entropy loss. It applies log-softmax internally, so it must be
        # fed the model's raw per-pixel class scores (not softmax probabilities) and
        # integer class-index targets.
        self.loss_function = torch.nn.CrossEntropyLoss()

        # Adam optimizer for the model weights
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.learning_rate)

        # Lower the learning rate when the validation loss (see validate()) plateaus
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

        # Set up Tensorboard. By default it saves data into runs folder. You need to launch
#         self.tensorboard_train_writer = SummaryWriter(comment="_train")
#         self.tensorboard_val_writer = SummaryWriter(comment="_val")

    def train(self):
        """
        This method is executed once per epoch and takes
        care of model weight update cycle
        """
        print(f"Training epoch {self.epoch}...")
        self.model.train()

        # Loop over the minibatches
        for i, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Feed data to the model and feed target to the loss function
            data = batch['image'].float()
            target = batch['seg']
            # Per the original author's note, prediction dims are:
            # batch_size, 3 classes, coronal, axial
            prediction = self.model(data.to(self.device))
            # Probabilities are only needed for the (disabled) TensorBoard logging below.
            prediction_softmax = F.softmax(prediction, dim=1)
            # Raw scores go to CrossEntropyLoss (it softmaxes internally). The target
            # drops its channel axis and is cast to long, as the loss requires class indices.
            loss = self.loss_function(prediction, target[:, 0, :, :].long().to(self.device))

            loss.backward()
            self.optimizer.step()

            if (i % 10) == 0:
                # Output to console on every 10th batch
                print(f"\nEpoch: {self.epoch} Train loss: {loss}, {100*(i+1)/len(self.train_loader):.1f}% complete")

                counter = 100*self.epoch + 100*(i/len(self.train_loader))

#                 log_to_tensorboard(
#                     self.tensorboard_train_writer,
#                     loss,
#                     data,
#                     target,
#                     prediction_softmax,
#                     prediction,
#                     counter)

            print(".", end='')

        print("\nTraining complete")

    def validate(self):
        """
        This method runs validation cycle, using same metrics as
        Train method. Note that model needs to be switched to eval
        mode and no_grad needs to be called so that gradients do not
        propagate. The mean validation loss drives the LR scheduler.
        """
        print(f"Validating epoch {self.epoch}...")

        # Switch to eval mode; no_grad below disables gradient tracking
        self.model.eval()
        loss_list = []

        with torch.no_grad():
            for i, batch in enumerate(self.val_loader):
                data = batch['image'].float()
                target = batch['seg']
                prediction = self.model(data.to(self.device))
                prediction_softmax = F.softmax(prediction, dim=1)
                # Same loss as in train(): raw scores and long class-index targets.
                loss = self.loss_function(prediction, target[:, 0, :, :].long().to(self.device))

                print(f"Batch {i}. Data shape {data.shape} Loss {loss}")

                # We report loss that is accumulated across all of validation set
                loss_list.append(loss.item())

        self.scheduler.step(np.mean(loss_list))

#         log_to_tensorboard(
#             self.tensorboard_val_writer,
#             np.mean(loss_list),
#             data,
#             target,
#             prediction_softmax,
#             prediction,
#             (self.epoch+1) * 100)
        print(f"Validation complete")

    def save_model_parameters(self):
        """
        Saves model parameters to <out_dir>/model.pth
        """
        path = os.path.join(self.out_dir, "model.pth")

        torch.save(self.model.state_dict(), path)

    def load_model_parameters(self, path=''):
        """
        Loads model parameters from a supplied path, or from
        <out_dir>/model.pth when no path is given.

        Raises FileNotFoundError (a subclass of Exception) if the file is missing.
        """
        model_path = path if path else os.path.join(self.out_dir, "model.pth")

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Could not find path {model_path}")
        # map_location lets weights saved from a GPU run be restored onto self.device.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def run_test(self):
        """
        This runs test cycle on the test dataset.
        Note that process and evaluations are quite different
        Here we are computing a lot more metrics and returning
        a dictionary that could later be persisted as JSON
        """
        print("Testing...")
        self.model.eval()

        inference_agent = UNetInferenceAgent(model=self.model, device=self.device)

        out_dict = {}
        out_dict["volume_stats"] = []
        dc_list = []
        jc_list = []

        # for every volume in the test set
        for i, x in enumerate(self.test_data):
            pred_label = inference_agent.single_volume_inference(x["image"])

            # We compute and report Dice and Jaccard similarity coefficients which
            # assess how close our volumes are to each other

            dc = Dice3d(pred_label, x["seg"])
            jc = Jaccard3d(pred_label, x["seg"])
            dc_list.append(dc)
            jc_list.append(jc)

            # STAND-OUT SUGGESTION: By way of exercise, consider also outputting:
            # * Sensitivity and specificity (and explain semantic meaning in terms of
            #   under/over segmenting)
            # * Dice-per-slice and render combined slices with lowest and highest DpS
            # * Dice per class (anterior/posterior)

            out_dict["volume_stats"].append({
                "filename": x['filename'],
                "dice": dc,
                "jaccard": jc
                })
            print(f"{x['filename']} Dice {dc:.4f}. {100*(i+1)/len(self.test_data):.2f}% complete")

        out_dict["overall"] = {
            "mean_dice": np.mean(dc_list),
            "mean_jaccard": np.mean(jc_list)}

        print("\nTesting complete.")
        return out_dict

    def run(self):
        """
        Kicks off train cycle and writes model parameter file at the end
        """
        self._time_start = time.time()

        print("Experiment started.")

        # Iterate over epochs
        for self.epoch in range(self.n_epochs):
            self.train()
            self.validate()

        # save model for inferencing
        self.save_model_parameters()

        self._time_end = time.time()
        print(f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(self._time_end - self._time_start))}")
示例#6
0
class CHDExperiment(PytorchExperiment):
    """
    Inference/analysis experiment built on PytorchExperiment.

    setup() builds a UNet (num_downs=3) and two NumpyDataSet loaders over the
    union of the fold's train and val keys: one from ``data_dir`` (target_size=256)
    and one from ``scaled_image_32_dir`` (target_size=32).
    inference() predicts on the low-resolution loader, upsamples predictions 2x,
    caches (prediction, target) stacks in ``stage_1_dir_32`` and then analyses them.
    compare_dice() compares low-resolution Dice with Dice after upsampling.
    """

    def setup(self):
        """Load the fold's keys and build the model, data loaders, losses and optimizer."""
        pkl_dir = self.config.split_dir
        # NOTE: pickle must only be used on trusted split files.
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold = splits[self.config.fold]
        # Both loaders iterate over the union of the train and val keys.
        keys = fold['train'] + fold['val']

        self.device = torch.device(self.config.device if torch.cuda.is_available() else 'cpu')

        self.model = UNet(num_classes=self.config.num_classes, num_downs=3)
        self.model.to(self.device)

        self.data_loader = NumpyDataSet(self.config.data_dir, target_size=256, batch_size=self.config.batch_size,
                                        keys=keys, mode='test', do_reshuffle=False)

        self.data_16_loader = NumpyDataSet(self.config.scaled_image_32_dir, target_size=32,
                                           batch_size=self.config.batch_size,
                                           keys=keys, mode='test', do_reshuffle=False)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        # SoftDiceLoss is given softmax probabilities by the callers below.
        self.dice_loss = SoftDiceLoss(batch_dice=True)
        # No softmax for CrossEntropyLoss: torch applies log-softmax internally.
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('checkpoint_dir is empty, please provide directory to load checkpoint.')
            else:
                # ("model",) is a one-element tuple; ("model") would be a plain string.
                self.load_checkpoint(name=self.config.checkpoint_dir, save_types=("model",))

    def inference(self):
        """Cache upsampled stage-1 predictions (unless already cached), then analyse them."""
        self.elog.print('=====INFERENCE=====')

        with torch.no_grad():
            if os.path.exists(self.config.stage_1_dir_32):
                print('stage_1_dir already exists')
            else:
                self._cache_stage_1_predictions()

            # do softmax analysis, and divide the pred image into 4 parts
            self._analyse_stage_1_predictions()

    def _cache_stage_1_predictions(self):
        """Predict on data_16_loader, upsample 2x and append (pred, target) slices per volume file."""
        os.makedirs(self.config.stage_1_dir_32, exist_ok=True)
        print('Creatting stage_1_dir...')

        for data_batch in self.data_16_loader:
            file_dir = data_batch['fnames']
            # original note: e.g. size (8, 1, 16, 16)
            data_16 = data_batch['data'][0].float().to(self.device)
            target_16 = data_batch['seg'][0].float().to(self.device)

            pred_16 = self.model(data_16)
            pred_16_softmax = F.softmax(pred_16, dim=1)
            dice_16 = 1 - self.dice_loss(pred_16_softmax, target_16.squeeze())
            ce_16 = self.ce_loss(pred_16, target_16.squeeze().long())

            if dice_16 < 0.6:
                print(file_dir[0])
                print(data_batch['slice_idxs'])

            pred_32 = F.interpolate(pred_16, scale_factor=2, mode='bilinear')
            # The target is a class-index map (it is cast to long for CE). Nearest keeps
            # it a valid label map; bilinear would blend labels at class boundaries.
            target_32 = F.interpolate(target_16, scale_factor=2, mode='nearest')
            pred_32_softmax = F.softmax(pred_32, dim=1)
            dice_32 = 1 - self.dice_loss(pred_32_softmax, target_32.squeeze())
            ce_32 = self.ce_loss(pred_32, target_32.squeeze().long())

            print('dice_16: %.4f dice_32: %.4f ce_16: %.4f  ce_32: %.4f' % (dice_16, dice_32, ce_16, ce_32))

            # NumPy cannot concatenate CUDA tensors; move to host memory once per batch.
            pred_32_np = pred_32.cpu().numpy()
            target_32_np = target_32.cpu().numpy()

            # Use the real batch length so a short final batch does not index past the end.
            for k in range(pred_32_np.shape[0]):
                filename = file_dir[k][0][-14:-4]
                output_dir = os.path.join(self.config.stage_1_dir_32,
                                          'pred_' + filename + '_64')
                # Prediction channels followed by the target channel along axis 1.
                new_data = np.concatenate((pred_32_np[k:k + 1], target_32_np[k:k + 1]), axis=1)
                if os.path.exists(output_dir + '.npy'):
                    all_data = np.concatenate((np.load(output_dir + '.npy'), new_data), axis=0)
                else:
                    all_data = new_data
                    print(filename)

                np.save(output_dir, all_data)

    def _analyse_stage_1_predictions(self):
        """Run softmax_analysis and per-slice Dice over cached files, then plot the softmax stats."""
        # NOTE(review): files are written as 'pred_<name>_64.npy' but selected here by
        # the suffix '32.npy'; confirm which naming is intended.
        pred_32_files = subfiles(self.config.stage_1_dir_32, suffix='32.npy', join=False)

        softmax = []
        for file in pred_32_files:
            stacked = np.load(os.path.join(self.config.stage_1_dir_32, file))
            # Assumes 8 prediction channels followed by 1 target channel — confirm num_classes.
            pred_32 = torch.tensor(stacked[:, 0:8]).float()  # size (N,8,32,32)
            target_32 = torch.tensor(stacked[:, 8:9]).long()

            md_softmax, index, weak_image = softmax_analysis(pred_32, threshold=0)
            softmax = softmax + md_softmax

            dice_score = []
            for k in range(pred_32.shape[0]):
                pred_softmax = F.softmax(pred_32[k:k + 1], 1)
                dice_score.append(self.dice_loss(pred_softmax, target_32[k]).item())
            # Per-file Dice summary (console output currently disabled).
            avg_dice = np.average(dice_score)
            min_dice = min(dice_score)
            # print(file, 'dice_loss:%.4f  min_dice:%.4f' % (avg_dice, min_dice))

        # Plot the accumulated softmax statistics once, after all files are processed.
        if pred_32_files:
            plot_bar(softmax)

    def compare_dice(self):
        """Print low-resolution Dice loss next to the Dice loss after upsampling the prediction."""
        with torch.no_grad():
            for data_batch in self.data_loader:
                # NOTE(review): inference() reads 'fnames' / 'slice_idxs' from the same
                # NumpyDataSet type; confirm the keys 'fname' / 'slice_indxs' used here.
                file = data_batch['fname']
                target = data_batch['seg'][0].long().to(self.device)

                file_16 = os.path.join(self.config.scaled_image_32_dir, file[0])
                image_16_tensor = torch.tensor(np.load(file_16))  # size (N, 2, 16, 16)
                data_16 = image_16_tensor[:, 0:1].float().to(self.device)  # size (N, 1, 16, 16)
                target_16 = image_16_tensor[:, 1:2].float().to(self.device)

                idxs = data_batch['slice_indxs']
                start_idx = min(idxs)
                end_idx = max(idxs)

                scaled_data = data_16[start_idx:end_idx + 1]
                scaled_target = target_16[start_idx:end_idx + 1]

                scaled_pred = self.model(scaled_data)
                scaled_pred_softmax = F.softmax(scaled_pred, dim=1)
                loss_16 = self.dice_loss(scaled_pred_softmax, scaled_target.squeeze())

                # Upsample the prediction (not the target) and score it against the
                # full-resolution target of this batch.
                # NOTE(review): scale_factor=16 assumes 16 px inputs and 256 px targets; confirm.
                upsample_pred = F.interpolate(scaled_pred, scale_factor=16, mode='bilinear')
                pred_softmax = F.softmax(upsample_pred, dim=1)
                loss = self.dice_loss(pred_softmax, target.squeeze())

                print('loss_16: %.4f  loss: %.4f' % (loss_16, loss))
class UNetExperiment:
    """
    This class implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    The basic life cycle of a UNetExperiment is:

        run():
            for epoch in n_epochs:
                train()
                validate()
        test()

    Args:
        config: configuration object; ``n_epochs``, ``name``, ``test_results_dir``,
            ``batch_size`` and ``learning_rate`` are read from it.
        split: mapping with "train", "val" and "test" entries used to index ``dataset``.
        dataset: indexable collection of volumes; test items are read with the
            "image", "seg" and "filename" keys (presumably a numpy object array of
            dicts -- TODO confirm against the loader).
    """
    def __init__(self, config, split, dataset):
        self.n_epochs = config.n_epochs
        self.split = split
        self._time_start = ""
        self._time_end = ""
        self.epoch = 0
        self.name = config.name

        # Create output folders
        dirname = f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{self.name}'
        self.out_dir = os.path.join(config.test_results_dir, dirname)
        os.makedirs(self.out_dir, exist_ok=True)

        # Create data loaders (2D UNet, so the datasets yield 2D slices)
        self.train_loader = DataLoader(SlicesDataset(dataset[split["train"]]),
                                       batch_size=config.batch_size,
                                       shuffle=True,
                                       num_workers=0)
        self.val_loader = DataLoader(SlicesDataset(dataset[split["val"]]),
                                     batch_size=config.batch_size,
                                     shuffle=True,
                                     num_workers=0)

        # we will access volumes directly for testing
        self.test_data = dataset[split["test"]]

        # Do we have CUDA available?
        if not torch.cuda.is_available():
            print(
                "WARNING: No CUDA device is found. This may take significantly longer!"
            )
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Configure our model and other training implements
        self.model = UNet(num_classes=3)
        self.model.to(self.device)

        # Cross entropy loss (applies log-softmax to the raw model output internally)
        self.loss_function = torch.nn.CrossEntropyLoss()

        # We are using the Adam optimizer to optimize our weights
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=config.learning_rate)
        # Scheduler lowers the learning rate when the validation loss plateaus
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min')

        # Set up Tensorboard. By default it saves data into runs folder.
        self.tensorboard_train_writer = SummaryWriter(comment="_train")
        self.tensorboard_val_writer = SummaryWriter(comment="_val")

    def train(self):
        """
        This method is executed once per epoch and takes
        care of model weight update cycle
        """
        print(f"Training epoch {self.epoch}...")
        self.model.train()

        # Loop over our minibatches
        for i, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Slices arrive as 4D tensors of shape [BATCH_SIZE, 1, PATCH_SIZE, PATCH_SIZE].
            # Feed data to the model and feed target to the loss function
            data = batch['image'].float().to(self.device)
            target = batch['seg'].to(self.device)

            prediction = self.model(data)

            # Softmax'd version of prediction, used only for visualisation as a probability map
            prediction_softmax = F.softmax(prediction, dim=1)

            # CrossEntropyLoss wants class indices of shape [B, H, W] as int64
            loss = self.loss_function(prediction, target[:, 0, :, :].long())

            loss.backward()
            self.optimizer.step()

            if (i % 10) == 0:
                # Output to console on every 10th batch
                print(
                    f"\nEpoch: {self.epoch} Train loss: {loss}, {100*(i+1)/len(self.train_loader):.1f}% complete"
                )

                counter = 100 * self.epoch + 100 * (i / len(self.train_loader))

                log_to_tensorboard(self.tensorboard_train_writer, loss, data,
                                   target, prediction_softmax, prediction,
                                   counter)

            print(".", end='')

        print("\nTraining complete")

    def validate(self):
        """
        This method runs validation cycle, using same metrics as
        Train method. Note that model needs to be switched to eval
        mode and no_grad needs to be called so that gradients do not
        propagate

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        print(f"Validating epoch {self.epoch}...")

        # Switch to eval mode; no_grad below turns off gradient tracking
        self.model.eval()
        loss_list = []

        with torch.no_grad():
            for i, batch in enumerate(self.val_loader):

                # Compute loss on a validation sample
                data = batch["image"].float().to(self.device)
                target = batch["seg"].to(self.device)
                prediction = self.model(data)
                prediction_softmax = F.softmax(prediction, dim=1)

                loss = self.loss_function(prediction, target[:, 0, :, :].long())

                print(f"Batch {i}. Data shape {data.shape} Loss {loss}")

                # We report loss that is accumulated across all of validation set
                loss_list.append(loss.item())

        # Without any batch there is no loss to report and no sample to log.
        if not loss_list:
            raise RuntimeError(
                "Validation loader yielded no batches; cannot compute a validation loss")

        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        log_to_tensorboard(self.tensorboard_val_writer, mean_loss,
                           data, target, prediction_softmax, prediction,
                           (self.epoch + 1) * 100)
        print("Validation complete")

    def save_model_parameters(self):
        """
        Saves model parameters to ``model.pth`` in the results directory
        """
        path = os.path.join(self.out_dir, "model.pth")

        torch.save(self.model.state_dict(), path)

    def load_model_parameters(self, path=''):
        """
        Loads model parameters from a supplied path or, when ``path`` is empty,
        from ``model.pth`` in the results directory.

        Tensors are mapped onto ``self.device`` so that weights saved on a GPU can
        be loaded on a CPU-only machine.

        Raises:
            Exception: if the parameter file does not exist.
        """
        if not path:
            model_path = os.path.join(self.out_dir, "model.pth")
        else:
            model_path = path

        if os.path.exists(model_path):
            self.model.load_state_dict(
                torch.load(model_path, map_location=self.device))
        else:
            raise Exception(f"Could not find path {model_path}")

    def run_test(self):
        """
        This runs test cycle on the test dataset.
        Note that process and evaluations are quite different
        Here we are computing a lot more metrics and returning
        a dictionary that could later be persisted as JSON

        Returns:
            dict with "volume_stats" (per-volume filename/dice/jaccard) and
            "overall" (mean_dice, mean_jaccard).
        """
        print("Testing...")
        self.model.eval()

        # Unlike train/validate, inference here runs on full 3D volumes, much like
        # it would when the model is deployed in a clinical environment.
        inference_agent = UNetInferenceAgent(model=self.model,
                                             device=self.device)

        out_dict = {}
        out_dict["volume_stats"] = []
        dc_list = []
        jc_list = []

        # for every volume in the test set
        for i, x in enumerate(self.test_data):
            pred_label = inference_agent.single_volume_inference(x["image"])

            dc = Dice3d(pred_label, x["seg"])
            jc = Jaccard3d(pred_label, x["seg"])
            dc_list.append(dc)
            jc_list.append(jc)

            out_dict["volume_stats"].append({
                "filename": x['filename'],
                "dice": dc,
                "jaccard": jc
            })
            print(
                f"{x['filename']} Dice {dc:.4f}. {100*(i+1)/len(self.test_data):.2f}% complete"
            )

        out_dict["overall"] = {
            "mean_dice": np.mean(dc_list),
            "mean_jaccard": np.mean(jc_list)
        }

        print("\nTesting complete.")
        return out_dict

    def run(self):
        """
        Kicks off train cycle and writes model parameter file at the end
        """
        self._time_start = time.time()

        print("Experiment started.")

        # Iterate over epochs
        for self.epoch in range(self.n_epochs):
            self.train()
            self.validate()

        # save model for inferencing
        self.save_model_parameters()

        self._time_end = time.time()
        print(
            f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(self._time_end - self._time_start))}"
        )
示例#8
0
class UNetExperiment(PytorchExperiment):
    """
    The UnetExperiment is inherited from the PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataLoader.

    The basic life cycle of a UnetExperiment is the same as PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """
    def setup(self):
        """Build data loaders, model, losses, optimizer and scheduler.

        Reads ``splits.pkl`` from ``config.split_dir`` and uses fold ``config.fold``.
        When ``config.do_load_checkpoint`` is set, model weights are restored from
        ``config.checkpoint_dir`` (if non-empty) and, for the supported
        ``config.fine_tune`` options, ``unfreeze_block_parameters`` is applied.
        """
        pkl_dir = self.config.split_dir
        # NOTE(review): pickle.load can execute arbitrary code; splits.pkl must be trusted.
        with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        tr_keys = splits[self.config.fold]['train']
        val_keys = splits[self.config.fold]['val']
        test_keys = splits[self.config.fold]['test']

        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else "cpu")

        self.train_data_loader = NumpyDataSet(
            self.config.data_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=tr_keys)
        self.val_data_loader = NumpyDataSet(self.config.data_dir,
                                            target_size=self.config.patch_size,
                                            batch_size=self.config.batch_size,
                                            keys=val_keys,
                                            mode="val",
                                            do_reshuffle=False)
        self.test_data_loader = NumpyDataSet(
            self.config.data_test_dir,
            target_size=self.config.patch_size,
            batch_size=self.config.batch_size,
            keys=test_keys,
            mode="test",
            do_reshuffle=False)
        self.model = UNet(num_classes=self.config.num_classes,
                          in_channels=self.config.in_channels)

        self.model.to(self.device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        self.dice_loss = SoftDiceLoss(
            batch_dice=True)  # Softmax for DICE Loss!
        self.ce_loss = torch.nn.CrossEntropyLoss(
        )  # No softmax for CE Loss -> is implemented in torch!

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)

        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If directory for checkpoint is provided, we load it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print('Checkpoint_dir is empty, training from scratch.')
            else:
                # Restore only the model weights. The trailing comma matters:
                # ("model") would be a plain string rather than a tuple.
                self.load_checkpoint(name=self.config.checkpoint_filename,
                                     save_types=("model",),
                                     path=self.config.checkpoint_dir)

            if self.config.fine_tune in ['expanding_all', 'expanding_plus1']:
                # freeze part of the network, fine-tune the other part
                unfreeze_block_parameters(
                    model=self.model, fine_tune_option=self.config.fine_tune)
            # else just train the whole network

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    # overloaded method from the base class PytorchExperiment
    def load_checkpoint(self,
                        name="checkpoint",
                        save_types=("model", "optimizer", "simple", "th_vars",
                                    "results"),
                        n_iter=None,
                        iter_format="{:05d}",
                        prefix=False,
                        path=None):
        """
        Loads a checkpoint and restores the experiment.
        Make sure you have your torch stuff already on the right devices beforehand,
        otherwise this could lead to errors e.g. when making a optimizer step
        (and for some reason the Adam states are not already on the GPU:
        https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/3 )
        Args:
            name (str): The name of the checkpoint file
            save_types (list or tuple): What kind of member variables should be loaded? Choices are:
                "model" <-- Pytorch models,
                "optimizer" <-- Optimizers,
                "simple" <-- Simple python variables (basic types and lists/tuples),
                "th_vars" <-- torch tensors,
                "results" <-- The result dict
            n_iter (int): Number of iterations. Together with the name, defined by the iter_format,
                a file name will be created and searched for.
            iter_format (str): Defines how the name and the n_iter will be combined.
            prefix (bool): If True, the formatted n_iter will be prepended, otherwise appended.
            path (str): If no path is given then it will take the current experiment dir and formatted
                name, otherwise it will simply use the path and the formatted name to define the
                checkpoint file.
        """
        if self.elog is None:
            return

        model_dict = {}
        optimizer_dict = {}
        simple_dict = {}
        th_vars_dict = {}
        results_dict = {}

        if "model" in save_types:
            model_dict = self.get_pytorch_modules()
        if "optimizer" in save_types:
            optimizer_dict = self.get_pytorch_optimizers()
        if "simple" in save_types:
            simple_dict = self.get_simple_variables()
        if "th_vars" in save_types:
            th_vars_dict = self.get_pytorch_variables()
        if "results" in save_types:
            results_dict = {"results": self.results}

        checkpoint_dict = {
            **model_dict,
            **optimizer_dict,
            **simple_dict,
            **th_vars_dict,
            **results_dict
        }

        if n_iter is not None:
            name = name_and_iter_to_filename(name,
                                             n_iter,
                                             ".pth.tar",
                                             iter_format=iter_format,
                                             prefix=prefix)

        # Layers to skip when loading a static checkpoint. Previously an option
        # (config.dont_load_lastlayer) excluded
        # {'model': ['model.model.5.weight', 'model.model.5.bias']}; it is disabled.
        exclude_layer_dict = {}

        if path is None:
            restore_dict = self.elog.load_checkpoint(name=name,
                                                     **checkpoint_dict)
        else:
            checkpoint_path = os.path.join(path, name)
            if checkpoint_path.endswith("/"):
                checkpoint_path = checkpoint_path[:-1]
            restore_dict = self.elog.load_checkpoint_static(
                checkpoint_file=checkpoint_path,
                exclude_layer_dict=exclude_layer_dict,
                **checkpoint_dict)

        self.update_attributes(restore_dict)

    def train(self, epoch):
        """Run one training epoch over ``self.train_data_loader`` (Dice + CE loss)."""
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        batch_counter = 0
        for data_batch in self.train_data_loader:

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]
            # Desired shape = [b, c, w, h]
            # Move data and target to the GPU
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)

            pred = self.model(data)
            pred_softmax = F.softmax(
                pred, dim=1
            )  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

            # Drop only the channel axis of the (presumably [b, 1, w, h]) target;
            # a bare squeeze() would also drop the batch axis when b == 1.
            target_labels = target.squeeze(1)
            loss = self.dice_loss(pred_softmax, target_labels) + self.ce_loss(
                pred, target_labels)

            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                self.elog.print('Epoch: {0} Loss: {1:.4f}'.format(
                    self._epoch_idx, loss.item()))

                self.add_result(
                    value=loss.item(),
                    name='Train_Loss',
                    tag='Loss',
                    counter=epoch +
                    (batch_counter /
                     self.train_data_loader.data_loader.num_batches))

                self.clog.show_image_grid(data.float().cpu(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float().cpu(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred.cpu(),
                                                       dim=1,
                                                       keepdim=True),
                                          name="unt_argmax",
                                          title="Unet",
                                          n_iter=epoch)
                self.clog.show_image_grid(pred.cpu()[:, 1:2, ],
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

            batch_counter += 1

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """Compute the mean Dice + CE loss over the validation set and log it."""
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                pred_softmax = F.softmax(
                    pred, dim=1
                )  # We calculate a softmax, because our SoftDiceLoss expects that as an input. The CE-Loss does the softmax internally.

                # Keep the batch axis even for a batch of size 1.
                target_labels = target.squeeze(1)
                loss = self.dice_loss(pred_softmax, target_labels) + self.ce_loss(
                    pred, target_labels)
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        self.elog.print('Epoch: %d Loss: %.4f' %
                        (self._epoch_idx, mean_loss))

        self.add_result(value=mean_loss,
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)

        # Visualise the last validation batch.
        self.clog.show_image_grid(data.float().cpu(),
                                  name="data_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)
        self.clog.show_image_grid(target.float().cpu(),
                                  name="mask_val",
                                  title="Mask",
                                  n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred.data.cpu(),
                                               dim=1,
                                               keepdim=True),
                                  name="unt_argmax_val",
                                  title="Unet",
                                  n_iter=epoch)
        self.clog.show_image_grid(pred.data.cpu()[:, 1:2, ],
                                  name="unt_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)

    def test(self):
        """Predict every test slice, group predictions per file and score them.

        The aggregated scores are written as JSON to
        ``<elog.work_dir>/<author>_<name>.json`` and stored in ``self.scores``.
        """
        from evaluation.evaluator import aggregate_scores, Evaluator
        from collections import defaultdict

        self.elog.print('=====TEST=====')
        self.model.eval()

        # Per-file lists of slice predictions / ground truths.
        pred_dict = defaultdict(list)
        gt_dict = defaultdict(list)

        batch_counter = 0

        if self.config.visualize_segm:
            color_class_converter = LabelTensorToColor()

        with torch.no_grad():
            for data_batch in self.test_data_loader:
                print('testing...', batch_counter)
                batch_counter += 1

                # Get data_batches
                mr_data = data_batch['data'][0].float().to(self.device)
                mr_target = data_batch['seg'][0].float().to(self.device)

                pred = self.model(mr_data)
                pred_argmax = torch.argmax(pred.data.cpu(),
                                           dim=1,
                                           keepdim=True)

                fnames = data_batch['fnames']
                for i, fname in enumerate(fnames):
                    pred_dict[fname[0]].append(
                        pred_argmax[i].detach().cpu().numpy())
                    gt_dict[fname[0]].append(
                        mr_target[i].detach().cpu().numpy())

                # Visualise one fixed batch (the 35th) when requested.
                if batch_counter == 35 and self.config.visualize_segm:
                    segm_visualization(mr_data, mr_target, pred_argmax,
                                       color_class_converter, self.config)

        test_ref_list = []
        for key in pred_dict.keys():
            test_ref_list.append(
                (np.stack(pred_dict[key]), np.stack(gt_dict[key])))

        json_output_file = os.path.join(
            self.elog.work_dir,
            "{}_{}.json".format(self.config.author, self.config.name))
        scores = aggregate_scores(test_ref_list,
                                  evaluator=Evaluator,
                                  json_author=self.config.author,
                                  json_task=self.config.name,
                                  json_name=self.config.name,
                                  json_output_file=json_output_file)

        self.scores = scores

        print("Scores:\n", scores)
示例#9
0
class UNetExperiment:
    """
    This class implements the basic life cycle for a segmentation task with UNet(https://arxiv.org/abs/1505.04597).
    The basic life cycle of a UNetExperiment is:

        run():
            for epoch in n_epochs:
                train()
                validate()
        test()

    Args:
        config: configuration object; ``n_epochs``, ``name``, ``test_results_dir``,
            ``batch_size`` and ``learning_rate`` are read from it.
        split: mapping with "train", "val" and "test" entries used to index ``dataset``.
        dataset: indexable collection of volumes; test items are read with the
            "image", "seg" and "filename" keys (presumably a numpy object array of
            dicts -- TODO confirm against the loader).
    """
    def __init__(self, config, split, dataset):
        self.n_epochs = config.n_epochs
        self.split = split
        self._time_start = ""
        self._time_end = ""
        self.epoch = 0
        self.name = config.name
        # Number of segmentation classes predicted by the model.
        self.num_classes = 3

        # Create output folders. The timestamp only has minute resolution, so two
        # runs started within the same minute share a folder: allow existing dirs.
        dirname = f'{time.strftime("%Y-%m-%d_%H%M", time.gmtime())}_{self.name}'
        self.out_dir = os.path.join(config.test_results_dir, dirname)
        os.makedirs(self.out_dir, exist_ok=True)
        self.out_images_dir = os.path.join(self.out_dir, "images")
        os.makedirs(self.out_images_dir, exist_ok=True)

        # Create data loaders
        # Note that we are using a 2D version of UNet here, which means that it will expect
        # batches of 2D slices.
        self.train_loader = DataLoader(SlicesDataset(dataset[split["train"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)
        self.val_loader = DataLoader(SlicesDataset(dataset[split["val"]]),
                batch_size=config.batch_size, shuffle=True, num_workers=0)

        # we will access volumes directly for testing
        self.test_data = dataset[split["test"]]

        # Do we have CUDA available?
        if not torch.cuda.is_available():
            print("WARNING: No CUDA device is found. This may take significantly longer!")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Configure our model and other training implements
        # We will use a recursive UNet model from German Cancer Research Center,
        # Division of Medical Image Computing.
        self.model = UNet(num_classes=self.num_classes)
        self.model.to(self.device)

        # Standard cross-entropy loss: the model outputs raw per-class scores
        # (logits) for every pixel and CrossEntropyLoss applies log-softmax itself.
        self.loss_function = torch.nn.CrossEntropyLoss()

        # We are using the Adam optimizer to optimize our weights
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.learning_rate)
        # Scheduler lowers the learning rate when the validation loss plateaus
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

        # Set up Tensorboard. By default it saves data into the runs folder;
        # launch `tensorboard --logdir runs` to view it.
        self.tensorboard_train_writer = SummaryWriter(comment="_train")
        self.tensorboard_val_writer = SummaryWriter(comment="_val")

    def train(self):
        """
        This method is executed once per epoch and takes
        care of model weight update cycle
        """
        print(f"Training epoch {self.epoch}...")
        self.model.train()

        # Loop over our minibatches
        for i, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # Slices arrive as 4D tensors of shape [BATCH_SIZE, 1, PATCH_SIZE, PATCH_SIZE].
            # Feed data to the model and feed target to the loss function
            data = batch['image'].to(self.device, dtype=torch.float)
            target = batch['seg'].to(self.device)

            # prediction has shape [BATCH_SIZE, num_classes, H, W]: one raw score
            # per class for every pixel of the input slice.
            prediction = self.model(data)

            # Softmax'd version of prediction, used to visualise a probability map
            # so that we can see how the model converges to the solution
            prediction_softmax = F.softmax(prediction, dim=1)

            # CrossEntropyLoss wants class indices of shape [B, H, W] as int64
            loss = self.loss_function(prediction, target[:, 0, :, :].long())

            loss.backward()
            self.optimizer.step()

            if (i % 10) == 0:
                # Output to console on every 10th batch
                print(f"\nEpoch: {self.epoch} Train loss: {loss}, {100*(i+1)/len(self.train_loader):.1f}% complete")

                counter = 100*self.epoch + 100*(i/len(self.train_loader))

                log_to_tensorboard(
                    self.tensorboard_train_writer,
                    loss,
                    data,
                    target,
                    prediction_softmax,
                    prediction,
                    counter)

            print(".", end='')

        print("\nTraining complete")

    def validate(self):
        """
        This method runs validation cycle, using same metrics as
        Train method. Note that model needs to be switched to eval
        mode and no_grad needs to be called so that gradients do not
        propagate

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        print(f"Validating epoch {self.epoch}...")

        # Switch to eval mode; no_grad below turns off gradient tracking
        self.model.eval()
        loss_list = []

        with torch.no_grad():
            for i, batch in enumerate(self.val_loader):

                # Compute loss on a validation sample
                data = batch['image'].to(self.device, dtype=torch.float)
                target = batch['seg'].to(self.device)

                prediction = self.model(data)

                # Softmax'd version of prediction, used to visualise a probability map
                prediction_softmax = F.softmax(prediction, dim=1)

                loss = self.loss_function(prediction, target[:, 0, :, :].long())

                print(f"Batch {i}. Data shape {data.shape} Loss {loss}")

                # We report loss that is accumulated across all of validation set
                loss_list.append(loss.item())

        # Without any batch there is no loss to report and no sample to log.
        if not loss_list:
            raise RuntimeError(
                "Validation loader yielded no batches; cannot compute a validation loss")

        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        log_to_tensorboard(
            self.tensorboard_val_writer,
            mean_loss,
            data,
            target,
            prediction_softmax,
            prediction,
            (self.epoch+1) * 100)
        print("Validation complete")

    def save_predictions(self):
        """
        Saves the middle axial slice of the first test volume (image, ground-truth
        label and model prediction) as PNG files in ``self.out_images_dir``, named
        ``Epoch<epoch>-image.png``, ``Epoch<epoch>-label.png`` and
        ``Epoch<epoch>-prediction.png``.
        """
        print("Save image predictions")

        # Prepare model for inference
        self.model.eval()
        inference_agent = UNetInferenceAgent(model=self.model, device=self.device)
        # Get first test data volume
        first_test_data = self.test_data[0]
        # Get the model predictions
        pred_label = inference_agent.single_volume_inference(first_test_data["image"])
        # Middle slice along the first (axial) axis
        axial_middle_index = pred_label.shape[0] // 2
        # Labels are presumably class indices 0..num_classes-1. Spread them over the
        # 8-bit range without overflow: 2 * 255 = 510 would wrap to 254 in uint8,
        # making class 2 indistinguishable from class 1.
        label_scale = 255 // max(self.num_classes - 1, 1)
        # assumes image intensities are normalised to [0, 1] -- TODO confirm
        image = (first_test_data["image"][axial_middle_index] * 255).astype(np.uint8)
        label = (first_test_data["seg"][axial_middle_index] * label_scale).astype(np.uint8)
        prediction = (pred_label[axial_middle_index] * label_scale).astype(np.uint8)
        # Convert from numpy arrays to image objects and save them
        prefix = os.path.join(self.out_images_dir, 'Epoch' + str(self.epoch))
        Image.fromarray(image).save(prefix + '-image.png')
        Image.fromarray(label).save(prefix + '-label.png')
        Image.fromarray(prediction).save(prefix + '-prediction.png')

    def save_model_parameters(self):
        """
        Saves model parameters to ``model.pth`` in the results directory
        """
        path = os.path.join(self.out_dir, "model.pth")

        torch.save(self.model.state_dict(), path)

    def load_model_parameters(self, path=''):
        """
        Loads model parameters from a supplied path or, when ``path`` is empty,
        from ``model.pth`` in the results directory.

        Tensors are mapped onto ``self.device`` so that weights saved on a GPU can
        be loaded on a CPU-only machine.

        Raises:
            Exception: if the parameter file does not exist.
        """
        if not path:
            model_path = os.path.join(self.out_dir, "model.pth")
        else:
            model_path = path

        if os.path.exists(model_path):
            self.model.load_state_dict(
                torch.load(model_path, map_location=self.device))
        else:
            raise Exception(f"Could not find path {model_path}")

    def run_test(self):
        """
        This runs test cycle on the test dataset.
        Note that process and evaluations are quite different
        Here we are computing a lot more metrics and returning
        a dictionary that could later be persisted as JSON

        Returns:
            dict with "volume_stats" (per-volume filename/dice/jaccard) and
            "overall" (mean_dice, mean_jaccard).
        """
        print("Testing...")
        self.model.eval()

        # Unlike train/validate, inference here runs on full 3D volumes, much like
        # it would when the model is deployed in a clinical environment.
        inference_agent = UNetInferenceAgent(model=self.model, device=self.device)

        out_dict = {}
        out_dict["volume_stats"] = []
        dc_list = []
        jc_list = []

        # for every volume in the test set
        for i, x in enumerate(self.test_data):
            pred_label = inference_agent.single_volume_inference(x["image"])

            # Dice and Jaccard similarity coefficients assess how close the
            # predicted volume is to the ground truth
            dc = Dice3d(pred_label, x["seg"])
            jc = Jaccard3d(pred_label, x["seg"])
            dc_list.append(dc)
            jc_list.append(jc)

            out_dict["volume_stats"].append({
                "filename": x['filename'],
                "dice": dc,
                "jaccard": jc
                })
            print(f"{x['filename']} Dice {dc:.4f} Jaccard {jc:.4f} {100*(i+1)/len(self.test_data):.2f}% complete")

        mean_dice = np.mean(dc_list)
        mean_jaccard = np.mean(jc_list)

        print(f" Mean Dice {mean_dice:.4f} Mean Jaccard {mean_jaccard:.4f}")

        out_dict["overall"] = {
            "mean_dice": mean_dice,
            "mean_jaccard": mean_jaccard}

        print("\nTesting complete.")
        return out_dict

    def run(self):
        """
        Kicks off train cycle and writes model parameter file at the end.
        After every epoch a prediction snapshot is saved via save_predictions().
        """
        self._time_start = time.time()

        print("Experiment started.")

        # Iterate over epochs
        for self.epoch in range(self.n_epochs):
            self.train()
            self.validate()
            self.save_predictions()

        # save model for inferencing
        self.save_model_parameters()

        self._time_end = time.time()
        print(f"Run complete. Total time: {time.strftime('%H:%M:%S', time.gmtime(self._time_end - self._time_start))}")
示例#10
0
class FCNExperiment(PytorchExperiment):
    """
    FCNExperiment inherits from PytorchExperiment. It implements the basic life cycle for a segmentation task with UNet (https://arxiv.org/abs/1505.04597).
    It is optimized to work with the provided NumpyDataSet loader.

    The basic life cycle of an FCNExperiment is the same as that of PytorchExperiment:

        setup()
        (--> Automatically restore values if a previous checkpoint is given)
        prepare()

        for epoch in n_epochs:
            train()
            validate()
            (--> save current checkpoint)

        end()
    """
    def setup(self):
        """
        Prepare data loaders, model, losses, optimizer and LR scheduler.

        Loads ``splits.pkl`` from ``self.config.split_dir`` and selects the
        train/val/test keys of fold ``self.config.fold``. Each key set is
        wrapped in a ``NumpyDataSet`` that reads from
        ``self.config.scaled_image_64_dir`` with ``target_size=64``.

        If ``self.config.do_load_checkpoint`` is set, the model weights are
        restored from ``self.config.checkpoint_dir``. An empty directory only
        prints a message. Finally a ``checkpoint_start`` checkpoint is written.
        """
        # NOTE(review): pickle.load can execute arbitrary code; split_dir must
        # only ever point at trusted, locally generated split files.
        with open(os.path.join(self.config.split_dir, "splits.pkl"), 'rb') as f:
            splits = pickle.load(f)

        fold_split = splits[self.config.fold]
        tr_keys = fold_split['train']
        val_keys = fold_split['val']
        test_keys = fold_split['test']

        # Use the configured device only when CUDA is actually available.
        self.device = torch.device(
            self.config.device if torch.cuda.is_available() else 'cpu')

        def make_loader(keys, **extra):
            # All splits share the image directory, target size and batch size;
            # only the key set and mode/shuffling options differ.
            return NumpyDataSet(self.config.scaled_image_64_dir,
                                target_size=64,
                                batch_size=self.config.batch_size,
                                keys=keys,
                                **extra)

        self.train_data_loader = make_loader(tr_keys, do_reshuffle=True)
        self.val_data_loader = make_loader(val_keys,
                                           mode="val",
                                           do_reshuffle=True)
        self.test_data_loader = make_loader(test_keys,
                                            mode="test",
                                            do_reshuffle=False)

        self.model = UNet(num_classes=self.config.num_classes, num_downs=3)
        self.model.to(self.device)

        # We use a combination of Dice loss and cross-entropy loss, which
        # proved good in the Medical Segmentation Decathlon.
        # The Dice loss expects softmax probabilities as input.
        self.dice_loss = SoftDiceLoss(batch_dice=True)
        # No softmax for the CE loss: torch's CrossEntropyLoss already applies
        # log-softmax internally, so it is given raw logits.
        self.ce_loss = torch.nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min')

        # If a checkpoint directory is provided, restore the model from it.
        if self.config.do_load_checkpoint:
            if self.config.checkpoint_dir == '':
                print(
                    'checkpoint_dir is empty, please provide directory to load checkpoint.'
                )
            else:
                # One-element tuple: the former ("model") was just the string
                # "model" and only worked through substring membership tests.
                self.load_checkpoint(name=self.config.checkpoint_dir,
                                     save_types=("model",))

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        """
        Run one training epoch over ``self.train_data_loader``.

        Each batch is optimised with the sum of the cross-entropy loss (on
        logits) and the soft Dice loss (on softmax probabilities). On every
        ``self.config.plot_freq``-th batch, starting with the first, the
        following happens:

        - the loss is printed via ``self.elog``;
        - the loss is recorded via ``self.add_result`` with a fractional epoch
          counter;
        - input, mask, argmax prediction and class-1 scores are plotted via
          ``self.clog``.

        Args:
            epoch: Index of the current epoch; used as the plot iteration and
                as the integer part of the result counter.

        Raises:
            AssertionError: if the training loader yielded no batches.
        """
        self.elog.print('=====TRAIN=====')
        self.model.train()

        data = None
        for batch_counter, data_batch in enumerate(self.train_data_loader):

            self.optimizer.zero_grad()

            # Shape of data_batch = [1, b, c, w, h]; desired shape = [b, c, w, h].
            # Move data and target to the training device.
            data = data_batch['data'][0].float().to(self.device)
            target = data_batch['seg'][0].long().to(self.device)
            # Drop only the channel axis (assumes seg is [b, 1, w, h] — TODO
            # confirm). A bare squeeze() would also drop the batch axis when a
            # batch holds a single sample.
            target_labels = target.squeeze(1)

            pred = self.model(data)
            # SoftDiceLoss expects probabilities; CrossEntropyLoss applies
            # log-softmax internally and therefore receives raw logits.
            pred_softmax = F.softmax(pred, dim=1)

            loss = (self.ce_loss(pred, target_labels) +
                    self.dice_loss(pred_softmax, target_labels))
            loss.backward()
            self.optimizer.step()

            # Some logging and plotting
            if (batch_counter % self.config.plot_freq) == 0:
                loss_value = loss.item()
                self.elog.print('Epoch: %d Loss: %.4f' %
                                (self._epoch_idx, loss_value))

                self.add_result(
                    value=loss_value,
                    name='Train_Loss',
                    tag='Loss',
                    counter=epoch +
                    (batch_counter /
                     self.train_data_loader.data_loader.num_batches))

                # Detach so plotting never holds on to the autograd graph.
                pred_cpu = pred.detach().cpu()
                self.clog.show_image_grid(data.float(),
                                          name="data",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)
                self.clog.show_image_grid(target.float(),
                                          name="mask",
                                          title="Mask",
                                          n_iter=epoch)
                self.clog.show_image_grid(torch.argmax(pred_cpu,
                                                       dim=1,
                                                       keepdim=True),
                                          name="unt_argmax",
                                          title="Unet",
                                          n_iter=epoch)
                self.clog.show_image_grid(pred_cpu[:, 1:2],
                                          name="unt",
                                          normalize=True,
                                          scale_each=True,
                                          n_iter=epoch)

        assert data is not None, 'data is None. Please check if your dataloader works properly'

    def validate(self, epoch):
        """
        Evaluate the soft Dice loss on ``self.val_data_loader`` without gradients.

        The mean batch loss is:

        - passed to ``self.scheduler.step``;
        - printed via ``self.elog``;
        - recorded via ``self.add_result`` with counter ``epoch + 1``.

        Input, mask and predictions of the last validation batch are plotted
        via ``self.clog``.

        Args:
            epoch: Index of the current epoch.

        Raises:
            AssertionError: if the validation loader yielded no batches.
        """
        self.elog.print('VALIDATE')
        self.model.eval()

        data = None
        loss_list = []

        with torch.no_grad():
            for data_batch in self.val_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                # Explicit class axis, matching train/test. The implicit-dim
                # form is deprecated; for 4-D logits it also resolved to dim=1.
                pred_softmax = F.softmax(pred, dim=1)

                # Validation tracks the Dice term only (training uses CE + Dice).
                # squeeze(1) keeps the batch axis even for single-sample batches.
                loss = self.dice_loss(pred_softmax, target.squeeze(1))
                loss_list.append(loss.item())

        assert data is not None, 'data is None. Please check if your dataloader works properly'
        mean_loss = np.mean(loss_list)
        self.scheduler.step(mean_loss)

        self.elog.print('Epoch: %d Loss: %.4f' % (self._epoch_idx, mean_loss))

        self.add_result(value=mean_loss,
                        name='Val_Loss',
                        tag='Loss',
                        counter=epoch + 1)

        # Plots show the last validation batch only.
        pred_cpu = pred.detach().cpu()
        self.clog.show_image_grid(data.float(),
                                  name="data_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)
        self.clog.show_image_grid(target.float(),
                                  name="mask_val",
                                  title="Mask",
                                  n_iter=epoch)
        self.clog.show_image_grid(torch.argmax(pred_cpu, dim=1, keepdim=True),
                                  name="unt_argmax_val",
                                  title="Unet",
                                  n_iter=epoch)
        self.clog.show_image_grid(pred_cpu[:, 1:2],
                                  name="unt_val",
                                  normalize=True,
                                  scale_each=True,
                                  n_iter=epoch)

    def test(self):
        """
        Evaluate the model on the test split and store per-sample predictions.

        For each batch, prints the cross-entropy loss and the ``dice_pytorch``
        result. For each sample in the batch:

        - input, target and softmax prediction are concatenated along axis 0;
        - a leading axis is added (presumably giving ``(1, 5, 64, 64)`` for one
          input channel, one mask channel and three classes — TODO confirm);
        - the result is appended along axis 0 to
          ``<cross_vali_result_all_dir>/pred_<dataset_name><id>.npy``, where
          ``<id>`` is characters ``[-8:-4]`` of the sample's file name. The
          file is created if absent.

        Finally prints the mean of all non-zero Dice values collected.

        Raises:
            AssertionError: if the test loader yielded no batches.
        """
        self.model.eval()
        data = None
        dice_scores = []

        num_of_parameters = sum(p.numel() for p in self.model.parameters()
                                if p.requires_grad)
        print("number of parameters:", num_of_parameters)

        result_dir = self.config.cross_vali_result_all_dir
        file_prefix = 'pred_' + self.config.dataset_name

        with torch.no_grad():
            for data_batch in self.test_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)
                # Per-sample name tuples (e.g. 8 x ('name',)); read as file_dir[k][0].
                file_dir = data_batch['fnames']

                pred = self.model(data)
                # Probabilities for the saved output; CE below takes raw logits.
                pred_softmax = F.softmax(pred, dim=1)
                pred_labels = torch.argmax(pred_softmax, dim=1)
                dice_result = dice_pytorch(outputs=pred_labels,
                                           labels=target,
                                           N_class=self.config.num_classes)
                # squeeze(1) keeps the batch axis even for single-sample batches.
                ce_loss = self.ce_loss(pred, target.squeeze(1))
                print('ce_loss:%.4f   dice:%s' %
                      (ce_loss.data, dice_result.data))

                data_np = data.cpu().numpy()
                target_np = target.cpu().numpy()
                probs_np = pred_softmax.cpu().numpy()
                dice_scores.extend(np.ravel(dice_result.cpu().numpy()))

                # Iterate over the actual batch length: the last batch may be
                # smaller than config.batch_size.
                for k in range(probs_np.shape[0]):
                    filename = file_dir[k][0][-8:-4]
                    output_path = os.path.join(result_dir,
                                               file_prefix + filename)

                    # Input, mask and class probabilities stacked along the
                    # channel axis, plus a leading axis for appending.
                    sample_stack = np.concatenate(
                        (data_np[k], target_np[k], probs_np[k]),
                        axis=0)[np.newaxis]

                    # Append to results from earlier runs/folds if present.
                    if os.path.exists(output_path + '.npy'):
                        sample_stack = np.concatenate(
                            (np.load(output_path + '.npy'), sample_stack),
                            axis=0)

                    np.save(output_path, sample_stack)

            dice_array = np.asarray(dice_scores, dtype=float)
            # NOTE(review): zero scores are excluded from the average, as
            # before. This also drops genuine Dice == 0 results — confirm intent.
            dice_array = dice_array[dice_array != 0]
            print("average dice:", np.average(dice_array))
            print('test_data loading finished')

        assert data is not None, 'data is None. Please check if your dataloader works properly'