# Report mask tensor shapes before wrapping them in datasets.
print('tr_mask:',tr_mask.shape)
print('val_mask:',val_mask.shape)
# Pair each image tensor with its mask tensor (samples are indexed jointly).
tr_dataset=TensorDataset(tr_data,tr_mask)
val_dataset=TensorDataset(val_data,val_mask)
device='cuda:0'
batch_size = 8
# Ask cuDNN for deterministic kernels.
# NOTE(review): no random seed is set in this chunk (the "fix random seed"
# comment below has no code after it). Unless a seed is set earlier in the
# file, data shuffling and weight initialisation are still unseeded, so runs
# are not reproducible.
torch.backends.cudnn.deterministic = True
train_loader=DataLoader(dataset=tr_dataset,batch_size=batch_size,shuffle=True,num_workers=4)
# NOTE(review): shuffling the validation set is typically unnecessary; it does
# not change whole-epoch metrics and makes per-batch logs harder to compare.
val_loader=DataLoader(dataset=val_dataset,batch_size=batch_size,shuffle=True,num_workers=4)
# Drop the local names and force a GC pass to release memory.
# NOTE(review): the DataLoaders above still reference the datasets, and through
# them the tensors, so this releases little or nothing.
del tr_data,tr_mask,val_data,val_mask,tr_dataset,val_dataset
gc.collect()
print('Data loading')
# fix random seed for reproducibility
# NOTE(review): the seeding code this comment refers to is missing.
# Build model
# NOTE(review): (3,256,256) is presumably (channels, height, width); confirm.
model = M.Hybrid_U_Net(input_size = (3,256,256)).to(device)
print('Training')
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so the model
# presumably ends in a sigmoid; confirm.
loss_function=nn.BCELoss()
optimiser=optim.Adam(model.parameters(),lr=1e-4)
# Multiply the LR by 0.8 once the monitored (minimised) metric has failed to
# improve for more than 10 epochs. LR changes smaller than eps are ignored.
# NOTE(review): the `verbose` argument is deprecated in recent PyTorch releases.
scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimiser,mode='min',eps=1e-5,patience=10,factor=0.8,verbose=True)
nb_epoch = 100
# Initial "best so far" values; presumably thresholds used to decide when to
# checkpoint. Confirm against the rest of the loop.
val_loss_best=30
val_accuracy_best=0.90
for epoch in range(nb_epoch):
    print('epoch:',epoch)
    # Switch the model to training mode for this epoch.
    model.train()
    # Per-epoch counters, presumably accumulated in the loop body below
    # (loss sum, sample count, correct predictions).
    train_loss=0
    train_num=0
    train_correct=0
    # NOTE(review): `input` shadows the Python built-in of the same name.
    for input,label in train_loader:
        input=input.to(device)
# ========= Test / prediction configuration =========
# All settings are hard-coded below; no external config file is read.

# Directory containing the preprocessed DRIVE HDF5 files.
path_data = './DRIVE_datasets_training_testing/'

# Original test images (used for field-of-view selection).
DRIVE_test_imgs_original = path_data + 'DRIVE_dataset_imgs_test.hdf5'
test_imgs_orig = load_hdf5(DRIVE_test_imgs_original)
# Axes 2 and 3 are read as height and width. The array is presumably
# (N, C, H, W); confirm against load_hdf5 / the preprocessing script.
full_img_height = test_imgs_orig.shape[2]
full_img_width = test_imgs_orig.shape[3]

# Border masks provided with the DRIVE dataset.
DRIVE_test_border_masks = path_data + 'DRIVE_dataset_borderMasks_test.hdf5'
test_border_masks = load_hdf5(DRIVE_test_border_masks)

# Dimensions of the patches fed to the network.
patch_height = 64
patch_width = 64
# Patch-extraction stride for the "average" output mode. The stride is much
# smaller than the patch, so patches overlap.
stride_height = 5
stride_width = 5

device = 'cuda:0'
# Build the model from the patch size above so the two cannot drift apart.
# The leading 1 is presumably a single input channel; confirm.
model = M.Hybrid_U_Net(input_size=(1, patch_height, patch_width)).to(device)

# Experiment name and the directory path derived from it.
name_experiment = 'output'
path_experiment = './' + name_experiment + '/'
# Number of full images to be predicted.
Imgs_to_test = 2
# Grouping of the predicted images (presumably images per visualisation).
N_visual = 1
# ====== average mode ===========