def main():
    #############
    # init args #
    #############
    train_T2_path = '/home/Multi_Modality/data/fold4/train/T2_aug'
    train_target_path = '/home/Multi_Modality/data/fold4/train/label_aug'
    train_DWI_path = '/home/Multi_Modality/data/fold4/train/DWI_aug'
    test_T2_path = '/home/Multi_Modality/data/fold4/test/test_T2'
    test_target_path = '/home/Multi_Modality/data/fold4/test/test_label'
    test_DWI_path = '/home/Multi_Modality/data/fold4/test/test_DWI'

    args = get_args()
    best_prec1 = 0.
    args.cuda = torch.cuda.is_available()

    if args.inference == '':
        args.save = args.save or 'work/AMRSegNet_fold4_SE.{}'.format(datestr())
    else:
        args.save = args.save or 'work/AMRSegNet_fold4_SE_inference.{}'.format(
            datestr())
    weight_decay = args.weight_decay
    setproctitle.setproctitle(args.save)

    torch.manual_seed(1)
    if args.cuda:
        torch.cuda.manual_seed(1)

    # writer for tensorboard (args.save is always set above, so a single
    # log-dir setup covers both the training and the inference case)
    if args.save:
        idx = args.save.rfind('/')
        log_dir = 'runs' + args.save[idx:]
        print('log_dir', log_dir)
        writer = SummaryWriter(log_dir)
    else:
        writer = SummaryWriter()

    ######################
    # building AMRSegNet #
    ######################
    print("building AMRSegNet-----")
    batch_size = args.ngpu * args.batchSz
    # model = unet.UNet(relu=False)
    model = AMRSegNet_noalpha.AMRSegNet()
    x = torch.zeros((1, 1, 256, 256))
    writer.add_graph(model, (x, x))

    if args.cuda:
        model = model.cuda()
        model = nn.parallel.DataParallel(model, list(range(args.ngpu)))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    print('Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    # define a logger and record the run configuration
    logger = Logger(os.path.join(args.save, 'log.txt'))
    logger.print3('batch size is %d' % args.batchSz)
    logger.print3('nums of gpu is %d' % args.ngpu)
    logger.print3('num of epochs is %d' % args.nEpochs)
    logger.print3('start-epoch is %d' % args.start_epoch)
    logger.print3('weight-decay is %e' % args.weight_decay)
    logger.print3('optimizer is %s' % args.opt)

    ################
    # prepare data #
    ################
    # train_transform = transforms.Compose([RandomHorizontalFlip(p=0.7),
    #                                       RandomRotation(30),
    #                                       Crop(),
    #                                       ToTensor(),
    #                                       Normalize(0.5, 0.5)])
    train_transform = transforms.Compose(
        [Crop(), ToTensor(), Normalize(0.5, 0.5)])
    test_transform = transforms.Compose(
        [Crop(), ToTensor(), Normalize(0.5, 0.5)])

    # inference dataset
    if args.inference != '':
        if not args.resume:
            print("args.resume must be set to do inference")
            exit(1)
        kwargs = {'num_workers': 0} if args.cuda else {}
        T2_src = args.inference
        DWI_src = args.dwiinference
        tar = args.target
        inference_batch_size = 1
        dataset = Lung_dataset(image_path=T2_src,
                               image2_path=DWI_src,
                               target_path=tar,
                               transform=test_transform)
        loader = DataLoader(dataset,
                            batch_size=inference_batch_size,
                            shuffle=False,
                            **kwargs)
        inference(args, loader, model)
        return

    # train dataset
    print("loading train set --- ")
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    train_set = Lung_dataset(image_path=train_T2_path,
                             image2_path=train_DWI_path,
                             target_path=train_target_path,
                             transform=train_transform)
    test_set = Lung_dataset(image_path=test_T2_path,
                            image2_path=test_DWI_path,
                            target_path=test_target_path,
                            transform=test_transform,
                            mode='test')
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              **kwargs)
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)

    # class weights: balance background vs. foreground by mean target area
    target_mean = train_set.get_target_mean()
    bg_weight = target_mean / (1. + target_mean)
    fg_weight = 1. - bg_weight
    class_weights = torch.FloatTensor([bg_weight, fg_weight])
    if args.cuda:
        class_weights = class_weights.cuda()

    #############
    # optimizer #
    #############
    lr = 0.7 * 1e-2
    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)

    # loss functions
    loss_fn = {}
    loss_fn['surface_loss'] = SurfaceLoss()
    loss_fn['ti_loss'] = TILoss()
    loss_fn['dice_loss'] = DiceLoss()
    loss_fn['l1_loss'] = nn.L1Loss()
    loss_fn['CELoss'] = nn.CrossEntropyLoss()

    ############
    # training #
    ############
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')

    for epoch in range(1, args.nEpochs + 1):
        # step-wise learning-rate decay
        if epoch > 20:
            lr = 1e-3
        if epoch > 30:
            lr = 1e-4
        if epoch > 50:
            lr = 1e-5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        mean_loss = train(args, epoch, model, train_loader, optimizer, trainF,
                          loss_fn, writer)
        dice, recall, precision = test(args, epoch, model, test_loader,
                                       optimizer, testF, loss_fn, logger,
                                       writer)
        writer.add_scalar('fold4_train_loss/epoch', mean_loss, epoch)

        is_best1 = False
        if dice > best_prec1:
            is_best1 = True
            best_prec1 = dice
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }, is_best1, args.save, "AMRSegNet_dice")

    trainF.close()
    testF.close()
    writer.close()
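# `get_args()` is referenced above but not defined in this section. Below is a
# minimal sketch of what it likely parses, inferred solely from the attributes
# `main()` reads; flag names and default values are assumptions, not the
# original CLI.
def get_args():
    import argparse
    parser = argparse.ArgumentParser(description='AMRSegNet training/inference')
    parser.add_argument('--batchSz', type=int, default=4)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--nEpochs', type=int, default=60)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--weight_decay', type=float, default=1e-8)
    parser.add_argument('--opt', type=str, default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--resume', type=str, default='',
                        help='path to a checkpoint to resume from')
    parser.add_argument('--save', type=str, default='',
                        help='output directory; a default is derived if empty')
    parser.add_argument('--inference', type=str, default='',
                        help='path to T2 images for inference; empty = train')
    parser.add_argument('--dwiinference', type=str, default='',
                        help='path to DWI images for inference')
    parser.add_argument('--target', type=str, default='',
                        help='path to labels for inference')
    return parser.parse_args()


if __name__ == '__main__':
    main()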
def train(model, device, args, num_fold=0):
    dataset_train = myDataset(args.data_root, args.target_root,
                              args.crop_size, "train",
                              k_fold=args.k_fold,
                              imagefile_csv=args.dataset_file_list,
                              num_fold=num_fold)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    num_train_data = len(dataset_train)  # size of the training set

    dataset_val = myDataset(args.data_root, args.target_root,
                            args.crop_size, "val",
                            k_fold=args.k_fold,
                            imagefile_csv=args.dataset_file_list,
                            num_fold=num_fold)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)
    num_train_val = len(dataset_val)  # size of the validation set

    writer = SummaryWriter(log_dir=args.log_dir[num_fold], comment='tb_log')

    if args.optim == "SGD":
        opt = torch.optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        opt = torch.optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    # define the loss functions
    if args.OHEM:
        criterion = OhemCrossEntropy(thres=0.8, min_kept=10000)
    else:
        criterion = nn.CrossEntropyLoss(
            torch.tensor(args.class_weight, device=device))
    criterion_dice = DiceLoss()

    cp_manager = utils.save_checkpoint_manager(3)

    step = 0
    for epoch in range(args.num_epochs):
        model.train()
        lr = utils.poly_learning_rate(args, opt, epoch)  # learning-rate schedule

        with tqdm(total=num_train_data,
                  desc=f'[Train] fold[{num_fold}/{args.k_fold}] '
                       f'Epoch[{epoch + 1}/{args.num_epochs} LR{lr:.8f}] ',
                  unit='img') as pbar:
            for batch in dataloader_train:
                step += 1
                # read a training batch
                image = batch["image"]
                label = batch["label"]
                assert len(image.size()) == 4
                assert len(label.size()) == 3
                image = image.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.long)

                # forward pass
                opt.zero_grad()
                outputs = model(image)
                main_out = outputs["main_out"]

                # compute the losses
                diceloss = criterion_dice(main_out, label)
                celoss = criterion(main_out, label)
                total_loss = celoss + diceloss * args.dice_weight
                if "sim_loss" in outputs.keys():
                    total_loss += outputs["sim_loss"] * 0.2
                if "aux_out" in outputs.keys():
                    # auxiliary losses from the deep-supervision heads
                    aux_losses = 0
                    for aux_p in outputs["aux_out"]:
                        auxloss = (criterion(aux_p, label) + criterion_dice(
                            aux_p, label) * args.dice_weight) * args.aux_weight
                        total_loss += auxloss
                        aux_losses += auxloss
                if "mu" in outputs.keys():
                    # EMAU basis update (moving average of mu)
                    with torch.no_grad():
                        mu = outputs["mu"]
                        mu = mu.mean(dim=0, keepdim=True)
                        momentum = 0.9
                        model.effcient_module.em.mu *= momentum
                        model.effcient_module.em.mu += mu * (1 - momentum)
                if "mu1" in outputs.keys():
                    with torch.no_grad():
                        mu1 = outputs['mu1'].mean(dim=0, keepdim=True)
                        model.donv_up1.em.mu = model.donv_up1.em.mu * 0.9 + mu1 * (1 - 0.9)
                        mu2 = outputs['mu2'].mean(dim=0, keepdim=True)
                        model.donv_up2.em.mu = model.donv_up2.em.mu * 0.9 + mu2 * (1 - 0.9)
                        mu3 = outputs['mu3'].mean(dim=0, keepdim=True)
                        model.donv_up3.em.mu = model.donv_up3.em.mu * 0.9 + mu3 * (1 - 0.9)
                        mu4 = outputs['mu4'].mean(dim=0, keepdim=True)
                        model.donv_up4.em.mu = model.donv_up4.em.mu * 0.9 + mu4 * (1 - 0.9)

                total_loss.backward()
                opt.step()

                if step % 5 == 0:
                    writer.add_scalar("Train/CE_loss", celoss.item(), step)
                    writer.add_scalar("Train/Dice_loss", diceloss.item(), step)
                    if args.aux:
                        writer.add_scalar("Train/aux_losses", aux_losses, step)
                    if "sim_loss" in outputs.keys():
                        writer.add_scalar("Train/sim_loss",
                                          outputs["sim_loss"], step)
                    writer.add_scalar("Train/Total_loss", total_loss.item(),
                                      step)

                pbar.set_postfix(**{'loss': total_loss.item()})  # show loss
                pbar.update(image.size()[0])

        if (epoch + 1) % args.val_step == 0:
            # validation
            mDice, mIoU, mAcc, mSensitivity, mSpecificity = val(
                model, dataloader_val, num_train_val, device, args)
            writer.add_scalar("Valid/Dice_val", mDice, step)
            writer.add_scalar("Valid/IoU_val", mIoU, step)
            writer.add_scalar("Valid/Acc_val", mAcc, step)
            writer.add_scalar("Valid/Sen_val", mSensitivity, step)
            writer.add_scalar("Valid/Spe_val", mSpecificity, step)

            # append the validation result to the csv file
            val_result = [
                num_fold, epoch + 1, mDice, mIoU, mAcc, mSensitivity,
                mSpecificity
            ]
            with open(args.val_result_file, "a") as f:
                w = csv.writer(f)
                w.writerow(val_result)

            # save a checkpoint, keeping the three best by mDice
            cp_manager.save(
                model,
                os.path.join(args.checkpoint_dir[num_fold],
                             f'CP_epoch{epoch + 1}_{float(mDice):.4f}.pth'),
                float(mDice))
        if (epoch + 1) == args.num_epochs:
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_dir[num_fold],
                             f'CP_epoch{epoch + 1}_{float(mDice):.4f}.pth'))
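# `utils.poly_learning_rate(args, opt, epoch)` is used above but not shown. A
# plausible sketch of the standard polynomial ("poly") schedule it likely
# implements; `args.lr` as the base rate and the `power` exponent default are
# assumptions (the original may also decay per iteration rather than per epoch):
def poly_learning_rate(args, opt, epoch):
    lr = args.lr * (1 - epoch / args.num_epochs) ** getattr(args, 'power', 0.9)
    for param_group in opt.param_groups:
        param_group['lr'] = lr
    return lr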
def inference(args, loader, model):
    model.eval()
    dice_list = []
    mean_precision = []
    mean_recall = []
    mean_hausdorff = []
    mean_jaccard = []
    with torch.no_grad():
        for num, sample in enumerate(loader):
            data, data2, target = sample['image'], sample['image_b'], sample['target']
            if args.cuda:
                data, data2, target = data.cuda(), data2.cuda(), target.cuda()
            output = model(data, data2)

            dice, jaccard = DiceLoss.dice_coeficient(output, target)
            precision, recall = confusion(output, target)
            hausdorff_distance = compute_hausdorff(output.cpu().numpy(),
                                                   target.cpu().numpy())
            dice = dice.cpu().numpy()
            dice_list.append(dice)
            mean_precision.append(precision.item())
            mean_recall.append(recall.item())
            mean_hausdorff.append(hausdorff_distance)
            mean_jaccard.append(jaccard.item())

            # undo Normalize(0.5, 0.5) for display
            data = data * 0.5 + 0.5
            data2 = data2 * 0.5 + 0.5
            img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            img2 = make_grid(data2, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            target = target.view(data.shape).float()
            gt = make_grid(target, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            pre = (output > 0.5).float().view(data.shape)
            pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]

            gt_img = label2rgb(gt, img, bg_label=0)
            pre_img = label2rgb(pre, img, bg_label=0)
            gt_img2 = label2rgb(gt, img2, bg_label=0)

            fig = plt.figure()
            ax = fig.add_subplot(231)
            ax.imshow(gt_img)
            ax.set_title('T2 ground truth')
            ax.axis('off')
            ax = fig.add_subplot(233)
            ax.imshow(pre_img)
            ax.set_title('prediction')
            ax.axis('off')
            ax = fig.add_subplot(232)
            ax.imshow(gt_img2)
            ax.set_title('DWI ground truth')
            ax.axis('off')
            ax = fig.add_subplot(234)
            ax.imshow(img)
            ax.set_title('T2 image')
            ax.axis('off')
            ax = fig.add_subplot(235)
            ax.imshow(img2)
            ax.set_title('DWI image')
            ax.axis('off')
            fig.tight_layout()
            fig.savefig(
                '/home/Multi_Modality/data/fold5/inference/AMRSegNet_noalpha/%d_%.4f.png'
                % (num, dice))
            print('processing {}/{} dice: {}'.format(num, len(loader.dataset),
                                                     dice))

    mean_jaccard = np.array(mean_jaccard).mean()
    mean_dice = np.array(dice_list).mean()
    std_dice = np.std(np.array(dice_list))
    mean_recall = np.mean(mean_recall)
    mean_precision = np.mean(mean_precision)
    mean_hausdorff = np.mean(mean_hausdorff)
    F1_score = 2 * mean_recall * mean_precision / (mean_recall + mean_precision)
    print('mean_jaccard: %.4f' % mean_jaccard)
    print('mean_dice: %.4f' % mean_dice)
    print('std: %.4f' % std_dice)
    print('mean_recall: %.4f' % mean_recall)
    print('mean_precision: %.4f' % mean_precision)
    print('F1_score: ', F1_score)
    print('mean_hausdorff: %.4f' % mean_hausdorff)
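# `compute_hausdorff` is called above on (N, 1, H, W)-shaped numpy arrays but is
# not defined in this section. A minimal sketch of a symmetric Hausdorff
# distance between binarized masks; building on scipy is an assumption about
# the original implementation:
from scipy.spatial.distance import directed_hausdorff

def compute_hausdorff(output, target, threshold=0.5):
    pred_pts = np.argwhere(output.squeeze() > threshold)
    gt_pts = np.argwhere(target.squeeze() > threshold)
    if len(pred_pts) == 0 or len(gt_pts) == 0:
        return 0.0  # no foreground in one of the masks
    # symmetric Hausdorff = max of the two directed distances
    return max(directed_hausdorff(pred_pts, gt_pts)[0],
               directed_hausdorff(gt_pts, pred_pts)[0])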
def test(args, epoch, model, test_loader, optimizer, testF, loss_fn, logger,
         writer):
    model.eval()
    mean_dice = []
    mean_jaccard = []
    mean_precision = []
    mean_recall = []
    mean_hausdorff = []
    with torch.no_grad():
        for sample in test_loader:
            data, data2, target = sample['image'], sample['image_b'], sample['target']
            if args.cuda:
                data, data2, target = data.cuda(), data2.cuda(), target.cuda()
            output = model(data, data2)

            # Hausdorff distance
            hausdorff_distance = compute_hausdorff(output.cpu().numpy(),
                                                   target.cpu().numpy())
            # Dice coefficient and Jaccard index
            dice, jaccard = DiceLoss.dice_coeficient(output, target)
            precision, recall = confusion(output, target)
            mean_precision.append(precision.item())
            mean_recall.append(recall.item())
            mean_dice.append(dice.item())
            mean_jaccard.append(jaccard.item())
            mean_hausdorff.append(hausdorff_distance)

        # visualize the last batch once per epoch
        shape = [data.shape[0], 1, data.shape[2], data.shape[3]]
        data = data * 0.5 + 0.5
        data2 = data2 * 0.5 + 0.5
        img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
        img2 = make_grid(data2, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
        target = target.view(shape).float()
        gt = make_grid(target, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
        pre = (output > 0.5).float().view(shape)
        pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]

        gt_img = label2rgb(gt, img, bg_label=0)
        pre_img = label2rgb(pre, img, bg_label=0)
        gt_img2 = label2rgb(gt, img2, bg_label=0)

        fig = plt.figure()
        ax = fig.add_subplot(311)
        ax.imshow(gt_img)
        ax.set_title('T2 ground truth')
        ax = fig.add_subplot(312)
        ax.imshow(pre_img)
        ax.set_title('prediction')
        ax = fig.add_subplot(313)
        ax.imshow(gt_img2)
        ax.set_title('DWI ground truth')
        fig.tight_layout()
        writer.add_figure('test result', fig, epoch)
        fig.clear()

    writer.add_scalar('fold4_test_dice/epoch', np.mean(mean_dice), epoch)
    writer.add_scalar('fold4_test_jaccard/epoch', np.mean(mean_jaccard), epoch)
    writer.add_scalar('fold4_test_precision/epoch', np.mean(mean_precision),
                      epoch)
    writer.add_scalar('fold4_test_recall/epoch', np.mean(mean_recall), epoch)
    writer.add_scalar('fold4_hausdorff_distance/epoch',
                      np.mean(mean_hausdorff), epoch)
    print('test mean_dice: ', np.mean(mean_dice))
    print('test mean jaccard: ', np.mean(mean_jaccard))
    print('mean_dice_length ', len(mean_dice))
    testF.write('{},{},{},{}\n'.format(epoch, np.mean(mean_dice),
                                       np.mean(mean_precision),
                                       np.mean(mean_recall)))
    testF.flush()
    return np.mean(mean_dice), np.mean(mean_recall), np.mean(mean_precision)
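# `confusion(output, target)` above returns (precision, recall) but is defined
# elsewhere. A sketch of one common implementation for binary masks; the
# threshold and the epsilon guard against division by zero are assumptions:
def confusion(output, target):
    smooth = 1e-8
    pred = (output > 0.5).float().view(-1)
    truth = target.float().view(-1)
    tp = (pred * truth).sum()                 # true positives
    precision = tp / (pred.sum() + smooth)    # TP / (TP + FP)
    recall = tp / (truth.sum() + smooth)      # TP / (TP + FN)
    return precision, recall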
def train(args, epoch, model, train_loader, optimizer, trainF, loss_fn,
          writer):
    model.train()
    nProcessed = 0
    nTrain = len(train_loader.dataset)
    loss_list = []
    print('--------------------Epoch{}------------------------'.format(epoch))
    for batch_idx, sample in enumerate(train_loader):
        # read data
        data, data2, target = sample['image'], sample['image_b'], sample['target']
        if args.cuda:
            data, data2, target = data.cuda(), data2.cuda(), target.cuda()

        # feed to model
        output = model(data, data2)
        target = target.view(output.shape[0],
                             target.numel() // output.shape[0])

        # loss
        loss = loss_fn['dice_loss'](output, target)
        target = target.long()
        dice, jaccard = DiceLoss.dice_coeficient(output > 0.5, target)
        precision, recall = confusion(output > 0.5, target)

        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log progress
        nProcessed += len(data)
        partialEpoch = epoch + batch_idx / len(train_loader) - 1
        print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
            partialEpoch, nProcessed, nTrain,
            100. * batch_idx / len(train_loader), loss.item()))
        print('Jaccard index: %6f, soft dice: %6f' % (jaccard, dice))
        trainF.write('{},{},{}\n'.format(partialEpoch, loss.item(),
                                         loss.item()))
        trainF.flush()

        # show images on tensorboard every 4 batches
        with torch.no_grad():
            shape = [data.shape[0], 1, data.shape[2], data.shape[3]]
            if batch_idx % 4 == 0:
                data = data * 0.5 + 0.5
                data2 = data2 * 0.5 + 0.5
                img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
                img2 = make_grid(data2, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
                target = target.view(shape).float()
                gt = make_grid(target, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
                pre = (output > 0.5).float().view(shape)
                pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]

                gt_img = label2rgb(gt, img, bg_label=0)
                pre_img = label2rgb(pre, img, bg_label=0)
                gt_img2 = label2rgb(gt, img2, bg_label=0)

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(gt_img)
                ax.set_title('T2 ground truth')
                ax = fig.add_subplot(312)
                ax.imshow(pre_img)
                ax.set_title('prediction')
                ax = fig.add_subplot(313)
                ax.imshow(gt_img2)
                ax.set_title('DWI ground truth')
                fig.tight_layout()
                writer.add_figure('train result', fig, epoch)
                fig.clear()

        loss_list.append(loss.item())
    return np.mean(loss_list)
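# `DiceLoss.dice_coeficient` (spelling as in the codebase) is called as a
# static method throughout this section but never defined here. A sketch of
# the soft-Dice / Jaccard pair it appears to return; the smoothing constant is
# an assumption:
def dice_coeficient(output, target):
    smooth = 1e-5
    output = output.float().contiguous().view(-1)
    target = target.float().contiguous().view(-1)
    intersection = (output * target).sum()
    total = output.sum() + target.sum()
    dice = (2. * intersection + smooth) / (total + smooth)
    # Jaccard = |A ∩ B| / |A ∪ B|, with |A ∪ B| = |A| + |B| - |A ∩ B|
    jaccard = (intersection + smooth) / (total - intersection + smooth)
    return dice, jaccard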
def Train_Val(epoches, net, train_data, val_data):
    # NOTE: this function relies on a globally defined `optimizer`
    net = net.train()
    net = net.cuda()
    loss1 = nn.BCEWithLogitsLoss().cuda()
    loss2 = DiceLoss().cuda()
    t_miou, v_miou = 0., 0.  # last computed train/val mean IoU
    for e in range(epoches):
        j = 0
        process = tqdm(train_data)
        losses = []
        for data in process:
            j += 1
            im = Variable(data[0].cuda())
            label = Variable(data[1].cuda())  # one-hot label

            out = net(im)
            sig = torch.sigmoid(out)
            loss = loss1(out, label) + loss2(sig, label)
            losses.append(loss.item())

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            process.set_postfix_str(f"loss {np.mean(losses)}")

            pred = torch.argmax(F.softmax(out, dim=1), dim=1)
            mask = torch.argmax(F.softmax(label, dim=1), dim=1)
            result = compute_iou(pred, mask)

            if j % 200 == 0:
                # per-class IoU over classes 1..7 (class 0 is background)
                tmiou = []
                for i in range(1, 8):
                    if result["TA"][i] != 0:
                        t_miou_i = result["TP"][i] / result["TA"][i]
                        print("{}: {:.4f} \n".format(i, t_miou_i))
                        tmiou.append(t_miou_i)
                t_miou = np.mean(tmiou)
                print("train_mean_iou:", t_miou)
                TP_sum = np.array(list(result["TP"].values()))[1:].sum()
                TA_sum = np.array(list(result["TA"].values()))[1:].sum()
                print("acc:", '%.5f' % (TP_sum / TA_sum))
            if j % 500 == 0:
                torch.save(net.state_dict(), 'deeplabv3p_baidulane.pth')
        torch.save(net.state_dict(), 'deeplabv3p_baidulane.pth')

        # validation
        j = 0
        process = tqdm(val_data)
        losses = []
        result = {
            "TP": {i: 0 for i in range(8)},
            "TA": {i: 0 for i in range(8)}
        }
        net = net.eval()
        for data in process:
            j += 1
            with torch.no_grad():
                im = Variable(data[0].cuda())
                label = Variable(data[1].cuda())

                # forward
                out = net(im)
                sig = torch.sigmoid(out)
                loss = loss1(out, label) + loss2(sig, label)
                losses.append(loss.item())
                pred = torch.argmax(F.softmax(out, dim=1), dim=1)
                mask = torch.argmax(F.softmax(label, dim=1), dim=1)
                result = compute_iou(pred, mask)
                process.set_postfix_str(f"loss {np.mean(losses)}")
            if j % 200 == 0:
                vmiou = []
                for i in range(1, 8):
                    if result["TA"][i] != 0:
                        v_miou_i = result["TP"][i] / result["TA"][i]
                        print("{}: {:.4f} \n".format(i, v_miou_i))
                        vmiou.append(v_miou_i)
                v_miou = np.mean(vmiou)
                print("val_mean_iou:", v_miou)
                TP_sum = np.array(list(result["TP"].values()))[1:].sum()
                TA_sum = np.array(list(result["TA"].values()))[1:].sum()
                print("acc:", '%.5f' % (TP_sum / TA_sum))

        print('Epoch: {}, Train Mean IoU: {:.5f}, Valid Mean IoU: {:.5f}'.format(
            e, t_miou, v_miou))
        net = net.train()  # switch back to training mode for the next epoch
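# `compute_iou(pred, mask)` above returns per-class pixel counts in a
# {"TP": {...}, "TA": {...}} dict, with TP[i] / TA[i] read as the IoU of
# class i. A sketch consistent with that usage, where TA is the union of
# prediction and ground truth (an assumption; it could also be the
# ground-truth area alone):
def compute_iou(pred, mask, num_classes=8):
    result = {"TP": {i: 0 for i in range(num_classes)},
              "TA": {i: 0 for i in range(num_classes)}}
    for i in range(num_classes):
        p = (pred == i)
        m = (mask == i)
        result["TP"][i] = (p & m).sum().item()  # intersection
        result["TA"][i] = (p | m).sum().item()  # union
    return result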
def train_general(args):
    args.optimizer = 'Adam'
    args.n_classes = 2
    args.batch_size = 8

    if args.model_name == 'FCNet':
        model = FCNet(args).cuda()
        model = torch.nn.DataParallel(model)
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(), .1,
                            weight_decay=5e-4, momentum=.99)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .1, weight_decay=5e-4)
        criterion = cross_entropy2d
        scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
    elif args.model_name == 'CENet':
        model = CE_Net_(args).cuda()
        model = torch.nn.DataParallel(model)
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(), .1,
                            weight_decay=5e-4, momentum=.99)
            scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .001, weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer, [400, 3200], .1)
        criterion = DiceLoss()

    start_iter = 0
    if args.model_path is not None:
        if os.path.isfile(args.model_path):
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
        else:
            print('Unable to load {}'.format(args.model_path))

    train_loader, valid_loader = get_loaders(args)

    os.makedirs('logs/', exist_ok=True)
    os.makedirs('results/' + args.model_name, exist_ok=True)

    writer = SummaryWriter(log_dir='logs/')
    best = -100.0
    i = start_iter
    flag = True
    running_metrics_val = Acc_Meter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    while i <= 300000 and flag:
        for (images, labels) in train_loader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(input=outputs, target=labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # step the scheduler after the optimizer

            time_meter.update(time.time() - start_ts)

            if (i + 1) % 50 == 0:
                fmt_str = "Iter [{:d}/{:d}] Loss: {:.4f} Time/Image: {:.4f}"
                print(fmt_str.format(i + 1, 300000, loss.item(),
                                     time_meter.avg / args.batch_size))

            if (i + 1) % 500 == 0 or (i + 1) == 300000:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(
                            enumerate(valid_loader)):
                        images_val = images_val.cuda()
                        labels_val = labels_val.cuda()
                        outputs = model(images_val)
                        val_loss = criterion(input=outputs, target=labels_val)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                print("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                results = running_metrics_val.get_acc()
                for k, v in results.items():
                    writer.add_scalar(k, v, i + 1)
                print(results)
                val_loss_meter.reset()
                running_metrics_val.reset()

                if results['cls_acc'] >= best:
                    best = results['cls_acc']
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best": best,
                    }
                    save_path = "results/{}/results_{}_best_model.pkl".format(
                        args.model_name, i + 1)
                    torch.save(state, save_path)

            if (i + 1) == 300000:
                flag = False
                break
    writer.close()