コード例 #1
0
    def predict(self, im):
        """Run the model on a single image and return the decoded output.

        If the width, height or channel count reported by
        ``moil.getWidthHeightChannels`` differ from
        ``(self.colDim, self.rowDim, self.channels)``, the image is first
        resized to ``(self.colDim, self.rowDim)`` with ``cv2.resize``. It is
        then prepared with ``self.prepareImage``, passed to
        ``self.model.predict``, and the raw output is decoded with
        ``moil.convertImageNetOutput``.

        NOTE(review): ``cv2.resize`` only changes the spatial size, so a
        channel-count mismatch on its own is not corrected here.
        """
        width, height, channels = moil.getWidthHeightChannels(im)
        needs_resize = (width != self.colDim
                        or height != self.rowDim
                        or channels != self.channels)
        if needs_resize:
            im = cv2.resize(im, (self.colDim, self.rowDim))[:, :]
        net_input = self.prepareImage(im)
        return moil.convertImageNetOutput(self.model.predict(net_input))
コード例 #2
0
    def check_performance(self,
                          validate_generator,
                          times=1,
                          metrics=None):
        """Score the model on the first sample of ``times`` validation batches.

        For every batch, ``batch[0][0]`` is reshaped to
        ``(1, self.rowDim, self.colDim, self.channels)`` and fed to
        ``self.model.predict``; ``batch[1][0]`` is used as ground truth. Both
        the prediction and the ground truth are decoded with
        ``moil.convertImageNetOutput`` and scored with ``met.customMetric``.
        When ``self.show_function`` is set it receives
        ``[input_image, true, pred]``.

        Args:
            validate_generator: iterator yielding batches indexed as
                ``batch[0]`` (inputs) and ``batch[1]`` (targets). Objects that
                only expose a legacy ``.next()`` method are also accepted.
            times: number of batches to evaluate.
            metrics: metric names passed to ``met.customMetric``; defaults to
                ``['distance', 'youden', 'jaccard', 'dice']``.
        """
        # A fresh list per call: a list literal as default would be shared
        # between calls and leak any mutation made by met.customMetric.
        if metrics is None:
            metrics = ['distance', 'youden', 'jaccard', 'dice']

        # Python 3 generators have no ``.next()`` method, so prefer the
        # iterator protocol; keep ``.next()`` for legacy iterator objects.
        fetch_batch = getattr(validate_generator, '__next__', None)
        if fetch_batch is None:
            fetch_batch = validate_generator.next

        for _ in range(times):
            pic = fetch_batch()
            image = pic[0][0]

            pred = self.model.predict(
                image.reshape(1, self.rowDim, self.colDim, self.channels))
            pred = moil.convertImageNetOutput(pred)
            true = moil.convertImageNetOutput(pic[1][0])

            met.customMetric(pred, true, metrics=metrics)

            if self.show_function is not None:
                self.show_function([
                    image.reshape((self.rowDim, self.colDim, self.channels)),
                    true,
                    pred,
                ])
コード例 #3
0
    def validate(self,
                 pathForce=None,
                 validateMode=0,
                 preprocessFunc=lambda x: x,
                 draw=True,
                 onlyWithMetric=False,
                 onlyWithoutMetric=False,
                 sumTimes=None,
                 metrics=None,
                 validTimes=1,
                 weightsTimesValids=None,
                 validName=''):
        """Run the model over validation folders and report averaged metrics.

        Every file in a folder is predicted. If a file with the same name
        exists under ``<folder>mask/``, the prediction is scored against it
        with ``met.customMetric``; otherwise it is only drawn with
        ``met.draw``. Sums and averages are printed, and the result is
        written with ``self.validate_to_csv``.

        Args:
            pathForce: if given, validate only this folder, once. Otherwise
                ask ``self.validate_path_provider_func`` for folders until it
                returns ``None``. The mask folder is built as
                ``path + 'mask/'``, so paths presumably end with a separator.
            validateMode: not used by this method; kept for compatibility.
            preprocessFunc: not used by this method; kept for compatibility.
            draw: when true, the input image is passed as ``toDraw`` to the
                metric/drawing helpers; otherwise ``None`` is passed.
            onlyWithMetric: skip images that have no mask.
            onlyWithoutMetric: skip images that have a mask.
            sumTimes: stop a run once this many masked images have been
                scored. When ``None``, every image is also passed to
                ``self.show_function`` (if it is set).
            metrics: metric names for ``met.customMetric``; defaults to
                ``['distance', 'youden', 'jaccard', 'dice']``. Any name
                containing ``'global'`` enables confusion-matrix accumulation
                and the global Jaccard/Dice report.
            validTimes: number of validation runs.
            weightsTimesValids: if not ``None``, before run ``i`` set
                ``self.constantVar = i * weightsTimesValids`` and reload the
                weights with ``self.load_weights()``.
            validName: forwarded to ``self.validate_to_csv``.

        Returns:
            The per-metric averages of the last run, followed by the values
            from ``met.globals`` when a global metric was requested. The
            averages are ``nan`` when no masked image was scored. The result
            is an empty list when ``validTimes`` is not positive.
        """
        if metrics is None:
            metrics = ['distance', 'youden', 'jaccard', 'dice']
        globalCount = any('global' in metr for metr in metrics)

        avgs, global_metrics = [], []
        for i in range(validTimes):
            if weightsTimesValids is not None:
                self.constantVar = i * weightsTimesValids
                self.load_weights()
            sums, confusion_matrix, times = self._validate_run(
                pathForce, draw, onlyWithMetric, onlyWithoutMetric, sumTimes,
                metrics, globalCount)

            # Guard the division: a run that scored no masked image used to
            # raise ZeroDivisionError here.
            if times:
                avgs = [val / times for val in sums]
            else:
                avgs = [float('nan')] * len(sums)
            strgSum = ''.join(str(val) + ', ' for val in sums)
            strgAvgs = ''.join(str(val) + ', ' for val in avgs)

            global_metrics = []
            if globalCount:
                # list() so that ``avgs + global_metrics`` concatenates even
                # if met.globals returns a tuple or an array.
                global_metrics = list(met.globals(confusion_matrix))
                print("Global Jaccard: " + str(global_metrics[0]) +
                      ", Global Dice: " + str(global_metrics[1]))
            print("Times: " + str(times) + ", sums: " + strgSum +
                  "Average metrics: " + strgAvgs)
            self.validate_to_csv(metrics, avgs + global_metrics, validName)
        return avgs + global_metrics

    def _validate_run(self, pathForce, draw, onlyWithMetric,
                      onlyWithoutMetric, sumTimes, metrics, globalCount):
        """Run one validation pass; return ``(sums, confusion_matrix, times)``.

        ``sums`` holds per-metric totals over the ``times`` scored (masked)
        images. ``confusion_matrix`` is the element-wise sum of the second
        item of every ``met.customMetric`` result.
        """
        sums = [0] * len(metrics)
        confusion_matrix = [0] * 4
        times = 0
        visited_path = {}
        forced_done = False
        while True:
            if pathForce is None:
                path = self.validate_path_provider_func(
                    self.validate_start_path, visited_path)
                visited_path[path] = times
            elif forced_done:
                # A forced folder is processed exactly once; previously the
                # loop re-read it on every pass and never terminated.
                break
            else:
                path, forced_done = pathForce, True
            if path is None:
                break
            if not os.path.exists(path):
                continue

            true_path = path + 'mask/'
            for imp in os.listdir(path):
                # isfile (rather than exists) also skips sub-directories such
                # as the 'mask/' folder, which os.listdir(path) returns too.
                if not os.path.isfile(os.path.join(path, imp)):
                    continue
                has_mask = os.path.exists(os.path.join(true_path, imp))
                if onlyWithMetric and not has_mask:
                    continue
                if onlyWithoutMetric and has_mask:
                    continue

                im = self.read_func(name=imp,
                                    extension='',
                                    path=path,
                                    target_size=(self.colDim, self.rowDim),
                                    mode=0)
                imgX, img = self.prepareImage(im, retboth=True)
                pred = moil.convertImageNetOutput(self.model.predict(imgX))
                toDraw = im if draw else None

                x = [im, pred, img]
                if has_mask:
                    true = self.read_func(name=imp,
                                          extension='',
                                          path=true_path,
                                          target_size=(self.colDim,
                                                       self.rowDim))
                    true = true.reshape(
                        (self.rowDim, self.colDim, self.out_channels))
                    x.append(true)
                    results = met.customMetric(pred,
                                               true,
                                               toDraw=toDraw,
                                               metrics=metrics,
                                               globalCount=globalCount)
                    sums = [a + b for a, b in zip(sums, results[0])]
                    confusion_matrix = [
                        a + b for a, b in zip(confusion_matrix, results[1])
                    ]
                    times += 1
                    # Stop the whole pass, not just this folder, once enough
                    # masked images have been scored.
                    if sumTimes is not None and times >= sumTimes:
                        return sums, confusion_matrix, times
                else:
                    met.draw(pred, toDraw)

                if sumTimes is None and self.show_function is not None:
                    self.show_function(x)
        return sums, confusion_matrix, times