def webcam(style_transform_path, width=1280, height=720):
    """
    Captures a webcam frame, saves it, performs style transfer, and saves the
    styled image. The styled image is then read back and shown in a window.

    The save/load round-trip SHOULD be eliminated; however, skipping it
    produces too much whitening in the generated styled image. This may be
    caused by the async nature of VideoCapture, and I don't know how to fix it.
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    print("Loading Transformer Network")
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_transform_path))
    net = net.to(device)
    print("Done Loading Transformer Network")

    # Set webcam settings
    cam = cv2.VideoCapture(0)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    # Main loop
    with torch.no_grad():
        count = 1
        while True:
            # Get webcam input
            ret_val, img = cam.read()

            # Mirror
            img = cv2.flip(img, 1)
            utils.saveimg(img, str(count) + ".png")

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate image
            content_image = utils.load_image(str(count) + ".png")
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            utils.saveimg(generated_image, str(count + 1) + ".png")
            img2 = cv2.imread(str(count + 1) + ".png")
            count += 2

            # Show webcam
            cv2.imshow('Demo webcam', img2)
            if cv2.waitKey(1) == 27:
                break  # esc to quit

    # Free-up memories
    cam.release()
    cv2.destroyAllWindows()
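# The docstring above notes that the save/load round-trip should go away.
# A minimal sketch of the direct in-memory path, assuming utils.itot accepts
# the same BGR numpy array that utils.load_image returns (true if load_image
# wraps cv2.imread); untested against the whitening issue described above.
def stylize_frame(net, frame, device):
    with torch.no_grad():
        content_tensor = utils.itot(frame).to(device)  # BGR uint8 frame -> tensor
        generated_tensor = net(content_tensor)         # style transfer
        return utils.ttoi(generated_tensor.detach())   # tensor -> numpy image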
def stylize_folder(style_path, folder_containing_the_content_folder, save_folder, batch_size=1):
    """Stylizes images in a folder by batch.

    If the images have different dimensions, use transforms.Resize() or a
    batch size of 1.

    IMPORTANT: Put content_folder inside another folder
    (folder_containing_the_content_folder):

        folder_containing_the_content_folder
            content_folder
                pic1.ext
                pic2.ext
                pic3.ext
                ...

    and the styled images are saved in save_folder as follows:

        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Image loader
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    image_dataset = utils.ImageFolderWithPaths(folder_containing_the_content_folder,
                                               transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path))
    net = net.to(device)

    # Stylize batches of images
    with torch.no_grad():
        for content_batch, _, path in image_loader:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate images
            generated_tensor = net(content_batch.to(device)).detach()

            # Save images
            for i in range(len(path)):
                generated_image = utils.ttoi(generated_tensor[i])
                if PRESERVE_COLOR:
                    # Reload the source image so its colors can be transferred
                    content_image = utils.load_image(path[i])
                    generated_image = utils.transfer_color(content_image, generated_image)
                image_name = os.path.basename(path[i])
                utils.saveimg(generated_image, save_folder + image_name)


#stylize()
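# Example invocation with hypothetical paths (any layout matching the
# docstring works); shown only to make the folder convention concrete.
# stylize_folder(style_path="transformer_weight.pth",
#                folder_containing_the_content_folder="images/",
#                save_folder="images/styled/",
#                batch_size=4)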
def test(framework):
    from utils import saveimg

    load_network(framework.model, os.path.join(exp_dir, args.checkpoint))
    framework.cuda()

    if config['framework'] == 'voxelflow':
        results = framework.predict(generator_k(1, framework.val_loader))
        saveimg(results['truth'], results['pred'],
                os.path.join(exp_dir, 'vis.jpg'), n=10)
        print(((results['truth'] - results['pred']) ** 2).mean())

    for subset in args.test:
        data, loader = framework.build_test_dataset(subset)
        result = framework.test(data, loader)
        print(print_dict(result, subset))
def stylize_folder_single(style_path, content_folder, save_folder):
    """
    Reads frames/pictures as follows:

        content_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    and saves the styled images in save_folder as follows:

        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path))
    net = net.to(device)

    # Stylize every frame
    images = [img for img in os.listdir(content_folder) if img.endswith(".jpg")]
    with torch.no_grad():
        for image_name in images:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Load content image
            content_image = utils.load_image(content_folder + image_name)
            content_tensor = utils.itot(content_image).to(device)

            # Generate image
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)

            # Save image
            utils.saveimg(generated_image, save_folder + image_name)
def stylize():
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(STYLE_TRANSFORM_PATH))
    net = net.to(device)

    with torch.no_grad():
        while True:
            torch.cuda.empty_cache()
            print("Stylize Image~ Press Ctrl+C and Enter to close the program")
            content_image_path = input("Enter the image path: ")
            content_image = utils.load_image(content_image_path)
            starttime = time.time()
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            print("Transfer Time: {}".format(time.time() - starttime))
            utils.show(generated_image)
            utils.saveimg(generated_image, "helloworld.jpg")
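# The entry points above all lean on utils.itot / utils.ttoi. The repo's real
# converters are not shown in this listing; the sketch below only illustrates
# the contract the callers rely on (OpenCV-style HWC BGR images in the 0-255
# range, NCHW float tensors). The actual helpers may differ, e.g. by resizing.
import numpy as np

def itot_sketch(img):
    # (H, W, C) uint8 image -> (1, C, H, W) float tensor, still in 0-255.
    tensor = torch.from_numpy(img.astype(np.float32))
    return tensor.permute(2, 0, 1).unsqueeze(0)

def ttoi_sketch(tensor):
    # (1, C, H, W) or (C, H, W) float tensor -> (H, W, C) numpy image.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    return tensor.cpu().permute(1, 2, 0).numpy()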
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)
    VGG = vgg.VGG16().to(device)

    # Get Style Features
    imagenet_neg_mean = (torch.tensor([-103.939, -116.779, -123.68],
                                      dtype=torch.float32).reshape(1, 3, 1, 1).to(device))
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    style_tensor = utils.itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * MSELoss(generated_features["relu2_2"],
                                                    content_features["relu2_2"])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count - 1) % SAVE_MODEL_EVERY == 0) or
                    (batch_count == NUM_EPOCHS * len(train_loader))):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum / batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum / batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum / batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() - start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(batch_count - 1) + ".pth"
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample transformed image at {}".format(sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()

    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history, total_loss_history)
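# The style term compares Gram matrices of VGG feature maps via utils.gram,
# which is not shown in this listing. A standard formulation (as in
# perceptual-loss style transfer) is sketched below: a batched channel outer
# product, normalized by the layer size.
def gram_sketch(features):
    # features: (B, C, H, W) activations from one VGG layer.
    B, C, H, W = features.shape
    x = features.view(B, C, H * W)
    return torch.bmm(x, x.transpose(1, 2)) / (C * H * W)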
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor()
    ])
    x_trainloader, x_testloader = utils.prepare_loader(DATASET_PATH, X_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)
    y_trainloader, y_testloader = utils.prepare_loader(DATASET_PATH, Y_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)

    # Load Networks
    Gxy = models.Generator(64).to(device)
    Gyx = models.Generator(64).to(device)
    Dxy = models.Discriminator(64).to(device)
    Dyx = models.Discriminator(64).to(device)

    # Optimizer Settings
    G_param = list(Gxy.parameters()) + list(Gyx.parameters())
    G_optim = optim.Adam(G_param, lr=LR, betas=[BETA_1, BETA_2])
    Dxy_optim = optim.Adam(Dxy.parameters(), lr=LR, betas=[BETA_1, BETA_2])
    Dyx_optim = optim.Adam(Dyx.parameters(), lr=LR, betas=[BETA_1, BETA_2])

    # Losses
    from losses import real_loss, fake_loss, cycle_loss

    # Fixed test samples
    x_testiter = iter(x_testloader)
    y_testiter = iter(y_testloader)
    fixed_X = next(x_testiter)[0]
    fixed_X = fixed_X.to(device)
    fixed_Y = next(y_testiter)[0]
    fixed_Y = fixed_Y.to(device)

    # Tensor to Image
    fixed_X_image = utils.ttoi(fixed_X)
    fixed_Y_image = utils.ttoi(fixed_Y)

    # Number of batches
    x_trainiter = iter(x_trainloader)
    y_trainiter = iter(y_trainloader)
    iter_per_epoch = min(len(x_trainiter), len(y_trainiter))

    # Training the CycleGAN
    for epoch in range(1, NUM_EPOCHS + 1):
        print("========Epoch {}/{}========".format(epoch, NUM_EPOCHS))

        # Reset iterators every epoch to avoid StopIteration
        x_trainiter = iter(x_trainloader)
        y_trainiter = iter(y_trainloader)

        for _ in range(iter_per_epoch - 1):  # -1 in case the last batches have unequal sizes
            # Fetch the training samples
            x_real = next(x_trainiter)[0]
            x_real = x_real.to(device)
            y_real = next(y_trainiter)[0]
            y_real = y_real.to(device)

            # ========= Discriminator ==========
            # In training the discriminators, we fix the generators' parameters.
            # It is alright to train both discriminators separately because
            # their forward passes don't share any parameters with each other.

            # Discriminator X -> Y Adversarial Loss
            Dxy_optim.zero_grad()                     # Zero-out gradients
            Dxy_real_out = Dxy(y_real)                # Dxy forward pass
            Dxy_real_loss = real_loss(Dxy_real_out)   # Dxy real loss
            Dxy_fake_out = Dxy(Gxy(x_real))           # Gxy produces fake-y images
            Dxy_fake_loss = fake_loss(Dxy_fake_out)   # Dxy fake loss
            Dxy_loss = Dxy_real_loss + Dxy_fake_loss  # Dxy total loss
            Dxy_loss.backward()                       # Dxy backprop
            Dxy_optim.step()                          # Dxy gradient descent

            # Discriminator Y -> X Adversarial Loss
            Dyx_optim.zero_grad()                     # Zero-out gradients
            Dyx_real_out = Dyx(x_real)                # Dyx forward pass
            Dyx_real_loss = real_loss(Dyx_real_out)   # Dyx real loss
            Dyx_fake_out = Dyx(Gyx(y_real))           # Gyx produces fake-x images
            Dyx_fake_loss = fake_loss(Dyx_fake_out)   # Dyx fake loss
            Dyx_loss = Dyx_real_loss + Dyx_fake_loss  # Dyx total loss
            Dyx_loss.backward()                       # Dyx backprop
            Dyx_optim.step()                          # Dyx gradient descent

            # ============= Generator ==============
            # Similar to training the discriminator networks, in training the
            # generator networks we fix the discriminator networks. However,
            # cycle consistency prohibits us from training the generators
            # separately.

            # Generator X -> Y Adversarial Loss
            G_optim.zero_grad()              # Zero-out gradients
            Gxy_out = Gxy(x_real)            # Gxy forward pass
            D_Gxy_out = Dxy(Gxy_out)         # Gxy -> Dxy forward
            Gxy_loss = real_loss(D_Gxy_out)  # Gxy real loss

            # Generator Y -> X Adversarial Loss
            Gyx_out = Gyx(y_real)            # Gyx forward pass
            D_Gyx_out = Dyx(Gyx_out)         # Gyx -> Dyx forward
            Gyx_loss = real_loss(D_Gyx_out)  # Gyx real loss

            # Cycle Consistency Loss
            # (each sample must survive a round trip through both generators)
            y_x_y = Gxy(Gyx_out)                        # Reconstruct Y from Gyx(y_real)
            yxy_cycle_loss = cycle_loss(y_x_y, y_real)  # Y-X-Y cycle reconstruction loss
            x_y_x = Gyx(Gxy_out)                        # Reconstruct X from Gxy(x_real)
            xyx_cycle_loss = cycle_loss(x_y_x, x_real)  # X-Y-X cycle reconstruction loss

            # Generator Total Loss
            G_loss = Gxy_loss + Gyx_loss + CYCLE_WEIGHT * (xyx_cycle_loss + yxy_cycle_loss)
            G_loss.backward()
            G_optim.step()

        # Print Losses
        print("Dxy Loss: {} Dyx Loss: {} Generator Loss: {}".format(
            Dxy_loss.item(), Dyx_loss.item(), G_loss.item()))

        # Generate Sample Fake Images
        # (Gxy maps X -> Y and Gyx maps Y -> X, so each fixed batch is fed to
        # the generator that leaves its domain)
        Gxy.eval()
        Gyx.eval()
        with torch.no_grad():
            generated_y = Gxy(fixed_X)
            generated_y_img = utils.ttoi(generated_y.clone().detach())
            generated_x = Gyx(fixed_Y)
            generated_x_img = utils.ttoi(generated_x.clone().detach())
            H = W = TRAIN_IMAGE_SIZE
            concat_y = utils.concatenate_images(fixed_X_image, generated_y_img, H, W)
            concat_x = utils.concatenate_images(fixed_Y_image, generated_x_img, H, W)
            utils.saveimg(concat_x, "generated_x.png")
            utils.saveimg(concat_y, "generated_y.png")
        Gxy.train()
        Gyx.train()
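# losses.py is not shown in this listing. CycleGAN conventionally uses
# least-squares adversarial losses plus an L1 cycle term; the sketch below
# assumes that convention (the repo's actual definitions may use BCE instead).
def real_loss_sketch(D_out):
    # Push discriminator outputs on real (or should-look-real) inputs toward 1.
    return torch.mean((D_out - 1) ** 2)

def fake_loss_sketch(D_out):
    # Push discriminator outputs on fake inputs toward 0.
    return torch.mean(D_out ** 2)

def cycle_loss_sketch(reconstructed, real):
    # L1 penalty between a round-trip reconstruction and its source.
    return torch.mean(torch.abs(reconstructed - real))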
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))

    # Prepare dataset to make it usable with ImageFolder. Please only do this once.
    # Uncomment this when you encounter "RuntimeError: Found 0 files in subfolders of:"
    # prepare_dataset_and_folder(DATASET_PATH, [IMAGE_SAVE_FOLDER, MODEL_SAVE_FOLDER])

    # Transform, Dataset, DataLoaders
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor()
    ])
    x_trainloader, x_testloader = utils.prepare_loader(DATASET_PATH, X_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)
    y_trainloader, y_testloader = utils.prepare_loader(DATASET_PATH, Y_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)

    # [Iterators]: We need iterators for the DataLoaders because we are
    # fetching training samples from 2 or more DataLoaders at once
    x_trainiter, x_testiter = iter(x_trainloader), iter(x_testloader)
    y_trainiter, y_testiter = iter(y_trainloader), iter(y_testloader)

    # Load Networks
    Gxy = models.Generator(64).to(device)
    Gyx = models.Generator(64).to(device)
    Dxy = models.Discriminator(64).to(device)
    Dyx = models.Discriminator(64).to(device)

    # Initialize weights with Normal(0, 0.02)
    Gxy.apply(models.init_weights)
    Gyx.apply(models.init_weights)
    Dxy.apply(models.init_weights)
    Dyx.apply(models.init_weights)

    # Optimizers
    # Concatenate generator params; see the training loop for an explanation.
    # Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py#L94
    G_optim = optim.Adam(itertools.chain(Gxy.parameters(), Gyx.parameters()),
                         lr=LR, betas=[BETA_1, BETA_2])
    Dxy_optim = optim.Adam(Dxy.parameters(), lr=LR, betas=[BETA_1, BETA_2])
    Dyx_optim = optim.Adam(Dyx.parameters(), lr=LR, betas=[BETA_1, BETA_2])

    # LR Scheduler
    # Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L51
    def lambda_rule(epoch):
        lr_l = 1.0 - max(0, epoch - START_LR_DECAY) / (NUM_EPOCHS - START_LR_DECAY)
        return lr_l

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(G_optim, lr_lambda=lambda_rule)
    lr_scheduler_Dxy = torch.optim.lr_scheduler.LambdaLR(Dxy_optim, lr_lambda=lambda_rule)
    lr_scheduler_Dyx = torch.optim.lr_scheduler.LambdaLR(Dyx_optim, lr_lambda=lambda_rule)

    # Helper: the generator output is Tanh, so we scale real images
    # from [0, 1] to [-1, 1] accordingly
    def scale(tensor, mini=-1, maxi=1):
        return tensor * (maxi - mini) + mini

    # Fixed test samples
    fixed_X, _ = next(x_testiter)
    fixed_X = fixed_X.to(device)
    fixed_Y, _ = next(y_testiter)
    fixed_Y = fixed_Y.to(device)

    # Loss History for the Whole Training Process
    loss_hist = losses.createLogger(["Gxy", "Gyx", "Dxy", "Dyx", "cycle"])

    # Number of batches
    iter_per_epoch = min(len(x_trainiter), len(y_trainiter))
    print("There are {} batches per epoch".format(iter_per_epoch))

    for epoch in range(1, NUM_EPOCHS + 1):
        print("========Epoch {}/{}========".format(epoch, NUM_EPOCHS))
        start_time = time.time()

        # Loss Logger for the Current Epoch
        curr_loss_hist = losses.createLogger(["Gxy", "Gyx", "Dxy", "Dyx", "cycle"])

        # Reset iterators every epoch; otherwise we'll have unequal batch sizes
        # or, worse, reach the end of the iterator and get a StopIteration error
        x_trainiter = iter(x_trainloader)
        y_trainiter = iter(y_trainloader)

        for i in range(1, iter_per_epoch):
            # Get current batches
            x_real, _ = next(x_trainiter)
            x_real = scale(x_real)
            x_real = x_real.to(device)
            y_real, _ = next(y_trainiter)
            y_real = scale(y_real)
            y_real = y_real.to(device)

            # ========= Discriminator ==========
            # In training the discriminators, we fix the generators' parameters.
            # It is alright to train both discriminators separately because
            # their forward passes don't share any parameters with each other.

            # Discriminator Y -> X Adversarial Loss
            Dyx_optim.zero_grad()                     # Zero-out gradients
            Dyx_real_out = Dyx(x_real)                # Dyx forward pass
            Dyx_real_loss = real_loss(Dyx_real_out)   # Dyx real loss
            Dyx_fake_out = Dyx(Gyx(y_real))           # Gyx produces fake-x images
            Dyx_fake_loss = fake_loss(Dyx_fake_out)   # Dyx fake loss
            Dyx_loss = Dyx_real_loss + Dyx_fake_loss  # Dyx total loss
            Dyx_loss.backward()                       # Dyx backprop
            Dyx_optim.step()                          # Dyx gradient descent

            # Discriminator X -> Y Adversarial Loss
            Dxy_optim.zero_grad()                     # Zero-out gradients
            Dxy_real_out = Dxy(y_real)                # Dxy forward pass
            Dxy_real_loss = real_loss(Dxy_real_out)   # Dxy real loss
            Dxy_fake_out = Dxy(Gxy(x_real))           # Gxy produces fake-y images
            Dxy_fake_loss = fake_loss(Dxy_fake_out)   # Dxy fake loss
            Dxy_loss = Dxy_real_loss + Dxy_fake_loss  # Dxy total loss
            Dxy_loss.backward()                       # Dxy backprop
            Dxy_optim.step()                          # Dxy gradient descent

            # ============= Generator ==============
            # Similar to training the discriminator networks, in training the
            # generator networks we fix the discriminator networks. However,
            # cycle consistency prohibits us from training the generators
            # separately.

            # Generator X -> Y Adversarial Loss
            G_optim.zero_grad()              # Zero-out gradients
            Gxy_out = Gxy(x_real)            # Gxy forward pass: generates fake-y images
            D_Gxy_out = Dxy(Gxy_out)         # Gxy -> Dxy forward pass
            Gxy_loss = real_loss(D_Gxy_out)  # Gxy real loss

            # Generator Y -> X Adversarial Loss
            Gyx_out = Gyx(y_real)            # Gyx forward pass: generates fake-x images
            D_Gyx_out = Dyx(Gyx_out)         # Gyx -> Dyx forward pass
            Gyx_loss = real_loss(D_Gyx_out)  # Gyx real loss

            # Cycle Consistency Loss
            yxy = Gxy(Gyx_out)                        # Reconstruct Y
            yxy_cycle_loss = cycle_loss(yxy, y_real)  # Y-X-Y L1 cycle reconstruction loss
            xyx = Gyx(Gxy_out)                        # Reconstruct X
            xyx_cycle_loss = cycle_loss(xyx, x_real)  # X-Y-X L1 cycle reconstruction loss
            G_cycle_loss = CYCLE_WEIGHT * (yxy_cycle_loss + xyx_cycle_loss)

            # Generator Total Loss
            G_loss = Gxy_loss + Gyx_loss + G_cycle_loss
            G_loss.backward()
            G_optim.step()

            # Record Losses
            curr_loss_hist = losses.updateEpochLogger(curr_loss_hist,
                                                      [Gxy_loss, Gyx_loss,
                                                       Dxy_loss, Dyx_loss,
                                                       G_cycle_loss])

        # Learning Rate Scheduler Step
        lr_scheduler_G.step()
        lr_scheduler_Dxy.step()
        lr_scheduler_Dyx.step()

        # Record and Print Losses
        print("Dxy: {} Dyx: {} G: {} Cycle: {}".format(
            Dxy_loss.item(), Dyx_loss.item(), G_loss.item(), G_cycle_loss.item()))
        print("Time Elapsed: {}".format(time.time() - start_time))
        loss_hist = losses.updateGlobalLogger(loss_hist, curr_loss_hist)

        # Generate Fake Images
        Gxy.eval()
        Gyx.eval()
        with torch.no_grad():
            # Generate fake X images
            x_tensor = generate.evaluate(Gyx, fixed_Y)  # Generate image tensor
            x_images = utils.concat_images(x_tensor)    # Merge image tensors -> numpy array
            save_path = IMAGE_SAVE_FOLDER + DATASET_PATH[:-1] + "X" + str(epoch) + ".png"
            utils.saveimg(x_images, save_path)

            # Generate fake Y images
            y_tensor = generate.evaluate(Gxy, fixed_X)  # Generate image tensor
            y_images = utils.concat_images(y_tensor)    # Merge image tensors -> numpy array
            save_path = IMAGE_SAVE_FOLDER + DATASET_PATH[:-1] + "Y" + str(epoch) + ".png"
            utils.saveimg(y_images, save_path)

        # Save Model Checkpoints
        if (epoch % SAVE_MODEL_EVERY == 0) or (epoch == 1):
            save_str = MODEL_SAVE_FOLDER + DATASET_PATH[:-1] + str(epoch)
            torch.save(Gxy.cpu().state_dict(), save_str + "_Gxy.pth")
            torch.save(Gyx.cpu().state_dict(), save_str + "_Gyx.pth")
            torch.save(Dxy.cpu().state_dict(), save_str + "_Dxy.pth")
            torch.save(Dyx.cpu().state_dict(), save_str + "_Dyx.pth")
            Gxy.to(device)
            Gyx.to(device)
            Dxy.to(device)
            Dyx.to(device)

        Gxy.train()
        Gyx.train()

    utils.plot_loss(loss_hist)
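# lambda_rule keeps the learning rate flat until START_LR_DECAY and then
# decays it linearly to zero at NUM_EPOCHS. A worked example with hypothetical
# constants:
#
#   NUM_EPOCHS = 200, START_LR_DECAY = 100
#   lambda_rule(50)  = 1.0 - max(0, -50) / 100 = 1.0   (full LR)
#   lambda_rule(150) = 1.0 - 50 / 100          = 0.5   (half LR)
#   lambda_rule(200) = 1.0 - 100 / 100         = 0.0   (LR fully decayed)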
def allfilters(img, k, tc, mu, beta, imagename, original):
    """
    Applies all the filters, saves the output images, and logs the RMSE.

    Filters applied: A, B, C, R1, R2, R3, R4, R3-Crisp, R4-Crisp and Median.
    """
    inp_img = img.copy()

    img_A = filterA(img, k, mu, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_A.png')
    img_A = saveimg(op_imagepath, img_A)

    img_B = filterB(img, k, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_B.png')
    img_B = saveimg(op_imagepath, img_B)

    img_C = filterC(img, k, mu, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_C.png')
    img_C = saveimg(op_imagepath, img_C)

    img_Med = medfilt2d(img, k)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_Med.png')
    img_Med = saveimg(op_imagepath, img_Med)

    img_R1 = filterR1(img, k, tc, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R1.png')
    img_R1 = saveimg(op_imagepath, img_R1)

    img_R2 = filterR2(img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R2.png')
    img_R2 = saveimg(op_imagepath, img_R2)

    img_R3 = filterR3(img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R3.png')
    img_R3 = saveimg(op_imagepath, img_R3)

    img_R3Crisp = filterR3Crisp(img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R3Crisp.png')
    img_R3Crisp = saveimg(op_imagepath, img_R3Crisp)

    img_R4 = filterR4(img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R4.png')
    img_R4 = saveimg(op_imagepath, img_R4)

    img_R4Crisp = filterR4Crisp(img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename + '_R4Crisp.png')
    img_R4Crisp = saveimg(op_imagepath, img_R4Crisp)

    if original is not None:
        orig_imagepath = os.path.join('images', original)
        orig_img = imageio.imread(orig_imagepath)

        for name, result in [('filterA', img_A), ('filterB', img_B),
                             ('filterC', img_C), ('filterMed', img_Med),
                             ('filterR1', img_R1), ('filterR2', img_R2),
                             ('filterR3', img_R3), ('filterR3Crisp', img_R3Crisp),
                             ('filterR4', img_R4), ('filterR4Crisp', img_R4Crisp)]:
            err = rmse(result, orig_img)
            print('RMSE of {} (against original image):{}'.format(name, err))

        noisy_err = rmse(orig_img, inp_img)
        print('RMSE of input image (against original image):{}'.format(noisy_err))
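# rmse is not defined in this listing; the standard root-mean-square error
# over pixel intensities would look like this sketch (the project's version
# may normalize or crop differently):
import numpy as np

def rmse_sketch(img_a, img_b):
    # RMSE between two equally sized images, computed in float64.
    diff = img_a.astype(np.float64) - img_b.astype(np.float64)
    return np.sqrt(np.mean(diff ** 2))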
def main(args):
    config = Config(args.config)
    cfg = config(vars(args), mode=['infer', 'init'])

    scale = cfg['infer']['scale']
    mdname = cfg['infer']['model']
    imgname = ''.join(mdname)  # + '/' + str(scale)
    # dirname = ''.join(mdname) + '_' + str(scale)
    sz = cfg['infer']['sz']
    infer_size = cfg['infer']['infer_size']
    # save_path = os.path.join(args.save_dir, cfg['init']['result'])
    save_path = create_path(args.save_dir, cfg['init']['result'])
    save_path = create_path(save_path, imgname)
    save_path = create_path(save_path, str(scale))
    tif_path = create_path(save_path, cfg['infer']['lab'])
    color_path = create_path(save_path, cfg['infer']['color'])
    gray_path = create_path(save_path, cfg['infer']['gray'])
    vdl_dir = os.path.join(args.save_dir, cfg['init']['vdl_dir'])

    palette = cfg['infer']['palette']
    palette = np.array(palette, dtype=np.uint8)
    num_class = cfg['init']['num_classes']
    batchsz = cfg['infer']['batchsz']
    infer_path = os.path.join(cfg['infer']['root_path'], cfg['infer']['path'])
    tagname = imgname + '/' + str(scale)

    vdl_dir = os.path.join(vdl_dir, 'infer')
    writer = LogWriter(logdir=vdl_dir)

    infer_ds = TeDataset(path=cfg['infer']['root_path'],
                         fl=cfg['infer']['path'],
                         sz=sz)
    total = len(infer_ds)

    # Accumulate predictions over the selected models
    # addresult = np.zeros((total//batchsz, batchsz, num_class, sz, sz))
    addresult = np.zeros((total, num_class, sz, sz))
    for mnet in mdname:
        net = modelset(mode=mnet, num_classes=cfg['init']['num_classes'])

        # Load model
        input = InputSpec([None, 3, 64, 64], 'float32', 'x')
        label = InputSpec([None, 1, 64, 64], 'int64', 'label')
        model = paddle.Model(net, input, label)
        model.load(path=os.path.join(args.save_dir, mnet) + '/' + mnet)
        model.prepare()

        result = model.predict(
            infer_ds,
            batch_size=batchsz,
            num_workers=cfg['infer']['num_workers'],
            stack_outputs=True  # [160, 2, 64, 64]
        )
        addresult = result[0] + scale * addresult

    pred = construct(addresult, infer_size, sz=sz)
    # pred = construct(addresult, infer_size, sz=sz)
    # # erosion and dilation

    # Write visualizations to VisualDL
    file_list = os.listdir(infer_path)
    file_list.sort(key=lambda x: int(x[-5:-4]))
    step = 0
    for i, fl in enumerate(file_list):
        name, _ = fl.split(".")

        # Save prediction as a label image
        lab_img = Image.fromarray(pred[i].astype(np.uint8)).convert("L")
        saveimg(lab_img, tif_path, name=name, type='.tif')

        # Gray label
        label = colorize(pred[i], palette)
        writer.add_image(tag=tagname,
                         img=saveimg(label, gray_path, name=name,
                                     type='.png', re_out=True),
                         step=step,
                         dataformats='HW')
        step += 1

        # Color label blended over the input image
        file = os.path.join(infer_path, fl)
        out = blend_image(file, label, alpha=0.25)
        writer.add_image(tag=tagname,
                         img=saveimg(out, color_path, name=name,
                                     type='.png', re_out=True),
                         step=step,
                         dataformats='HWC')
        step += 1

    writer.close()
inp_img = imageio.imread(imagepath)
img = inp_img.astype(float)  # Save a copy, as inp_img will be modified

assert img.shape == (256, 256), 'Image is not of size (256, 256).'

orig_imagepath = os.path.join('images', args.original)
orig_img = imageio.imread(orig_imagepath)

if args.method == 'Sharpen':
    # Unsharp masking: add back a weighted difference of two blurs
    blurred_img = ndimage.gaussian_filter(img, 3)
    filter_blurred_img = ndimage.gaussian_filter(blurred_img, 1)
    alpha = 30
    sharpened = blurred_img + alpha * (blurred_img - filter_blurred_img)
    op_imagepath = os.path.join(
        'images', 'enhanced',
        os.path.basename(imagepath)[:-4] + '_sharpen.png')
    sharpened = saveimg(op_imagepath, sharpened)
    err = rmse(sharpened, orig_img)
    print('RMSE of filterSharpen (against original image):{}'.format(err))

if args.method == 'Gauss':
    gauss_denoised = ndimage.gaussian_filter(img, 2)
    op_imagepath = os.path.join(
        'images', 'enhanced',
        os.path.basename(imagepath)[:-4] + '_gauss.png')
    gauss_denoised = saveimg(op_imagepath, gauss_denoised)
    err = rmse(gauss_denoised, orig_img)
    print('RMSE of filterGauss (against original image):{}'.format(err))

if args.method == 'TVC':
    tvc = denoise_tv_chambolle(img, weight=30, multichannel=False)
    op_imagepath = os.path.join(
        'images', 'enhanced',
        os.path.basename(imagepath)[:-4] + '_tvc.png')
    tvc = saveimg(op_imagepath, tvc)
    err = rmse(tvc, orig_img)
    print('RMSE of filterTVC (against original image):{}'.format(err))
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        # transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)
    if USE_LATEST_CHECKPOINT:
        files = glob.glob("/home/clng/github/fast-neural-style-pytorch/models/checkpoint*")
        if len(files) == 0:
            print("use latest checkpoint requested, but no checkpoint found")
        else:
            files.sort(key=os.path.getmtime, reverse=True)
            latest_checkpoint_path = files[0]
            print("using latest checkpoint %s" % (latest_checkpoint_path))
            params = torch.load(latest_checkpoint_path, map_location=device)
            TransformerNetwork.load_state_dict(params)
    VGG = vgg.VGG19().to(device)

    # Get Style Features
    imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68],
                                     dtype=torch.float32).reshape(1, 3, 1, 1).to(device)
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    if ADJUST_BRIGHTNESS == "1":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = utils.hist_norm(style_image,
                                      [0, 64, 96, 128, 160, 192, 255],
                                      [0, 0.05, 0.15, 0.5, 0.85, 0.95, 1],
                                      inplace=True)
    elif ADJUST_BRIGHTNESS == "2":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = cv2.equalizeHist(style_image)
    elif ADJUST_BRIGHTNESS == "3":
        a = 1  # placeholder branch; the HSV adjustment below is disabled
        # hsv = cv2.cvtColor(style_image, cv2.COLOR_BGR2HSV)
        # hsv = utils.auto_brightness(hsv)
        # style_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    style_image = ensure_three_channels(style_image)

    sname = os.path.splitext(os.path.basename(STYLE_IMAGE_PATH))[0] + "_train"
    cv2.imwrite("/home/clng/datasets/bytenow/neural_styles/{s}.jpg".format(s=sname),
                style_image)

    style_tensor = utils.itot(style_image, max_size=TRAIN_STYLE_SIZE).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            # torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu3_4'],
                                                    content_features['relu3_4'])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value), style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count - 1) % SAVE_MODEL_EVERY == 0) or
                    (batch_count == NUM_EPOCHS * len(train_loader))):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum / batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum / batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum / batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() - start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(batch_count - 1) + ".pth"
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample transformed image at {}".format(sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()

    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + STYLE_NAME + ".pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history, total_loss_history)
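# utils.plot_loss_hist is not shown in this listing. A sketch of a matching
# plotting helper using matplotlib (the repo's version may style or save the
# figure differently):
import matplotlib.pyplot as plt

def plot_loss_hist_sketch(content_hist, style_hist, total_hist):
    # Each history holds running-average losses logged at every checkpoint.
    plt.figure(figsize=(10, 5))
    plt.plot(content_hist, label="Content Loss")
    plt.plot(style_hist, label="Style Loss")
    plt.plot(total_hist, label="Total Loss")
    plt.xlabel("Checkpoint")
    plt.ylabel("Running average loss")
    plt.legend()
    plt.show()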