Exemplo n.º 1
0
    def valid(self, epoch, val_loader):
        """Evaluate ``self.G`` on ``val_loader`` and checkpoint on a new best F1.

        F1 is computed once over every pixel of the validation set (from the
        summed confusion matrix). Dice and Jaccard (JSS) are computed per image
        and averaged. When the dataset-level F1 beats ``self.best_metric`` it is
        stored and ``self.save(epoch)`` is called. A summary line is written to
        ``self.logger``.

        Args:
            epoch: current epoch number (logged and passed to ``self.save``).
            val_loader: iterable of ``(input, target, _)`` batches.
        """
        self.G.eval()
        with torch.no_grad():
            # (tn, fp, fn, tp) summed over the whole set. Must be an ndarray:
            # `list += ndarray` extends the list instead of adding element-wise.
            confusions_sum = np.zeros(4, dtype=np.int64)
            dice, jss, n_images = 0.0, 0.0, 0
            for input_, target_, _ in val_loader:
                _, output_, target_ = self._test_foward(input_, target_)

                target_np = utils.slice_threshold(target_, 0.5)
                output_np = utils.slice_threshold(output_, 0.5)

                # Iterate the batch axis so Dice/JSS really are per-image means.
                for target_img, output_img in zip(target_np, output_np):
                    # labels=[0, 1] keeps the matrix 2x2 even for single-class
                    # images (assumes slice_threshold yields 0/1 masks).
                    confusions = confusion_matrix(
                        target_img.flatten(), output_img.flatten(), labels=[0, 1]
                    ).ravel()
                    confusions_sum += confusions
                    # get_roc_pr returns (sen, spec, prec, rec, f1, jss).
                    scores = utils.get_roc_pr(*confusions)
                    dice += scores[-2]
                    jss += scores[-1]
                    n_images += 1

            f1 = utils.get_roc_pr(*confusions_sum)[-2]
            if n_images:
                jss /= n_images
                dice /= n_images

            if f1 > self.best_metric:
                self.best_metric = f1
                self.save(epoch)

            self.logger.write("[Val] epoch:%d f1:%f jss:%f dice:%f"%(epoch, f1, jss, dice))
Exemplo n.º 2
0
    def pre_train(self, train_loader, val_loader, target_f1=0.85, epochs=3,
                  max_trials=None):
        """Re-initialise and briefly train ``self.G`` until validation F1 is high enough.

        Each trial calls ``self._init_model()`` and trains for ``epochs`` passes
        over ``train_loader`` with ``self.recon_loss`` / ``self.optim``. It then
        computes the dataset-level F1 on ``val_loader`` at threshold 0.5 and
        prints it.

        Args:
            train_loader: iterable of ``(input, target, _)`` training batches.
            val_loader: iterable of ``(input, target, _)`` validation batches.
            target_f1: stop once validation F1 reaches this value.
            epochs: training epochs per trial.
            max_trials: optional cap on the number of trials. ``None`` (the
                default) retries until ``target_f1`` is reached, which may loop
                forever if it is never reached.
        """
        print("PretrainStart")
        cnt, f1 = 0, 0
        while f1 < target_f1 and (max_trials is None or cnt < max_trials):
            # Model Init
            self._init_model()

            # Nothing in the loop changes the mode, so set it once per trial.
            self.G.train()
            for _epoch in range(epochs):
                for input_, target_, _ in train_loader:
                    input_, target_ = input_.to(self.torch_device), target_.to(self.torch_device)
                    output_ = self.G(input_)
                    recon_loss = self.recon_loss(output_, target_)

                    self.optim.zero_grad()
                    recon_loss.backward()
                    self.optim.step()

            self.G.eval()
            with torch.no_grad():
                # (tn, fp, fn, tp); an ndarray so `+=` adds element-wise
                # (a list would be extended instead).
                confusions_sum = np.zeros(4, dtype=np.int64)
                for input_, target_, _ in val_loader:
                    _, output_, target_ = self._test_foward(input_, target_)

                    target_np = utils.slice_threshold(target_, 0.5)
                    output_np = utils.slice_threshold(output_, 0.5)

                    # labels=[0, 1] keeps the matrix 2x2 even for single-class
                    # batches (assumes binary 0/1 masks).
                    confusions_sum += confusion_matrix(
                        target_np.flatten(), output_np.flatten(), labels=[0, 1]
                    ).ravel()

                f1 = utils.get_roc_pr(*confusions_sum)[-2]
            cnt += 1
            print("[Cnt:%d] val_f1:%f" % (cnt, f1))
Exemplo n.º 3
0
    def test(self, test_loader):
        """Evaluate on ``test_loader`` at the F1-optimal threshold and log the results.

        Pass 1 does the following:

        - Collects pixel labels (target binarised at 0.5) and raw outputs.
        - Saves ROC and precision-recall curve points to
          ``<save_path>/test_roc_values.npy`` and ``test_pr_values.npy``.
        - Picks the threshold that maximises F1 on the PR curve.

        Pass 2 binarises outputs at that threshold and does the following:

        - Saves an input/target/prediction image per sample via
          ``utils.image_save``.
        - Queues per-image scores with ``self.logger.will_write``.
        - Writes a summary line: dataset-level scores plus the mean per-image
          F1, logged as "dice".

        Args:
            test_loader: iterable of ``(input, target, file_names)`` batches;
                it is iterated twice.
        """
        print("\nStart Test")
        self.G.eval()
        with torch.no_grad():
            true_parts, pred_parts = [], []
            for input_, target_, _ in test_loader:
                _, output_, target_ = self._test_foward(input_, target_)
                target_np = utils.slice_threshold(target_, 0.5)
                # Gather per batch and concatenate once: concatenating onto a
                # growing array every batch is quadratic in the pixel count.
                true_parts.append(target_np.flatten())
                pred_parts.append(output_.flatten())
            # float64 matches the result of concatenating onto np.array([]).
            y_true = (np.concatenate(true_parts).astype(np.float64)
                      if true_parts else np.array([]))
            y_pred = (np.concatenate(pred_parts).astype(np.float64)
                      if pred_parts else np.array([]))

            roc_values = np.array(roc_curve(y_true, y_pred))
            precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
            # precision/recall hold one more point than thresholds, so the triple
            # is ragged; NumPy >= 1.24 rejects that unless dtype=object.
            pr_values = np.array([precision, recall, thresholds], dtype=object)

            f1_best, th_best = -1, 0
            for p, r, th in zip(precision, recall, thresholds):
                # A 0/0 point scores 0. The old "use 1 and skip f1 == 1" sentinel
                # also threw away genuinely perfect thresholds.
                f1 = 2 * p * r / (p + r) if (p + r) != 0 else 0.0
                if f1 > f1_best:
                    f1_best, th_best = f1, th

            np.save("%s/test_roc_values.npy"%(self.save_path), roc_values)
            np.save("%s/test_pr_values.npy"%(self.save_path),  pr_values)

            # (tn, fp, fn, tp); an ndarray so `+=` adds element-wise.
            confusions = np.zeros(4, dtype=np.int64)
            cnt, f1_sum = 0, 0
            for input_, target_, f_name in test_loader:
                input_, output_, target_ = self._test_foward(input_, target_)

                target_np = utils.slice_threshold(target_, 0.5)
                output_np = utils.slice_threshold(output_, th_best)
                for batch_idx in range(input_.shape[0]):
                    target_b = target_np[batch_idx, 0, :, :]
                    output_b = output_np[batch_idx, 0, :, :]
                    name = f_name[batch_idx][:-4]

                    input_norm = input_[batch_idx, 0, :, :]
                    input_norm = (input_norm - input_norm.min()) / (input_norm.max() - input_norm.min())
                    utils.image_save("%s/%s"%(self.save_path, name), input_norm, target_b, output_b)

                    # labels=[0, 1] keeps the matrix 2x2 for single-class images
                    # (assumes binary 0/1 masks).
                    confusion = confusion_matrix(
                        target_b.flatten(), output_b.flatten(), labels=[0, 1]
                    ).ravel()
                    confusions += confusion
                    scores = utils.get_roc_pr(*confusion)
                    self.logger.will_write("[Save] fname:%s sen:%f spec:%f prec:%f rec:%f f1:%f jss:%f"%(name, *scores))

                    f1_sum += scores[-2] # image per f1
                    cnt += 1

            scores = utils.get_roc_pr(*confusions)
        mean_f1 = f1_sum / float(cnt) if cnt else 0.0
        self.logger.write("Best Threshold:%f sen:%f spec:%f prec:%f rec:%f f1:%f jss:%f dice:%f"%(th_best, *scores, mean_f1))
        print("End Test\n")
Exemplo n.º 4
0
    def test(self, test_loader, val_loader):
        """Load the saved model, evaluate on ``test_loader`` at threshold 0.5 and log.

        For each sample it saves an input/target/prediction image under
        ``<save_path>/fold<fold>/`` and queues per-image scores with
        ``self.logger.will_write``. It also does the following:

        - Saves ROC and PR curve points over all pixels as ``.npy`` files.
        - Writes a summary line with dataset-level scores, mean per-image F1
          (logged as "dice"), F0.5, F2, ROC-AUC and PR-AUC.

        Args:
            test_loader: iterable of ``(input, target, file_names)`` batches.
            val_loader: not used by this method; kept for interface compatibility.
        """
        print("\nStart Test")
        self.load()
        self.G.eval()
        with torch.no_grad():
            true_parts, pred_parts = [], []
            # (tn, fp, fn, tp) summed over all images. Must be an ndarray:
            # `list += ndarray` extends the list instead of adding element-wise.
            confusions = np.zeros(4, dtype=np.int64)
            cnt, f1_sum = 0, 0
            for input_, target_, f_name in test_loader:
                input_, output_, target_ = self._test_foward(input_, target_)

                target_np = utils.slice_threshold(target_, 0.5)
                output_np = utils.slice_threshold(output_, 0.5)

                # Gather per batch and concatenate once below (repeated
                # concatenation onto a growing array is quadratic).
                true_parts.append(target_np.flatten())
                pred_parts.append(output_.flatten())

                for batch_idx in range(input_.shape[0]):
                    target_b = target_np[batch_idx, 0, :, :]
                    output_b = output_np[batch_idx, 0, :, :]
                    name = f_name[batch_idx][:-4]

                    save_path = "%s/fold%s/%s" % (self.save_path, self.fold, name)
                    input_norm = input_[batch_idx, 0, :, :]
                    input_norm = (input_norm - input_norm.min()) / (input_norm.max() - input_norm.min())
                    utils.image_save(save_path, input_norm, target_b, output_b)

                    # labels=[0, 1] keeps the matrix 2x2 for single-class images
                    # (assumes binary 0/1 masks).
                    confusion = confusion_matrix(
                        target_b.flatten(), output_b.flatten(), labels=[0, 1]
                    ).ravel()
                    confusions += confusion
                    scores = utils.get_roc_pr(*confusion)
                    self.logger.will_write("[Save] fname:%s sen:%f spec:%f prec:%f rec:%f f1:%f jss:%f" % (name, *scores))

                    f1_sum += scores[-2] # image per f1
                    cnt += 1

            # float64 matches the result of concatenating onto np.array([]).
            y_true = (np.concatenate(true_parts).astype(np.float64)
                      if true_parts else np.array([]))
            y_pred = (np.concatenate(pred_parts).astype(np.float64)
                      if pred_parts else np.array([]))

            roc_values = np.array(roc_curve(y_true, y_pred))
            pr_precision, pr_recall, pr_thresholds = precision_recall_curve(y_true, y_pred)
            # Ragged triple (thresholds is one shorter); NumPy >= 1.24 needs dtype=object.
            pr_values = np.array([pr_precision, pr_recall, pr_thresholds], dtype=object)

            np.save("%s/fold%s/test_roc_values.npy" % (self.save_path, self.fold), roc_values)
            np.save("%s/fold%s/test_pr_values.npy" % (self.save_path, self.fold),  pr_values)

            tn, fp, fn, tp = (int(v) for v in confusions)
            self.logger.will_write("[Save] tn:%d fp:%d fn:%d tp:%d"%(tn, fp, fn, tp))
            precision = tp / (tp + fp) if (tp + fp) else 0.0
            # Recall is tp / (tp + fn); the previous tp / (fp + fn) was wrong.
            recall = tp / (tp + fn) if (tp + fn) else 0.0
            # F-beta = (1 + b^2) P R / (b^2 P + R) with b = 0.5 and b = 2.
            f05_den = precision + (4 * recall)
            f05 = (5 * precision * recall) / f05_den if f05_den else 0.0
            f2_den = (4 * precision) + recall
            f2 = (5 * precision * recall) / f2_den if f2_den else 0.0
            scores = utils.get_roc_pr(*confusions)
            roc_auc = roc_auc_score(y_true, y_pred)
            # The PR curve integrates precision over recall (x = recall, which is
            # monotonic). `reorder=` was removed from sklearn.metrics.auc in 0.22.
            pr_auc = auc(pr_recall, pr_precision)

        mean_f1 = f1_sum / float(cnt) if cnt else 0.0
        self.logger.write("Best Threshold:%f sen:%f spec:%f prec:%f rec:%f f1:%f jss:%f dice:%f f05:%f f2:%f roc:%f pr:%f" % (0.5, *scores, mean_f1, f05, f2, roc_auc, pr_auc))
        print("End Test\n")
    def valid(self, epoch, val_loader):
        """Choose the F1-optimal threshold on ``val_loader`` and checkpoint on improvement.

        Pass 1 collects pixel labels and raw outputs, builds the
        precision-recall curve and picks the threshold maximising F1. Pass 2
        binarises outputs at that threshold and sums the confusion matrix. If
        the resulting dataset-level F1 beats ``self.best_metric``, the method
        stores the metric and threshold on ``self`` and calls
        ``self.save(epoch)``. A summary line is written to ``self.logger``.

        Args:
            epoch: current epoch number (logged and passed to ``self.save``).
            val_loader: iterable of ``(input, target, _)`` batches; iterated twice.
        """
        self.G.eval()
        with torch.no_grad():
            true_parts, pred_parts = [], []
            for input_, target_, _ in val_loader:
                _, output_, target_ = self._test_foward(input_, target_)
                target_np = utils.slice_threshold(target_, 0.5)
                # Gather per batch and concatenate once (avoids quadratic copying).
                true_parts.append(target_np.flatten())
                pred_parts.append(output_.flatten())
            y_true = (np.concatenate(true_parts).astype(np.float64)
                      if true_parts else np.array([]))
            y_pred = (np.concatenate(pred_parts).astype(np.float64)
                      if pred_parts else np.array([]))

            # Unpacked directly: wrapping the ragged triple in np.array fails on
            # NumPy >= 1.24 (thresholds is one element shorter).
            precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

            f1_best, th_best = -1, 0
            for p, r, th in zip(precision, recall, thresholds):
                # A 0/0 point scores 0; the old "1 then skip f1 == 1" sentinel
                # also discarded genuinely perfect thresholds.
                f1 = 2 * p * r / (p + r) if (p + r) != 0 else 0.0
                if f1 > f1_best:
                    f1_best, th_best = f1, th

            # (tn, fp, fn, tp); an ndarray so `+=` adds element-wise.
            confusions_sum = np.zeros(4, dtype=np.int64)
            for input_, target_, _ in val_loader:
                _, output_, target_ = self._test_foward(input_, target_)

                target_np = utils.slice_threshold(target_, 0.5)
                output_np = utils.slice_threshold(output_, th_best)

                # labels=[0, 1] keeps the matrix 2x2 for single-class batches
                # (assumes binary 0/1 masks).
                confusions_sum += confusion_matrix(
                    target_np.flatten(), output_np.flatten(), labels=[0, 1]
                ).ravel()

            *_, total_f1, jss = utils.get_roc_pr(*confusions_sum)
            if total_f1 > self.best_metric:
                self.best_metric = total_f1
                self.th_best = th_best
                self.save(epoch)

            self.logger.write(
                "[Val] epoch:%d th:%f f1_best:%f f1_total:%f jss:%f" %
                (epoch, th_best, f1_best, total_f1, jss))
Exemplo n.º 6
0
    def valid(self, epoch, val_loader):
        """Validate with a combined F0.5 + JSS + Dice metric and checkpoint on improvement.

        F0.5 and F2 come from the confusion matrix summed over every image.
        Dice and JSS are per-image means. When ``F0.5 + JSS + Dice`` beats
        ``self.best_metric`` it is stored and ``self.save(epoch)`` is called.
        F2, JSS and Dice are written to ``self.logger``.

        Args:
            epoch: current epoch number (logged and passed to ``self.save``).
            val_loader: iterable of ``(input, target, _)`` batches.
        """
        self.G.eval()
        with torch.no_grad():
            # (tn, fp, fn, tp). Must be an ndarray: `list += ndarray` extends
            # the list instead of adding element-wise.
            confusions_sum = np.zeros(4, dtype=np.int64)
            # F1(per dataset) , JSS, Dice(per image)
            # cnt counts images, so it starts at 0 (it used to start at 1,
            # which biased the averages low).
            dice, jss, cnt = 0.0, 0.0, 0

            for input_, target_, _ in val_loader:
                _, output_, target_ = self._test_foward(input_, target_)

                target_np = utils.slice_threshold(target_, 0.5)
                output_np = utils.slice_threshold(output_, 0.5)
                for b in range(target_.shape[0]):
                    # labels=[0, 1] keeps the matrix 2x2 for single-class images
                    # (assumes binary 0/1 masks).
                    confusions = confusion_matrix(
                        target_np[b].flatten(), output_np[b].flatten(), labels=[0, 1]
                    ).ravel()
                    confusions_sum += confusions
                    # get_roc_pr returns (..., f1, jss); f1 is used as Dice.
                    score = utils.get_roc_pr(*confusions)
                    dice += score[-2]
                    jss += score[-1]
                    cnt += 1

            _, fp, fn, tp = (int(v) for v in confusions_sum)
            precision = tp / (tp + fp) if (tp + fp) else 0.0
            # Recall is tp / (tp + fn); the previous tp / (fp + fn) was wrong.
            recall = tp / (tp + fn) if (tp + fn) else 0.0
            # F-beta = (1 + b^2) P R / (b^2 P + R) with b = 0.5 and b = 2.
            f05_den = precision + (4 * recall)
            f05 = (5 * precision * recall) / f05_den if f05_den else 0.0
            f2_den = (4 * precision) + recall
            f2 = (5 * precision * recall) / f2_den if f2_den else 0.0
            if cnt:
                jss /= cnt
                dice /= cnt

            metric = f05 + jss + dice
            if metric > self.best_metric:
                self.best_metric = metric
                self.save(epoch)

            self.logger.write("[Val] epoch:%d f2:%f jss:%f dice:%f" %
                              (epoch, f2, jss, dice))
    def inference(self, train_loader, valid_loader, test_loader):
        """Intended to save inference visualisations for the train/valid/test loaders.

        NOTE(review): currently disabled. The method raises NotImplementedError
        on its first line, so everything after the ``raise`` is unreachable
        draft code. As written, that draft would:

        - Create ``<save_path>/infer`` and ``infer/{Train,Valid,Test}``
          (``os.mkdir`` raises if a directory already exists).
        - For the first item of each batch, binarise target and output and
          compute ``self.metric(...)`` and ``utils.dice(...)``.
        - Save the min-max-normalised input, target and prediction via
          ``utils.image_save`` whenever Dice is not exactly 1.0.

        It only ever reads index 0 of each batch, so it presumably expects
        batch size 1 (confirm). ``metric_avg`` and ``dice_avg`` are accumulated
        but not reported or returned in the lines shown.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

        # --- unreachable draft below (kept for reference) ---
        print("Start Infernce")
        os.mkdir("%s/infer" % (self.save_path))
        os.mkdir("%s/infer/Train" % (self.save_path))
        os.mkdir("%s/infer/Valid" % (self.save_path))
        os.mkdir("%s/infer/Test" % (self.save_path))
        loaders = [("Train", train_loader), ("Valid", valid_loader),
                   ("Test", test_loader)]
        self.G.eval()
        with torch.no_grad():
            for path, loader in loaders:
                # Running sums; never divided or logged in the visible code.
                metric_avg, dice_avg = 0.0, 0.0
                for i, (input_, target_, f_name) in enumerate(loader):
                    input_, output_, target_ = self._test_foward(
                        input_, target_)

                    # First batch item only; self.z_idx indexes axis 1
                    # (presumably the slice/channel to visualise -- confirm).
                    input_np = input_.type(
                        torch.FloatTensor).numpy()[0, self.z_idx, :, :]
                    target_np = utils.slice_threshold(target_[0, 0, :, :], 0.5)
                    output_np = utils.slice_threshold(output_[0, 0, :, :],
                                                      self.threshold)

                    jss = self.metric(output_np, target_np)
                    dice = utils.dice(output_np, target_np)

                    # Only imperfect predictions are saved.
                    if dice != 1.0:
                        save_path = "%s/infer/%s/%s" % (self.save_path, path,
                                                        f_name[0][:-4])
                        # Min-max normalise for display (NaN if the image is constant).
                        input_np = (input_np - input_np.min()) / (
                            input_np.max() - input_np.min())
                        utils.image_save(save_path, input_np, target_np,
                                         output_np)

                    metric_avg += jss
                    dice_avg += dice
Exemplo n.º 8
0
    def get_best_th(self, loader):
        """Return the best F0.5 score on ``loader`` and the threshold achieving it.

        Runs ``self._test_foward`` over every batch and builds the pixel-level
        precision-recall curve from the raw outputs, with targets binarised at
        0.5. It then scans the curve for the maximum
        F0.5 = 5PR / (P + 4R), i.e. F-beta with beta = 0.5.

        NOTE(review): this method does not itself switch ``self.G`` to eval
        mode or disable autograd; presumably the caller does -- confirm.

        Args:
            loader: iterable of ``(input, target, _)`` batches.

        Returns:
            ``(f_best, th_best)``; ``(-1, 0)`` if the curve has no thresholds.
        """
        true_parts, pred_parts = [], []
        for input_, target_, _ in loader:
            _, output_, target_ = self._test_foward(input_, target_)
            target_np = utils.slice_threshold(target_, 0.5)
            # Gather per batch and concatenate once (avoids quadratic copying).
            true_parts.append(target_np.flatten())
            pred_parts.append(output_.flatten())
        y_true = (np.concatenate(true_parts).astype(np.float64)
                  if true_parts else np.array([]))
        y_pred = (np.concatenate(pred_parts).astype(np.float64)
                  if pred_parts else np.array([]))

        # Unpacked directly: np.array() of this ragged triple (thresholds is one
        # element shorter) raises on NumPy >= 1.24.
        precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

        f_best, th_best = -1, 0
        for p, r, th in zip(precision, recall, thresholds):
            denom = p + (4 * r)
            # P = R = 0 is scored 0 rather than producing NaN.
            f05 = (5 * p * r) / denom if denom else 0.0
            if f05 > f_best:
                f_best, th_best = f05, th
        return f_best, th_best