Example #1
0
    def validation(self, epoch):
        """Evaluate the network on the validation set and checkpoint the best model.

        The per-batch value comes from ``MulticlassJaccardLoss`` over 23
        classes, so it is a *loss* (lower is better) even though it is kept in
        a variable named ``DC``. The epoch score is the sample-weighted mean of
        the per-batch losses.

        Args:
            epoch: Index of the current epoch. It is stored in
                ``self.best_epoch`` when this epoch produces a new best score.

        Side effects:
            Puts ``self.net`` in eval mode and prints the score. When the score
            is lower than ``self.best_score``, it updates ``self.best_score``
            and ``self.best_epoch`` and saves the state dict to
            ``self.net_path``.
        """
        self.net.eval()

        # Build the criterion once rather than once per batch.
        # NOTE(review): 23 classes and 256x256 masks are hard-coded - confirm they match the dataset.
        criterion = MulticlassJaccardLoss(classes=list(range(23)))

        DC = 0.  # running sample-weighted sum of the Jaccard loss (a plain float)
        length = 0  # number of samples seen

        # Evaluation needs no gradients. Without no_grad() the accumulated
        # tensor kept every batch's autograd graph alive.
        with torch.no_grad():
            for imgs, probs, gts in self.val_loader:
                imgs = imgs.to(self.device)
                probs = probs.to(self.device)
                gts = gts.round().long().to(self.device)

                outputs = self.net(imgs, probs)
                batch_loss = criterion(outputs, gts.reshape(-1, 256, 256))

                # Weight by batch size so that dividing by the sample count
                # gives a true mean. This assumes the criterion returns a
                # per-batch mean scalar (TODO confirm), as train() also assumes.
                batch_size = imgs.size(0)
                DC += float(batch_loss) * batch_size
                length += batch_size

        DC = DC / length

        score = DC

        print('[Validation] DC: %.4f' % (
            DC))

        # Save the best model. The score is actually a loss, so keep the smaller value.
        if score < self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            print('Best %s model score: %.4f'%(self.model_type, self.best_score))
            torch.save(self.net.state_dict(), self.net_path)
Example #2
0
    def test(self):
        """Reload the saved checkpoint, evaluate it on the test set, and log the result.

        Rebuilds ``self.net`` through ``self.build_model()`` and loads the
        weights from ``self.net_path``. It then computes the sample-weighted
        mean of the per-batch ``MulticlassJaccardLoss`` (23 classes) over
        ``self.test_loader``. Finally it appends one row to
        ``<self.csv_path>/result.csv`` containing the model type, that score,
        and the training hyper-parameters.
        """
        del self.net
        self.build_model()
        # map_location lets a checkpoint saved on a different device load here.
        self.net.load_state_dict(torch.load(self.net_path, map_location=self.device))

        self.net.eval()

        # Build the criterion once rather than once per batch.
        # NOTE(review): 23 classes and 256x256 masks are hard-coded - confirm they match the dataset.
        criterion = MulticlassJaccardLoss(classes=list(range(23)))

        DC = 0.  # running sample-weighted sum of the Jaccard loss (a plain float)
        length = 0  # number of samples seen

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for imgs, probs, gts in self.test_loader:
                imgs = imgs.to(self.device)
                probs = probs.to(self.device)
                gts = gts.round().long().to(self.device)

                outputs = self.net(imgs, probs)
                batch_loss = criterion(outputs, gts.reshape(-1, 256, 256))

                # Weight by batch size so that dividing by the sample count
                # gives a true mean. This assumes the criterion returns a
                # per-batch mean scalar (TODO confirm).
                batch_size = imgs.size(0)
                DC += float(batch_loss) * batch_size
                # Count each batch exactly once. It used to be added twice,
                # which halved the reported score.
                length += batch_size

        DC = DC / length

        # DC is a plain float, so the CSV gets a number rather than a tensor repr.
        with open(os.path.join(self.csv_path, 'result.csv'),
                  'a',
                  encoding='utf8',
                  newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type, DC, self.lr, self.best_epoch, self.num_epochs,
                self.num_epochs_decay, self.augmentation_prob, self.batch_size,
                self.comment
            ])
Example #3
0
    def train(self, epoch):
        """Run one training epoch with a weighted cross-entropy + Jaccard loss.

        In the last ``self.num_epochs_decay`` epochs, the learning rate stored
        in ``self.decayed_lr`` is reduced by ``self.lr / self.num_epochs_decay``
        each epoch and written to every optimizer param group. The method then
        iterates ``self.train_loader``, optimizes the sum of the class-weighted
        cross-entropy and ``MulticlassJaccardLoss``, and prints the
        sample-weighted mean loss for the epoch.

        Args:
            epoch: Zero-based index of the current epoch.
        """
        self.net.train(True)

        # Decay learning rate
        if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
            self.decayed_lr -= (self.lr / float(self.num_epochs_decay))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.decayed_lr
            print('epoch{}: Decay learning rate to lr: {}.'.format(epoch, self.decayed_lr))

        # Per-class cross-entropy weights, one entry per class (23 classes).
        # They are built once per epoch instead of once per batch.
        weight = torch.tensor([
            1., 100., 130., 130., 1000., 1000., 700., 700., 900., 30., 30., 1000., 60., 60., 200., 200.,
            100., 100., 300., 300., 100., 55., 55.
        ]).to(self.device)
        num_classes = weight.numel()

        ce_criterion = nn.CrossEntropyLoss(weight=weight, reduction='mean')
        # The Jaccard term covers the same classes as the CE weights and as
        # validation()/test(). It previously used range(14), a leftover from an
        # older 14-class configuration that ignored classes 14-22.
        jaccard_criterion = MulticlassJaccardLoss(classes=list(range(num_classes)))

        epoch_loss = 0
        length = 0

        for img, prob, gt in tqdm(self.train_loader):
            img = img.to(self.device)
            prob = prob.to(self.device)
            # NOTE(review): 256x256 masks are hard-coded - confirm they match the dataset.
            gt = gt.round().long().to(self.device).reshape(-1, 256, 256)

            self.optimizer.zero_grad()

            outputs = self.net(img, prob)

            ce_loss = ce_criterion(outputs, gt)
            dice_loss = jaccard_criterion(outputs, gt)
            loss = ce_loss + dice_loss

            # The criteria use mean reduction, so weight by batch size to get
            # a per-sample average once divided by `length`.
            epoch_loss += loss.item() * img.size(0)
            loss.backward()
            self.optimizer.step()

            length += img.size(0)

        print('EPOCH{}, Loss{}'.format(epoch,epoch_loss/length))