Y_train = np.load(TRAIN_DATA_PATH / 'Y_train.npy')

# Shuffle X and Y identically by reseeding before each permutation.
np.random.seed(SEED)
X_train = np.random.permutation(X_train)
np.random.seed(SEED)
Y_train = np.random.permutation(Y_train)

# Hold out the first VAL_BATCH_SIZE examples for validation.
m = X_train.shape[0]
m_val = VAL_BATCH_SIZE
X_val, Y_val = X_train[:m_val], Y_train[:m_val]
X_train, Y_train = X_train[m_val:], Y_train[m_val:]

print('Beginning training...')
model.train(X_train, Y_train, X_val, Y_val,
            max_epochs=100,
            batch_size=32,
            learning_rate_init=1e-4,
            reg_param=0,
            learning_rate_decay_type='constant',
            learning_rate_decay_parameter=1,
            early_stopping=True,
            save_path='./models/0/UNet0',
            reset_parameters=True,
            check_val_every_n_batches=100,
            seed=SEED,
            data_on_GPU=False)
transform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

args = parser.parse_args()
VAR = args.var
DATA_DIR = args.data_dir
CHECKPOINT = args.checkpoint

testset = CustomImageDataset(DATA_DIR, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4)
dataiter = iter(testloader)

# Restore the trained model on CPU and switch it to inference mode.
checkpoint = torch.load(CHECKPOINT, map_location=torch.device('cpu'))
model_test = UNet(in_channels=3, out_channels=3).double()
model_test.load_state_dict(checkpoint['model_state_dict'])
model_test = model_test.cpu()
model_test.eval()  # was model_test.train(); inference should freeze dropout/BN

noisy = NoisyDataset(var=VAR)
images, _ = next(dataiter)  # dataiter.next() was removed in recent PyTorch
noisy_images = noisy(images)

# Displaying the noisy images
imshow(torchvision.utils.make_grid(noisy_images.cpu()))

# Displaying the denoised images
with torch.no_grad():
    denoised = model_test(noisy_images.cpu())
imshow(torchvision.utils.make_grid(denoised))
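# `imshow` is called above but not defined in this snippet. A minimal sketch,
# assuming matplotlib and images normalized to [-1, 1] by the transform above
# (hypothetical helper, not part of the original script):
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5                          # undo Normalize((0.5,), (0.5,))
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))   # CHW -> HWC for matplotlib
    plt.show()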
def train(train_loader, valid_loader, loss_type, act_type, tolerance,
          result_path, log_interval=10, lr=0.000001, max_epochs=500):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = UNet(upsample_mode='transpose').to(device)
    model = UNet(upsample_mode='bilinear').to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    best_model_path = result_path + '/best_model.pth'
    model_path = result_path + '/model_epoch'

    # Loss logs; flush after each write instead of the original
    # close-and-reopen-in-append pattern (same effect, less churn).
    train_batch_loss_file = open(result_path + "/train_batch_loss.txt", "w")
    valid_batch_loss_file = open(result_path + "/valid_batch_loss.txt", "w")
    train_all_epochs_loss_file = open(result_path + "/train_all_epochs_loss.txt", "w")
    valid_all_epochs_loss_file = open(result_path + "/valid_all_epochs_loss.txt", "w")
    valid_all_epochs_loss = []

    minimum_loss = np.inf
    finish = False
    for epoch in range(1, max_epochs + 1):
        for phase in ['train', 'val']:
            if phase == 'train':
                # Pick one random batch per epoch for image dumps.
                train_smpl = random.sample(range(len(train_loader)), 1)
                loader = train_loader
                model.train()
            else:
                val_smpl = random.sample(range(len(valid_loader)), 1)
                loader = valid_loader
                model.eval()

            all_batches_losses = []
            for batch_i, sample in enumerate(loader):
                data = sample['image'].to(device)
                target = sample['image_anno'].to(device)
                loss_weight = sample['loss_weight'].to(device) / 1000

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(data)

                    # Set activation type:
                    if act_type == 'sigmoid':
                        activation = torch.nn.Sigmoid().to(device)
                    elif act_type == 'tanh':
                        activation = torch.nn.Tanh().to(device)
                    elif act_type == 'soft':
                        activation = torch.nn.Softmax(dim=1).to(device)  # dim was unspecified

                    # Calculate loss:
                    if loss_type == 'wbce':    # weighted BCE with averaging
                        criterion = torch.nn.BCELoss(weight=loss_weight).to(device)
                        loss = criterion(activation(output), target)
                    elif loss_type == 'bce':   # BCE with averaging
                        criterion = torch.nn.BCELoss().to(device)
                        loss = criterion(activation(output), target)
                    elif loss_type == 'mse':   # MSE
                        loss = F.mse_loss(output, target)
                    else:                      # loss_type == 'jac'
                        loss = jaccard_loss(activation(output), target)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                if phase == 'train':
                    train_batch_loss_file.write(str(loss.item()) + "\n")
                    train_batch_loss_file.flush()
                else:
                    valid_batch_loss_file.write(str(loss.item()) + "\n")
                    valid_batch_loss_file.flush()
                all_batches_losses.append(loss.item())

                if batch_i % log_interval == 0:
                    print('{} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        phase, epoch, batch_i * len(data), len(loader.dataset),
                        100. * batch_i / len(loader), loss.item()))

                # Dump input/target/output images for the sampled batch.
                if phase == 'train' and batch_i in train_smpl:
                    thres = Binarize_Output(threshold=output.mean())(output)
                    weight_thres = Binarize_Output(threshold=loss_weight.mean())(loss_weight)
                    for name, img in [('input', data), ('target', target),
                                      ('output', output), ('thres', thres),
                                      ('weights', weight_thres)]:
                        utils.save_image(img, "{}/train_{}_{}_{}.png".format(
                            result_path, name, epoch, batch_i))
                    if epoch % 25 == 0:
                        torch.save(model.state_dict(), model_path + '_{}.pth'.format(epoch))

                if phase == 'val' and batch_i in val_smpl:
                    thres = Binarize_Output(threshold=output.mean())(output)
                    weight_thres = Binarize_Output(threshold=loss_weight.mean())(loss_weight)
                    for name, img in [('input', data), ('target', target),
                                      ('output', output), ('thres', thres),
                                      ('weights', weight_thres)]:
                        utils.save_image(img, "{}/valid_{}_{}_{}.png".format(
                            result_path, name, epoch, batch_i))

            if phase == 'train':
                train_last_avg_loss = np.mean(all_batches_losses)
                print("------average %s loss %f" % (phase, train_last_avg_loss))
                train_all_epochs_loss_file.write(str(train_last_avg_loss) + "\n")
                train_all_epochs_loss_file.flush()
            else:
                valid_last_avg_loss = np.mean(all_batches_losses)
                print("------average %s loss %f" % (phase, valid_last_avg_loss))
                valid_all_epochs_loss_file.write(str(valid_last_avg_loss) + "\n")
                valid_all_epochs_loss_file.flush()
                valid_all_epochs_loss.append(valid_last_avg_loss)

                if valid_last_avg_loss < minimum_loss:
                    minimum_loss = valid_last_avg_loss
                    # --------------- Saving the best found model ---------------
                    torch.save(model.state_dict(), best_model_path)
                    print("Minimum Average Loss so far:", minimum_loss)

                if early_stopping(epoch, valid_all_epochs_loss, tolerance):
                    finish = True
                    break
        if finish:
            break

    for f in (train_batch_loss_file, valid_batch_loss_file,
              train_all_epochs_loss_file, valid_all_epochs_loss_file):
        f.close()
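# `early_stopping` is called above but not defined in this snippet. A minimal
# sketch, assuming the intent is "stop when the validation loss has not
# improved for `tolerance` consecutive epochs" (hypothetical helper, not the
# original implementation):
def early_stopping(epoch, valid_losses, tolerance):
    if epoch <= tolerance:
        return False
    best = min(valid_losses)
    # True only if none of the last `tolerance` epochs matched the best loss.
    return all(l > best for l in valid_losses[-tolerance:])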
# Shuffle X and Y identically by reseeding before each permutation.
np.random.seed(SEED)
X_train = np.random.permutation(X_train)
np.random.seed(SEED)
Y_train = np.random.permutation(Y_train)

# Hold out the first VAL_BATCH_SIZE examples for validation.
m = X_train.shape[0]
m_val = VAL_BATCH_SIZE
X_val, Y_val = X_train[:m_val], Y_train[:m_val]
X_train, Y_train = X_train[m_val:], Y_train[m_val:]

print('Beginning training...')
model.train(X_train, Y_train, X_val, Y_val,
            max_epochs=int(1e9),
            batch_size=16,
            learning_rate_init=2e-3,
            reg_param=0,
            learning_rate_decay_type='inverse',
            learning_rate_decay_parameter=10,
            keep_prob=[0.7, 0.8],
            early_stopping=True,
            save_path=MODEL_PATH,
            reset_parameters=True,
            val_checks_per_epoch=10,
            seed=SEED,
            data_on_GPU=False)
trainset = TrainSet(IMG_ROOT, LABEL_ROOT, data_type)
trainloader = DataLoader(trainset, BATCH_SIZE, shuffle=True)
print('loader done')

# Define the model and the optimization method
device = 'cuda:0'
# device = 'cpu'
unet = UNet(in_channel=3, class_num=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=0.0005, amsgrad=True)

epochs = 1
lsize = len(trainloader)
itr = 0
p_itr = 10  # print every N iterations
unet.train()
tloss = 0
loss_history = []

for epoch in range(epochs):
    with tqdm(total=lsize) as pbar:
        for x, y, path in trainloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = unet(x)
            # CrossEntropyLoss expects class indices, hence the channel squeeze.
            loss = criterion(output, y[:, 0, :, :])
            loss.backward()
            optimizer.step()
            tloss += loss.item()
            loss_history.append(loss.item())
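            # The counters defined above (`itr`, `p_itr`, `tloss`) suggest a
            # periodic logging step that was cut off in the original snippet.
            # A plausible completion (assumption, not from the source):
            itr += 1
            if itr % p_itr == 0:
                pbar.set_postfix(loss=tloss / p_itr)
                tloss = 0
            pbar.update(1)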
def train(args, Dataset):
    # ---------------------------- Initializing model ----------------------------
    step = args.lr
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device: {}".format(device))
    print_every = int(args.print_every)
    num_epochs = int(args.num_epochs)
    save_every = int(args.save_every)
    save_path = str(args.model_save_path)
    batch_size = int(args.batch_size)
    in_ch = int(args.in_ch)
    val_split = args.val_split
    img_directory = args.image_directory

    # model = MW_Unet(in_ch=in_ch)
    model = UNet(in_ch=in_ch)
    model.to(device)
    model.apply(init_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=step)
    # criterion = nn.MSELoss()
    criterion = torch.nn.L1Loss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    # ------------------------------- Loading data -------------------------------
    dataset_total = Dataset
    dataset_size = len(dataset_total)
    indices = list(range(dataset_size))
    split = int(np.floor(val_split * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    dataloader_train = torch.utils.data.DataLoader(dataset_total, batch_size=batch_size,
                                                   sampler=train_sampler, num_workers=8)
    dataloader_val = torch.utils.data.DataLoader(dataset_total, batch_size=batch_size,
                                                 sampler=valid_sampler, num_workers=2)
    print("length of train set: ", len(train_indices))
    print("length of val set: ", len(val_indices))

    best_val_MSE, best_val_PSNR, best_val_SSIM = 100.0, -1, -1
    train_PSNRs, train_losses, train_SSIMs, train_MSEs = [], [], [], []
    val_PSNRs, val_losses, val_SSIMs, val_MSEs = [], [], [], []

    try:
        for epoch in range(1, num_epochs + 1):
            print("epoch: ", epoch)
            with tqdm(total=len(dataloader_train)) as pbar:
                for index, sample in enumerate(dataloader_train):
                    model.train()
                    target, model_input, features = (sample['target'],
                                                     sample['input'],
                                                     sample['features'])
                    N, P, C, H, W = model_input.shape
                    _, _, C_feat, _, _ = features.shape
                    model_input = torch.reshape(model_input, (-1, C, H, W)).to(device)
                    features = torch.reshape(features, (-1, C_feat, H, W))

                    # Albedo demodulation: divide the noisy input by the albedo
                    # feature, denoise, then re-modulate the network output.
                    albedo = features[:, 3:, :, :].to(device)
                    eps = torch.tensor(1e-2).to(device)
                    model_input /= (albedo + eps)

                    target = torch.reshape(target, (-1, C, H, W)).to(device)
                    features = features.to(device)
                    model_input = torch.cat((model_input, features), dim=1)

                    output = model.forward(model_input)
                    output *= (albedo + eps)

                    train_loss = utils.backprop(optimizer, output, target, criterion)
                    train_PSNR = utils.get_PSNR(output, target)
                    train_MSE = utils.get_MSE(output, target)
                    train_SSIM = utils.get_SSIM(output, target)

                    train_losses.append(train_loss.cpu().detach().numpy())
                    train_PSNRs.append(train_PSNR)
                    train_MSEs.append(train_MSE)
                    train_SSIMs.append(train_SSIM)

                    # Validate on the last training batch of each epoch.
                    if index == len(dataloader_train) - 1:
                        avg_val_PSNR, avg_val_loss = [], []
                        avg_val_MSE, avg_val_SSIM = [], []
                        model.eval()
                        with torch.no_grad():
                            for val_index, val_sample in enumerate(dataloader_val):
                                target_val, model_input_val, features_val = (
                                    val_sample['target'], val_sample['input'],
                                    val_sample['features'])
                                N, P, C, H, W = model_input_val.shape
                                _, _, C_feat, _, _ = features_val.shape
                                model_input_val = torch.reshape(
                                    model_input_val, (-1, C, H, W)).to(device)
                                features_val = torch.reshape(
                                    features_val, (-1, C_feat, H, W))
                                albedo = features_val[:, 3:, :, :].to(device)
                                model_input_val /= (albedo + eps)
                                target_val = torch.reshape(
                                    target_val, (-1, C, H, W)).to(device)
                                features_val = features_val.to(device)
                                model_input_val = torch.cat(
                                    (model_input_val, features_val), dim=1)

                                output_val = model.forward(model_input_val)
                                output_val *= (albedo + eps)

                                loss_val = criterion(output_val, target_val)
                                avg_val_PSNR.append(utils.get_PSNR(output_val, target_val))
                                avg_val_loss.append(loss_val.cpu().detach().numpy())
                                avg_val_MSE.append(utils.get_MSE(output_val, target_val))
                                avg_val_SSIM.append(utils.get_SSIM(output_val, target_val))

                        avg_val_PSNR = np.mean(avg_val_PSNR)
                        avg_val_loss = np.mean(avg_val_loss)
                        avg_val_MSE = np.mean(avg_val_MSE)
                        avg_val_SSIM = np.mean(avg_val_SSIM)
                        val_PSNRs.append(avg_val_PSNR)
                        val_losses.append(avg_val_loss)
                        val_MSEs.append(avg_val_MSE)
                        val_SSIMs.append(avg_val_SSIM)
                        scheduler.step(avg_val_loss)

                        # Save a target / input / train-output / val-output grid.
                        img_grid = torchvision.utils.make_grid(output.data[:9])
                        real_grid = torchvision.utils.make_grid(target.data[:9])
                        input_grid = torchvision.utils.make_grid(
                            model_input.data[:9, :3, :, :])
                        val_grid = torchvision.utils.make_grid(output_val.data[:9])
                        fig, ax = plt.subplots(4)
                        fig.subplots_adjust(hspace=0.5)
                        for a, (title, grid) in zip(ax, [('target', real_grid),
                                                         ('input', input_grid),
                                                         ('output_train', img_grid),
                                                         ('output_val', val_grid)]):
                            a.set_title(title)
                            a.imshow(grid.cpu().numpy().transpose((1, 2, 0)))
                        plt.savefig('{}train_output_target_img_{}.png'.format(
                            img_directory, epoch))
                        plt.close()
                    pbar.update(1)

            if epoch % print_every == 0:
                print("Epoch: {}, Loss: {}, Train MSE: {}, Train PSNR: {}, "
                      "Train SSIM: {}".format(epoch, train_loss, train_MSE,
                                              train_PSNR, train_SSIM))
                print("Epoch: {}, Avg Val Loss: {}, Avg Val MSE: {}, Avg Val PSNR: {}, "
                      "Avg Val SSIM: {}".format(epoch, avg_val_loss, avg_val_MSE,
                                                avg_val_PSNR, avg_val_SSIM))

                # Loss/metric curves, refreshed every `print_every` epochs.
                for fname, series, plot_fn, ylabel in [
                        ("train_loss", train_losses, plt.semilogy, "Loss"),
                        ("val_loss", val_losses, plt.semilogy, "Loss"),
                        ("train_PSNR", train_PSNRs, plt.plot, "PSNR"),
                        ("val_PSNR", val_PSNRs, plt.plot, "PSNR"),
                        ("train_MSE", train_MSEs, plt.semilogy, "MSE"),
                        ("val_MSE", val_MSEs, plt.semilogy, "MSE"),
                        ("train_SSIM", train_SSIMs, plt.plot, "SSIM"),
                        ("val_SSIM", val_SSIMs, plt.plot, "SSIM")]:
                    plt.figure()
                    plot_fn(np.linspace(0, epoch, len(series)), series)
                    plt.xlabel("Epoch")
                    plt.ylabel(ylabel)
                    plt.savefig("{}{}.png".format(img_directory, fname))
                    plt.close()

            if best_val_MSE > avg_val_MSE:
                best_val_MSE, best_val_PSNR, best_val_SSIM = (
                    avg_val_MSE, avg_val_PSNR, avg_val_SSIM)
                print("new best Avg Val MSE: {}".format(best_val_MSE))
                print("Saving model to {}".format(save_path))
                torch.save({'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': train_loss},
                           save_path + "best_model.pth")
                print("Saved successfully to {}".format(save_path))
    except KeyboardInterrupt:
        print("Training interrupted...")
        print("Saving model to {}".format(save_path))
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss},
                   save_path + "checkpoint{}.pth".format(epoch))
        print("Saved successfully to {}".format(save_path))

    print("Training completed.")
    print("Best MSE: %.10f, Best PSNR: %.10f, Best SSIM: %.10f"
          % (best_val_MSE, best_val_PSNR, best_val_SSIM))
    return (train_losses, train_PSNRs, val_losses, val_PSNRs, best_val_MSE)
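# The `utils` metric helpers above are project-specific. A minimal sketch of
# what get_MSE / get_PSNR might compute, assuming images scaled to [0, 1]
# (hypothetical implementations, not the original module):
def get_MSE(output, target):
    return torch.mean((output - target) ** 2).item()

def get_PSNR(output, target, max_val=1.0):
    mse = torch.mean((output - target) ** 2).item()
    return 10.0 * np.log10(max_val ** 2 / mse)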
def main_loop(data_path, batch_size=batch_size, model_type='UNet',
              green=False, tensorboard=True):
    # Load train and val data
    tasks = ['EX']
    n_labels = len(tasks)
    n_channels = 1 if green else 3  # green channel only, or RGB
    train_loader, val_loader = load_train_val_data(tasks=tasks,
                                                   data_path=data_path,
                                                   batch_size=batch_size,
                                                   green=green)

    if model_type == 'UNet':
        lr = learning_rate
        model = UNet(n_channels, n_labels)
        # Choose loss function
        criterion = nn.MSELoss()
        # criterion = dice_loss
        # criterion = mean_dice_loss
        # criterion = nn.BCELoss()
    elif model_type == 'GCN':
        lr = 1e-4
        model = GCN(n_labels, image_size[0])
        criterion = weighted_BCELoss
        # criterion = nn.BCELoss()
    else:
        raise TypeError('Please enter a valid name for the model type')

    try:
        loss_name = criterion._get_name()
    except AttributeError:
        loss_name = criterion.__name__
    if loss_name == 'BCEWithLogitsLoss':
        lr = 1e-4
    print('learning rate: ', lr)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # choose optimizer
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              verbose=True,
                                                              patience=7)

    if tensorboard:
        log_dir = tensorboard_folder + session_name + '/'
        print('log dir: ', log_dir)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        writer = SummaryWriter(log_dir)
    else:
        writer = None

    max_aupr = 0.0
    for epoch in range(epochs):  # loop over the dataset multiple times
        print('******** Epoch [{}/{}] ********'.format(epoch + 1, epochs))
        print(session_name)

        # train for one epoch
        model.train(True)
        print('Training with batch size : ', batch_size)
        train_loop(train_loader, model, criterion, optimizer, writer, epoch,
                   lr_scheduler=lr_scheduler, model_type=model_type)

        # evaluate on validation set
        print('Validation')
        with torch.no_grad():
            model.eval()
            val_loss, val_aupr = train_loop(val_loader, model, criterion,
                                            optimizer, writer, epoch)

        # Save best model
        if val_aupr > max_aupr and epoch > 3:
            print('\t Saving best model, mean aupr on validation set: {:.4f}'
                  .format(val_aupr))
            max_aupr = val_aupr
            save_checkpoint({'epoch': epoch,
                             'best_model': True,
                             'model': model_type,
                             'state_dict': model.state_dict(),
                             'val_loss': val_loss,
                             'loss': loss_name,
                             'optimizer': optimizer.state_dict()}, model_path)
        elif save_model and (epoch + 1) % save_frequency == 0:
            save_checkpoint({'epoch': epoch,
                             'best_model': False,
                             'model': model_type,
                             'loss': loss_name,
                             'state_dict': model.state_dict(),
                             'val_loss': val_loss,
                             'optimizer': optimizer.state_dict()}, model_path)
    return model
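# `weighted_BCELoss` is referenced above but not defined in this snippet.
# A minimal sketch of a pixel-weighted binary cross-entropy, assuming `output`
# holds probabilities and `weights` broadcasts against its shape (hypothetical
# helper, not the original implementation):
def weighted_BCELoss(output, target, weights=None):
    eps = 1e-7
    output = output.clamp(eps, 1 - eps)
    loss = -(target * torch.log(output) + (1 - target) * torch.log(1 - output))
    if weights is not None:
        loss = loss * weights
    return loss.mean()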
def train(net: UNet,
          train_ids_file_path: str,
          val_ids_file_path: str,
          in_dir_path: str,
          mask_dir_path: str,
          check_points: str,
          epochs=10,
          batch_size=4,
          learning_rate=0.1,
          device=torch.device("cpu")):
    train_data_set = ImageSet(train_ids_file_path, in_dir_path, mask_dir_path)
    train_data_loader = DataLoader(train_data_set, batch_size=batch_size,
                                   shuffle=True, num_workers=1)
    net = net.to(device)
    loss_func = nn.BCELoss()  # expects sigmoid probabilities from the net
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.99)
    writer = SummaryWriter("tensorboard")

    g_step = 0
    for epoch in range(epochs):
        net.train()
        total_loss = 0
        with tqdm(total=len(train_data_set),
                  desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            # (the original wrapped this loop in a second tqdm; one bar suffices)
            for step, (imgs, masks) in enumerate(train_data_loader):
                imgs = imgs.to(device)
                masks = masks.to(device)
                outputs = net(imgs)
                loss = loss_func(outputs, masks)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # record
                writer.add_scalar("Loss/Train", loss.item(), g_step)
                writer.flush()
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])
                g_step += 1
                if g_step % 10 == 0:
                    writer.add_images('masks/origin', imgs, g_step)
                    writer.add_images('masks/true', masks, g_step)
                    writer.add_images('masks/pred', outputs > 0.5, g_step)
                    writer.flush()

        try:
            os.mkdir(check_points)
            logging.info('Created checkpoint directory')
        except OSError:
            pass
        torch.save(net.state_dict(), check_points + f'CP_epoch{epoch + 1}.pth')
        logging.info(f'Checkpoint {epoch + 1} saved !')
    writer.close()
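# A hedged usage sketch for the function above. The UNet constructor arguments
# and the file layout are assumptions, not taken from the original snippet:
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(n_channels=3, n_classes=1)
    train(net,
          train_ids_file_path='data/train_ids.txt',
          val_ids_file_path='data/val_ids.txt',
          in_dir_path='data/imgs/',
          mask_dir_path='data/masks/',
          check_points='checkpoints/',
          epochs=10,
          batch_size=4,
          learning_rate=0.1,
          device=device)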
def run():
    print('loop')
    # torch.backends.cudnn.enabled = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)

    # Two discriminator/generator pairs, one per translation direction.
    Dx = Discriminator().to(device)
    Gx = UNet(3, 3).to(device)
    Dy = Discriminator().to(device)
    Gy = UNet(3, 3).to(device)

    ld = False
    if ld:
        try:
            Gx.load_state_dict(torch.load('./genx'))
            Dx.load_state_dict(torch.load('./fcnx'))
            Gy.load_state_dict(torch.load('./geny'))
            Dy.load_state_dict(torch.load('./fcny'))
            print('net loaded')
        except Exception as e:
            print(e)

    dataset = 'ukiyoe2photo'  # trainA has 562 images
    image_path_A = './datasets/' + dataset + '/trainA/*.jpg'
    image_path_B = './datasets/' + dataset + '/trainB/*.jpg'
    plt.ion()
    train_image_paths_A = glob.glob(image_path_A)
    train_image_paths_B = glob.glob(image_path_B)
    print(len(train_image_paths_A), len(train_image_paths_B))

    b_size = 8
    train_dataset_A = CustomDataset(train_image_paths_A, train=True)
    train_loader_A = torch.utils.data.DataLoader(train_dataset_A, batch_size=b_size,
                                                 shuffle=True, num_workers=4,
                                                 pin_memory=False, drop_last=True)
    train_dataset_B = CustomDataset(train_image_paths_B, True, 562, train=True)
    train_loader_B = torch.utils.data.DataLoader(train_dataset_B, batch_size=b_size,
                                                 shuffle=True, num_workers=4,
                                                 pin_memory=False, drop_last=True)

    Gx.train()
    Dx.train()
    Gy.train()
    Dy.train()

    criterion = nn.BCEWithLogitsLoss().to(device)
    # criterion2 = nn.SmoothL1Loss().to(device)
    criterion2 = nn.L1Loss().to(device)

    g_lr = 2e-4
    d_lr = 2e-4
    optimizer_x = optim.Adam(Gx.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_x_d = optim.Adam(Dx.parameters(), lr=d_lr, betas=(0.5, 0.999))
    optimizer_y = optim.Adam(Gy.parameters(), lr=g_lr, betas=(0.5, 0.999))
    optimizer_y_d = optim.Adam(Dy.parameters(), lr=d_lr, betas=(0.5, 0.999))

    # Constant real/fake targets for the relativistic discriminator loss.
    _zero = torch.from_numpy(np.zeros((b_size, 1))).float().to(device)
    _zero.requires_grad = False
    _one = torch.from_numpy(np.ones((b_size, 1))).float().to(device)
    _one.requires_grad = False

    for epoch in trange(100, desc='epoch'):
        loop = zip(tqdm(train_loader_A, desc='iteration'), train_loader_B)
        batch_idx = 0
        for data_A, data_B in loop:
            batch_idx += 1
            zero = _zero
            one = _one
            _data_A = data_A.to(device)
            _data_B = data_B.to(device)

            # Dy loss (A -> B): relativistic average discriminator update.
            gen = Gy(_data_A)
            optimizer_y_d.zero_grad()
            output2_p = Dy(_data_B.detach())
            output_p = Dy(gen.detach())
            errD = (criterion(output2_p - torch.mean(output_p), one.detach()) +
                    criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_y_d.step()

            # Dx loss (B -> A)
            gen = Gx(_data_B)
            optimizer_x_d.zero_grad()
            output2_p = Dx(_data_A.detach())
            output_p = Dx(gen.detach())
            errD = (criterion(output2_p - torch.mean(output_p), one.detach()) +
                    criterion(output_p - torch.mean(output2_p), zero.detach())) / 2
            errD.backward()
            optimizer_x_d.step()

            # Gy loss (A -> B)
            optimizer_y.zero_grad()
            gen = Gy(_data_A)
            output_p = Dy(gen)
            output2_p = Dy(_data_B.detach())
            g_loss = (criterion(output2_p - torch.mean(output_p), zero.detach()) +
                      criterion(output_p - torch.mean(output2_p), one.detach())) / 2
            # Gy cycle loss (B -> A -> B); fA is detached so only Gy is updated.
            fA = Gx(_data_B)
            gen = Gy(fA.detach())
            c_loss = criterion2(gen, _data_B)
            errG = g_loss + c_loss
            errG.backward()
            optimizer_y.step()

            if batch_idx % 10 == 0:
                # Live preview of a real B image, its A translation, and the
                # reconstructed B image (inputs are in [-1, 1]).
                for fig_num, img in [(1, _data_B), (2, fA), (3, gen)]:
                    fig = plt.figure(fig_num)
                    fig.clf()
                    plt.imshow((np.transpose(img.detach().cpu().numpy()[0],
                                             (1, 2, 0)) + 1) / 2)
                    fig.canvas.draw()
                    fig.canvas.flush_events()

            # Gx loss (B -> A)
            optimizer_x.zero_grad()
            gen = Gx(_data_B)
            output_p = Dx(gen)
            output2_p = Dx(_data_A.detach())
            g_loss = (criterion(output2_p - torch.mean(output_p), zero.detach()) +
                      criterion(output_p - torch.mean(output2_p), one.detach())) / 2
            # Gx cycle loss (A -> B -> A)
            fB = Gy(_data_A)
            gen = Gx(fB.detach())
            c_loss = criterion2(gen, _data_A)
            errG = g_loss + c_loss
            errG.backward()
            optimizer_x.step()

        # Checkpoint all four networks once per epoch.
        torch.save(Gx.state_dict(), './genx')
        torch.save(Dx.state_dict(), './fcnx')
        torch.save(Gy.state_dict(), './geny')
        torch.save(Dy.state_dict(), './fcny')

    print('\nFinished Training')
def train():
    # Number of training epochs
    epoch = 500
    # Image directory
    img_dir = "./data/training/images"
    # Mask directory
    mask_dir = "./data/training/1st_manual"
    # Network input size
    img_size = (512, 512)

    # Build the training and validation loaders
    tr_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'train'),
                           batch_size=4, shuffle=True, num_workers=2,
                           pin_memory=True, drop_last=True)
    val_loader = DataLoader(DRIVE_Loader(img_dir, mask_dir, img_size, 'val'),
                            batch_size=4, shuffle=True, num_workers=2,
                            pin_memory=True, drop_last=True)

    # Loss function
    criterion = DiceBCELoss()
    # Move the network to the GPU
    network = UNet().cuda()
    # Optimizer
    optimizer = Adam(network.parameters(), weight_decay=0.0001)

    best_score = 1.0
    for i in range(epoch):
        # Training mode: BN and Dropout parameters are updated
        network.train()
        train_step = 0
        train_loss = 0
        val_loss = 0
        val_step = 0

        # Training
        for batch in tr_loader:
            # Fetch the images and masks of this batch
            imgs, mask = batch
            # Move them to the GPU
            imgs = imgs.cuda()
            mask = mask.cuda()
            # Forward pass to get a prediction
            mask_pred = network(imgs)
            # Compute the loss between prediction and mask
            loss = criterion(mask_pred, mask)
            # Accumulate the training loss
            train_loss += loss.item()
            train_step += 1
            # Clear gradients
            optimizer.zero_grad()
            # Backpropagate the loss
            loss.backward()
            # Update the weights with Adam
            optimizer.step()

        # Evaluation mode: BN and Dropout parameters are frozen
        network.eval()
        # Validation
        with torch.no_grad():
            for batch in val_loader:
                imgs, mask = batch
                imgs = imgs.cuda()
                mask = mask.cuda()
                # Evaluation metric: Dice loss (lower is better)
                val_loss += DiceLoss()(network(imgs), mask).item()
                val_step += 1

        # Average the training loss and validation metric over the epoch
        train_loss /= train_step
        val_loss /= val_step

        # If the validation metric beats the best so far, save the weights
        if val_loss < best_score:
            best_score = val_loss
            torch.save(network.state_dict(), "./checkpoint.pth")

        # Log progress
        print(str(i), "train_loss:", train_loss, "val_dice:", val_loss)
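# DiceBCELoss / DiceLoss are used above but not defined in this snippet.
# A minimal sketch, assuming the network outputs sigmoid probabilities
# (hypothetical implementations, not the original classes):
import torch.nn as nn

class DiceLoss(nn.Module):
    def forward(self, pred, target, smooth=1.0):
        pred, target = pred.flatten(), target.flatten()
        intersection = (pred * target).sum()
        return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

class DiceBCELoss(nn.Module):
    def forward(self, pred, target):
        # Sum of soft-Dice loss and pixelwise binary cross-entropy.
        return DiceLoss()(pred, target) + \
            nn.functional.binary_cross_entropy(pred, target)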