Code example #1
def data_analysis():
    import time
    from utils.visualizer import Visualizer
    vis = Visualizer(env='v99_99_debug', port=31434)
    dataroot = '/root/workspace/2018_US_project/ultrsound_zhy'
    # ultraDataLoader is assumed to be imported from the project's data-loading module
    train_loader, test_loader = ultraDataLoader(dataroot, 1).data_load()

    for i, (sample, target, name) in enumerate(train_loader):
        # sample.shape = (B, 1, 480, 480)
        vis.images(sample)
        vis.plot_single_win(dict(max=sample.max(),
                                 mean=sample.mean(),
                                 min=sample.min()),
                            win='sample')
        vis.plot_multi_win(dict(target=target.item()))
        time.sleep(2)
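
The helper above streams each ultrasound batch, its min/mean/max statistics, and its label to Visdom every two seconds. A minimal, hypothetical entry point for running it (assuming a Visdom server is reachable on port 31434 and ultraDataLoader is importable) could look like:

if __name__ == '__main__':
    # Hypothetical entry point, not part of the original snippet.
    data_analysis()
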
Code example #2
# Standard-library and third-party imports used below. Project-local helpers
# (ParserArgs, cuda_visible, Visualizer, print_args, adjust_lr, save_ckpt,
# LastAvgMeter, AverageMeter, PNetModel and the dataloaders) are assumed to be
# imported from this project's own modules.
import datetime
import os
import sys
import time

import numpy as np
import torch
import torchvision as tv
from sklearn import metrics
from torch.backends import cudnn


class RunMyModel(object):
    def __init__(self):
        args = ParserArgs().get_args()
        cuda_visible(args.gpu)

        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version),
                              port=args.port,
                              server=args.vis_server)

        if args.data_modality == 'fundus':
            # IDRiD dataset for segmentation
            # image, mask, image_name_item

            # iSee dataset for classification
            # image, image_name
            self.train_loader, self.normal_test_loader, \
            self.amd_fundus_loader, self.myopia_fundus_loader, \
            self.glaucoma_fundus_loader, self.dr_fundus_loader = \
                NewClsFundusDataloader(data_root=args.isee_fundus_root,
                                       batch=args.batch,
                                       scale=args.scale).data_load()

        else:
            # Challenge OCT dataset for classification
            # image, [case_name, image_name]
            self.train_loader, self.normal_test_loader, self.oct_abnormal_loader = OCT_ClsDataloader(
                data_root=args.challenge_oct,
                batch=args.batch,
                scale=args.scale).data_load()

        print_args(args)
        self.args = args
        self.new_lr = self.args.lr
        self.model = PNetModel(args)

        if args.predict:
            self.test_acc()
        else:
            self.train_val()

    def train_val(self):
        # general metrics
        self.best_auc = 0
        self.is_best = False
        # needed below in validate_cls(); assumes AverageMeter provides top_update_calc()
        self.total_auc_top10 = AverageMeter()
        self.total_auc_last10 = LastAvgMeter(length=10)
        self.acc_last10 = LastAvgMeter(length=10)

        # metrics for iSee
        self.myopia_auc_last10 = LastAvgMeter(length=10)
        self.amd_auc_last10 = LastAvgMeter(length=10)
        self.glaucoma_auc_last10 = LastAvgMeter(length=10)
        self.dr_auc_last10 = LastAvgMeter(length=10)

        for epoch in range(self.args.start_epoch, self.args.n_epochs):
            if self.args.data_modality == 'fundus':
                # total: 1000
                adjust_lr_epoch_list = [40, 80, 160, 240]
            else:
                # total: 180
                adjust_lr_epoch_list = [20, 40, 80, 120]
            _ = adjust_lr(self.args.lr, self.model.optimizer_G, epoch,
                          adjust_lr_epoch_list)
            new_lr = adjust_lr(self.args.lr, self.model.optimizer_D, epoch,
                               adjust_lr_epoch_list)
            self.new_lr = min(new_lr, self.new_lr)

            self.epoch = epoch
            self.train()
            # validate with the given frequency; in the last epochs, validate every epoch
            if epoch > self.args.validate_start_epoch \
                    and (epoch % self.args.validate_freq == 0
                         or epoch > (self.args.n_epochs - self.args.validate_each_epoch)):
                self.validate_cls()

            print('\n', '*' * 10, 'Program Information', '*' * 10)
            print('Node: {}'.format(self.args.node))
            print('GPU: {}'.format(self.args.gpu))
            print('Version: {}\n'.format(self.args.version))

    def train(self):
        self.model.train()
        prev_time = time.time()
        train_loader = self.train_loader

        for i, (
                image,
                _,
        ) in enumerate(train_loader):
            image = image.cuda(non_blocking=True)

            # train
            seg_mask, image_rec, gen_loss, dis_loss, logs = \
                self.model.process(image)

            # backward
            self.model.backward(gen_loss, dis_loss)

            # --------------
            #  Log Progress
            # --------------
            # Determine approximate time left
            batches_done = self.epoch * len(train_loader) + i
            batches_left = self.args.n_epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (self.epoch, self.args.n_epochs, i, train_loader.__len__(),
                   dis_loss.item(), gen_loss.item(), time_left))

            # --------------
            #  Visdom
            # --------------
            if i % self.args.vis_freq == 0:
                image = image[:self.args.vis_batch]

                if self.args.data_modality == 'oct':
                    # BCWH -> BWH, torch.max in Channel dimension
                    seg_mask = torch.argmax(seg_mask[:self.args.vis_batch],
                                            dim=1).float()
                    # BWH -> B1WH, 11 -> 1
                    seg_mask = (seg_mask.unsqueeze(dim=1) / 11).clamp(0, 1)

                else:
                    seg_mask = seg_mask[:self.args.vis_batch].clamp(0, 1)
                image_rec = image_rec[:self.args.vis_batch].clamp(0, 1)
                image_diff = torch.abs(image - image_rec)

                vim_images = torch.cat(
                    [image, seg_mask, image_rec, image_diff], dim=0)
                self.vis.images(vim_images,
                                win_name='train',
                                nrow=self.args.vis_batch)

                output_save = os.path.join(self.args.output_root,
                                           self.args.project, 'output_v1_0812',
                                           self.args.version, 'train')
                if not os.path.exists(output_save):
                    os.makedirs(output_save)
                tv.utils.save_image(vim_images,
                                    os.path.join(output_save,
                                                 '{}.png'.format(i)),
                                    nrow=4)

            if i + 1 == train_loader.__len__():
                self.vis.plot_multi_win(
                    dict(dis_loss=dis_loss.item(), lr=self.new_lr))
                self.vis.plot_single_win(dict(
                    gen_loss=gen_loss.item(),
                    gen_l1_loss=logs['gen_l1_loss'].item(),
                    gen_fm_loss=logs['gen_fm_loss'].item(),
                    gen_gan_loss=logs['gen_gan_loss'].item(),
                    gen_content_loss=logs['gen_content_loss'].item(),
                    gen_style_loss=logs['gen_style_loss'].item()),
                                         win='gen_loss')

    def validate_cls(self):
        # self.model.eval()
        self.model.train()

        with torch.no_grad():
            """
            Difference: abnormal dataloader and abnormal_list
            """
            if self.args.data_modality == 'fundus':
                myopia_gt_list, myopia_pred_list = self.forward_cls_dataloader(
                    loader=self.myopia_fundus_loader, is_disease=True)

                amd_gt_list, amd_pred_list = self.forward_cls_dataloader(
                    loader=self.amd_fundus_loader, is_disease=True)
                glaucoma_gt_list, glaucoma_pred_list = self.forward_cls_dataloader(
                    loader=self.glaucoma_fundus_loader, is_disease=True)
                dr_gt_list, dr_pred_list = self.forward_cls_dataloader(
                    loader=self.dr_fundus_loader, is_disease=True)
            else:
                abnormal_gt_list, abnormal_pred_list = self.forward_cls_dataloader(
                    loader=self.oct_abnormal_loader, is_disease=True)

            _, normal_train_pred_list = self.forward_cls_dataloader(
                loader=self.train_loader, is_disease=False)
            normal_gt_list, normal_pred_list = self.forward_cls_dataloader(
                loader=self.normal_test_loader, is_disease=False)
            """
            compute metrics
            """
            # Difference: total_true_list and total_pred_list
            if self.args.data_modality == 'fundus':
                # test metrics for myopia
                m_true_list = myopia_gt_list + normal_gt_list
                m_pred_list = myopia_pred_list + normal_pred_list
                # test metrics for amd
                a_true_list = amd_gt_list + normal_gt_list
                a_pred_list = amd_pred_list + normal_pred_list
                # test metrics for glaucoma
                g_true_list = glaucoma_gt_list + normal_gt_list
                g_pred_list = glaucoma_pred_list + normal_pred_list
                # test metrics for dr
                d_true_list = dr_gt_list + normal_gt_list
                d_pred_list = dr_pred_list + normal_pred_list
                # total
                total_true_list = a_true_list + myopia_gt_list + glaucoma_gt_list + dr_gt_list
                total_pred_list = a_pred_list + myopia_pred_list + glaucoma_pred_list + dr_pred_list

                # fpr, tpr, thresholds = metrics.roc_curve()
                myopia_auc = metrics.roc_auc_score(np.array(m_true_list),
                                                   np.array(m_pred_list))
                amd_auc = metrics.roc_auc_score(np.array(a_true_list),
                                                np.array(a_pred_list))
                glaucoma_auc = metrics.roc_auc_score(np.array(g_true_list),
                                                     np.array(g_pred_list))
                dr_auc = metrics.roc_auc_score(np.array(d_true_list),
                                               np.array(d_pred_list))
            else:
                total_true_list = abnormal_gt_list + normal_gt_list
                total_pred_list = abnormal_pred_list + normal_pred_list

            # get roc curve and compute the auc
            fpr, tpr, thresholds = metrics.roc_curve(np.array(total_true_list),
                                                     np.array(total_pred_list))
            total_auc = metrics.auc(fpr, tpr)
            """
            compute threshold, and then compute the accuracy
            """
            percentage = 0.75
            _threshold_for_acc = sorted(normal_train_pred_list)[int(
                len(normal_train_pred_list) * percentage)]
            normal_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in normal_pred_list]
            amd_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                 for i in amd_pred_list]
            myopia_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in myopia_pred_list]
            glaucoma_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                      for i in glaucoma_pred_list]
            dr_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                for i in dr_pred_list]

            # acc, sensitivity and specificity
            def calcu_cls_acc(pred_list, gt_list):
                cls_pred_list = normal_cls_pred_list + pred_list
                gt_list = normal_gt_list + gt_list
                acc = metrics.accuracy_score(y_true=gt_list,
                                             y_pred=cls_pred_list)
                tn, fp, fn, tp = metrics.confusion_matrix(
                    y_true=gt_list, y_pred=cls_pred_list).ravel()
                sen = tp / (tp + fn + 1e-7)
                spe = tn / (tn + fp + 1e-7)
                return acc, sen, spe

            total_acc, total_sen, total_spe = calcu_cls_acc(
                amd_cls_pred_list + myopia_cls_pred_list,
                amd_gt_list + myopia_gt_list)
            amd_acc, amd_sen, amd_spe = calcu_cls_acc(amd_cls_pred_list,
                                                      amd_gt_list)
            myopia_acc, myopia_sen, myopia_spe = calcu_cls_acc(
                myopia_cls_pred_list, myopia_gt_list)

            # update
            if self.args.data_modality == 'fundus':
                self.myopia_auc_last10.update(myopia_auc)
                self.amd_auc_last10.update(amd_auc)

            self.total_auc_last10.update(total_auc)
            mean, deviation = self.total_auc_top10.top_update_calc(total_auc)

            self.is_best = total_auc > self.best_auc
            self.best_auc = max(total_auc, self.best_auc)
            """
            plot metrics curve
            """
            # ROC curve
            self.vis.draw_roc(fpr, tpr)
            # total auc, primary metrics
            self.vis.plot_single_win(dict(value=total_auc,
                                          best=self.best_auc,
                                          last_avg=self.total_auc_last10.avg,
                                          last_std=self.total_auc_last10.std,
                                          top_avg=mean,
                                          top_dev=deviation),
                                     win='total_auc')

            self.vis.plot_single_win(dict(total_acc=total_acc,
                                          total_sen=total_sen,
                                          total_spe=total_spe,
                                          amd_acc=amd_acc,
                                          amd_sen=amd_sen,
                                          amd_spe=amd_spe,
                                          myopia_acc=myopia_acc,
                                          myopia_sen=myopia_sen,
                                          myopia_spe=myopia_spe),
                                     win='accuracy')

            # Difference
            if self.args.data_modality == 'fundus':
                self.vis.plot_single_win(dict(
                    value=amd_auc,
                    last_avg=self.amd_auc_last10.avg,
                    last_std=self.amd_auc_last10.std),
                                         win='amd_auc')
                self.vis.plot_single_win(dict(
                    value=myopia_auc,
                    last_avg=self.myopia_auc_last10.avg,
                    last_std=self.myopia_auc_last10.std),
                                         win='myopia_auc')

                metrics_str = 'best_auc = {:.4f},' \
                              'total_avg = {:.4f}, total_std = {:.4f}, ' \
                              'total_top_avg = {:.4f}, total_top_dev = {:.4f}, ' \
                              'amd_avg = {:.4f}, amd_std = {:.4f}, ' \
                              'myopia_avg = {:.4f}, myopia_std = {:.4f}'.format(self.best_auc,
                                       self.total_auc_last10.avg, self.total_auc_last10.std,
                                       mean, deviation,
                                       self.amd_auc_last10.avg, self.amd_auc_last10.std,
                                       self.myopia_auc_last10.avg, self.myopia_auc_last10.std)
                metrics_acc_str = '\n total_acc = {:.4f}, total_sen = {:.4f}, total_spe = {:.4f}, ' \
                                  'amd_acc = {:.4f}, amd_sen = {:.4f}, amd_spe = {:.4f}, ' \
                                  'myopia_acc = {:.4f}, myopia_sen = {:.4f}, myopia_spe = {:.4f}'\
                    .format(total_acc, total_sen, total_spe, amd_acc, amd_sen,
                            amd_spe, myopia_acc, myopia_sen, myopia_spe)

            else:
                metrics_str = 'best_auc = {:.4f},' \
                              'total_avg = {:.4f}, total_std = {:.4f}, ' \
                              'total_top_avg = {:.4f}, total_top_dev = {:.4f}'.format(self.best_auc,
                                      self.total_auc_last10.avg,
                                      self.total_auc_last10.std,
                                      mean, deviation)
                metrics_acc_str = '\n None'

            self.vis.text(metrics_str + metrics_acc_str)

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G2.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  is_best=self.is_best,
                  args=self.args)

        print('\n Save ckpt successfully!')
        print('\n', metrics_str + metrics_acc_str)

    def test_acc(self):
        self.model.train()

        with torch.no_grad():
            """
            Difference: abnormal dataloader and abnormal_list
            """
            _, normal_train_pred_list = self.forward_cls_dataloader(
                loader=self.train_loader, is_disease=False)

            if self.args.data_modality == 'fundus':
                myopia_gt_list, myopia_pred_list = self.forward_cls_dataloader(
                    loader=self.myopia_fundus_loader, is_disease=True)

                amd_gt_list, amd_pred_list = self.forward_cls_dataloader(
                    loader=self.amd_fundus_loader, is_disease=True)
            else:
                abnormal_gt_list, abnormal_pred_list = self.forward_cls_dataloader(
                    loader=self.oct_abnormal_loader, is_disease=True)

            normal_gt_list, normal_pred_list = self.forward_cls_dataloader(
                loader=self.normal_test_loader, is_disease=False)
            """
            compute metrics
            """
            # Difference: total_true_list and total_pred_list
            if self.args.data_modality == 'fundus':
                # test metrics for amd
                amd_auc_true_list = amd_gt_list + normal_gt_list
                amd_auc_pred_list = amd_pred_list + normal_pred_list
                # myopia
                myopia_auc_true_list = myopia_gt_list + normal_gt_list
                myopia_auc_pred_list = myopia_pred_list + normal_pred_list
                # total
                total_true_list = amd_auc_true_list + myopia_gt_list
                total_pred_list = amd_auc_pred_list + myopia_pred_list

                # fpr, tpr, thresholds = metrics.roc_curve()
                myopia_auc = metrics.roc_auc_score(
                    np.array(myopia_auc_true_list),
                    np.array(myopia_auc_pred_list))
                amd_auc = metrics.roc_auc_score(np.array(amd_auc_true_list),
                                                np.array(amd_auc_pred_list))

            else:
                total_true_list = abnormal_gt_list + normal_gt_list
                total_pred_list = abnormal_pred_list + normal_pred_list

            # get roc curve and compute the auc
            fpr, tpr, thresholds = metrics.roc_curve(np.array(total_true_list),
                                                     np.array(total_pred_list))
            total_auc = metrics.auc(fpr, tpr)
            """
            compute threshold, and then compute the accuracy of AMD and Myopia
            """
            percentage = 0.75
            _threshold_for_acc = sorted(normal_train_pred_list)[int(
                len(normal_train_pred_list) * percentage)]

            normal_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in normal_pred_list]
            amd_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                 for i in amd_pred_list]
            myopia_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in myopia_pred_list]

            # acc, sensitivity and specificity
            def calcu_cls_acc(pred_list, gt_list):
                cls_pred_list = normal_cls_pred_list + pred_list
                gt_list = normal_gt_list + gt_list
                acc = metrics.accuracy_score(y_true=gt_list,
                                             y_pred=cls_pred_list)
                tn, fp, fn, tp = metrics.confusion_matrix(
                    y_true=gt_list, y_pred=cls_pred_list).ravel()
                sen = tp / (tp + fn + 1e-7)
                spe = tn / (tn + fp + 1e-7)
                return acc, sen, spe

            amd_acc, amd_sen, amd_spe = calcu_cls_acc(amd_cls_pred_list,
                                                      amd_gt_list)
            myopia_acc, myopia_sen, myopia_spe = calcu_cls_acc(
                myopia_cls_pred_list, myopia_gt_list)
            """
            plot metrics curve
            """
            # ROC curve
            self.vis.draw_roc(fpr, tpr)

            metrics_auc_str = 'AUC = {:.4f}, AMD AUC = {:.4f}, Myopia AUC = {:.4f}'.\
                format(total_auc, amd_auc, myopia_auc)
            metrics_amd_acc_str = '\n amd_acc = {:.4f}, amd_sen = {:.4f}, amd_spe = {:.4f}'.\
                format(amd_acc, amd_sen, amd_spe)
            metrics_myopia_acc_str = '\n myopia_acc = {:.4f}, myopia_sen = {:.4f}, myopia_spe = {:.4f}'.\
                format(myopia_acc,  myopia_sen, myopia_spe)

            self.vis.text(metrics_auc_str + metrics_amd_acc_str +
                          metrics_myopia_acc_str)
            print(metrics_auc_str + metrics_amd_acc_str +
                  metrics_myopia_acc_str)

    def forward_cls_dataloader(self, loader, is_disease):
        gt_list = []
        pred_list = []
        for i, (image, image_name_item) in enumerate(loader):
            image = image.cuda(non_blocking=True)
            # val, forward
            seg_mask, image_rec = self.model(image)

            if self.args.data_modality == 'fundus':
                case_name = ['']
                image_name = image_name_item
            else:
                case_name, image_name = image_name_item
            """
            prediction
            """
            # BCWH -> B, anomaly score
            image_residual = torch.abs(image_rec - image)
            image_diff_mae = image_residual.mean(dim=3).mean(dim=2).mean(dim=1)

            # image: tensor
            # image_name: list
            # image_name.shape[0]: batch
            gt_list += [1 if is_disease else 0] * len(image_name)
            pred_list += image_diff_mae.tolist()
            """
            visdom
            """
            if i % self.args.vis_freq_inval == 0:
                image = image[:self.args.vis_batch]
                image_rec = image_rec[:self.args.vis_batch].clamp(0, 1)
                image_diff = torch.abs(image - image_rec)
                """
                Difference: seg_mask is different between fundus and oct images
                """
                if self.args.data_modality == 'fundus':
                    seg_mask = seg_mask[:self.args.vis_batch].clamp(0, 1)
                else:
                    seg_mask = torch.argmax(seg_mask[:self.args.vis_batch],
                                            dim=1).float()
                    seg_mask = (seg_mask.unsqueeze(dim=1) / 11).clamp(0, 1)

                vim_images = torch.cat(
                    [image, seg_mask, image_rec, image_diff], dim=0)

                self.vis.images(vim_images,
                                win_name='val',
                                nrow=self.args.vis_batch)
                """
                save images
                """
                output_save = os.path.join(self.args.output_root,
                                           self.args.project, 'output_v1_0812',
                                           '{}'.format(self.args.version),
                                           'val')

                if not os.path.exists(output_save):
                    os.makedirs(output_save)
                tv.utils.save_image(vim_images,
                                    os.path.join(
                                        output_save, '{}_{}.png'.format(
                                            case_name[0], image_name[0])),
                                    nrow=self.args.vis_batch)

        return gt_list, pred_list
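
The training and validation code above tracks rolling statistics with a LastAvgMeter helper that is not part of the listing. A minimal sketch of such a meter, assuming it only needs update(), avg and std over the most recent length values (the project's own implementation may differ):

from collections import deque

import numpy as np


class LastAvgMeter(object):
    """Mean and standard deviation over the last `length` values.

    Hypothetical reconstruction of the helper used above; the project's
    actual implementation is not shown in the snippet.
    """

    def __init__(self, length=10):
        self.values = deque(maxlen=length)
        self.avg = 0.0
        self.std = 0.0

    def update(self, value):
        self.values.append(float(value))
        self.avg = float(np.mean(self.values))
        self.std = float(np.std(self.values))
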
Code example #3
File: srunitnet_4x.py  Project: happog/FudanOCR
    def train(self, edgenetpath=None, srresnetpath=None, random_scale=True, rotate=True, fliplr=True, fliptb=True):
        vis = Visualizer(self.env)

        print('================ Loading datasets =================')
        # load training dataset
        print('## Current Mode: Train')
        train_data_loader = self.load_dataset(
            mode='train', random_scale=random_scale, rotate=rotate, fliplr=fliplr, fliptb=fliptb)

        ##########################################################
        ##################### build network ######################
        ##########################################################
        print('Building networks and initializing parameter weights....')
        # init sr resnet
        srresnet = Upscale4xResnetGenerator(input_nc=3, output_nc=3, n_blocks=5,
                                            norm='batch', learn_residual=True)
        srresnet.apply(weights_init_normal)

        # init discriminator
        discnet = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=5)

        # init edgenet
        edgenet = HED_1L()
        if edgenetpath is None or not os.path.exists(edgenetpath):
            raise Exception('Invalid edgenet model')
        else:
            pretrained_dict = torch.load(edgenetpath)
            model_dict = edgenet.state_dict()
            pretrained_dict = {k: v for k,
                               v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            edgenet.load_state_dict(model_dict)

        # init vgg feature
        featuremapping = VGGFeatureMap(models.vgg19(pretrained=True))

        # load pretrained srresnet or just initialize
        if srresnetpath is None or not os.path.exists(srresnetpath):
            print('===> initialize the srresnet')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            # deblurnet.load_state_dict(torch.load(srresnetpath))
            pretrained_dict = torch.load(srresnetpath)
            model_dict = srresnet.state_dict()

            pretrained_dict = {k: v for k,
                               v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            srresnet.load_state_dict(model_dict)

        # optimizer init
        # different learning rate
        lr = self.lr

        srresnet_optimizer = optim.Adam(
            srresnet.parameters(), lr=lr*10, betas=(0.9, 0.999))
        disc_optimizer = optim.Adam(
            discnet.parameters(), lr=lr/10, betas=(0.9, 0.999))

        # loss function init
        MSE_loss = nn.MSELoss()
        BCE_loss = nn.BCELoss()

        # cuda accelerate
        if USE_GPU:
            edgenet.cuda()
            srresnet.cuda()
            discnet.cuda()
            featuremapping.cuda()
            MSE_loss.cuda()
            BCE_loss.cuda()
            print('\tCUDA acceleration is available.')

        ##########################################################
        ##################### train network ######################
        ##########################################################
        import torchnet as tnt
        from tqdm import tqdm
        from PIL import Image

        batchnorm = nn.BatchNorm2d(1).cuda()

        edge_avg_loss = tnt.meter.AverageValueMeter()
        total_avg_loss = tnt.meter.AverageValueMeter()
        disc_avg_loss = tnt.meter.AverageValueMeter()
        psnr_2x_avg = tnt.meter.AverageValueMeter()
        ssim_2x_avg = tnt.meter.AverageValueMeter()
        psnr_4x_avg = tnt.meter.AverageValueMeter()
        ssim_4x_avg = tnt.meter.AverageValueMeter()

        save_dir = os.path.join(self.save_dir, 'train_result')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        srresnet.train()
        discnet.train()
        itcnt = 0
        for epoch in range(self.num_epochs):
            psnr_2x_avg.reset()
            ssim_2x_avg.reset()
            psnr_4x_avg.reset()
            ssim_4x_avg.reset()

            # learning rate is decayed by a factor of 10 every 20 epochs
            if (epoch + 1) % 20 == 0:
                for param_group in srresnet_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srresnet: lr={}".format(
                    srresnet_optimizer.param_groups[0]["lr"]))
                for param_group in disc_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for discnet: lr={}".format(
                    disc_optimizer.param_groups[0]["lr"]))

            itbar = tqdm(enumerate(train_data_loader))
            for ii, (hr, lr2x, lr4x, bc2x, bc4x) in itbar:

                mini_batch = hr.size()[0]

                hr_ = Variable(hr)
                lr2x_ = Variable(lr2x)
                lr4x_ = Variable(lr4x)
                bc2x_ = Variable(bc2x)
                bc4x_ = Variable(bc4x)
                real_label = Variable(torch.ones(mini_batch))
                fake_label = Variable(torch.zeros(mini_batch))

                # cuda mode setting
                if USE_GPU:
                    hr_ = hr_.cuda()
                    lr2x_ = lr2x_.cuda()
                    lr4x_ = lr4x_.cuda()
                    bc2x_ = bc2x_.cuda()
                    bc4x_ = bc4x_.cuda()
                    real_label = real_label.cuda()
                    fake_label = fake_label.cuda()

                # =============================================================== #
                # ================ Edge-based srresnet training ================= #
                # =============================================================== #
                sr2x_, sr4x_ = srresnet(lr4x_)

                '''===================== Train Discriminator ====================='''
                if epoch + 1 > self.pretrain_epochs:
                    disc_optimizer.zero_grad()

                    #===== 2x disc loss =====#
                    real_decision_2x = discnet(lr2x_)
                    real_loss_2x = BCE_loss(
                        real_decision_2x, real_label.detach())

                    fake_decision_2x = discnet(sr2x_.detach())
                    fake_loss_2x = BCE_loss(
                        fake_decision_2x, fake_label.detach())

                    disc_loss_2x = real_loss_2x + fake_loss_2x

                    disc_loss_2x.backward()
                    disc_optimizer.step()

                    #===== 4x disc loss =====#
                    real_decision_4x = discnet(hr_)
                    real_loss_4x = BCE_loss(
                        real_decision_4x, real_label.detach())

                    fake_decision_4x = discnet(sr4x_.detach())
                    fake_loss_4x = BCE_loss(
                        fake_decision_4x, fake_label.detach())

                    disc_loss_4x = real_loss_4x + fake_loss_4x

                    disc_loss_4x.backward()
                    disc_optimizer.step()

                    disc_avg_loss.add(
                        (disc_loss_2x + disc_loss_4x).data.item())

                '''=================== Train srresnet Generator ==================='''
                srresnet_optimizer.zero_grad()

                edge_trade_off = [0.7, 0.2, 0.1, 0.05, 0.01, 0.3]
                if epoch + 1 > self.pretrain_epochs:
                    a1, a2, a3 = 0.6, 0.1, 0.65
                else:
                    a1, a2, a3 = 0.45, 0.0, 0.95

                #============ calculate 2x loss ==============#
                #### Edgenet Loss ####
                pred = edgenet(sr2x_)
                real = edgenet(lr2x_)

                edge_loss_2x = BCE_loss(pred.detach(), real.detach())
                # for i in range(6):
                #     edge_loss_2x += edge_trade_off[i] * \
                #         BCE_loss(pred[i].detach(), real[i].detach())
                # edge_loss = 0.7 * BCE2d(pred[0], real[i]) + 0.3 * BCE2d(pred[5], real[i])

                #### Content Loss ####
                content_loss_2x = MSE_loss(sr2x_, lr2x_)

                #### Perceptual Loss ####
                real_feature = featuremapping(lr2x_)
                fake_feature = featuremapping(sr2x_)
                vgg_loss_2x = MSE_loss(fake_feature, real_feature.detach())

                #### Adversarial Loss ####
                advs_loss_2x = BCE_loss(discnet(sr2x_), real_label)

                total_loss_2x = a1 * edge_loss_2x + a2 * advs_loss_2x + \
                    a3 * content_loss_2x + (1.0 - a3) * vgg_loss_2x

                #============ calculate 4x loss ==============#
                #### Edgenet Loss ####
                pred = edgenet(sr4x_)
                real = edgenet(hr_)

                # edge_loss_4x = 0
                edge_loss_4x = BCE_loss(pred.detach(), real.detach())
                # for i in range(6):
                #     edge_loss_4x += edge_trade_off[i] * \
                #         BCE_loss(pred[i].detach(), real[i].detach())
                # edge_loss = 0.7 * BCE2d(pred[0], real[i]) + 0.3 * BCE2d(pred[5], real[i])

                #### Content Loss ####
                content_loss_4x = MSE_loss(sr4x_, hr_)

                #### Perceptual Loss ####
                real_feature = featuremapping(hr_)
                fake_feature = featuremapping(sr4x_)
                vgg_loss_4x = MSE_loss(fake_feature, real_feature.detach())

                #### Adversarial Loss ####
                advs_loss_4x = BCE_loss(discnet(sr4x_), real_label)

                total_loss_4x = a1 * edge_loss_4x + a2 * advs_loss_4x + \
                    a3 * content_loss_4x + (1.0 - a3) * vgg_loss_4x

                #============== loss backward ===============#
                total_loss = 0.01 * total_loss_2x + 1.0 * total_loss_4x
                total_loss.backward()
                srresnet_optimizer.step()

                #============ calculate scores ==============#
                psnr_2x_score_process = batch_compare_filter(
                    sr2x_.cpu().data, lr2x, PSNR)
                psnr_2x_avg.add(psnr_2x_score_process)

                ssim_2x_score_process = batch_compare_filter(
                    sr2x_.cpu().data, lr2x, SSIM)
                ssim_2x_avg.add(ssim_2x_score_process)

                psnr_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, PSNR)
                psnr_4x_avg.add(psnr_4x_score_process)

                ssim_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, SSIM)
                ssim_4x_avg.add(ssim_4x_score_process)

                total_avg_loss.add(total_loss.data.item())
                edge_avg_loss.add((edge_loss_2x+edge_loss_4x).data.item())
                disc_avg_loss.add((advs_loss_2x+advs_loss_4x).data.item())

                if (ii+1) % self.plot_iter == self.plot_iter-1:
                    res = {'edge loss': edge_avg_loss.value()[0],
                           'generate loss': total_avg_loss.value()[0],
                           'discriminate loss': disc_avg_loss.value()[0]}
                    vis.plot_many(res, 'Deblur net Loss')

                    psnr_2x_score_origin = batch_compare_filter(
                        bc2x, lr2x, PSNR)
                    psnr_4x_score_origin = batch_compare_filter(bc4x, hr, PSNR)
                    res_psnr = {'2x_origin_psnr': psnr_2x_score_origin,
                                '2x_sr_psnr': psnr_2x_score_process,
                                '4x_origin_psnr': psnr_4x_score_origin,
                                '4x_sr_psnr': psnr_4x_score_process}
                    vis.plot_many(res_psnr, 'PSNR Score')

                    ssim_2x_score_origin = batch_compare_filter(
                        bc2x, lr2x, SSIM)
                    ssim_4x_score_origin = batch_compare_filter(bc4x, hr, SSIM)
                    res_ssim = {'2x_origin_ssim': ssim_2x_score_origin,
                                '2x_sr_ssim': ssim_2x_score_process,
                                '4x_origin_ssim': ssim_4x_score_origin,
                                '4x_sr_ssim': ssim_4x_score_process}
                    vis.plot_many(res_ssim, 'SSIM Score')

                #======================= Output result of total training processing =======================#
                itcnt += 1
                itbar.set_description("Epoch: [%2d] [%d/%d] PSNR_2x_Avg: %.6f, SSIM_2x_Avg: %.6f, PSNR_4x_Avg: %.6f, SSIM_4x_Avg: %.6f"
                                      % ((epoch + 1), (ii + 1), len(train_data_loader),
                                         psnr_2x_avg.value()[0], ssim_2x_avg.value()[
                                          0],
                                         psnr_4x_avg.value()[0], ssim_4x_avg.value()[0]))

                if (ii+1) % self.plot_iter == self.plot_iter-1:
                    # test_ = deblurnet(torch.cat([y_.detach(), x_edge], 1))
                    hr_edge = edgenet(hr_)
                    sr2x_edge = edgenet(sr2x_)
                    sr4x_edge = edgenet(sr4x_)

                    vis.images(hr_edge.cpu().data, win='HR edge predict', opts=dict(
                        title='HR edge predict'))
                    vis.images(sr2x_edge.cpu().data, win='SR2X edge predict', opts=dict(
                        title='SR2X edge predict'))
                    vis.images(sr4x_edge.cpu().data, win='SR4X edge predict', opts=dict(
                        title='SR4X edge predict'))

                    vis.images(lr2x, win='LR2X image',
                               opts=dict(title='LR2X image'))
                    vis.images(lr4x, win='LR4X image',
                               opts=dict(title='LR4X image'))
                    vis.images(bc2x, win='BC2X image',
                               opts=dict(title='BC2X image'))
                    vis.images(bc4x, win='BC4X image',
                               opts=dict(title='BC4X image'))
                    vis.images(sr2x_.cpu().data, win='SR2X image',
                               opts=dict(title='SR2X image'))
                    vis.images(sr4x_.cpu().data, win='SR4X image',
                               opts=dict(title='SR4X image'))

                    vis.images(hr, win='HR image',
                               opts=dict(title='HR image'))

                t_save_dir = 'results/train_result/'+self.train_dataset
                if not os.path.exists(t_save_dir):
                    os.makedirs(t_save_dir)

            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(srresnet, os.path.join(self.save_dir, 'checkpoints'), 'srresnet_param_batch{}_lr{}_epoch{}'.
                                format(self.batch_size, self.lr, epoch+1))

        # Save final trained model and results
        vis.save([self.env])
        self.save_model(srresnet, os.path.join(self.save_dir, 'checkpoints'), 'srresnet_param_batch{}_lr{}_epoch{}'.
                        format(self.batch_size, self.lr, self.num_epochs))
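
Both super-resolution listings score their outputs with batch_compare_filter(batch_a, batch_b, PSNR) / (..., SSIM), whose definition is not included. A minimal sketch under the assumption that it averages a per-image metric over a mini-batch of BxCxHxW tensors (the project's helper and its PSNR/SSIM callables may differ):

import numpy as np


def batch_compare_filter(batch_a, batch_b, metric_fn):
    """Average a per-image quality metric (e.g. PSNR or SSIM) over a batch.

    Hypothetical reconstruction: both inputs are BxCxHxW tensors in [0, 1];
    metric_fn compares two HxWxC numpy images and returns a scalar.
    """
    scores = []
    for a, b in zip(batch_a, batch_b):
        img_a = a.detach().cpu().numpy().transpose(1, 2, 0)
        img_b = b.detach().cpu().numpy().transpose(1, 2, 0)
        scores.append(metric_fn(img_a, img_b))
    return float(np.mean(scores))
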
Code example #4
File: srcnn.py  Project: happog/FudanOCR
    def train(self,
              srcnn_path=None,
              random_scale=True,
              rotate=True,
              fliplr=True,
              fliptb=True):
        vis = Visualizer(self.env)

        print('================ Loading datasets =================')
        # load training dataset
        print('## Current Mode: Train')
        # train_data_loader = self.load_dataset(mode='valid')
        train_data_loader = self.load_dataset(mode='train',
                                              random_scale=random_scale,
                                              rotate=rotate,
                                              fliplr=fliplr,
                                              fliptb=fliptb)

        ##########################################################
        ##################### build network ######################
        ##########################################################
        print('Building networks and initializing parameter weights....')
        # init srnet
        srcnn = SRCNN()
        srcnn.apply(weights_init_normal)

        # load pretrained srresnet or just initialize
        if srcnn_path is None or not os.path.exists(srcnn_path):
            print('===> initialize the srcnn')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            pretrained_dict = torch.load(srcnn_path)
            model_dict = srcnn.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            srcnn.load_state_dict(model_dict)

        # optimizer init
        # different learning rate
        lr = self.lr

        srcnn_optimizer = optim.Adam(srcnn.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999))

        # loss function init
        MSE_loss = nn.MSELoss()
        BCE_loss = nn.BCELoss()

        # cuda accelerate
        if USE_GPU:
            srcnn.cuda()
            MSE_loss.cuda()
            BCE_loss.cuda()
            print('\tCUDA acceleration is available.')

        ##########################################################
        ##################### train network ######################
        ##########################################################
        import torchnet as tnt
        from tqdm import tqdm
        from PIL import Image

        total_avg_loss = tnt.meter.AverageValueMeter()
        psnr_2x_avg = tnt.meter.AverageValueMeter()
        ssim_2x_avg = tnt.meter.AverageValueMeter()
        psnr_4x_avg = tnt.meter.AverageValueMeter()
        ssim_4x_avg = tnt.meter.AverageValueMeter()

        srcnn.train()
        itcnt = 0
        for epoch in range(self.num_epochs):
            psnr_2x_avg.reset()
            ssim_2x_avg.reset()
            psnr_4x_avg.reset()
            ssim_4x_avg.reset()

            # learning rate is decayed by a factor of 10 every 20 epochs
            if (epoch + 1) % 20 == 0:
                for param_group in srcnn_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srcnn: lr={}".format(
                    srcnn_optimizer.param_groups[0]["lr"]))

            itbar = tqdm(enumerate(train_data_loader))
            for ii, (hr, lr2x, lr4x, bc2x, bc4x) in itbar:

                mini_batch = hr.size()[0]

                hr_ = Variable(hr)
                lr2x_ = Variable(lr2x)
                lr4x_ = Variable(lr4x)
                bc2x_ = Variable(bc2x)
                bc4x_ = Variable(bc4x)

                # cuda mode setting
                if USE_GPU:
                    hr_ = hr_.cuda()
                    lr2x_ = lr2x_.cuda()
                    lr4x_ = lr4x_.cuda()
                    bc2x_ = bc2x_.cuda()
                    bc4x_ = bc4x_.cuda()

                # =============================================================== #
                # ======================= srcnn training ======================== #
                # =============================================================== #
                sr4x_ = srcnn(bc4x_)

                #============ calculate 4x loss ==============#
                srcnn_optimizer.zero_grad()

                #### Content Loss ####
                content_loss_4x = MSE_loss(sr4x_, hr_)

                #============ calculate scores ==============#
                psnr_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, PSNR)
                psnr_4x_avg.add(psnr_4x_score_process)

                ssim_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, SSIM)
                ssim_4x_avg.add(ssim_4x_score_process)

                #============== loss backward ===============#
                total_loss_4x = content_loss_4x

                total_loss_4x.backward()
                srcnn_optimizer.step()

                total_avg_loss.add(total_loss_4x.data.item())

                if (ii + 1) % self.plot_iter == self.plot_iter - 1:
                    res = {'generate loss': total_avg_loss.value()[0]}
                    vis.plot_many(res, 'SRCNN Loss')

                    psnr_4x_score_origin = batch_compare_filter(bc4x, hr, PSNR)
                    res_psnr = {
                        '4x_origin_psnr': psnr_4x_score_origin,
                        '4x_sr_psnr': psnr_4x_score_process
                    }
                    vis.plot_many(res_psnr, 'PSNR Score')

                    ssim_4x_score_origin = batch_compare_filter(bc4x, hr, SSIM)
                    res_ssim = {
                        '4x_origin_ssim': ssim_4x_score_origin,
                        '4x_sr_ssim': ssim_4x_score_process
                    }
                    vis.plot_many(res_ssim, 'SSIM Score')

                #======================= Output result of total training processing =======================#
                itcnt += 1
                itbar.set_description(
                    "Epoch: [%2d] [%d/%d] PSNR_2x_Avg: %.6f, SSIM_2x_Avg: %.6f, PSNR_4x_Avg: %.6f, SSIM_4x_Avg: %.6f"
                    % ((epoch + 1), (ii + 1), len(train_data_loader),
                       psnr_2x_avg.value()[0], ssim_2x_avg.value()[0],
                       psnr_4x_avg.value()[0], ssim_4x_avg.value()[0]))

                if (ii + 1) % self.plot_iter == self.plot_iter - 1:

                    vis.images(lr4x,
                               win='LR4X image',
                               opts=dict(title='LR4X image'))
                    vis.images(bc4x,
                               win='BC4X image',
                               opts=dict(title='BC4X image'))
                    vis.images(sr4x_.cpu().data,
                               win='SR4X image',
                               opts=dict(title='SR4X image'))

                    vis.images(hr, win='HR image', opts=dict(title='HR image'))

            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(
                    srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
                    'srcnn_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, epoch + 1))

        # Save final trained model and results
        vis.save([self.env])
        self.save_model(
            srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
            'srcnn_param_batch{}_lr{}_epoch{}'.format(self.batch_size, self.lr,
                                                      self.num_epochs))
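
Both SR listings initialize their networks with srresnet.apply(weights_init_normal) / srcnn.apply(weights_init_normal) without showing the initializer. The conventional DCGAN-style version found in many PyTorch projects is sketched below as a hypothetical stand-in; the project's own function may differ:

import torch.nn as nn


def weights_init_normal(m):
    # Hypothetical stand-in: N(0, 0.02) for conv weights,
    # N(1, 0.02) for batch-norm weights, zero biases.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
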
Code example #5
# Imports as in code example #2 above, plus: import torch.nn.functional as F
class RunMyModel(object):
    def __init__(self):
        args = ParserArgs().get_args()
        cuda_visible(args.gpu)

        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version),
                              port=args.port,
                              server=args.vis_server)

        if args.data_modality == 'fundus':
            self.source_loader = AnoDRIVE_Loader(
                data_root=args.fundus_data_root,
                batch=args.batch,
                scale=args.scale,
                pre=True  # pre-process
            ).data_load()
            # self.target_loader, _ = AnoIDRID_Loader(data_root=args.fundus_data_root,
            #                                      batch=args.batch,
            #                                      scale=args.scale,
            #                                     pre=True).data_load()
            self.target_loader = NewClsFundusDataloader(
                data_root=args.isee_fundus_root,
                batch=args.batch,
                scale=args.scale).load_for_seg()

        else:
            self.source_loader = ChengOCTloader(
                data_root=args.cheng_oct,
                batch=args.batch,
                scale=args.scale,
                flip=args.flip,
                rotate=args.rotate,
                enhance_p=args.enhance_p).data_load()
            self.target_loader, _ = ChallengeOCTloader(
                data_root=args.challenge_oct,
                batch=args.batch,
                scale=args.scale).data_load()

        print_args(args)
        self.args = args
        self.new_lr = self.args.lr
        self.model = SegTransferModel(args)

        if args.predict:
            self.validate_loader(self.target_loader)
        else:
            self.train_validate()

    def train_validate(self):
        for epoch in range(self.args.start_epoch, self.args.n_epochs):
            _ = adjust_lr(self.args.lr, self.model.optimizer_G, epoch,
                          [40, 80, 160, 240])
            new_lr = adjust_lr(self.args.lr, self.model.optimizer_D, epoch,
                               [40, 80, 160, 240])
            self.new_lr = min(new_lr, self.new_lr)

            self.epoch = epoch

            self.train()
            if epoch % self.args.validate_freq == 0 and epoch > self.args.save_freq:
                self.validate()
                # self.validate_loader(self.normal_test_loader)
                # self.validate_loader(self.amd_fundus_loader)
                # self.validate_loader(self.myopia_fundus_loader)

            print('\n', '*' * 10, 'Program Information', '*' * 10)
            print('Node: {}'.format(self.args.node))
            print('GPU: {}'.format(self.args.gpu))
            print('Version: {}\n'.format(self.args.version))

    def train(self):
        self.model.train()

        prev_time = time.time()

        target_loader_iter = self.target_loader.__iter__()
        # target_loader_isee_iter = self.target_loader.__iter__()
        for i, (image_source, mask_source_gt,
                _) in enumerate(self.source_loader):
            mask_source_gt = mask_source_gt.cuda(non_blocking=True)
            image_source = image_source.cuda(non_blocking=True).float()

            image_target, _ = next(target_loader_iter)
            image_target = image_target.cuda(non_blocking=True)
            output_source_mask, output_target_mask, logs = \
                self.model.process(image_source, mask_source_gt, image_target)

            # if self.epoch % 2 == 0:
            #     # train on IDRiD dataset
            #     image_target, _, _ = next(target_loader_iter)
            #     image_target = image_target.cuda(non_blocking=True)
            #     output_source_mask, output_target_mask, logs = \
            #         self.model.process(image_source, mask_source_gt, image_target)
            # else:
            #     # train on iSee dataset
            #     image_target, _, = next(target_loader_isee_iter)
            #     image_target = image_target.cuda(non_blocking=True)
            #     output_source_mask, output_target_mask, logs = \
            #         self.model.process(image_source, mask_source_gt, image_target)

            # --------------
            #  Log Progress
            # --------------
            # Determine approximate time left
            batches_done = self.epoch * len(self.source_loader) + i
            batches_left = self.args.n_epochs * len(self.source_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (self.epoch, self.args.n_epochs, i,
                   self.source_loader.__len__(), logs['dis_loss'].item(),
                   logs['gen_loss'].item(), time_left))

            # --------------
            #  Visdom
            # --------------
            if i % self.args.vis_freq == 0:
                image_source = image_source[:self.args.vis_batch]
                image_target = image_target[:self.args.vis_batch]
                if self.args.data_modality == 'oct':
                    # OCT: {0, 1, ..., 11}, BWH
                    # BWH -> B1WH,
                    mask_source_gt = (
                        mask_source_gt[:self.args.vis_batch].unsqueeze(dim=1) / 11)
                    # B1WH
                    output_source_mask = torch.clamp(
                        output_source_mask[:self.args.vis_batch] / 11, 0, 1)
                    output_target_mask = torch.clamp(
                        output_target_mask[:self.args.vis_batch] / 11, 0, 1)
                else:
                    # fundus: {0, 1}, B1WH
                    mask_source_gt = mask_source_gt[:self.args.vis_batch]
                    output_source_mask = torch.clamp(
                        output_source_mask[:self.args.vis_batch], 0, 1)
                    output_target_mask = torch.clamp(
                        output_target_mask[:self.args.vis_batch], 0, 1)

                vim_images = torch.cat([
                    image_source, mask_source_gt, output_source_mask,
                    image_target, output_target_mask
                ], dim=0)
                self.vis.images(vim_images,
                                win_name='train',
                                nrow=self.args.vis_batch)

            if i + 1 == self.source_loader.__len__():
                self.vis.plot_multi_win(
                    dict(dis_loss=logs['dis_loss'].item(),
                         seg_loss=logs['seg_loss'].item(),
                         lr=self.new_lr))
                self.vis.plot_single_win(dict(
                    gen_loss=logs['gen_loss'].item(),
                    gen_fm_loss=logs['gen_fm_loss'].item(),
                    gen_gan_loss=logs['gen_gan_loss'].item(),
                    gen_content_loss=logs['gen_content_loss'].item(),
                    gen_style_loss=logs['gen_style_loss'].item()),
                                         win='gen_loss')

    def validate(self):
        self.model.eval()
        with torch.no_grad():
            for i, (image, _) in enumerate(self.target_loader):
                image = image.cuda(non_blocking=True).float()

                # forward
                output_mask = self.model(image)

                if i % self.args.vis_freq_inval == 0:
                    image = image[:self.args.vis_batch]
                    if self.args.data_modality == 'oct':
                        # OCT: {0, 1, ..., 11}
                        # gt: BWH
                        # model output: BCWH (C=12)

                        # BCWH -> BWH -> B1WH
                        output_mask = F.log_softmax(output_mask, dim=1)
                        _, output_mask = torch.max(output_mask, dim=1)
                        output_mask = output_mask.float().unsqueeze(dim=1)

                        # {0, 1, ..., 11} -> (0, 1)
                        output_mask = torch.clamp(
                            output_mask[:self.args.vis_batch] / 11, 0, 1)
                    else:
                        # fundus: {0, 1}, B1WH
                        output_mask = output_mask[:self.args.vis_batch]

                    save_images = torch.cat([image, output_mask], dim=0)
                    output_save = os.path.join(self.args.output_root,
                                               self.args.project, 'output',
                                               self.args.version, 'val')
                    if not os.path.exists(output_save):
                        os.makedirs(output_save)
                    tv.utils.save_image(save_images,
                                        os.path.join(output_save,
                                                     '{}.png'.format(i)),
                                        nrow=self.args.vis_batch)

                    # print('val: [Batch {}/{}]'.format(i, self.target_loader.__len__()))

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  args=self.args)
        print('Save ckpt successfully!')

    def validate_loader(self, dataloader):
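        # Variant of validate() for an arbitrary dataloader; saved grids are
        # named after the input image and no checkpoint is written.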
        self.model.eval()
        with torch.no_grad():
            for i, (image, image_name) in enumerate(dataloader):
                image = image.cuda(non_blocking=True).float()

                # forward
                output_mask = self.model(image)

                if i % self.args.vis_freq_inval == 0:
                    image = image[:self.args.vis_batch]
                    if self.args.data_modality == 'oct':
                        # OCT: {0, 1, ..., 11}
                        # gt: BWH
                        # model output: BCWH (C=12)

                        # BCWH -> BWH -> B1WH
                        output_mask = F.log_softmax(output_mask, dim=1)
                        _, output_mask = torch.max(output_mask, dim=1)
                        output_mask = output_mask.float().unsqueeze(dim=1)

                        # {0, 1, ..., 11} -> (0, 1)
                        output_mask = torch.clamp(
                            output_mask[:self.args.vis_batch] / 11, 0, 1)
                    else:
                        # fundus: {0, 1}, B1WH
                        output_mask = output_mask[:self.args.vis_batch]

                    save_images = torch.cat([image, output_mask], dim=0)
                    output_save = os.path.join(self.args.output_root,
                                               self.args.project, 'output',
                                               self.args.version, 'val')
                    if not os.path.exists(output_save):
                        os.makedirs(output_save)
                    tv.utils.save_image(save_images,
                                        os.path.join(
                                            output_save,
                                            '{}.png'.format(image_name[0])),
                                        nrow=self.args.vis_batch)

    def predict(self):
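        # Single-image inference on the target loader; for OCT the argmax mask
        # can optionally be refined with a dense CRF before saving.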
        self.model.eval()
        with torch.no_grad():
            for i, (image, _, item_name) in enumerate(self.target_loader):
                image = image.cuda(non_blocking=True).float()

                if self.args.batch == 1:
                    if self.args.data_modality == 'oct':
                        case_name, image_name = item_name
                        case_name = case_name[0]
                        image_name = image_name[0]
                    else:
                        case_name = 'fundus'
                        image_name = item_name[0]
                else:
                    raise NotImplementedError('error')

                # forward
                output_mask = self.model(image)

                dim_channel = 1
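                # Channel dimension holds the per-class scores (C=12 for OCT).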
                if self.args.data_modality == 'oct':
                    # mask prob for CRF
                    mask_prob = F.softmax(output_mask, dim=dim_channel)

                    # output the segmentation mask
                    output_mask = F.log_softmax(output_mask, dim=dim_channel)
                    _, output_mask = torch.max(output_mask, dim=dim_channel)
                    output_mask = output_mask.float().unsqueeze(
                        dim=dim_channel)
                    # {0, 1, ..., 11} -> (0, 1)
                    _output_mask = torch.clamp(output_mask / 11, 0, 1)

                    if self.args.use_crf:
                        # CHW -> HWC (224, 224, 1)
                        # equivalent: tensor.permute(1, 2, 0)
                        _image = image.squeeze(dim=0).cpu().transpose(
                            0, 2).transpose(0, 1)
                        # OCT, 1 channel. (224, 224, 1) -> (224, 224, 3)
                        _image = _image.repeat(1, 1, 3)
                        mask = mask_prob.squeeze(dim=0).cpu()
                        crf_mask = dense_crf(
                            np.array(_image).astype(np.uint8), mask)
                        _crf_mask = torch.Tensor(crf_mask.astype(
                            np.float32)) / 11
                        # HW -> BCHW
                        _crf_mask = _crf_mask.expand((1, 1, -1, -1)).cuda()
                    else:
                        _crf_mask = output_mask

                else:
                    # fundus: {0, 1}, B1WH
                    _output_mask = output_mask.clamp(0, 1)
                    # raise NotImplementedError('error for fundus mode')

                save_images = torch.cat([image, _output_mask], dim=0)
                output_save_path = os.path.join(
                    '/home/imed/new_disk/workspace/', self.args.project,
                    'output', self.args.version, 'predict')
                save_name = '{}_{}.png'.format(case_name, image_name)
                self.vis.images(save_images, win_name='predict')
                if not os.path.exists(output_save_path):
                    os.makedirs(output_save_path)
                tv.utils.save_image(save_images,
                                    os.path.join(output_save_path, save_name),
                                    nrow=2)

                # ---------
                # save mask
                # ---------
                # To optimize: mask_vgg_root and mask_crf_root are not defined
                # in this snippet; define them before enabling save_flag.
                save_flag = False
                if save_flag:
                    save_path = os.path.join(mask_vgg_root, case_name)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    self.save_oct(output_mask,
                                  os.path.join(save_path, image_name))

                    save_path = os.path.join(mask_crf_root, case_name)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    self.save_oct(crf_mask,
                                  os.path.join(save_path, image_name),
                                  crf_mode=True)

    def save_oct(self, tensor, filename, crf_mode=False):
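        # Write a single mask to disk: a (1, 1, H, W) tensor in the normal
        # case, or an (H, W) ndarray when it comes from the CRF.
        # Note: scipy.misc.imsave was removed in SciPy 1.2; newer code would
        # use imageio.imwrite instead.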
        if crf_mode:
            misc.imsave(filename, tensor)
        else:
            B, C, _, _ = tensor.shape
            assert B == 1 and C == 1, 'error about shape'
            tensor = tensor.squeeze()
            ndarr = tensor.cpu().numpy()
            misc.imsave(filename, ndarr)
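
# ----------------------------------------------------------------------------
# Hedged sketch (assumption, not part of the original source): the dense_crf()
# helper used by predict() above is not shown in this listing. A common
# implementation on top of pydensecrf looks roughly like the following; the
# pairwise parameters and the number of inference steps are illustrative only.
# ----------------------------------------------------------------------------
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax


def dense_crf(image_hwc_uint8, mask_prob):
    """Refine a softmax mask with a fully connected CRF.

    image_hwc_uint8: (H, W, 3) uint8 image.
    mask_prob: torch tensor of shape (C, H, W) with softmax probabilities.
    Returns an (H, W) integer label map.
    """
    prob = mask_prob.detach().cpu().numpy()
    n_labels, h, w = prob.shape

    d = dcrf.DenseCRF2D(w, h, n_labels)
    d.setUnaryEnergy(np.ascontiguousarray(unary_from_softmax(prob)))
    # Smoothness kernel (location only) and appearance kernel (location + color).
    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=80, srgb=13,
                           rgbim=np.ascontiguousarray(image_hwc_uint8),
                           compat=10)

    q = d.inference(5)
    return np.argmax(np.array(q), axis=0).reshape(h, w)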
Code example #6
0
class ResnetRunner(object):
    def __init__(self):
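        # Binary ResNet-50 classifier on single-channel images, evaluated on a
        # validation split plus several test splits (ours / bigan / cyclegan).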
        args = ParserArgs().args
        cuda_visible(args.gpu)

        model = resnet50(in_channels=1, num_classes=2)
        model = nn.DataParallel(model).cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

        # Optionally resume from a checkpoint
        if args.resume:
            ckpt_root = os.path.join('/root/workspace', args.project,
                                     'checkpoints')
            ckpt_path = os.path.join(ckpt_root, args.resume)
            if os.path.isfile(ckpt_path):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(ckpt_path)
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version), port=args.port)

        self.train_loader = ultraLoader(root=args.dataroot,
                                        batch=args.batch,
                                        version='train').data_load()
        self.val_loader = ultraLoader(root=args.dataroot,
                                      batch=args.batch,
                                      version='validation').data_load()
        self.test_loader = ultraLoader(root=args.dataroot,
                                       batch=args.batch,
                                       version='test_ours').data_load()
        self.test_loader_bigan = ultraLoader(root=args.dataroot,
                                             batch=args.batch,
                                             version='bigan').data_load()
        self.test_loader_cyclegan = ultraLoader(
            root=args.dataroot, batch=args.batch,
            version='cyclegan').data_load()

        print_args(args)
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss().cuda()

    def train_test(self):
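        # One training pass per epoch, followed by evaluation on the
        # validation split and every test split.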
        self.best_acc = 0
        for epoch in range(self.args.n_epochs):
            adjust_lr(self.args.lr, self.optimizer, epoch, 30)
            self.epoch = epoch

            self.train()
            self.test(self.val_loader, 'validation')
            self.test(self.test_loader, 'test_ours')
            self.test(self.test_loader_bigan, 'bigan')
            self.test(self.test_loader_cyclegan, 'cyclegan')
            print('\n', '*' * 10, 'Program Information', '*' * 10)
            print('Node: {}'.format(self.args.node))
            print('Version: {}\n'.format(self.args.version))

    def train(self):
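        # Standard cross-entropy training loop with periodic Visdom previews.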
        self.model.train()
        for i, (img, label) in enumerate(self.train_loader):
            img = img.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)

            output = self.model(img)
            _, pred = torch.max(output, 1)

            loss = self.criterion(output, label)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % 2 == 0:
                self.vis.images(img[0].squeeze(),
                                name='train',
                                img_name='{}_{}'.format(
                                    label[0].item(), pred[0].item()))

            if i + 1 == len(self.train_loader):
                self.vis.plot_many(dict(loss=loss.item()))
            if i % self.args.print_freq == 0:
                print('[{}] Epoch: [{}][{}/{}]\tLoss: {:.4f}'.format(
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    self.epoch, i, len(self.train_loader), loss.item()))

    def test(self, test_loader, version):
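        # Collect per-sample probabilities and predictions (batch size 1 is
        # assumed), compute AUC and accuracy, and checkpoint on the best
        # validation accuracy.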
        prob_list = []
        pred_list = []
        true_list = []

        self.model.eval()
        with torch.no_grad():
            for i, (img, label) in enumerate(test_loader):
                img = img.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

                output = self.model(img)
                output = F.softmax(output, dim=1)
                _, pred = torch.max(output, 1)

                prob_list.append(output[0][1].item())
                pred_list.append(pred.item())
                true_list.append(label.item())

                if i % 3 == 0:
                    self.vis.images(img.squeeze(),
                                    name=version,
                                    img_name='{}_{}'.format(
                                        label.item(), pred.item()))

            # Equivalent alternative:
            # fpr, tpr, thresholds = metrics.roc_curve(
            #     y_true=true_list, y_score=prob_list, pos_label=1, drop_intermediate=False)
            # auc = metrics.auc(fpr, tpr)
            auc = metrics.roc_auc_score(y_true=true_list, y_score=prob_list)
            acc = metrics.accuracy_score(y_true=true_list, y_pred=pred_list)

            if version == 'validation':
                is_best = acc > self.best_acc
                self.best_acc = max(acc, self.best_acc)
                save_ckpt(version=self.args.version,
                          state={
                              'epoch': self.epoch + 1,
                              'state_dict': self.model.state_dict(),
                              'best_acc': self.best_acc,
                              'optimizer': self.optimizer.state_dict(),
                          },
                          is_best=is_best,
                          epoch=self.epoch + 1,
                          project='2018_OCT_transfer')
                print('Save ckpt successfully!')

            print('*' * 10, 'Auc = {:.3f}, Acc = {:.3f}'.format(auc, acc),
                  '*' * 10)
            self.vis.plot_legend(win='auc', name=version, y=auc)
            self.vis.plot_legend(win='acc', name=version, y=acc)
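
# Hedged usage sketch (not part of the original listing): a typical entry
# point would construct the runner and start training/evaluation.
if __name__ == '__main__':
    runner = ResnetRunner()
    runner.train_test()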