def validate(self, epoch):
        """Run one validation pass over ``self.valDataLoader`` and report metrics.

        The network is put into eval mode and every validation batch (unpacked
        as ``(inputs, labels)``) is fed through it under ``torch.no_grad()``.
        For each batch, ``bratsUtils.dice``, ``bratsUtils.sensitivity`` and
        ``bratsUtils.specificity`` are computed on the whole network output.
        On Hausdorff epochs, ``bratsUtils.getHd95`` is computed too, and
        negative results are skipped.

        The batch scores are then averaged, and ``self.bestMeanDice`` /
        ``self.bestMeanDiceEpoch`` and the moving average are updated. The
        results are printed and, when ``self.experiment`` is set, logged to it.

        Args:
            epoch: index of the current epoch. It is used for reporting and to
                decide whether Hausdorff distances are computed. They are
                computed on the last epoch of every
                ``LOG_HAUSDORFF_EVERY_K_EPOCHS``-epoch period; a value <= 0
                disables them.
        """
        expConfig = self.expConfig

        # set net up for inference
        expConfig.net.eval()

        hdEvery = expConfig.LOG_HAUSDORFF_EVERY_K_EPOCHS
        # the > 0 test short-circuits before the modulo, so 0 is safe
        logHausdorff = hdEvery > 0 and epoch % hdEvery == hdEvery - 1

        startTime = time.time()
        diceScores, sensScores, specScores, hdScores = [], [], [], []
        with torch.no_grad():
            for inputs, labels in self.valDataLoader:
                # feed inputs through neural net
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = expConfig.net(inputs)

                diceScores.append(bratsUtils.dice(outputs, labels))
                sensScores.append(bratsUtils.sensitivity(outputs, labels))
                specScores.append(bratsUtils.specificity(outputs, labels))

                if logHausdorff:
                    hd95 = bratsUtils.getHd95(outputs, labels)
                    # ignore edge cases in which no distance could be calculated
                    if hd95 >= 0:
                        hdScores.append(hd95)

        # calculate mean scores
        meanDice = np.mean(diceScores)
        meanSens = np.mean(sensScores)
        meanSpec = np.mean(specScores)
        # every hd95 may have been skipped; np.mean([]) warns, so give nan quietly
        meanHd = np.mean(hdScores) if hdScores else float("nan")

        if meanDice > self.bestMeanDice:
            self.bestMeanDice = meanDice
            self.bestMeanDiceEpoch = epoch

        # update moving avg
        self._updateMovingAvg(meanDice, epoch)

        # print metrics
        print("------ Validation epoch {} ------".format(epoch))
        print("Dice Mean: {:.4f} MovingAvg: {:.4f}".format(
            meanDice, self.movingAvg))
        print("Sensitivity: {:.4f}".format(meanSens))
        print("Specificity: {:.4f}".format(meanSpec))
        if logHausdorff:
            print("Hausdorff: {:6.2f}".format(meanHd))

        # log metrics. This variant only computes whole-output scores (no
        # per-region wt/tc/et split), so each group is logged under "mean".
        if self.experiment is not None:
            self.experiment.log_metrics(
                {"mean": meanDice, "movingAvg": self.movingAvg}, "dice", epoch)
            self.experiment.log_metrics({"mean": meanSens}, "sensitivity", epoch)
            self.experiment.log_metrics({"mean": meanSpec}, "specificity", epoch)
            if logHausdorff:
                self.experiment.log_metrics({"mean": meanHd}, "hausdorff", epoch)

        # log validation time
        if expConfig.LOG_VALIDATION_TIME:
            print("Time for validation: {:.2f}s".format(time.time() - startTime))
        print("--------------------------------")
Ejemplo n.º 2
0
    def validate(self, epoch):
        """Run one validation pass and report per-region (WT/TC/ET) metrics.

        The network is put into eval mode and every batch of
        ``self.valDataLoader`` is fed through it under ``torch.no_grad()``.
        Each batch is unpacked as ``(inputs, <ignored>, labels)``. Outputs and
        labels are reduced to three region maps (wt, tc, et) in one of two ways:

        - If ``expConfig.TRAIN_ORIGINAL_CLASSES`` is set, take the argmax over
          dim 1 and apply ``bratsUtils.getWTMask`` / ``getTCMask`` /
          ``getETMask``.
        - Otherwise, split dim 1 into three chunks and drop that dim.

        Dice, sensitivity and specificity are computed per region for every
        batch. On Hausdorff epochs ``bratsUtils.getHd95`` is collected as well,
        and negative results are skipped.

        After the loop:

        - The per-region means are printed.
        - The mean Dice over the three regions updates ``self.bestMeanDice`` /
          ``self.bestMeanDiceEpoch`` and the moving average.
        - The header, Dice and (when computed) Hausdorff lines are written to
          ``self.log``.
        - All groups are sent to ``self.experiment`` when it is set.

        Args:
            epoch: index of the current epoch. It is used for reporting and to
                decide whether Hausdorff distances are computed. They are
                computed on the last epoch of every
                ``LOG_HAUSDORFF_EVERY_K_EPOCHS``-epoch period; a value <= 0
                disables them.
        """
        expConfig = self.expConfig

        # set net up for inference
        expConfig.net.eval()

        hdEvery = expConfig.LOG_HAUSDORFF_EVERY_K_EPOCHS
        # the > 0 test short-circuits before the modulo, so 0 is safe
        logHausdorff = hdEvery > 0 and epoch % hdEvery == hdEvery - 1

        regions = ("wt", "tc", "et")
        metricFns = (("dice", bratsUtils.dice),
                     ("sensitivity", bratsUtils.sensitivity),
                     ("specificity", bratsUtils.specificity))
        # scores[metric][region] -> list of per-batch values
        scores = {name: {region: [] for region in regions}
                  for name in ("dice", "sensitivity", "specificity", "hausdorff")}

        def dropChannelDim(t):
            # (N, 1, ...) -> (N, ...); same as the former view(s[0], s[2], s[3], s[4])
            return t.view(t.shape[0], *t.shape[2:])

        def meanOrNan(values):
            # np.mean([]) emits a RuntimeWarning; give nan quietly instead
            # (e.g. when every hd95 of a region was skipped)
            return np.mean(values) if values else float("nan")

        startTime = time.time()
        with torch.no_grad():
            for inputs, _, labels in self.valDataLoader:

                # feed inputs through neural net
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = expConfig.net(inputs)

                if expConfig.TRAIN_ORIGINAL_CLASSES:
                    # argmax over dim 1 turns per-class scores into a label map
                    outputs = torch.argmax(outputs, 1)
                    preds = (bratsUtils.getWTMask(outputs),
                             bratsUtils.getTCMask(outputs),
                             bratsUtils.getETMask(outputs))

                    labels = torch.argmax(labels, 1)
                    targets = (bratsUtils.getWTMask(labels),
                               bratsUtils.getTCMask(labels),
                               bratsUtils.getETMask(labels))
                else:
                    # separate outputs/labels channelwise, one chunk per region
                    wt, tc, et = outputs.chunk(3, dim=1)
                    wtMask, tcMask, etMask = labels.chunk(3, dim=1)
                    preds = tuple(dropChannelDim(t) for t in (wt, tc, et))
                    targets = tuple(dropChannelDim(t) for t in (wtMask, tcMask, etMask))

                #TODO: add special evaluation metrics for original 5

                for name, fn in metricFns:
                    for region, pred, target in zip(regions, preds, targets):
                        scores[name][region].append(fn(pred, target))

                if logHausdorff:
                    for region, pred, target in zip(regions, preds, targets):
                        hd95 = bratsUtils.getHd95(pred, target)
                        # ignore edge cases in which no distance could be calculated
                        if hd95 >= 0:
                            scores["hausdorff"][region].append(hd95)

        # per-region means, computed once and reused for printing and logging
        diceMean, sensMean, specMean, hdMean = (
            {region: meanOrNan(scores[name][region]) for region in regions}
            for name in ("dice", "sensitivity", "specificity", "hausdorff"))

        meanDice = np.mean([diceMean[region] for region in regions])
        if meanDice > self.bestMeanDice:
            self.bestMeanDice = meanDice
            self.bestMeanDiceEpoch = epoch

        # update moving avg
        self._updateMovingAvg(meanDice, epoch)

        # print metrics
        header = "------ Validation epoch {} ------".format(epoch)
        diceLine = "Dice        WT: {:.4f} TC: {:.4f} ET: {:.4f} Mean: {:.4f} MovingAvg: {:.4f}".format(
            diceMean["wt"], diceMean["tc"], diceMean["et"], meanDice, self.movingAvg)
        print(header)
        print(diceLine)
        print("Sensitivity WT: {:.4f} TC: {:.4f} ET: {:.4f}".format(sensMean["wt"], sensMean["tc"], sensMean["et"]))
        print("Specificity WT: {:.4f} TC: {:.4f} ET: {:.4f}".format(specMean["wt"], specMean["tc"], specMean["et"]))
        # NOTE(review): no trailing newline is passed to write(); if self.log is a
        # plain file these entries run together on one line — confirm what self.log is
        self.log.write(header)
        self.log.write(diceLine)
        if logHausdorff:
            hdLine = "Hausdorff   WT: {:6.2f} TC: {:6.2f} ET: {:6.2f}".format(hdMean["wt"], hdMean["tc"], hdMean["et"])
            print(hdLine)
            # self.log is used through .write(), like the lines above
            self.log.write(hdLine)

        # log metrics
        if self.experiment is not None:
            self.experiment.log_metrics({**diceMean, "mean": meanDice, "movingAvg": self.movingAvg}, "dice", epoch)
            self.experiment.log_metrics(dict(sensMean), "sensitivity", epoch)
            self.experiment.log_metrics(dict(specMean), "specificity", epoch)
            if logHausdorff:
                # keys are "wt"/"tc"/"et", matching the other metric groups
                self.experiment.log_metrics(dict(hdMean), "hausdorff", epoch)

        # log validation time
        if expConfig.LOG_VALIDATION_TIME:
            print("Time for validation: {:.2f}s".format(time.time() - startTime))
        print("--------------------------------")