def main(args):
    """Resume training from the saved CB_iFL checkpoints.

    Writes ``str(args)`` to ``<savedir>/opts.txt``, builds the refinement
    UNet and the segmentation ``Net`` (encoder initialised from the ImageNet
    ERFNet checkpoint at ``args.pretrainedEncoder``), restores both from the
    hard-coded CB_iFL results directory, and calls ``train`` with them.
    """
    project_root = '/home/shyam.nandan/NewExp/final_code/'  # change path here
    savedir = project_root + 'save/' + args.savedir
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    with open(savedir + '/opts.txt', "w") as opts_file:
        opts_file.write(str(args))

    refine_net = torch.nn.DataParallel(UNet()).cuda()

    # Pull the encoder out of the DataParallel-wrapped ImageNet ERFNet.
    imagenet_net = torch.nn.DataParallel(ERFNet_imagenet(1000))
    imagenet_net.load_state_dict(
        torch.load(args.pretrainedEncoder)['state_dict'])
    pretrained_encoder = next(imagenet_net.children()).features.encoder

    seg_net = torch.nn.DataParallel(
        fill_weights(Net(NUM_CLASSES), pretrained_encoder)).cuda()

    # Restore the best checkpoints of a previous run, then continue training.
    ckpt_dir = project_root + 'results/CB_iFL/'
    refine_net.load_state_dict(torch.load(ckpt_dir + 'rmodel_best.pth'))
    seg_net.load_state_dict(torch.load(ckpt_dir + 'model_best.pth'))
    train(args, refine_net, seg_net, False)
def test():
    """Denoise every test image with the trained UNet and save the results.

    All settings come from the module-level ``conf``. For each test sample,
    three PNGs are written into ``conf.denoised_dir``:
    ``<name>-noisy.png``, ``<name>-denoised.png`` and
    ``<name>-ground_truth.png``.
    """
    device = torch.device(conf.cuda if torch.cuda.is_available() else "cpu")

    test_dataset = Testinging_Dataset(conf.data_path_test,
                                      conf.test_noise_param,
                                      conf.crop_img_size)
    # batch_size=1 and shuffle=False so batch_idx indexes image_list directly.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    print('Loading model from: {}'.format(conf.model_path_test))
    model = UNet(in_channels=conf.img_channel, out_channels=conf.img_channel)
    print('loading model')
    # map_location lets a GPU-saved checkpoint load on a CPU-only host,
    # which the device fallback above explicitly allows.
    model.load_state_dict(torch.load(conf.model_path_test, map_location=device))
    model.eval()
    model.to(device)

    result_dir = conf.denoised_dir
    os.makedirs(result_dir, exist_ok=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_idx, (source, img_cropped) in enumerate(test_loader):
            source_img = tvF.to_pil_image(source.squeeze(0))
            img_truth = img_cropped.squeeze(0).numpy().astype(np.uint8)

            denoised_img = model(source.to(device)).cpu()
            denoised_result = tvF.to_pil_image(
                torch.clamp(denoised_img.squeeze(0), 0, 1))

            img_name = test_loader.dataset.image_list[batch_idx]
            fname = os.path.splitext(img_name)[0]
            source_img.save(os.path.join(result_dir, f'{fname}-noisy.png'))
            denoised_result.save(os.path.join(result_dir, f'{fname}-denoised.png'))
            io.imsave(os.path.join(result_dir, f'{fname}-ground_truth.png'), img_truth)
def main(args):
    """Run the two-stage UNet cascade on a dataset and report pixel metrics.

    Stage 1 (``model``) predicts from the input batch. Stage 2 (``model2``)
    receives the input concatenated with the stage-1 output along dim 1.
    The arg-max label map of stage 2 is:

    * saved as a 0/255 PNG to ``args.preDir/<name>.png``, only when any
      pixel is predicted as foreground;
    * compared with ``args.label_root_dir + <name> + '.png'`` to accumulate
      a 2x2 confusion matrix, which is printed and summarised via
      ``evaluate``.

    Raises:
        Exception: if ``args.cuda`` is set but no GPU is available.
    """
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    device = torch.device('cuda' if args.cuda else 'cpu')

    model = UNet(n_channels=args.colordim, n_classes=args.num_class)
    model2 = UNet(n_channels=args.colordim2, n_classes=args.num_class2)
    if args.cuda:
        model = model.cuda()
        model2 = model2.cuda()
    # map_location so GPU-saved weights also load when running without --cuda.
    model.load_state_dict(torch.load(args.pretrain_net, map_location=device))
    model2.load_state_dict(torch.load(args.pretrain_net2, map_location=device))
    model.eval()
    model2.eval()

    predDataset = generateDataset(args.pre_root_dir, args.img_size,
                                  args.colordim, isTrain=False)
    predLoader = DataLoader(dataset=predDataset,
                            batch_size=args.predictbatchsize,
                            num_workers=args.threads)

    with torch.no_grad():
        cm_w = np.zeros((2, 2))
        for batch_x, batch_name in predLoader:
            # Cast on both CPU and GPU paths (previously only on CUDA).
            batch_x = batch_x.float()
            if args.cuda:
                batch_x = batch_x.cuda()
            out1 = model(batch_x)
            prediction2 = torch.cat((batch_x, out1), 1)
            out = model2(prediction2)
            _, pred_label = torch.max(out, 1)
            pred_label_np = pred_label.cpu().numpy()

            for idx, name in enumerate(batch_name):
                predLabel_filename = args.preDir + '/' + name + '.png'
                pred_label_single = pred_label_np[idx, :, :]
                label_filename = args.label_root_dir + name + '.png'
                label = io.imread(label_filename)
                # Fix both label values so the matrix is always 2x2. Without
                # this, an image containing only one class produced a 1x1
                # matrix that broadcast into every cell of cm_w.
                # NOTE(review): assumes binary labels encoded as 0/1 in both
                # the label PNGs and the predictions; other values are ignored.
                cm = confusion_matrix(label.ravel(), pred_label_single.ravel(),
                                      labels=[0, 1])
                pred_label_single = np.where(pred_label_single > 0, 255, 0)
                print(np.max(pred_label_single))
                print(name)
                if np.max(pred_label_single) > 0:
                    io.imsave(predLabel_filename, pred_label_single.astype(np.uint8))
                cm_w = cm_w + cm

    print(cm_w)
    OA_w, F1_w, IoU_w = evaluate(cm_w)
    print('OA_w = ' + str(OA_w) + ', F1_w = ' + str(F1_w) + ', IoU = ' + str(IoU_w))
def _segment_phase(net, phase_np, device):
    """Segment every slice of an HxWxF phase volume; stack masks on axis 2."""
    masks = []
    for i in range(phase_np.shape[2]):
        img_np = phase_np[:, :, i]
        img_tensor = pre_transform(img_np).to(device)
        mask = predict_img(net, img_tensor)
        # Only the first three channels of the prediction are post-processed.
        masks.append(post_transform(img_np, mask[0:3, :, :]))
    return np.concatenate(masks, axis=2)


def submit_mnms(model_path, input_data_directory, output_data_directory, device):
    """Segment the ED and ES phases of every case and save the masks.

    Args:
        model_path: path to the UNet state dict (1 input channel, 4 classes).
        input_data_directory: directory scanned by ``load_path``.
        output_data_directory: destination passed to ``save_phase``.
        device: torch device used for loading and inference.
    """
    data_paths = load_path(input_data_directory)
    net = UNet(n_channels=1, n_classes=4, bilinear=True)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    # Inference: freeze dropout/batch-norm. Idempotent if predict_img
    # already does this.
    net.eval()

    for path in data_paths:
        ED_np, ES_np = load_phase(path)  # HxWxF
        ED_masks = _segment_phase(net, ED_np, device)
        ES_masks = _segment_phase(net, ES_np, device)
        save_phase(ED_masks, ES_masks, output_data_directory, path)
import os
import cv2
from torchvision import transforms
from unet_model import UNet
# NOTE(review): `torch` and `glob` are used below but not imported in the
# visible lines; confirm they are imported elsewhere.
if __name__ == "__main__":
    # Choose the device: CUDA if available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the network: single-channel images, 1 output class.
    net = UNet(n_channels=1, n_classes=1)
    # Move the network onto the device.
    net.to(device=device)
    # Load model parameters.
    # net.load_state_dict(torch.load('best_model_100X.pth', map_location=device))
    net.load_state_dict(torch.load('best_model.pth', map_location=device))
    # Evaluation (test) mode.
    net.eval()
    # Collect all test image paths.
    tests_path = glob.glob('dataS/test/*.png')
    # tests_path = glob.glob('data100X/test/*.png')
    # Iterate over all images.
    for test_path in tests_path:
        # Path where the result will be saved.
        # NOTE(review): split('.') cuts at the FIRST dot, so a path such as
        # './x/y.png' would break; os.path.splitext would be safer.
        save_res_path = test_path.split('.')[0] + '_res.png'
        # Read the image.
        img = cv2.imread(test_path)
        # Convert to grayscale.
        # NOTE(review): cv2.imread returns BGR, so COLOR_BGR2GRAY is the
        # matching code; keep RGB2GRAY only if training used the same.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Reshape to (batch=1, channel=1, H, W); the original comment says
        # 512x512, but the size is whatever the image is.
        # NOTE(review): the loop ends here; prediction and saving to
        # save_res_path are not in this block.
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import tqdm
import cv2
from unet_model import UNet

# Inference runs on the CPU; use the commented line to pick a GPU instead.
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

best_model_name = 'best_model.pt'
# map_location remaps GPU-saved tensors so the checkpoint loads on the
# chosen device even on a machine without CUDA.
best_model = torch.load(best_model_name, map_location=device)
model = UNet()
model.load_state_dict(best_model['state_dict'])
model.eval()
model.to(device)

test_dir = '../data/poster/images/'
out_dir = '../data/poster/model/'
test_images = [os.path.join(test_dir, x) for x in os.listdir(test_dir)]

counter = 0
i = 0
for i in range(len(test_images)):
    test_image_one = test_images[i]
    # NOTE(review): the loop body stops here; nothing is predicted or written
    # to out_dir in this block.
    #if 'post' not in test_image_one:
    #    i += 1
    #    continue
    #counter += 1
    #i += 1
# Redirect console printing to a log file named after the run settings.
logfile = f'logs/train/{args.dataset_name}_{args.epochs}ep_{args.batch_size}bs.log'
# Make sure the log directory exists before Logger is given the path.
os.makedirs(os.path.dirname(logfile), exist_ok=True)
sys.stdout = Logger(logfile)
print(args)

# tensorboard writer
# writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')

checkpoint_dir = args.checkpoints_dir + '/' + args.dataset_name
os.makedirs(checkpoint_dir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## --- Set up model
# Gray images and only one labelled class.
net = UNet(n_channels=1, n_classes=1).to(device)
if args.pretrained_weights:
    # map_location: the device above may be the CPU even if the weights
    # were saved from a GPU.
    net.load_state_dict(torch.load(args.pretrained_weights, map_location=device))
    # Report the path that was actually loaded (previously args.load).
    print(f'Pretrained weights loaded from {args.pretrained_weights}')
print(
    f' Using device {device}\n Network:\n \t{net.n_channels} input channels\n \t{net.n_classes} output channels (classes)\n \t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling'
)
# faster convolutions, but more memory
# cudnn.benchmark = True

## --- Set up data
imgs_dir = 'data/imgs/' + args.dataset_name
masks_dir = 'data/masks/' + args.dataset_name
print(f'imgs_dir: {imgs_dir} masks_dir: {masks_dir}')
dataset = BasicDataset(imgs_dir, masks_dir, args.down_scale)
# Hold out a fraction args.valid_ratio of the samples for validation.
n_val = int(len(dataset) * args.valid_ratio)
n_train = len(dataset) - n_val
print(f'n_val: {n_val} n_train: {n_train}')
test_in_path= "/home/star/0_code_lhj/DL-SIM-github/TESTING_DATA/microtuble/HE_X2/", transform=ToTensor(), img_type='tif', in_size=256) test_dataloader = torch.utils.data.DataLoader( SRRFDATASET, batch_size=batch_size, shuffle=True, pin_memory=True) # better than for loop model = UNet(n_channels=3, n_classes=1) print("{} paramerters in total".format( sum(x.numel() for x in model.parameters()))) model.cuda(cuda) model.load_state_dict( torch.load( "/home/star/0_code_lhj/DL-SIM-github/MODELS/UNet_SIM3_microtubule.pkl" )) model.eval() for batch_idx, items in enumerate(test_dataloader): image = items['image_in'] image_name = items['image_name'] print(image_name[0]) model.train() image = np.swapaxes(image, 1, 3) image = np.swapaxes(image, 2, 3) image = image.float() image = image.cuda(cuda)
LE_img = imgread(os.path.join(dir_path_LE, "LE_01.tif")) LE_512 = cropImage(LE_img, IMG_SHAPE[0],IMG_SHAPE[1]) sample_le = {} for le_512 in LE_512: tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE) for n,img in enumerate(tiles): if n not in sample_le: sample_le[n] = [] img = transform.resize(img,(IMG_SIZE*2, IMG_SIZE*2),preserve_range=True,order=3) sample_le[n].append(img) SNR_model = UNet(n_channels=15, n_classes=15) print("{} paramerters in total".format(sum(x.numel() for x in SNR_model.parameters()))) SNR_model.cuda(cuda) SNR_model.load_state_dict(torch.load(SNR_model_path)) # SNR_model.load_state_dict(torch.load(os.path.join(dir_path,"model","LE_HE_mito","LE_HE_0825.pkl"))) SNR_model.eval() SIM_UNET = UNet(n_channels=15, n_classes=1) print("{} paramerters in total".format(sum(x.numel() for x in SIM_UNET.parameters()))) SIM_UNET.cuda(cuda) SIM_UNET.load_state_dict(torch.load(SIM_UNET_model_path)) # SIM_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl"))) SIM_UNET.eval() SRRFDATASET = ReconsDataset( img_dict=sample_le, transform=ToTensor(), in_norm = LE_in_norm, img_type=".tif",
class UNetObjPrior(nn.Module):
    """
    Wrapper around UNet that takes object priors (gaussians) and images as input.

    ``forward`` concatenates the image and the object prior along dim 1 and
    feeds the result to the wrapped UNet (``self.model``).
    """

    def __init__(self, params, depth=5):
        super(UNetObjPrior, self).__init__()
        self.in_channels = 4
        self.model = UNet(1, self.in_channels, depth, cuda=params['cuda'])
        self.params = params
        self.device = torch.device('cuda' if params['cuda'] else 'cpu')

    def forward(self, im, obj_prior):
        x = torch.cat((im, obj_prior), dim=1)
        return self.model(x)

    def train(self, dataloader_train=True, dataloader_val=None):
        """Fit the wrapped UNet, or switch training mode like ``nn.Module``.

        ``nn.Module.eval()`` is implemented as ``self.train(False)``, and a
        parent module's ``train(mode)`` calls ``train(mode)`` on its children.
        So this override must still accept a single boolean (or no argument).
        In that case it defers to ``nn.Module.train`` and returns ``self``.

        Otherwise, runs ``params['num_epochs']`` epochs of SGD on an MSE loss.
        Each epoch has a training phase and a validation phase. After each
        epoch it writes a preview image and the loss CSVs, and saves
        ``checkpoint.pth.tar`` via ``utls.save_checkpoint``.

        Args:
            dataloader_train: training loader, or a bool training-mode flag.
            dataloader_val: validation loader; needs ``sample_uniform()``.
        """
        if isinstance(dataloader_train, bool):
            return super(UNetObjPrior, self).train(dataloader_train)
        if dataloader_val is None:
            raise TypeError("train() missing required argument: 'dataloader_val'")

        best_loss = float("inf")

        dataloader_train.mode = 'train'
        dataloader_val.mode = 'val'
        dataloaders = {'train': dataloader_train, 'val': dataloader_val}

        optimizer = optim.SGD(self.model.parameters(),
                              momentum=self.params['momentum'],
                              lr=self.params['lr'],
                              weight_decay=self.params['weight_decay'])

        loggers = {
            'train': LossLogger('train', self.params['batch_size'],
                                len(dataloader_train), self.params['out_dir']),
            'val': LossLogger('val', self.params['batch_size'],
                              len(dataloader_val), self.params['out_dir']),
        }

        self.criterion = nn.MSELoss()

        for epoch in range(self.params['num_epochs']):
            print('Epoch {}/{}'.format(epoch, self.params['num_epochs'] - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()  # Set model to evaluate mode

                samp = 1
                for data in dataloaders[phase]:
                    optimizer.zero_grad()
                    # track history only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        out = self.forward(data.image, data.obj_prior)
                        loss = self.criterion(out, data.truth)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    loggers[phase].update(epoch, samp, loss.item())
                    samp += 1

                loggers[phase].print_epoch(epoch)

                # Preview a prediction on one validation sample. Run it in
                # eval mode without autograd, so it neither builds a graph nor
                # updates BatchNorm running statistics with validation data.
                if phase == 'train':
                    path = os.path.join(self.params['out_dir'], 'previews',
                                        'epoch_{:04d}.jpg'.format(epoch))
                    data = dataloaders['val'].sample_uniform()
                    self.model.eval()
                    with torch.no_grad():
                        pred = self.forward(data.image, data.obj_prior)
                    utls.save_tensors(data.image[0], pred[0, ...],
                                      data.truth[0], path)

                if phase == 'val' and loggers['val'].get_loss(epoch) < best_loss:
                    best_loss = loggers['val'].get_loss(epoch)

                loggers[phase].save('log_{}.csv'.format(phase))

                # save checkpoint
                if phase == 'val':
                    is_best = loggers['val'].get_loss(epoch) <= best_loss
                    path = os.path.join(self.params['out_dir'], 'checkpoint.pth.tar')
                    utls.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.state_dict(),
                            'best_loss': best_loss,
                            'optimizer': optimizer.state_dict()
                        },
                        is_best,
                        path=path)

    def load_checkpoint(self, path, device='gpu'):
        """Load the wrapped UNet weights from a checkpoint saved by ``train``.

        Any ``device`` other than ``'gpu'`` maps the stored tensors to CPU
        storage.
        """
        if device != 'gpu':
            checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['state_dict'])
action='store_true', dest='gpu', default=False, help='use cuda') parser.add_option('-c', '--load', dest='load', default=False, help='load file model') (options, args) = parser.parse_args() net = UNet(3, 1) if options.load: net.load_state_dict(torch.load(options.load)) print('Model loaded from {}'.format(options.load)) if options.gpu: net.cuda() cudnn.benchmark = True try: train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu) except KeyboardInterrupt: torch.save(net.state_dict(), 'INTERRUPTED.pth') print('Saved interrupt')
parser.add_argument('--mask_threshold', type=float, default=0.5,
                    help="Minimum probability value to consider a mask pixel white")
parser.add_argument('--scale', type=float, default=1.0,
                    help="downscale factor for the input images")
args = parser.parse_args()

# Checkpoint file name without directory or extension, used to name the log.
# (The previous `[:-4]` slice was only correct for 3-letter extensions.)
ckpt_str = os.path.splitext(os.path.basename(args.weights_path))[0]
logfile = f'logs/segment/{args.dataset_name}_{ckpt_str}.log'
os.makedirs(os.path.dirname(logfile), exist_ok=True)
sys.stdout = Logger(logfile)
print(args)

output_dir = args.output_dir + '/' + args.dataset_name
os.makedirs(output_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=1, n_classes=1).to(device)
# map_location so a GPU-saved checkpoint also loads on the CPU fallback.
net.load_state_dict(torch.load(args.weights_path, map_location=device)['state_dict'])
net.eval()

count = 0
inference_time_total = 0.0
val_score_total = 0.0
image_folder = args.image_folder + '/' + args.dataset_name
for i, fn in enumerate(os.listdir(image_folder)):
    if not fn.endswith('.jpg'):
        continue
    count += 1
    # input image; its mask lives at the same path with 'imgs' -> 'masks'
    img_path = os.path.join(image_folder, fn)
    target_path = img_path.replace('imgs', 'masks')
    # single image prediction
def _thin_vessel_scores(gtruth_masks, pred_imgs_out, se_size, thresh=0.5):
    """Per-image recall and ROC AUC restricted to thin vessels.

    Thin vessels are the ground-truth pixels removed by a morphological
    opening with ``square(se_size)``. For the AUC, predictions on the kept
    (thick) pixels are zeroed. This works on a copy, so ``pred_imgs_out`` is
    not modified. Scores are rounded to 4 decimals.
    """
    recalls, aucs = [], []
    for j in range(pred_imgs_out.shape[0]):
        gt = gtruth_masks[j, 0, :, :]
        thick = opening(gt, square(se_size))
        thin_gt = gt - thick
        # .copy(): a plain slice is a view, and zeroing it used to mutate
        # pred_imgs_out, contaminating every metric computed afterwards.
        thin_pred = pred_imgs_out[j, 0, :, :].copy()
        thin_pred[thick == 1] = 0
        recalls.append(round(thin_recall(thin_gt, pred_imgs_out[j, 0, :, :], thresh=thresh), 4))
        aucs.append(round(roc_auc_score(thin_gt.flatten(), thin_pred.flatten()), 4))
    return recalls, aucs


def test(experiment_path, test_epoch):
    """Evaluate a trained DRIVE vessel-segmentation model on the test set.

    Steps:

    * Read ``./<experiment_path>/<experiment_path>_config.txt``.
    * Load ``<name>/DRIVE_<test_epoch>epoch.pth``.
    * Predict every test patch and reassemble full images (overlap-averaged
      when ``average_mode`` is set).
    * Save predictions and plots into ``./<name>/``.
    * Append all metrics to ``test_performances_all_epochs.txt``.

    Args:
        experiment_path: experiment directory name; also the config prefix.
        test_epoch: epoch number of the checkpoint to evaluate.
    """
    # ========= CONFIG FILE TO READ FROM =======
    config = configparser.RawConfigParser()
    config.read('./' + experiment_path + '/' + experiment_path + '_config.txt')
    path_data = config.get('data paths', 'path_local')
    model = config.get('training settings', 'model')

    # original test images (for FOV selection)
    DRIVE_test_imgs_original = path_data + config.get('data paths', 'test_imgs_original')
    test_imgs_orig = load_hdf5(DRIVE_test_imgs_original)
    full_img_height = test_imgs_orig.shape[2]
    full_img_width = test_imgs_orig.shape[3]
    # the border masks provided by the DRIVE
    DRIVE_test_border_masks = path_data + config.get('data paths', 'test_border_masks')
    test_border_masks = load_hdf5(DRIVE_test_border_masks)
    # dimension of the patches
    patch_height = int(config.get('data attributes', 'patch_height'))
    patch_width = int(config.get('data attributes', 'patch_width'))
    # the stride used when averaging overlapping patches
    stride_height = int(config.get('testing settings', 'stride_height'))
    stride_width = int(config.get('testing settings', 'stride_width'))
    assert (stride_height < patch_height and stride_width < patch_width)
    # model name
    name_experiment = config.get('experiment name', 'name')
    path_experiment = './' + name_experiment + '/'
    # N full images to be predicted
    Imgs_to_test = int(config.get('testing settings', 'full_images_to_test'))
    # Grouping of the predicted images
    N_visual = int(config.get('testing settings', 'N_group_visual'))
    average_mode = config.getboolean('testing settings', 'average_mode')
    DRIVE_test_groundTruth = path_data + config.get('data paths', 'test_groundTruth')

    # ============ Load the data and divide in patches
    new_height = new_width = masks_test = patches_masks_test = None
    if average_mode:
        patches_imgs_test, new_height, new_width, masks_test = get_data_testing_overlap(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width,
            stride_height=stride_height,
            stride_width=stride_width)
    else:
        patches_imgs_test, patches_masks_test = get_data_testing_test(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width)

    # ================ Run the prediction of the patches ================
    if model == 'UNet':
        net = UNet(n_channels=1, n_classes=2)
    elif model == 'UNet_cat':
        net = UNet_cat(n_channels=1, n_classes=2)
    else:
        net = UNet_level4_our(n_channels=1, n_classes=2)

    test_data = data.TensorDataset(torch.tensor(patches_imgs_test),
                                   torch.zeros(patches_imgs_test.shape[0]))
    test_loader = data.DataLoader(test_data, batch_size=1, pin_memory=True, shuffle=False)

    trained_model = path_experiment + 'DRIVE_' + str(test_epoch) + 'epoch.pth'
    print(trained_model)
    net.load_state_dict(torch.load(trained_model))
    net.eval()
    print('Finished loading model :' + trained_model)
    net = net.cuda()
    cudnn.benchmark = True

    # Per-pixel softmax over the 2 classes, one row per patch.
    n_pixels = patch_height * patch_width
    predictions_out = np.empty((patches_imgs_test.shape[0], n_pixels, 2))
    with torch.no_grad():
        for i_batch, (images, _) in enumerate(test_loader):
            out1 = net(images.float().cuda())
            pred = F.softmax(out1.permute(0, 2, 3, 1), dim=-1)
            # Copy to host memory before storing into the NumPy buffer
            # (a CUDA tensor cannot be converted to NumPy directly).
            predictions_out[i_batch] = pred.reshape(-1, n_pixels, 2)[0].cpu().numpy()

    # ===== Convert the prediction arrays in corresponding images
    pred_patches_out = pred_to_imgs(predictions_out, patch_height, patch_width, "original")

    # ========== Elaborate and visualize the predicted images ==========
    if average_mode:
        pred_imgs_out = recompone_overlap(pred_patches_out, new_height, new_width,
                                          stride_height, stride_width)
        orig_imgs = my_PreProc(test_imgs_orig[0:pred_imgs_out.shape[0], :, :, :])  # originals
        gtruth_masks = masks_test  # ground truth masks
    else:
        # NOTE(review): the 10x9 patch grid is hard-coded here.
        pred_imgs_out = recompone(pred_patches_out, 10, 9)  # predictions
        orig_imgs = recompone(patches_imgs_test, 10, 9)  # originals
        gtruth_masks = recompone(patches_masks_test, 10, 9)  # masks

    # Set predictions outside the DRIVE field of view to zero.
    kill_border(pred_imgs_out, test_border_masks)
    # back to original dimensions
    orig_imgs = orig_imgs[:, :, 0:full_img_height, 0:full_img_width]
    pred_imgs_out = pred_imgs_out[:, :, 0:full_img_height, 0:full_img_width]
    gtruth_masks = gtruth_masks[:, :, 0:full_img_height, 0:full_img_width]
    print("Orig imgs shape: " + str(orig_imgs.shape))
    print("pred imgs shape: " + str(pred_imgs_out.shape))
    print("Gtruth imgs shape: " + str(gtruth_masks.shape))
    np.save(path_experiment + 'pred_img_' + str(test_epoch) + "_epoch" + '.npy', pred_imgs_out)
    if average_mode:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "thresh_epoch")
    else:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "epoch_no_average")
    visualize(group_images(gtruth_masks, N_visual), path_experiment + "all_groundTruths")

    # ====== Evaluate the results
    print("\n\n======== Evaluate the results =======================")
    # predictions only inside the FOV
    y_scores, y_true = pred_only_FOV(pred_imgs_out, gtruth_masks, test_border_masks)

    # Area under the ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    AUC_ROC = roc_auc_score(y_true, y_scores)
    print("\nArea under the ROC curve: " + str(AUC_ROC))
    roc_fig = plt.figure()
    plt.plot(fpr, tpr, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_ROC)
    plt.title('ROC curve')
    plt.xlabel("FPR (False Positive Rate)")
    plt.ylabel("TPR (True Positive Rate)")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "ROC.png")
    plt.close(roc_fig)

    # Precision-recall curve
    pr_precision, pr_recall, _ = precision_recall_curve(y_true, y_scores)
    # Reverse so recall is increasing (avoids a negative integral).
    pr_precision = pr_precision[::-1]
    pr_recall = pr_recall[::-1]
    AUC_prec_rec = np.trapz(pr_precision, pr_recall)
    print("\nArea under Precision-Recall curve: " + str(AUC_prec_rec))
    pr_fig = plt.figure()
    plt.plot(pr_recall, pr_precision, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_prec_rec)
    plt.title('Precision - Recall curve')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "Precision_recall.png")
    plt.close(pr_fig)

    # Confusion matrix
    threshold_confusion = 0.5
    print("\nConfusion matrix: Custom threshold (for positive) of " + str(threshold_confusion))
    y_pred = (np.ravel(y_scores) >= threshold_confusion).astype(np.float64)
    confusion = confusion_matrix(y_true, y_pred)
    print(confusion)
    accuracy = 0
    if float(np.sum(confusion)) != 0:
        accuracy = float(confusion[0, 0] + confusion[1, 1]) / float(np.sum(confusion))
    print("Global Accuracy: " + str(accuracy))
    specificity = 0
    if float(confusion[0, 0] + confusion[0, 1]) != 0:
        specificity = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])
    print("Specificity: " + str(specificity))
    sensitivity = 0
    if float(confusion[1, 1] + confusion[1, 0]) != 0:
        sensitivity = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])
    print("Sensitivity: " + str(sensitivity))
    precision = 0
    if float(confusion[1, 1] + confusion[0, 1]) != 0:
        precision = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[0, 1])
    print("Precision: " + str(precision))

    # Jaccard similarity index.
    # NOTE(review): for 1-D binary input, scikit-learn's
    # jaccard_similarity_score equals accuracy_score; it was removed in
    # scikit-learn 0.23 in favour of jaccard_score (true IoU).
    jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True)
    print("\nJaccard similarity score: " + str(jaccard_index))

    # F1 score
    F1_score = f1_score(y_true, y_pred, labels=None, average='binary', sample_weight=None)
    print("\nF1 score (F-measure): " + str(F1_score))

    # ==== evaluate the thin vessels
    thin_3pixel_recall_indivi, thin_3pixel_auc_roc = _thin_vessel_scores(gtruth_masks, pred_imgs_out, 3)
    thin_2pixel_recall_indivi, thin_2pixel_auc_roc = _thin_vessel_scores(gtruth_masks, pred_imgs_out, 2)

    # Save the results
    with open(path_experiment + 'test_performances_all_epochs.txt', mode='a') as f:
        f.write("\n\n" + path_experiment + " test epoch:" + str(test_epoch)
                + '\naverage mode is:' + str(average_mode)
                + "\nArea under the ROC curve: %.4f" % (AUC_ROC)
                + "\nArea under Precision-Recall curve: %.4f" % (AUC_prec_rec)
                + "\nJaccard similarity score: %.4f" % (jaccard_index)
                + "\nF1 score (F-measure): %.4f" % (F1_score)
                + "\nConfusion matrix:" + str(confusion)
                + "\nACCURACY: %.4f" % (accuracy)
                + "\nSENSITIVITY: %.4f" % (sensitivity)
                + "\nSPECIFICITY: %.4f" % (specificity)
                + "\nPRECISION: %.4f" % (precision)
                + "\nthin 2vessels recall indivi:\n" + str(thin_2pixel_recall_indivi)
                + "\nthin 2vessels recall mean:%.4f" % (np.mean(thin_2pixel_recall_indivi))
                + "\nthin 2vessels auc indivi:\n" + str(thin_2pixel_auc_roc)
                + "\nthin 2vessels auc score mean:%.4f" % (np.mean(thin_2pixel_auc_roc))
                + "\nthin 3vessels recall indivi:\n" + str(thin_3pixel_recall_indivi)
                + "\nthin 3vessels recall mean:%.4f" % (np.mean(thin_3pixel_recall_indivi))
                + "\nthin 3vessels auc indivi:\n" + str(thin_3pixel_auc_roc)
                + "\nthin 3vessels auc score mean:%.4f" % (np.mean(thin_3pixel_auc_roc))
                )
# Tile every LE image, upsample each tile 2x with cubic interpolation
# (order=3, value range preserved), and group tiles by their index in the
# tiling, so sample_le[n] lists tile n from every image.
sample_le = {}
for le_512 in LE_512:
    tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
    for n, img in enumerate(tiles):
        if n not in sample_le:
            sample_le[n] = []
        img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2), preserve_range=True, order=3)
        sample_le[n].append(img)
# 15 input channels -> 1 output channel.
SC_UNET = UNet(n_channels=15, n_classes=1)
print("{} paramerters in total".format(
    sum(x.numel() for x in SC_UNET.parameters())))
SC_UNET.cuda(cuda)
SC_UNET.load_state_dict(torch.load(model_path))
# SC_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
SC_UNET.eval()
# NOTE(review): sample_le is built above but sample_he is passed here, with
# the LE normalisation (LE_in_norm). sample_he is not defined in this block;
# confirm which tile dictionary is intended.
SRRFDATASET = ReconsDataset(img_dict=sample_he, transform=ToTensor(), in_norm=LE_in_norm, img_type=".tif", in_size=256)
test_dataloader = torch.utils.data.DataLoader(
    SRRFDATASET, batch_size=1, shuffle=False, pin_memory=True)  # better than for loop
# One 256x256 output slot per dataset sample.
result = np.zeros((256, 256, len(SRRFDATASET)))
for batch_idx, items in enumerate(test_dataloader):
    image = items['image_in']
    # NOTE(review): the loop body ends here; `result` is never filled in this block.
    image_idx = items['image_name']
test_dir = dir_names.test_dir + '/' + experiment_name + '/' if not os.path.exists(test_dir): os.makedirs(test_dir) if not load_model: c.force_create(model_dir) c.force_create(tfboard_dir) #Define image dataset (reads in full images and segmentations) test_dataset = p.ImageDataset_withPrior(csv_file=c.final_test_csv) num_class = 3 model_file = model_dir + 'model_17.pth' net = UNet(num_class, in_channels=2) net.load_state_dict(torch.load(model_file, map_location=device)) net.eval() net.to(device) pad_size = c.half_patch[0] include_prior = True with torch.no_grad(): for i in range(6, len(test_dataset)): sample = test_dataset[i] if include_prior: prior = sample['prior'] test_patches = p.GeneratePatches(sample, is_training=False, transform=False,
import os
from PIL import Image
from predict import *
from utils import encode
from unet_model import UNet


def submit(net, gpu=False):
    """Predict a mask for every image in ``data/test/`` and write SUBMISSION.csv.

    Each CSV row is ``<image file name>,<space-separated run-length encoding>``.

    Args:
        net: trained segmentation network passed to ``predict_img``.
        gpu: forwarded to ``predict_img``.
    """
    test_dir = 'data/test/'
    filenames = os.listdir(test_dir)
    total = len(filenames)
    # 'w', not 'a': the header is written unconditionally, so appending to an
    # existing file would produce a CSV with a repeated header mid-file.
    with open('SUBMISSION.csv', 'w') as f:
        f.write('img,rle_mask\n')
        for index, name in enumerate(filenames):
            print('{}/{}'.format(index, total))
            with Image.open(test_dir + name) as img:
                mask = predict_img(net, img, gpu)
            # NOTE(review): rle_encode is not imported by name here; presumably
            # provided by `from predict import *`.
            enc = rle_encode(mask)
            f.write('{},{}\n'.format(name, ' '.join(map(str, enc))))


if __name__ == '__main__':
    net = UNet(3, 1).cuda()
    net.load_state_dict(torch.load('MODEL.pth'))
    # Inference mode (idempotent if predict_img also sets it).
    net.eval()
    submit(net, True)
loss = criterion(y_pred, y.unsqueeze(0).float()) l += loss.data[0] loss.backward() if i % 10 == 0: optimizer.step() print('Stepped') print('{0:.4f}%\t\t{1:.6f}'.format(i / len(ids) * 100, loss.data[0])) l = l / len(ids) print('Loss : {}'.format(l)) torch.save(net.state_dict(), 'MODEL_EPOCH{}_LOSS{}.pth'.format(epoch + 1, l)) print('Saved') try: net.load_state_dict(torch.load('MODEL_INTERRUPTED.pth')) train(net) except KeyboardInterrupt: print('Interrupted') torch.save(net.state_dict(), 'MODEL_INTERRUPTED.pth') try: sys.exit(0) except SystemExit: os._exit(0)
if convert_to_2d: inp_set = get_2d_converted_data(inp_set) inp_set = torch.from_numpy(inp_set).float() file_name = os.path.basename(inp_file) out_file = os.path.join(out_dir, file_name) data.append((inp_set, out_file)) return data def save_pred(model, data): model.eval() for image, file_path in data: img = image.cuda(cuda) pred = model(img) pred = pred.detach().cpu().numpy()[0] pred = (pred * 255) .astype(np.uint8) # save_path = file_path.replace('.mat', '.png') # cv2.imwrite(save_path, pred) pred = pred.transpose((1, 2, 0)) savemat(file_path, {'crop_g': pred}) if __name__ == '__main__': cuda = torch.device('cuda') model = UNet(n_channels=45, n_classes=3) print("{} Parameters in total".format(sum(x.numel() for x in model.parameters()))) model.cuda(cuda) model.load_state_dict(torch.load(model_loc+"Model_Final_999_3_5.pkl")) model.eval() model.cuda(cuda) data = get_images() save_pred(model, data)
img = transform(img)
# Add a leading batch dimension of size 1.
img = img.unsqueeze(0)


def get_layer_param(model):
    """Return the total number of scalar entries across ``model``'s parameters."""
    return sum(torch.numel(param) for param in model.parameters())


net = UNet(1, 3).to(device)
print(net)
print('parameters:', get_layer_param(net))

print("Loading checkpoint...")
# map_location keeps the load working when `device` is the CPU but the
# checkpoint was saved from a GPU.
checkpoint = torch.load(ckpt_path, map_location=device)
net.load_state_dict(checkpoint['net_state_dict'])
net.eval()

print("Starting Test...")
# -----------------------------------------------------------
# Initial batch
data_A = img.to(device)
# -----------------------------------------------------------
# Generate fake img (inference only, so no autograd graph is needed):
with torch.no_grad():
    fake_B = net(data_A)
# -----------------------------------------------------------
# Save the output as <input stem>_fake_B_leaky.jpg in the working directory.
# splitext instead of [0:-4] so extensions of any length are stripped.
vutils.save_image(fake_B, os.path.join('./', '%s_fake_B_leaky.jpg' % os.path.splitext(filename)[0]),
                  padding=0, nrow=1, normalize=True)
def train_second_layer(model, criterion, optimizer, scheduler, num_epochs=25):
    """Create two UNet instances initialised from the same weights.

    NOTE(review): none of the parameters are used and nothing is returned;
    the body looks unfinished. ``best_model_wts`` is not defined in this
    function, so it must come from an enclosing or module scope — confirm.
    """
    model1 = UNet()
    model2 = UNet()
    # Both copies start from identical weights.
    model1.load_state_dict(best_model_wts)
    model2.load_state_dict(best_model_wts)
# Batch inference script: run the saved generator over every file in `input_dir`
# (in random order) and write each output to `output_dir` under the same file name.
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image
from PIL import Image
import os
import numpy as np
from PIL import Image
from unet_model import UNet
import random

input_dir = "../Test_Data/"
output_dir = "../Generated_Test/"
model_path = "models/model_gen_latest"

generator = UNet(n_channels=3, n_classes=2)
# The generator stays on the CPU below, so load the weights onto the CPU as well.
generator.load_state_dict(torch.load(model_path, map_location="cpu"))
generator.eval()

# Build the preprocessing pipeline once instead of once per image.
preprocess = transforms.Compose(
    [transforms.Resize((75, 210)), transforms.ToTensor()])

os.makedirs(output_dir, exist_ok=True)

# List the directory once and shuffle it (same effect as sampling the whole list).
filenames = os.listdir(input_dir)
random.shuffle(filenames)

with torch.no_grad():  # inference only
    for filename in filenames:
        # convert("RGB") so grayscale / RGBA files still match n_channels=3;
        # the context manager closes the underlying file.
        with Image.open(os.path.join(input_dir, filename)) as raw:
            img = preprocess(raw.convert("RGB")).unsqueeze(0)
        output_img = generator(img)
        save_image(output_img, os.path.join(output_dir, filename))
# NOTE(review): fragment with lost line breaks, kept verbatim. It begins inside a
# DataLoader(...) call and ends inside the `val_metrics = {` literal, both outside this view.
# NOTE(review): model-selection bug -- `not (convert_to_2d, data_reduced)` negates a
# non-empty tuple, which is always truthy, so the condition is always False and UNet3D
# is never built. Also, the second test is a separate `if` rather than `elif`, so its
# `else` would overwrite a UNet3D model anyway (the author's own "tautology" TODO).
# Presumably intended: `if is_3d and not (convert_to_2d or data_reduced): ...
# elif data_reduced and single_unet: ... else: ...` -- confirm before fixing.
# NOTE(review): torch.load(model_loc + "650.pt") has no map_location, and the test
# loader is created with shuffle=True.
pin_memory=False) # better than for loop test_dataloader = torch.utils.data.DataLoader( test_data, batch_size=batch_size, shuffle=True, pin_memory=False) # better than for loop if is_3d and not (convert_to_2d, data_reduced): # 3D Processing directly model = UNet3D(n_channels=X_train.shape[1], n_classes=y_train.shape[1]) if data_reduced and single_unet: # 3D Converted to 2D with data reduction of reduced_size model = UNet3SIM(n_channels=X_train.shape[1], n_classes=y_train.shape[1]) else: model = UNet(n_channels=X_train.shape[1], n_classes=y_train.shape[1]) ## Thats tautology TO-DO - fix and optimize start_epoch = 0 if load_model: weight = torch.load(model_loc + "650.pt") model.load_state_dict(weight['model_state_dict']) start_epoch = weight['epoch'] print("{} Parameters in Total".format( sum(x.numel() for x in model.parameters()))) print(" See model input shape first:", X_train.shape[1], y_train.shape[1]) if have_cuda: model.cuda(cuda) learning_rate = 0.001 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999)) train_loss = [] learn_rate = [] c = 0 val_metrics = {
class Train(object):
    """Train a UNet with Dice loss, keeping the weights with the best validation DSC.

    ``configs`` is a dict. Recognised keys, with defaults in brackets:
    batch_size [16], epochs [100], lr [1e-4], device ["cuda"], workers [4],
    vis_images [200], vis_freq [10], weights ["./weights"], logs ["./logs"],
    images_path ["./data"], is_resize [False], image_short_side [256],
    is_padding [False], DataParallel (legacy spelling "DateParallel") [False],
    pre_train [False], model_path ['./weights/unet_idcard_adam.pth'].
    Numeric settings may be given as numbers or numeric strings.
    """

    def __init__(self, configs):
        # Coerce numeric settings: the former string defaults ("16", "100", ...)
        # broke DataLoader, range() and Adam, and config files often yield strings.
        self.batch_size = int(configs.get("batch_size", 16))
        self.epochs = int(configs.get("epochs", 100))
        self.lr = float(configs.get("lr", 0.0001))
        device_args = configs.get("device", "cuda")
        self.device = torch.device(
            "cpu" if not torch.cuda.is_available() else device_args)
        self.workers = int(configs.get("workers", 4))
        self.vis_images = int(configs.get("vis_images", 200))
        self.vis_freq = int(configs.get("vis_freq", 10))
        self.weights = configs.get("weights", "./weights")
        os.makedirs(self.weights, exist_ok=True)
        self.logs = configs.get("logs", "./logs")
        os.makedirs(self.logs, exist_ok=True)  # previously re-checked self.weights
        self.images_path = configs.get("images_path", "./data")
        # These used to read an undefined/global `config`; read the argument instead.
        self.is_resize = configs.get("is_resize", False)
        self.image_short_side = int(configs.get("image_short_side", 256))
        self.is_padding = configs.get("is_padding", False)
        # "DateParallel" is the historical (misspelled) key; keep accepting it.
        is_multi_gpu = configs.get("DataParallel", configs.get("DateParallel", False))
        pre_train = configs.get("pre_train", False)
        model_path = configs.get("model_path", './weights/unet_idcard_adam.pth')

        self.step = 0
        self.dsc_loss = DiceLoss()
        self.model = UNet(in_channels=Dataset.in_channels,
                          out_channels=Dataset.out_channels)
        if pre_train:
            state_dict = torch.load(model_path, map_location=self.device)
            # strict=False silently skips unknown keys, so a checkpoint saved from a
            # DataParallel model ("module." keys) would otherwise load nothing at all.
            self.model.load_state_dict(self._strip_module_prefix(state_dict),
                                       strict=False)
        if is_multi_gpu:
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        self.best_validation_dsc = 0.0
        self.loader_train, self.loader_valid = self.data_loaders()
        self.params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(self.params, lr=self.lr, weight_decay=0.0005)
        self.scheduler = lr_scheduler.LR_Scheduler_Head(
            'poly', self.lr, self.epochs, len(self.loader_train))

    @staticmethod
    def _strip_module_prefix(state_dict, prefix="module."):
        """Return ``state_dict`` with ``prefix`` removed when every key carries it.

        Mixed or unprefixed state dicts are returned unchanged.
        """
        keys = list(state_dict.keys())
        if keys and all(key.startswith(prefix) for key in keys):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
        return state_dict

    def _model_state_dict(self):
        """Return the underlying model's weights without any (Distributed)DataParallel wrapper.

        Saving unwrapped weights keeps checkpoints loadable into a plain UNet.
        """
        model = self.model
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module
        return model.state_dict()

    def datasets(self):
        """Build the train and validation ``Dataset`` objects from ``self.images_path``."""
        train_datasets = Dataset(
            images_dir=self.images_path,
            subset="train",
            transform=get_transforms(train=True),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=self.is_padding)
        valid_datasets = Dataset(
            images_dir=self.images_path,
            subset="validation",
            transform=get_transforms(train=False),
            is_resize=self.is_resize,
            image_short_side=self.image_short_side,
            is_padding=False)
        return train_datasets, valid_datasets

    def data_loaders(self):
        """Return ``(loader_train, loader_valid)``; validation uses batch size 1."""
        dataset_train, dataset_valid = self.datasets()
        loader_train = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.workers,
        )
        loader_valid = DataLoader(
            dataset_valid,
            batch_size=1,
            drop_last=False,
            num_workers=self.workers,
        )
        return loader_train, loader_valid

    @staticmethod
    def dsc_per_volume(validation_pred, validation_true):
        """Return one ``dsc`` score per (prediction, target) pair."""
        assert len(validation_pred) == len(validation_true)
        # np.array([...]) adds a leading axis, as dsc was originally called with.
        return [dsc(np.array([pred]), np.array([true]))
                for pred, true in zip(validation_pred, validation_true)]

    @staticmethod
    def get_logger(filename, verbosity=1, name=None):
        """Return a logger writing to both ``filename`` (mode "w") and the console.

        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        """
        level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
        formatter = logging.Formatter(
            "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
        )
        logger = logging.getLogger(name)
        logger.setLevel(level_dict[verbosity])

        fh = logging.FileHandler(filename, "w")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

        return logger

    def train_one_epoch(self, epoch):
        """Run one training epoch.

        Returns the mean training loss, or nan if the loader yielded no batches.
        """
        self.model.train()
        loss_train = []
        for i, (x, y_true) in enumerate(self.loader_train):
            # The project scheduler is invoked once per iteration to set the LR.
            self.scheduler(self.optimizer, i, epoch, self.best_validation_dsc)
            x, y_true = x.to(self.device), y_true.to(self.device)
            y_pred = self.model(x)
            loss = self.dsc_loss(y_pred, y_true)
            loss_train.append(loss.item())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.step % 200 == 0:
                print('Epoch:[{}/{}]\t iter:[{}]\t loss={:.5f}\t '.format(
                    epoch, self.epochs, i, loss.item()))
            self.step += 1
        return float(np.mean(loss_train)) if loss_train else float("nan")

    def eval_model(self, patience):
        """Evaluate on the validation set and save the weights when mean DSC improves.

        ``patience`` is accepted for compatibility but unused (no early stopping).
        Returns the mean DSC of this evaluation.
        """
        self.model.eval()
        validation_pred = []
        validation_true = []
        with torch.no_grad():
            for x, y_true in self.loader_valid:
                x, y_true = x.to(self.device), y_true.to(self.device)
                y_pred = self.model(x)
                # Debug preview of the latest binarised prediction; imwrite needs uint8
                # (bool * 255 previously produced int64).
                mask = ((y_pred > 0.5).to(torch.uint8) * 255).cpu().numpy()[0][0]
                cv2.imwrite('result.png', mask)
                # Extending with an array appends one entry per batch element.
                validation_pred.extend(y_pred.cpu().numpy())
                validation_true.extend(y_true.cpu().numpy())

        mean_dsc = np.mean(self.dsc_per_volume(validation_pred, validation_true))
        if mean_dsc > self.best_validation_dsc:
            self.best_validation_dsc = mean_dsc
            torch.save(self._model_state_dict(),
                       os.path.join(self.weights, "unet_xia_adam.pth"))
            print("Best validation mean DSC: {:4f}".format(self.best_validation_dsc))
        return mean_dsc

    def main(self):
        """Train for ``self.epochs`` epochs, validating after each, then save final weights."""
        for epoch in tqdm(range(self.epochs), total=self.epochs):
            self.train_one_epoch(epoch)
            self.eval_model(patience=10)
        torch.save(self._model_state_dict(),
                   os.path.join(self.weights, "unet_final.pth"))
class OnePredict(object):
    """Single-image UNet segmentation inference.

    ``params`` must contain ``'model_path'``, the path of a state_dict for a
    ``UNet(in_channels=3, out_channels=1)``.
    """

    def __init__(self, params):
        self.params = params
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model_path = params['model_path']
        self.model = UNet(in_channels=3, out_channels=1)
        self.threshold = 0.5  # probability cut-off used to binarise predictions
        self.resume()
        self.transform = get_transforms_3()
        self.is_resize = True
        self.image_short_side = 1024
        self.init_torch_tensor()
        self.model.eval()

    def init_torch_tensor(self):
        """Set the global default tensor type (CUDA when available) and ``self.device``.

        NOTE: torch.set_default_tensor_type changes process-wide state.
        """
        torch.set_default_tensor_type('torch.FloatTensor')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.device = torch.device('cpu')

    def resume(self):
        """Load weights from ``self.model_path`` (non-strict) and move the model to ``self.device``."""
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device), strict=False)
        self.model.to(self.device)

    def _target_size(self, width, height):
        """Return ``(new_width, new_height)`` for an image of the given size.

        The short side becomes ``self.image_short_side`` when ``is_resize`` is set,
        otherwise it is rounded down to a multiple of 32 (at least 32, so tiny
        images no longer collapse to 0). The long side keeps the aspect ratio and
        is rounded up to a multiple of 32.
        """
        if self.is_resize:
            short_side = self.image_short_side
        else:
            short_side = max(32, min(width, height) // 32 * 32)
        if height < width:
            new_height = short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        return new_width, new_height

    def resize_img(self, img):
        """Resize a PIL image so that both sides are multiples of 32 (see ``_target_size``)."""
        new_width, new_height = self._target_size(*img.size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        return img.resize((new_width, new_height), Image.LANCZOS)

    def format_output(self):
        pass

    @staticmethod
    def pre_process(img):
        return img

    @staticmethod
    def pad_sample(img):
        """Pad a PIL image with black borders so it becomes square, content centred."""
        a = img.size[0]
        b = img.size[1]
        if a == b:
            return img
        diff = (max(a, b) - min(a, b)) / 2.0
        if a > b:
            padding = (0, int(np.floor(diff)), 0, int(np.ceil(diff)))
        else:
            padding = (int(np.floor(diff)), 0, int(np.ceil(diff)), 0)
        img = ImageOps.expand(img, border=padding, fill=0)  # left, top, right, bottom
        assert img.size[0] == img.size[1]
        return img

    def post_process(self, preds, img):
        """Binarise ``preds`` and write debug images of the mask and its contours.

        Writes ``mask.png`` (0/255 mask) and ``result2.png`` (``img`` with the mask
        contours drawn in red). Returns a list of boxes -- currently always empty.
        """
        # uint8 is what cv2.imwrite / cv2.findContours expect (bool * 255 gave int64).
        mask = ((preds > self.threshold).to(torch.uint8) * 255).cpu().numpy()[0][0]
        cv2.imwrite('mask.png', mask)
        # [-2] picks the contour list under both OpenCV 3 (3 returns) and 4 (2 returns).
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
        # `img` is an RGB PIL image; OpenCV draws and writes in BGR order.
        canvas = cv2.cvtColor(np.array(img, np.uint8), cv2.COLOR_RGB2BGR)
        cv2.drawContours(canvas, contours, -1, (0, 0, 255), 1)
        cv2.imwrite('result2.png', canvas)
        boxes = []
        return boxes

    @staticmethod
    def demo_visualize():
        pass

    def inference(self, img_path, is_visualize=True, is_format_output=False):
        """Segment the image at ``img_path`` and return the post-processed boxes.

        ``is_visualize`` and ``is_format_output`` are currently unused. Side effects:
        writes ``img.png`` (resized input) plus the files written by ``post_process``.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        # cv2.imread returns BGR; the old call passed cv2.COLOR_BGR2RGB as the *read
        # flag*, so the model was fed BGR pixels labelled as RGB. Convert explicitly
        # (matches the commented-out Image.open(...).convert("RGB") path).
        bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError("Could not read image: {}".format(img_path))
        img = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        img = self.resize_img(img)
        ori_img = img
        img.save('img.png')
        batch = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            s1 = time.time()
            preds = self.model(batch)
            s2 = time.time()
        print(s2 - s1)  # forward-pass time in seconds
        boxes = self.post_process(preds, ori_img)
        return boxes
# NOTE(review): fragment with lost line breaks, kept verbatim; it begins inside an `if`
# body (presumably `if not os.path.exists(save_dir):`) that lies outside this view.
# It builds the generator and the global/local discriminators, optionally resumes them
# when the first CLI argument is 'resume', moves them to `device`, and creates the optimisers.
# - Other snippets in this file construct UNet(n_channels=..., n_classes=...); confirm that
#   `out_channels=1` is an accepted keyword for the UNet imported here.
# - On resume, torch.load is called without map_location before the models are moved to
#   `device`, so a GPU-saved checkpoint will fail to load on a CPU-only host.
# - optimizer_g / optimizer_l optimise the global / local discriminators (lr 5e-5);
#   gen_optimizer optimises the generator (lr 2e-4).
os.makedirs(save_dir) if not os.path.exists(model_path): os.makedirs(model_path) generator = UNet(n_channels=3, out_channels=1) discriminator_g = GlobalDiscriminator() discriminator_l = LocalDiscriminator() resume=False if(len(sys.argv)>1 and sys.argv[1]=='resume'): resume=True # Load model if available if(resume==True): print('Resuming training....') generator.load_state_dict(torch.load(os.path.join(model_path,'model_gen_latest'))) discriminator_g.load_state_dict(torch.load(os.path.join(model_path,'model_gdis_latest'))) discriminator_l.load_state_dict(torch.load(os.path.join(model_path,'model_ldis_latest'))) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") generator = generator.to(device) discriminator_g = discriminator_g.to(device) discriminator_l = discriminator_l.to(device) optimizer_g = optim.Adam(discriminator_g.parameters(), lr=0.00005) optimizer_l = optim.Adam(discriminator_l.parameters(), lr=0.00005) gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002) lossdis = nn.BCELoss() lossgen = FocalLoss() lamda = 75
# NOTE(review): fragment with lost line breaks, kept verbatim apart from translated comments.
# It begins inside a Dice-coefficient function whose signature is outside this view and
# ends inside the test loop. The Dice value sums the intersection and both totals over the
# whole batch (`.sum()` with no dim), giving one batch-level score rather than a mean of
# per-sample scores. Within this view `tests_path` is computed but not used, and the loader
# named `train_loader` iterates the test dataset.
smooth = 1. num = pred.size(0) m1 = pred.view(num, -1) # Flatten m2 = target.view(num, -1) # Flatten intersection = (m1 * m2).sum() return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth) if __name__ == "__main__": # select the device: use cuda if available, otherwise cpu device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # load the network: single-channel image, 1 class (NOTE: the code below uses n_channels=3) net = UNet(n_channels=3, n_classes=1) # move the network to the device net.to(device=device) # load the model parameters net.load_state_dict(torch.load('./save_model/CP_epoch300.pth', map_location=device)) # evaluation mode net.eval() # collect all image paths data_path = "./data/CHASE/test" tests_path = glob.glob(os.path.join(data_path, r'image/*.jpg')) name_dataset = DealDataset(data_path) train_loader = torch.utils.data.DataLoader(dataset=name_dataset, batch_size=1,shuffle=False) # set shuffle=True to shuffle the order # print(train_loader.len) # save_res_path = test_path.split('.')[0] + '_res2021.png' # iterate over all images total_batch = int(len(name_dataset)/1) bar = tqdm(enumerate(train_loader),total=total_batch) for batch_index, batch in bar: image = batch['image'] # label = batch['label'].squeeze(0)