def data_analysis():
    """Visually inspect the ultrasound training set through Visdom.

    For every training batch the images are displayed, their max/mean/min
    values are plotted in the 'sample' window, the target is plotted, and
    the loop then pauses for two seconds so each batch can be looked at.
    """
    import time
    from utils.visualizer import Visualizer

    viz = Visualizer(env='{}'.format('v99_99_debug'), port=31434)
    root = '/root/workspace/2018_US_project/ultrsound_zhy'
    train_loader, _test_loader = ultraDataLoader(root, 1).data_load()

    for batch, label, _batch_name in train_loader:
        # original author's note: batch.shape == (B, 1, 480, 480)
        viz.images(batch)
        stats = dict(max=batch.max(), mean=batch.mean(), min=batch.min())
        viz.plot_single_win(stats, win='sample')
        viz.plot_multi_win(dict(target=label.item()))
        time.sleep(2)
class RunMyModel(object):
    """Train and evaluate a PNetModel as a reconstruction-based anomaly detector.

    The anomaly score of an image is its mean absolute reconstruction error.
    Normal images are labelled 0 and abnormal images 1.  AUC, accuracy,
    sensitivity and specificity are derived from those scores.

    Construction parses the command-line arguments and builds the data
    loaders and the model.  It then immediately runs either ``test_acc``
    (when ``args.predict`` is set) or the full ``train_val`` loop.
    """

    def __init__(self):
        args = ParserArgs().get_args()
        cuda_visible(args.gpu)
        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version), port=args.port,
                              server=args.vis_server)

        # Use the local ``args`` here.  ``self.args`` is only assigned further
        # below, so reading it at this point would raise AttributeError.
        if args.data_modality == 'fundus':
            # iSee fundus dataset for classification: batches of (image, image_name)
            (self.train_loader, self.normal_test_loader,
             self.amd_fundus_loader, self.myopia_fundus_loader,
             self.glaucoma_fundus_loader, self.dr_fundus_loader) = \
                NewClsFundusDataloader(data_root=args.isee_fundus_root,
                                       batch=args.batch,
                                       scale=args.scale).data_load()
        else:
            # Challenge OCT dataset for classification:
            # batches of (image, [case_name, image_name])
            (self.train_loader, self.normal_test_loader,
             self.oct_abnormal_loader) = OCT_ClsDataloader(
                data_root=args.challenge_oct, batch=args.batch,
                scale=args.scale).data_load()

        print_args(args)
        self.args = args
        self.new_lr = self.args.lr
        self.model = PNetModel(args)

        if args.predict:
            self.test_acc()
        else:
            self.train_val()

    # ------------------------------------------------------------------ #
    # training
    # ------------------------------------------------------------------ #
    def train_val(self):
        """Run the epoch loop, validating periodically and near the end."""
        # general metrics
        self.best_auc = 0
        self.is_best = False
        # every total AUC seen so far; used for the top-10 mean/std statistics
        self.total_auc_history = []
        self.total_auc_last10 = LastAvgMeter(length=10)
        self.acc_last10 = LastAvgMeter(length=10)
        # per-disease AUC meters (iSee fundus)
        self.myopia_auc_last10 = LastAvgMeter(length=10)
        self.amd_auc_last10 = LastAvgMeter(length=10)
        self.glaucoma_auc_last10 = LastAvgMeter(length=10)
        self.dr_auc_last10 = LastAvgMeter(length=10)
        # same meters, keyed by the names used in _disease_loaders()
        self.disease_auc_last10 = dict(myopia=self.myopia_auc_last10,
                                       amd=self.amd_auc_last10,
                                       glaucoma=self.glaucoma_auc_last10,
                                       dr=self.dr_auc_last10)

        if self.args.data_modality == 'fundus':
            # total: 1000 epochs
            adjust_lr_epoch_list = [40, 80, 160, 240]
        else:
            # total: 180 epochs
            adjust_lr_epoch_list = [20, 40, 80, 120]

        for epoch in range(self.args.start_epoch, self.args.n_epochs):
            _ = adjust_lr(self.args.lr, self.model.optimizer_G, epoch,
                          adjust_lr_epoch_list)
            new_lr = adjust_lr(self.args.lr, self.model.optimizer_D, epoch,
                               adjust_lr_epoch_list)
            self.new_lr = min(new_lr, self.new_lr)

            self.epoch = epoch
            self.train()

            # Once past validate_start_epoch, validate every validate_freq
            # epochs, and every epoch during the last validate_each_epoch epochs.
            if epoch > self.args.validate_start_epoch and (
                    epoch % self.args.validate_freq == 0
                    or epoch > (self.args.n_epochs
                                - self.args.validate_each_epoch)):
                self.validate_cls()

            print('\n', '*' * 10, 'Program Information', '*' * 10)
            print('Node: {}'.format(self.args.node))
            print('GPU: {}'.format(self.args.gpu))
            print('Version: {}\n'.format(self.args.version))

    def train(self):
        """Train for one epoch over ``self.train_loader`` and log progress."""
        self.model.train()
        prev_time = time.time()
        train_loader = self.train_loader
        n_batches = len(train_loader)

        for i, (image, _) in enumerate(train_loader):
            image = image.cuda(non_blocking=True)

            # forward + losses, then backward/optimizer steps
            seg_mask, image_rec, gen_loss, dis_loss, logs = \
                self.model.process(image)
            self.model.backward(gen_loss, dis_loss)

            # Approximate time left, extrapolated from the last batch's duration.
            batches_done = self.epoch * n_batches + i
            batches_left = self.args.n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(
                seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (self.epoch, self.args.n_epochs, i, n_batches,
                   dis_loss.item(), gen_loss.item(), time_left))

            if i % self.args.vis_freq == 0:
                vim_images, n_vis = self._make_vis_grid(image, seg_mask,
                                                        image_rec)
                self.vis.images(vim_images, win_name='train', nrow=n_vis)
                # one row per image type, matching the Visdom grid
                tv.utils.save_image(
                    vim_images,
                    os.path.join(self._output_dir('train'), '{}.png'.format(i)),
                    nrow=n_vis)

            if i + 1 == n_batches:
                self.vis.plot_multi_win(
                    dict(dis_loss=dis_loss.item(), lr=self.new_lr))
                self.vis.plot_single_win(dict(
                    gen_loss=gen_loss.item(),
                    gen_l1_loss=logs['gen_l1_loss'].item(),
                    gen_fm_loss=logs['gen_fm_loss'].item(),
                    gen_gan_loss=logs['gen_gan_loss'].item(),
                    gen_content_loss=logs['gen_content_loss'].item(),
                    gen_style_loss=logs['gen_style_loss'].item()),
                    win='gen_loss')

    # ------------------------------------------------------------------ #
    # evaluation
    # ------------------------------------------------------------------ #
    def validate_cls(self):
        """Evaluate, update running meters, plot metrics and save a checkpoint."""
        result = self._evaluate_cls()
        total_auc = result['total_auc']
        total_acc = result['total_cls'][0]

        # update running statistics
        self.total_auc_last10.update(total_auc)
        self.acc_last10.update(total_acc)
        for name, auc in result['disease_auc'].items():
            self.disease_auc_last10[name].update(auc)
        self.total_auc_history.append(total_auc)
        mean, deviation = self._top_k_mean_std(self.total_auc_history, k=10)

        self.is_best = total_auc > self.best_auc
        self.best_auc = max(total_auc, self.best_auc)

        # ROC curve and primary (total AUC) metrics
        self.vis.draw_roc(result['fpr'], result['tpr'])
        self.vis.plot_single_win(dict(value=total_auc,
                                      best=self.best_auc,
                                      last_avg=self.total_auc_last10.avg,
                                      last_std=self.total_auc_last10.std,
                                      top_avg=mean,
                                      top_dev=deviation),
                                 win='total_auc')

        acc_plot = dict(zip(('total_acc', 'total_sen', 'total_spe'),
                            result['total_cls']))
        for name, (acc, sen, spe) in result['disease_cls'].items():
            acc_plot['{}_acc'.format(name)] = acc
            acc_plot['{}_sen'.format(name)] = sen
            acc_plot['{}_spe'.format(name)] = spe
        self.vis.plot_single_win(acc_plot, win='accuracy')

        metrics_parts = [
            'best_auc = {:.4f}'.format(self.best_auc),
            'total_avg = {:.4f}, total_std = {:.4f}'.format(
                self.total_auc_last10.avg, self.total_auc_last10.std),
            'total_top_avg = {:.4f}, total_top_dev = {:.4f}'.format(
                mean, deviation),
        ]
        # per-disease AUC windows (fundus only; empty for OCT)
        for name, auc in result['disease_auc'].items():
            meter = self.disease_auc_last10[name]
            self.vis.plot_single_win(dict(value=auc,
                                          last_avg=meter.avg,
                                          last_std=meter.std),
                                     win='{}_auc'.format(name))
            metrics_parts.append('{0}_avg = {1:.4f}, {0}_std = {2:.4f}'.format(
                name, meter.avg, meter.std))
        metrics_str = ', '.join(metrics_parts)
        metrics_acc_str = '\n ' + self._format_cls_metrics(result, sep=', ')
        self.vis.text(metrics_str + metrics_acc_str)

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G2.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  is_best=self.is_best,
                  args=self.args)

        print('\n Save ckpt successfully!')
        print('\n', metrics_str + metrics_acc_str)

    def test_acc(self):
        """Evaluate once (predict mode), then show and print the metrics."""
        result = self._evaluate_cls()

        # ROC curve
        self.vis.draw_roc(result['fpr'], result['tpr'])

        metrics_auc_str = 'AUC = {:.4f}'.format(result['total_auc']) + ''.join(
            ', {} AUC = {:.4f}'.format(name, auc)
            for name, auc in result['disease_auc'].items())
        metrics_acc_str = '\n ' + self._format_cls_metrics(result, sep='\n ')
        self.vis.text(metrics_auc_str + metrics_acc_str)
        print(metrics_auc_str + metrics_acc_str)

    def forward_cls_dataloader(self, loader, is_disease):
        """Score every image of ``loader`` by its reconstruction error.

        Args:
            loader: evaluation loader.  Fundus batches are
                (image, image_name); OCT batches are
                (image, [case_name, image_name]).
            is_disease: label assigned to every image (True -> 1, False -> 0).

        Returns:
            (gt_list, pred_list): per-image labels and anomaly scores
            (mean absolute reconstruction error).
        """
        gt_list = []
        pred_list = []
        label = 1 if is_disease else 0

        for i, (image, image_name_item) in enumerate(loader):
            image = image.cuda(non_blocking=True)
            seg_mask, image_rec = self.model(image)

            if self.args.data_modality == 'fundus':
                case_name = ['']
                image_name = image_name_item
            else:
                case_name, image_name = image_name_item

            # anomaly score: BCWH residual averaged down to one value per image
            image_residual = torch.abs(image_rec - image)
            image_diff_mae = image_residual.mean(dim=3).mean(dim=2).mean(dim=1)
            gt_list += [label] * len(image_name)
            pred_list += image_diff_mae.tolist()

            if i % self.args.vis_freq_inval == 0:
                vim_images, n_vis = self._make_vis_grid(image, seg_mask,
                                                        image_rec)
                self.vis.images(vim_images, win_name='val', nrow=n_vis)
                tv.utils.save_image(
                    vim_images,
                    os.path.join(self._output_dir('val'),
                                 '{}_{}.png'.format(case_name[0],
                                                    image_name[0])),
                    nrow=n_vis)

        return gt_list, pred_list

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    def _disease_loaders(self):
        """Return (name, loader) pairs for the abnormal evaluation sets."""
        if self.args.data_modality == 'fundus':
            return [('myopia', self.myopia_fundus_loader),
                    ('amd', self.amd_fundus_loader),
                    ('glaucoma', self.glaucoma_fundus_loader),
                    ('dr', self.dr_fundus_loader)]
        return [('abnormal', self.oct_abnormal_loader)]

    def _evaluate_cls(self):
        """Score all evaluation loaders and compute AUC / threshold metrics.

        Returns a dict with:
            fpr, tpr: ROC curve over normal-test plus all abnormal images.
            total_auc: area under that ROC curve.
            disease_auc: {name: AUC of that disease vs normal-test}
                (fundus only; empty for OCT).
            total_cls: (acc, sen, spe) over normal-test plus all abnormal images.
            disease_cls: {name: (acc, sen, spe)} (fundus only; empty for OCT).
        """
        # Evaluation deliberately keeps the model in train mode
        # (``self.model.eval()`` was commented out by the original author).
        self.model.train()
        with torch.no_grad():
            disease_results = [
                (name, self.forward_cls_dataloader(loader=loader,
                                                   is_disease=True))
                for name, loader in self._disease_loaders()]
            _, normal_train_pred_list = self.forward_cls_dataloader(
                loader=self.train_loader, is_disease=False)
            normal_gt_list, normal_pred_list = self.forward_cls_dataloader(
                loader=self.normal_test_loader, is_disease=False)

        # Decision threshold: the 75th-percentile anomaly score of the
        # normal training images.
        threshold = self._score_threshold(normal_train_pred_list, 0.75)
        normal_cls_pred_list = self._binarize(normal_pred_list, threshold)

        per_disease = self.args.data_modality == 'fundus'
        total_true_list = list(normal_gt_list)
        total_pred_list = list(normal_pred_list)
        total_cls_pred_list = list(normal_cls_pred_list)
        disease_auc, disease_cls = {}, {}
        for name, (gt_list, pred_list) in disease_results:
            cls_pred_list = self._binarize(pred_list, threshold)
            total_true_list += gt_list
            total_pred_list += pred_list
            total_cls_pred_list += cls_pred_list
            if per_disease:
                disease_auc[name] = metrics.roc_auc_score(
                    np.array(gt_list + normal_gt_list),
                    np.array(pred_list + normal_pred_list))
                disease_cls[name] = self._cls_metrics(
                    normal_gt_list + gt_list,
                    normal_cls_pred_list + cls_pred_list)

        fpr, tpr, _ = metrics.roc_curve(np.array(total_true_list),
                                        np.array(total_pred_list))
        return dict(fpr=fpr,
                    tpr=tpr,
                    total_auc=metrics.auc(fpr, tpr),
                    disease_auc=disease_auc,
                    total_cls=self._cls_metrics(total_true_list,
                                                total_cls_pred_list),
                    disease_cls=disease_cls)

    def _make_vis_grid(self, image, seg_mask, image_rec):
        """Stack [input, seg mask, reconstruction, |input - rec|] for display.

        Only the first ``args.vis_batch`` samples are used.  Returns the
        concatenated tensor and the number of samples per image type, which
        is the ``nrow`` that puts one image type on each grid row.
        """
        vis_batch = self.args.vis_batch
        image = image[:vis_batch]
        if self.args.data_modality == 'fundus':
            seg_mask = seg_mask[:vis_batch].clamp(0, 1)
        else:
            # BCWH -> BWH: argmax over the channel (class) dimension
            seg_mask = torch.argmax(seg_mask[:vis_batch], dim=1).float()
            # BWH -> B1WH; divide class ids by 11 to map them into [0, 1]
            seg_mask = (seg_mask.unsqueeze(dim=1) / 11).clamp(0, 1)
        image_rec = image_rec[:vis_batch].clamp(0, 1)
        image_diff = torch.abs(image - image_rec)
        vim_images = torch.cat([image, seg_mask, image_rec, image_diff], dim=0)
        return vim_images, image.size(0)

    def _output_dir(self, split):
        """Return (creating it if needed) the image output directory for ``split``."""
        output_save = os.path.join(self.args.output_root, self.args.project,
                                   'output_v1_0812',
                                   '{}'.format(self.args.version), split)
        os.makedirs(output_save, exist_ok=True)
        return output_save

    @staticmethod
    def _format_cls_metrics(result, sep):
        """Format total and per-disease (acc, sen, spe), joined by ``sep``."""
        parts = ['total_acc = {:.4f}, total_sen = {:.4f}, total_spe = {:.4f}'
                 .format(*result['total_cls'])]
        for name, (acc, sen, spe) in result['disease_cls'].items():
            parts.append('{0}_acc = {1:.4f}, {0}_sen = {2:.4f}, '
                         '{0}_spe = {3:.4f}'.format(name, acc, sen, spe))
        return sep.join(parts)

    @staticmethod
    def _score_threshold(scores, percentage=0.75):
        """Return the value at fraction ``percentage`` of the sorted ``scores``.

        Raises:
            ValueError: if ``scores`` is empty.
        """
        ordered = sorted(scores)
        if not ordered:
            raise ValueError('cannot derive a threshold from an empty score list')
        return ordered[min(int(len(ordered) * percentage), len(ordered) - 1)]

    @staticmethod
    def _binarize(scores, threshold):
        """Map scores to labels: 0 below ``threshold``, 1 otherwise."""
        return [0 if s < threshold else 1 for s in scores]

    @staticmethod
    def _cls_metrics(gt_list, cls_pred_list):
        """Return (accuracy, sensitivity, specificity) with 1 = disease.

        Computed directly from the confusion counts, so it stays well defined
        when only one class is present.
        """
        pairs = list(zip(gt_list, cls_pred_list))
        tp = sum(1 for g, p in pairs if g == 1 and p == 1)
        tn = sum(1 for g, p in pairs if g == 0 and p == 0)
        fp = sum(1 for g, p in pairs if g == 0 and p == 1)
        fn = sum(1 for g, p in pairs if g == 1 and p == 0)
        acc = (tp + tn) / len(pairs) if pairs else 0.0
        sen = tp / (tp + fn + 1e-7)
        spe = tn / (tn + fp + 1e-7)
        return acc, sen, spe

    @staticmethod
    def _top_k_mean_std(values, k=10):
        """Return (mean, population std) of the ``k`` largest ``values``.

        Stands in for the commented-out ``AverageMeter.top_update_calc``.
        Returns (0.0, 0.0) when ``values`` is empty.
        """
        top = sorted(values, reverse=True)[:k]
        if not top:
            return 0.0, 0.0
        return float(np.mean(top)), float(np.std(top))
def train(self, edgenetpath=None, srresnetpath=None, random_scale=True,
          rotate=True, fliplr=True, fliptb=True):
    """Train the edge-guided 2x/4x super-resolution ResNet with a GAN loss.

    Args:
        edgenetpath: path to pretrained HED_1L weights (required).
        srresnetpath: optional path to pretrained Upscale4xResnetGenerator
            weights.  Only checkpoint keys that exist in the model are loaded.
        random_scale, rotate, fliplr, fliptb: augmentation switches passed to
            ``self.load_dataset``.

    Raises:
        Exception: if ``edgenetpath`` is None or does not exist.
    """
    import torchnet as tnt
    from tqdm import tqdm

    def load_matching_weights(net, path):
        """Load the entries of the checkpoint at ``path`` whose keys exist in ``net``."""
        pretrained_dict = torch.load(path)
        model_dict = net.state_dict()
        model_dict.update({k: v for k, v in pretrained_dict.items()
                           if k in model_dict})
        net.load_state_dict(model_dict)

    vis = Visualizer(self.env)

    print('================ Loading datasets =================')
    # load training dataset
    print('## Current Mode: Train')
    train_data_loader = self.load_dataset(
        mode='train', random_scale=random_scale, rotate=rotate,
        fliplr=fliplr, fliptb=fliptb)

    ##########################################################
    ##################### build network ######################
    ##########################################################
    print('Building Networks and initialize parameters\' weights....')
    srresnet = Upscale4xResnetGenerator(input_nc=3, output_nc=3, n_blocks=5,
                                        norm='batch', learn_residual=True)
    srresnet.apply(weights_init_normal)
    discnet = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=5)

    edgenet = HED_1L()
    if edgenetpath is None or not os.path.exists(edgenetpath):
        raise Exception('Invalid edgenet model')
    load_matching_weights(edgenet, edgenetpath)

    featuremapping = VGGFeatureMap(models.vgg19(pretrained=True))

    # load pretrained srresnet or keep the fresh initialization
    if srresnetpath is None or not os.path.exists(srresnetpath):
        print('===> initialize the deblurnet')
        print('======> No pretrained model')
    else:
        print('======> loading the weight from pretrained model')
        load_matching_weights(srresnet, srresnetpath)

    # generator uses 10x, discriminator 0.1x the base learning rate
    lr = self.lr
    srresnet_optimizer = optim.Adam(
        srresnet.parameters(), lr=lr * 10, betas=(0.9, 0.999))
    disc_optimizer = optim.Adam(
        discnet.parameters(), lr=lr / 10, betas=(0.9, 0.999))

    MSE_loss = nn.MSELoss()
    BCE_loss = nn.BCELoss()

    if USE_GPU:
        edgenet.cuda()
        srresnet.cuda()
        discnet.cuda()
        featuremapping.cuda()
        MSE_loss.cuda()
        BCE_loss.cuda()
        print('\tCUDA acceleration is available.')

    def scale_loss(sr, target, real_label, a1, a2, a3):
        """Weighted edge + adversarial + content + perceptual loss for one scale.

        Returns (total_loss, edge_loss, adversarial_loss).
        """
        # NOTE(review): both edge maps are detached, so the edge term is
        # logged and added to the total but gives srresnet no gradient --
        # confirm this is intended.
        pred = edgenet(sr)
        real = edgenet(target)
        edge_loss = BCE_loss(pred.detach(), real.detach())
        content_loss = MSE_loss(sr, target)
        real_feature = featuremapping(target)
        fake_feature = featuremapping(sr)
        vgg_loss = MSE_loss(fake_feature, real_feature.detach())
        advs_loss = BCE_loss(discnet(sr), real_label)
        total = a1 * edge_loss + a2 * advs_loss + \
            a3 * content_loss + (1.0 - a3) * vgg_loss
        return total, edge_loss, advs_loss

    ##########################################################
    ##################### train network ######################
    ##########################################################
    edge_avg_loss = tnt.meter.AverageValueMeter()
    total_avg_loss = tnt.meter.AverageValueMeter()
    disc_avg_loss = tnt.meter.AverageValueMeter()
    psnr_2x_avg = tnt.meter.AverageValueMeter()
    ssim_2x_avg = tnt.meter.AverageValueMeter()
    psnr_4x_avg = tnt.meter.AverageValueMeter()
    ssim_4x_avg = tnt.meter.AverageValueMeter()

    save_dir = os.path.join(self.save_dir, 'train_result')
    os.makedirs(save_dir, exist_ok=True)

    srresnet.train()
    discnet.train()
    for epoch in range(self.num_epochs):
        psnr_2x_avg.reset()
        ssim_2x_avg.reset()
        psnr_4x_avg.reset()
        ssim_4x_avg.reset()

        # Decay both learning rates by 10x every 20 epochs.  The parentheses
        # matter: `epoch + 1 % 20` parses as `epoch + 1`, which never equals 0.
        if (epoch + 1) % 20 == 0:
            for param_group in srresnet_optimizer.param_groups:
                param_group["lr"] /= 10.0
            print("Learning rate decay for srresnet: lr={}".format(
                srresnet_optimizer.param_groups[0]["lr"]))
            for param_group in disc_optimizer.param_groups:
                param_group["lr"] /= 10.0
            print("Learning rate decay for discnet: lr={}".format(
                disc_optimizer.param_groups[0]["lr"]))

        itbar = tqdm(enumerate(train_data_loader))
        for ii, (hr, lr2x, lr4x, bc2x, bc4x) in itbar:
            mini_batch = hr.size()[0]

            hr_ = Variable(hr)
            lr2x_ = Variable(lr2x)
            lr4x_ = Variable(lr4x)
            real_label = Variable(torch.ones(mini_batch))
            fake_label = Variable(torch.zeros(mini_batch))

            if USE_GPU:
                hr_ = hr_.cuda()
                lr2x_ = lr2x_.cuda()
                lr4x_ = lr4x_.cuda()
                real_label = real_label.cuda()
                fake_label = fake_label.cuda()

            # ============ Edge-based srresnet forward ============ #
            sr2x_, sr4x_ = srresnet(lr4x_)

            # ================ Train Discriminator ================ #
            if epoch + 1 > self.pretrain_epochs:
                # 2x: real = lr2x ground truth, fake = sr2x output
                disc_optimizer.zero_grad()
                real_loss_2x = BCE_loss(discnet(lr2x_), real_label.detach())
                fake_loss_2x = BCE_loss(discnet(sr2x_.detach()),
                                        fake_label.detach())
                disc_loss_2x = real_loss_2x + fake_loss_2x
                disc_loss_2x.backward()
                disc_optimizer.step()

                # 4x: clear the 2x gradients so this step follows only the 4x loss
                disc_optimizer.zero_grad()
                real_loss_4x = BCE_loss(discnet(hr_), real_label.detach())
                fake_loss_4x = BCE_loss(discnet(sr4x_.detach()),
                                        fake_label.detach())
                disc_loss_4x = real_loss_4x + fake_loss_4x
                disc_loss_4x.backward()
                disc_optimizer.step()

                disc_avg_loss.add((disc_loss_2x + disc_loss_4x).data.item())

            # ============== Train srresnet Generator ============== #
            srresnet_optimizer.zero_grad()
            if epoch + 1 > self.pretrain_epochs:
                a1, a2, a3 = 0.6, 0.1, 0.65
            else:
                # pre-training: no adversarial term
                a1, a2, a3 = 0.45, 0.0, 0.95

            total_loss_2x, edge_loss_2x, advs_loss_2x = scale_loss(
                sr2x_, lr2x_, real_label, a1, a2, a3)
            total_loss_4x, edge_loss_4x, advs_loss_4x = scale_loss(
                sr4x_, hr_, real_label, a1, a2, a3)

            # The 2x auxiliary loss is weighted 0.01 and the 4x loss 1.0.  The
            # old code added total_loss_2x twice and never used the 4x loss.
            total_loss = 0.01 * total_loss_2x + 1.0 * total_loss_4x
            total_loss.backward()
            srresnet_optimizer.step()

            # ================== calculate scores ================== #
            psnr_2x_score_process = batch_compare_filter(
                sr2x_.cpu().data, lr2x, PSNR)
            psnr_2x_avg.add(psnr_2x_score_process)
            ssim_2x_score_process = batch_compare_filter(
                sr2x_.cpu().data, lr2x, SSIM)
            ssim_2x_avg.add(ssim_2x_score_process)
            psnr_4x_score_process = batch_compare_filter(
                sr4x_.cpu().data, hr, PSNR)
            psnr_4x_avg.add(psnr_4x_score_process)
            ssim_4x_score_process = batch_compare_filter(
                sr4x_.cpu().data, hr, SSIM)
            ssim_4x_avg.add(ssim_4x_score_process)

            total_avg_loss.add(total_loss.data.item())
            edge_avg_loss.add((edge_loss_2x + edge_loss_4x).data.item())
            disc_avg_loss.add((advs_loss_2x + advs_loss_4x).data.item())

            should_plot = (ii + 1) % self.plot_iter == self.plot_iter - 1
            if should_plot:
                res = {'edge loss': edge_avg_loss.value()[0],
                       'generate loss': total_avg_loss.value()[0],
                       'discriminate loss': disc_avg_loss.value()[0]}
                vis.plot_many(res, 'Deblur net Loss')

                psnr_2x_score_origin = batch_compare_filter(bc2x, lr2x, PSNR)
                psnr_4x_score_origin = batch_compare_filter(bc4x, hr, PSNR)
                res_psnr = {'2x_origin_psnr': psnr_2x_score_origin,
                            '2x_sr_psnr': psnr_2x_score_process,
                            '4x_origin_psnr': psnr_4x_score_origin,
                            '4x_sr_psnr': psnr_4x_score_process}
                vis.plot_many(res_psnr, 'PSNR Score')

                ssim_2x_score_origin = batch_compare_filter(bc2x, lr2x, SSIM)
                ssim_4x_score_origin = batch_compare_filter(bc4x, hr, SSIM)
                res_ssim = {'2x_origin_ssim': ssim_2x_score_origin,
                            '2x_sr_ssim': ssim_2x_score_process,
                            '4x_origin_ssim': ssim_4x_score_origin,
                            '4x_sr_ssim': ssim_4x_score_process}
                vis.plot_many(res_ssim, 'SSIM Score')

            # ============ Output result of total training processing ============ #
            itbar.set_description(
                "Epoch: [%2d] [%d/%d] PSNR_2x_Avg: %.6f, SSIM_2x_Avg: %.6f, PSNR_4x_Avg: %.6f, SSIM_4x_Avg: %.6f"
                % ((epoch + 1), (ii + 1), len(train_data_loader),
                   psnr_2x_avg.value()[0], ssim_2x_avg.value()[0],
                   psnr_4x_avg.value()[0], ssim_4x_avg.value()[0]))

            if should_plot:
                hr_edge = edgenet(hr_)
                sr2x_edge = edgenet(sr2x_)
                sr4x_edge = edgenet(sr4x_)
                vis.images(hr_edge.cpu().data, win='HR edge predict',
                           opts=dict(title='HR edge predict'))
                vis.images(sr2x_edge.cpu().data, win='SR2X edge predict',
                           opts=dict(title='SR2X edge predict'))
                vis.images(sr4x_edge.cpu().data, win='SR4X edge predict',
                           opts=dict(title='SR4X edge predict'))
                vis.images(lr2x, win='LR2X image', opts=dict(title='LR2X image'))
                vis.images(lr4x, win='LR4X image', opts=dict(title='LR4X image'))
                vis.images(bc2x, win='BC2X image', opts=dict(title='BC2X image'))
                vis.images(bc4x, win='BC4X image', opts=dict(title='BC4X image'))
                vis.images(sr2x_.cpu().data, win='SR2X image',
                           opts=dict(title='SR2X image'))
                vis.images(sr4x_.cpu().data, win='SR4X image',
                           opts=dict(title='SR4X image'))
                vis.images(hr, win='HR image', opts=dict(title='HR image'))

        t_save_dir = 'results/train_result/' + self.train_dataset
        os.makedirs(t_save_dir, exist_ok=True)

        if (epoch + 1) % self.save_epochs == 0:
            self.save_model(srresnet,
                            os.path.join(self.save_dir, 'checkpoints'),
                            'srresnet_param_batch{}_lr{}_epoch{}'.format(
                                self.batch_size, self.lr, epoch + 1))

    # Save final trained model and results
    vis.save([self.env])
    self.save_model(srresnet, os.path.join(self.save_dir, 'checkpoints'),
                    'srresnet_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, self.num_epochs))
def train(self, srcnn_path=None, random_scale=True, rotate=True, fliplr=True,
          fliptb=True):
    """Train the SRCNN baseline: bicubic-upsampled 4x input -> HR image.

    Args:
        srcnn_path: optional path to pretrained SRCNN weights.  Only
            checkpoint keys that exist in the model are loaded.
        random_scale, rotate, fliplr, fliptb: augmentation switches passed to
            ``self.load_dataset``.
    """
    import torchnet as tnt
    from tqdm import tqdm

    vis = Visualizer(self.env)

    print('================ Loading datasets =================')
    # load training dataset
    print('## Current Mode: Train')
    train_data_loader = self.load_dataset(mode='train',
                                          random_scale=random_scale,
                                          rotate=rotate,
                                          fliplr=fliplr,
                                          fliptb=fliptb)

    ##########################################################
    ##################### build network ######################
    ##########################################################
    print('Building Networks and initialize parameters\' weights....')
    srcnn = SRCNN()
    srcnn.apply(weights_init_normal)

    # load pretrained srcnn or keep the fresh initialization
    if srcnn_path is None or not os.path.exists(srcnn_path):
        print('===> initialize the srcnn')
        print('======> No pretrained model')
    else:
        print('======> loading the weight from pretrained model')
        pretrained_dict = torch.load(srcnn_path)
        model_dict = srcnn.state_dict()
        model_dict.update({k: v for k, v in pretrained_dict.items()
                           if k in model_dict})
        srcnn.load_state_dict(model_dict)

    lr = self.lr
    srcnn_optimizer = optim.Adam(srcnn.parameters(), lr=lr, betas=(0.9, 0.999))

    MSE_loss = nn.MSELoss()

    if USE_GPU:
        srcnn.cuda()
        MSE_loss.cuda()
        print('\tCUDA acceleration is available.')

    ##########################################################
    ##################### train network ######################
    ##########################################################
    # Only the 4x path is trained, so only 4x PSNR/SSIM are tracked.
    total_avg_loss = tnt.meter.AverageValueMeter()
    psnr_4x_avg = tnt.meter.AverageValueMeter()
    ssim_4x_avg = tnt.meter.AverageValueMeter()

    srcnn.train()
    for epoch in range(self.num_epochs):
        psnr_4x_avg.reset()
        ssim_4x_avg.reset()

        # Decay the learning rate by 10x every 20 epochs.  The parentheses
        # matter: `epoch + 1 % 20` parses as `epoch + 1`, which never equals 0.
        if (epoch + 1) % 20 == 0:
            for param_group in srcnn_optimizer.param_groups:
                param_group["lr"] /= 10.0
            print("Learning rate decay for srcnn: lr={}".format(
                srcnn_optimizer.param_groups[0]["lr"]))

        itbar = tqdm(enumerate(train_data_loader))
        for ii, (hr, lr2x, lr4x, bc2x, bc4x) in itbar:
            # only the HR target and the bicubic 4x input feed the network
            hr_ = Variable(hr)
            bc4x_ = Variable(bc4x)
            if USE_GPU:
                hr_ = hr_.cuda()
                bc4x_ = bc4x_.cuda()

            # ===================== srcnn training ===================== #
            sr4x_ = srcnn(bc4x_)
            srcnn_optimizer.zero_grad()
            content_loss_4x = MSE_loss(sr4x_, hr_)

            # ==================== calculate scores ==================== #
            psnr_4x_score_process = batch_compare_filter(
                sr4x_.cpu().data, hr, PSNR)
            psnr_4x_avg.add(psnr_4x_score_process)
            ssim_4x_score_process = batch_compare_filter(
                sr4x_.cpu().data, hr, SSIM)
            ssim_4x_avg.add(ssim_4x_score_process)

            # ===================== loss backward ====================== #
            total_loss_4x = content_loss_4x
            total_loss_4x.backward()
            srcnn_optimizer.step()
            total_avg_loss.add(total_loss_4x.data.item())

            should_plot = (ii + 1) % self.plot_iter == self.plot_iter - 1
            if should_plot:
                vis.plot_many({'generate loss': total_avg_loss.value()[0]},
                              'SRCNN Loss')
                psnr_4x_score_origin = batch_compare_filter(bc4x, hr, PSNR)
                vis.plot_many({'4x_origin_psnr': psnr_4x_score_origin,
                               '4x_sr_psnr': psnr_4x_score_process},
                              'PSNR Score')
                ssim_4x_score_origin = batch_compare_filter(bc4x, hr, SSIM)
                vis.plot_many({'4x_origin_ssim': ssim_4x_score_origin,
                               '4x_sr_ssim': ssim_4x_score_process},
                              'SSIM Score')

            # ============ Output result of total training processing ============ #
            itbar.set_description(
                "Epoch: [%2d] [%d/%d] PSNR_4x_Avg: %.6f, SSIM_4x_Avg: %.6f"
                % ((epoch + 1), (ii + 1), len(train_data_loader),
                   psnr_4x_avg.value()[0], ssim_4x_avg.value()[0]))

            if should_plot:
                vis.images(lr4x, win='LR4X image', opts=dict(title='LR4X image'))
                vis.images(bc4x, win='BC4X image', opts=dict(title='BC4X image'))
                vis.images(sr4x_.cpu().data, win='SR4X image',
                           opts=dict(title='SR4X image'))
                vis.images(hr, win='HR image', opts=dict(title='HR image'))

        if (epoch + 1) % self.save_epochs == 0:
            self.save_model(
                srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
                'srcnn_param_batch{}_lr{}_epoch{}'.format(
                    self.batch_size, self.lr, epoch + 1))

    # Save final trained model and results
    vis.save([self.env])
    self.save_model(
        srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
        'srcnn_param_batch{}_lr{}_epoch{}'.format(self.batch_size, self.lr,
                                                  self.num_epochs))
class RunMyModel(object):
    """Segmentation domain-transfer runner: a labelled source set paired with a target set.

    Fundus mode uses ``AnoDRIVE_Loader`` (source) and ``NewClsFundusDataloader``
    (target); OCT mode uses ``ChengOCTloader`` (source) and ``ChallengeOCTloader``
    (target). Training is delegated to ``SegTransferModel.process``; progress is
    shown in visdom and sample predictions / checkpoints are written to disk.
    """

    def __init__(self):
        args = ParserArgs().get_args()
        cuda_visible(args.gpu)

        cudnn.benchmark = True

        self.vis = Visualizer(env='{}'.format(args.version),
                              port=args.port,
                              server=args.vis_server)

        if args.data_modality == 'fundus':
            self.source_loader = AnoDRIVE_Loader(
                data_root=args.fundus_data_root,
                batch=args.batch,
                scale=args.scale,
                pre=True  # pre-process
            ).data_load()
            self.target_loader = NewClsFundusDataloader(
                data_root=args.isee_fundus_root,
                batch=args.batch,
                scale=args.scale).load_for_seg()
        else:
            self.source_loader = ChengOCTloader(
                data_root=args.cheng_oct,
                batch=args.batch,
                scale=args.scale,
                flip=args.flip,
                rotate=args.rotate,
                enhance_p=args.enhance_p).data_load()
            self.target_loader, _ = ChallengeOCTloader(
                data_root=args.challenge_oct,
                batch=args.batch,
                scale=args.scale).data_load()

        print_args(args)
        self.args = args
        self.new_lr = self.args.lr
        self.model = SegTransferModel(args)

        if args.predict:
            self.validate_loader(self.target_loader)
        else:
            self.train_validate()

    def train_validate(self):
        """Run the epoch loop: step both LR schedules, train, and periodically validate."""
        lr_milestones = [40, 80, 160, 240]
        for epoch in range(self.args.start_epoch, self.args.n_epochs):
            _ = adjust_lr(self.args.lr, self.model.optimizer_G, epoch, lr_milestones)
            new_lr = adjust_lr(self.args.lr, self.model.optimizer_D, epoch, lr_milestones)
            # Track the smallest LR seen; it is plotted in `train`.
            self.new_lr = min(new_lr, self.new_lr)

            self.epoch = epoch
            self.train()
            # Validation (which also checkpoints) only runs once epoch > save_freq.
            if epoch % self.args.validate_freq == 0 and epoch > self.args.save_freq:
                self.validate()

        print('\n', '*' * 10, 'Program Information', '*' * 10)
        print('Node: {}'.format(self.args.node))
        print('GPU: {}'.format(self.args.gpu))
        print('Version: {}\n'.format(self.args.version))

    def train(self):
        """Train one epoch over the source loader, pairing every source batch with a target batch."""
        self.model.train()
        prev_time = time.time()
        n_batches = len(self.source_loader)
        target_loader_iter = iter(self.target_loader)

        for i, (image_source, mask_source_gt, _) in enumerate(self.source_loader):
            mask_source_gt = mask_source_gt.cuda(non_blocking=True)
            image_source = image_source.cuda(non_blocking=True).float()

            # The target loader may hold fewer batches than the source loader;
            # restart it instead of letting StopIteration abort the epoch.
            try:
                image_target, _ = next(target_loader_iter)
            except StopIteration:
                target_loader_iter = iter(self.target_loader)
                image_target, _ = next(target_loader_iter)
            image_target = image_target.cuda(non_blocking=True)

            output_source_mask, output_target_mask, logs = \
                self.model.process(image_source, mask_source_gt, image_target)

            # --------------
            #  Log Progress
            # --------------
            # Approximate time left, extrapolated from the duration of this batch.
            batches_done = self.epoch * n_batches + i
            batches_left = self.args.n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s" %
                (self.epoch, self.args.n_epochs, i, n_batches,
                 logs['dis_loss'].item(), logs['gen_loss'].item(), time_left))

            # --------------
            #  Visdom
            # --------------
            if i % self.args.vis_freq == 0:
                vis_batch = self.args.vis_batch
                image_source = image_source[:vis_batch]
                image_target = image_target[:vis_batch]

                if self.args.data_modality == 'oct':
                    # OCT: labels {0, 1, ..., 11}; gt BWH -> B1WH, scaled into [0, 1]
                    mask_source_gt = mask_source_gt[:vis_batch].unsqueeze(dim=1) / 11
                    output_source_mask = torch.clamp(output_source_mask[:vis_batch] / 11, 0, 1)
                    output_target_mask = torch.clamp(output_target_mask[:vis_batch] / 11, 0, 1)
                else:
                    # fundus: {0, 1}, B1WH
                    mask_source_gt = mask_source_gt[:vis_batch]
                    output_source_mask = torch.clamp(output_source_mask[:vis_batch], 0, 1)
                    output_target_mask = torch.clamp(output_target_mask[:vis_batch], 0, 1)

                vim_images = torch.cat([
                    image_source, mask_source_gt, output_source_mask,
                    image_target, output_target_mask
                ], dim=0)
                self.vis.images(vim_images, win_name='train', nrow=vis_batch)

            # Loss curves are plotted once per epoch, on the last batch.
            if i + 1 == n_batches:
                self.vis.plot_multi_win(
                    dict(dis_loss=logs['dis_loss'].item(),
                         seg_loss=logs['seg_loss'].item(),
                         lr=self.new_lr))
                self.vis.plot_single_win(dict(
                    gen_loss=logs['gen_loss'].item(),
                    gen_fm_loss=logs['gen_fm_loss'].item(),
                    gen_gan_loss=logs['gen_gan_loss'].item(),
                    gen_content_loss=logs['gen_content_loss'].item(),
                    gen_style_loss=logs['gen_style_loss'].item()),
                    win='gen_loss')

    def _mask_for_display(self, output_mask):
        """Convert raw model output to a [0, 1] B1WH image, keeping the first ``vis_batch`` samples."""
        vis_batch = self.args.vis_batch
        if self.args.data_modality == 'oct':
            # OCT: model output BCWH (C=12) -> class index BWH -> B1WH
            _, labels = torch.max(F.log_softmax(output_mask, dim=1), dim=1)
            labels = labels.float().unsqueeze(dim=1)
            # {0, 1, ..., 11} -> [0, 1]
            return torch.clamp(labels[:vis_batch] / 11, 0, 1)
        # fundus: {0, 1}, B1WH
        return output_mask[:vis_batch]

    def _save_prediction(self, image, output_mask, save_dir, file_name):
        """Save the first ``vis_batch`` images above their predicted masks as one PNG grid."""
        save_images = torch.cat(
            [image[:self.args.vis_batch], self._mask_for_display(output_mask)], dim=0)
        os.makedirs(save_dir, exist_ok=True)
        tv.utils.save_image(save_images,
                            os.path.join(save_dir, file_name),
                            nrow=self.args.vis_batch)

    def _val_output_dir(self):
        """Directory that receives validation sample images."""
        return os.path.join(self.args.output_root, self.args.project,
                            'output', self.args.version, 'val')

    def validate(self):
        """Save sample predictions on the target loader, then write a G/D checkpoint."""
        self.model.eval()
        output_save = self._val_output_dir()
        with torch.no_grad():
            for i, (image, _) in enumerate(self.target_loader):
                image = image.cuda(non_blocking=True).float()
                output_mask = self.model(image)
                if i % self.args.vis_freq_inval == 0:
                    self._save_prediction(image, output_mask, output_save,
                                          '{}.png'.format(i))

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  args=self.args)
        print('Save ckpt successfully!')

    def validate_loader(self, dataloader):
        """Save sample predictions for ``dataloader``, named after each batch's first image."""
        self.model.eval()
        output_save = self._val_output_dir()
        with torch.no_grad():
            for i, (image, image_name) in enumerate(dataloader):
                image = image.cuda(non_blocking=True).float()
                output_mask = self.model(image)
                if i % self.args.vis_freq_inval == 0:
                    self._save_prediction(image, output_mask, output_save,
                                          '{}.png'.format(image_name[0]))

    def predict(self):
        """Predict a mask for every target image (batch size 1) and save image/mask pairs.

        In OCT mode with ``args.use_crf`` the softmax probabilities are also
        refined with ``dense_crf``.

        Raises:
            NotImplementedError: if ``args.batch`` is not 1.
        """
        self.model.eval()
        dim_channel = 1
        with torch.no_grad():
            for image, _, item_name in self.target_loader:
                image = image.cuda(non_blocking=True).float()

                if self.args.batch != 1:
                    raise NotImplementedError('error')
                if self.args.data_modality == 'oct':
                    case_name, image_name = item_name
                    case_name = case_name[0]
                    image_name = image_name[0]
                else:
                    case_name = 'fundus'
                    image_name = item_name[0]

                output_mask = self.model(image)
                crf_mask = None
                if self.args.data_modality == 'oct':
                    # mask prob for CRF
                    mask_prob = F.softmax(output_mask, dim=dim_channel)
                    # hard segmentation: argmax over channels, kept as B1WH
                    output_mask = F.log_softmax(output_mask, dim=dim_channel)
                    _, output_mask = torch.max(output_mask, dim=dim_channel)
                    output_mask = output_mask.float().unsqueeze(dim=dim_channel)
                    # {0, 1, ..., 11} -> [0, 1]
                    _output_mask = torch.clamp(output_mask / 11, 0, 1)

                    if self.args.use_crf:
                        # CHW -> HWC, e.g. (224, 224, 1)
                        _image = image.squeeze(dim=0).cpu().transpose(0, 2).transpose(0, 1)
                        # OCT has 1 channel; replicate it: (224, 224, 1) -> (224, 224, 3)
                        _image = _image.repeat(1, 1, 3)
                        mask = mask_prob.squeeze(dim=0).cpu()
                        crf_mask = dense_crf(np.array(_image).astype(np.uint8), mask)
                else:
                    # fundus: {0, 1}, B1WH
                    _output_mask = output_mask.clamp(0, 1)

                save_images = torch.cat([image, _output_mask], dim=0)
                # NOTE(review): hard-coded, machine-specific root (validate uses args.output_root).
                output_save_path = os.path.join('/home/imed/new_disk/workspace/',
                                                self.args.project, 'output',
                                                self.args.version, 'predict')
                save_name = '{}_{}.png'.format(case_name, image_name)
                self.vis.images(save_images, win_name='predict')
                os.makedirs(output_save_path, exist_ok=True)
                tv.utils.save_image(save_images,
                                    os.path.join(output_save_path, save_name),
                                    nrow=2)

                # ---------
                # save mask
                # ---------
                # NOTE(review): disabled. `mask_vgg_root` / `mask_crf_root` are not
                # defined in this class; define them before enabling.
                save_flag = False
                if save_flag:
                    save_path = os.path.join(mask_vgg_root, case_name)
                    os.makedirs(save_path, exist_ok=True)
                    self.save_oct(output_mask, os.path.join(save_path, image_name))

                    if crf_mask is not None:
                        save_path = os.path.join(mask_crf_root, case_name)
                        os.makedirs(save_path, exist_ok=True)
                        self.save_oct(crf_mask, os.path.join(save_path, image_name),
                                      crf_mode=True)

    def save_oct(self, tensor, filename, crf_mode=False):
        """Write a mask to ``filename`` with ``misc.imsave``.

        With ``crf_mode`` the input is written as-is; otherwise it must be a
        (1, 1, H, W) tensor, which is squeezed and moved to the CPU first.
        """
        if crf_mode:
            misc.imsave(filename, tensor)
        else:
            B, C, _, _ = tensor.shape
            assert B == 1 and C == 1, 'error about shape'
            ndarr = tensor.squeeze().cpu().numpy()
            misc.imsave(filename, ndarr)
class ResnetRunner(object):
    """Train and evaluate a 2-class ResNet-50 on single-channel images.

    Each epoch trains once, then evaluates on the 'validation', 'test_ours',
    'bigan' and 'cyclegan' splits. The validation pass also saves a
    checkpoint and tracks the best accuracy.
    """

    def __init__(self):
        args = ParserArgs().args
        cuda_visible(args.gpu)

        model = resnet50(in_channels=1, num_classes=2)
        model = nn.DataParallel(model).cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

        # Optionally resume from a checkpoint
        if args.resume:
            ckpt_root = os.path.join('/root/workspace', args.project, 'checkpoints')
            ckpt_path = os.path.join(ckpt_root, args.resume)
            if os.path.isfile(ckpt_path):
                print("=> loading checkpoint '{}'".format(args.resume))
                # NOTE(review): restoring is disabled, so training still starts from
                # scratch despite the message above. The code below reads 'best_iou',
                # while `test` saves 'best_acc'. Reconcile the two before re-enabling.
                # checkpoint = torch.load(ckpt_path)
                # args.start_epoch = checkpoint['epoch']
                # self.val_best_iou = checkpoint['best_iou']
                # model.load_state_dict(checkpoint['state_dict'])
                # optimizer.load_state_dict(checkpoint['optimizer'])
                # print("=> loaded checkpoint '{}' (epoch {})"
                #       .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        cudnn.benchmark = True
        self.vis = Visualizer(env='{}'.format(args.version), port=args.port)

        def make_loader(version):
            # All splits share the data root and batch size; only `version` differs.
            return ultraLoader(root=args.dataroot, batch=args.batch,
                               version=version).data_load()

        self.train_loader = make_loader('train')
        self.val_loader = make_loader('validation')
        self.test_loader = make_loader('test_ours')
        self.test_loader_bigan = make_loader('bigan')
        self.test_loader_cyclegan = make_loader('cyclegan')

        print_args(args)
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.criterion = nn.CrossEntropyLoss().cuda()

    def train_test(self):
        """Per epoch: adjust the LR, train, then evaluate every split."""
        self.best_acc = 0
        for epoch in range(self.args.n_epochs):
            adjust_lr(self.args.lr, self.optimizer, epoch, 30)
            self.epoch = epoch
            self.train()
            self.test(self.val_loader, 'validation')
            self.test(self.test_loader, 'test_ours')
            self.test(self.test_loader_bigan, 'bigan')
            self.test(self.test_loader_cyclegan, 'cyclegan')

        print('\n', '*' * 10, 'Program Information', '*' * 10)
        print('Node: {}'.format(self.args.node))
        print('Version: {}\n'.format(self.args.version))

    def train(self):
        """Run one training epoch, logging sample images, the last loss and periodic progress."""
        self.model.train()
        n_batches = len(self.train_loader)
        for i, (img, label) in enumerate(self.train_loader):
            img = img.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)

            output = self.model(img)
            _, pred = torch.max(output, 1)
            loss = self.criterion(output, label)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % 2 == 0:
                # Image title is "<true label>_<prediction>".
                self.vis.images(img[0].squeeze(), name='train',
                                img_name='{}_{}'.format(label[0].item(), pred[0].item()))
            if i + 1 == n_batches:
                self.vis.plot_many(dict(loss=loss.item()))
            if i % self.args.print_freq == 0:
                print('[{}] Epoch: [{}][{}/{}]\t, Loss: {:.4f}'.format(
                    datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    self.epoch, i, n_batches, loss.item()))

    @staticmethod
    def _score_batch(output):
        """Return ``(softmax probabilities, argmax predictions)`` for a batch of logits."""
        prob = F.softmax(output, dim=1)
        _, pred = torch.max(prob, 1)
        return prob, pred

    def test(self, test_loader, version):
        """Evaluate on ``test_loader`` and plot AUC / accuracy under the ``version`` legend.

        Any batch size is supported. When ``version == 'validation'``,
        ``best_acc`` is updated and a checkpoint is saved.
        """
        prob_list = []
        pred_list = []
        true_list = []
        self.model.eval()
        with torch.no_grad():
            for i, (img, label) in enumerate(test_loader):
                img = img.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

                prob, pred = self._score_batch(self.model(img))
                # Probability of class 1 is the score used for the ROC AUC.
                prob_list.extend(prob[:, 1].tolist())
                pred_list.extend(pred.tolist())
                true_list.extend(label.tolist())

                if i % 3 == 0:
                    # Image title is "<true label>_<prediction>", as in `train`.
                    self.vis.images(img[0].squeeze(), name=version,
                                    img_name='{}_{}'.format(label[0].item(), pred[0].item()))

        auc = metrics.roc_auc_score(y_true=true_list, y_score=prob_list)
        acc = metrics.accuracy_score(y_true=true_list, y_pred=pred_list)

        if version == 'validation':
            is_best = acc > self.best_acc
            self.best_acc = max(acc, self.best_acc)
            # NOTE(review): saved under a fixed project name, while resume reads args.project.
            save_ckpt(version=self.args.version,
                      state={
                          'epoch': self.epoch + 1,
                          'state_dict': self.model.state_dict(),
                          'best_acc': self.best_acc,
                          'optimizer': self.optimizer.state_dict(),
                      },
                      is_best=is_best,
                      epoch=self.epoch + 1,
                      project='2018_OCT_transfer')
            print('Save ckpt successfully!')

        print('*' * 10, 'Auc = {:.3f}, Acc = {:.3f}'.format(auc, acc), '*' * 10)

        self.vis.plot_legend(win='auc', name=version, y=auc)
        self.vis.plot_legend(win='acc', name=version, y=acc)