def stylize_folder(style_path, folder_containing_the_content_folder, save_folder, batch_size=1):
    """Stylizes images in a folder by batch

    If the images are of different dimensions, use transform.resize() or
    use a batch size of 1.

    IMPORTANT: Put content_folder inside another folder
    folder_containing_the_content_folder
        folder_containing_the_content_folder
            content_folder
                pic1.ext
                pic2.ext
                pic3.ext
                ...

    and saves as the styled images in save_folder as follow:
        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    Args:
        style_path: path to the trained TransformerNetwork .pth weights.
        folder_containing_the_content_folder: parent folder (see above).
        save_folder: output folder; assumed to end with a path separator,
            since the image name is appended by string concatenation.
        batch_size: number of images stylized per forward pass.
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Image loader; pixel values are scaled to [0, 255] as the network expects
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    image_dataset = utils.ImageFolderWithPaths(
        folder_containing_the_content_folder, transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path))
    net = net.to(device)

    # Stylize batches of images
    with torch.no_grad():
        for content_batch, _, path in image_loader:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate image
            generated_tensor = net(content_batch.to(device)).detach()

            # Save images
            for i in range(len(path)):
                generated_image = utils.ttoi(generated_tensor[i])
                if (PRESERVE_COLOR):
                    # BUG FIX: the original referenced an undefined name
                    # `content_image` here (NameError when PRESERVE_COLOR is
                    # set). Recover the content image from the input batch.
                    content_image = utils.ttoi(content_batch[i])
                    generated_image = utils.transfer_color(content_image, generated_image)
                image_name = os.path.basename(path[i])
                utils.saveimg(generated_image, save_folder + image_name)
def stylize():
    """Interactive single-image stylization loop.

    Repeatedly prompts for an image path, pushes the image through the
    transformer network, optionally restores the original colors
    (PRESERVE_COLOR), shows the result, and writes it to "helloworld.jpg".
    Exit with Ctrl+C.
    """
    # Pick GPU when available, otherwise CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the transformer network and load the trained weights
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(STYLE_TRANSFORM_PATH))
    net = net.to(device)

    with torch.no_grad():
        while True:
            # Release any cached CUDA memory before the next image
            torch.cuda.empty_cache()
            print("Stylize Image~ Press Ctrl+C and Enter to close the program")
            image_path = input("Enter the image path: ")
            content_image = utils.load_image(image_path)

            tic = time.time()
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            print("Transfer Time: {}".format(time.time() - tic))

            utils.show(generated_image)
            utils.saveimg(generated_image, "helloworld.jpg")
def stylize(
    content_image_path=None,
    style_path="fast_neural_style_pytorch/transforms/starry.pth",
):
    """Stylize a single image and return it.

    NOTE(review): this redefines the interactive stylize() defined earlier
    in the file — at import time this later definition wins.

    Args:
        content_image_path: path to the content image; if None, the user
            is prompted for one.
        style_path: path to the trained TransformerNetwork .pth weights.

    Returns:
        The stylized image (as produced by utils.ttoi).
    """
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Transformer Network (map to CPU so CUDA-saved weights load anywhere)
    net = transformer.TransformerNetwork()
    net.load_state_dict(
        torch.load(style_path, map_location=torch.device("cpu")))
    net = net.to(device)

    with torch.no_grad():
        # BUG FIX: the original wrapped this body in `while 1`, but the body
        # unconditionally returns on the first iteration — the loop was dead.
        torch.cuda.empty_cache()
        print("Stylize Image~ Press Ctrl+C and Enter to close the program")
        # Idiom fix: `is None` instead of `== None`
        if content_image_path is None:
            content_image_path = input("Enter the image path: ")
        content_image = utils.load_image(content_image_path)
        starttime = time.time()
        content_tensor = utils.itot(content_image).to(device)
        generated_tensor = net(content_tensor)
        generated_image = utils.ttoi(generated_tensor.detach())
        if PRESERVE_COLOR:
            generated_image = utils.transfer_color(content_image,
                                                   generated_image)
        print("Transfer Time: {}".format(time.time() - starttime))
        # utils.show(generated_image)
        # utils.saveimg(generated_image, "helloworld.jpg")
        return generated_image
def webcam(style_transform_path, width=1280, height=720):
    """
    Live style transfer on the default webcam feed.

    Grabs frames, mirrors them, stylizes each through the transformer
    network and shows the result in a window until Esc is pressed.
    """
    # Pick GPU when available, otherwise CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the transformer network and load the trained weights
    print("Loading Transformer Network")
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_transform_path))
    net = net.to(device)
    print("Done Loading Transformer Network")

    # Open webcam and request the desired frame size
    # (property ids 3/4 are CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT)
    cam = cv2.VideoCapture(0)
    cam.set(3, width)
    cam.set(4, height)

    with torch.no_grad():
        while True:
            # Grab a frame and mirror it for a natural selfie view
            _, frame = cam.read()
            frame = cv2.flip(frame, 1)

            # Release any cached CUDA memory before the next frame
            torch.cuda.empty_cache()

            # Stylize the frame
            frame_tensor = utils.itot(frame).to(device)
            styled = utils.ttoi(net(frame_tensor).detach())
            if PRESERVE_COLOR:
                styled = utils.transfer_color(frame, styled)
            # cv2.imshow expects floats in [0, 1]
            styled = styled / 255

            cv2.imshow('Demo webcam', styled)
            if cv2.waitKey(1) == 27:  # Esc quits
                break

    # Release the camera and close the window
    cam.release()
    cv2.destroyAllWindows()
def stylize_folder_single(style_path, content_folder, save_folder):
    """
    Reads frames/pictures as follows:
        content_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    and saves as the styled images in save_folder as follow:
        save_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    Only files ending in ".jpg" are processed, one at a time. Both folder
    arguments are assumed to end with a path separator, since file names
    are appended by string concatenation.
    """
    # Pick GPU when available, otherwise CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the transformer network and load the trained weights
    # (mapped to CPU so CUDA-saved weights load anywhere)
    net = transformer.TransformerNetwork()
    net.load_state_dict(
        torch.load(style_path, map_location=torch.device("cpu")))
    net = net.to(device)

    # Collect the .jpg frames and stylize each one
    jpg_names = [
        name for name in os.listdir(content_folder) if name.endswith(".jpg")
    ]
    with torch.no_grad():
        for name in jpg_names:
            # Release any cached CUDA memory before the next image
            torch.cuda.empty_cache()

            # Load and stylize
            content_image = utils.load_image(content_folder + name)
            content_tensor = utils.itot(content_image).to(device)
            stylized = utils.ttoi(net(content_tensor).detach())
            if PRESERVE_COLOR:
                stylized = utils.transfer_color(content_image, stylized)

            # Write the result under the same file name
            utils.saveimg(stylized, save_folder + name)
def train():
    """Train the TransformerNetwork with perceptual (content + style) losses.

    Reads all configuration from module-level globals (SEED, DATASET_PATH,
    TRAIN_IMAGE_SIZE, BATCH_SIZE, NUM_EPOCHS, ADAM_LR, CONTENT_WEIGHT,
    STYLE_WEIGHT, STYLE_IMAGE_PATH, SAVE_MODEL_EVERY, SAVE_MODEL_PATH,
    SAVE_IMAGE_PATH, PLOT_LOSS). Periodically saves checkpoints and a sample
    stylized image, and saves the final weights when done.
    """
    # Seeds for reproducibility
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader; pixel values scaled to [0, 255]
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)
    VGG = vgg.VGG16().to(device)

    # Get Style Features (computed once; Gram matrices are the style targets)
    imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68],
                                     dtype=torch.float32).reshape(1, 3, 1, 1).to(device)
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    style_tensor = utils.itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {key: utils.gram(value)
                  for key, value in style_features.items()}

    # Optimizer settings; MSELoss hoisted out of the loop (loop-invariant)
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)
    MSELoss = nn.MSELoss().to(device)

    # Loss trackers (running sums of scalars; histories of running means)
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            # (channel flip RGB -> BGR to match the VGG preprocessing)
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            content_loss = CONTENT_WEIGHT * MSELoss(
                generated_features['relu2_2'], content_features['relu2_2'])
            # BUG FIX: accumulate the Python scalar, not the tensor — the
            # original summed graph-attached tensors, retaining every
            # iteration's autograd graph (unbounded memory growth).
            batch_content_loss_sum += content_loss.item()

            # Style Loss (slice style Gram to the batch in case of odd sizes)
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count - 1) % SAVE_MODEL_EVERY == 0)
                    or (batch_count == NUM_EPOCHS * len(train_loader))):
                # Print Losses (running means since the start of training)
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum /
                                                       batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum /
                                                     batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum /
                                                     batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() -
                                                         start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(
                    batch_count - 1) + ".pth"
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(
                    checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(
                    dim=0)
                # ???: domain=(-1.4763, 508.0564), img is clipped to (0,255) in utils.saveimg
                # BUG FIX: .item() instead of .numpy() — calling .numpy() on a
                # CUDA tensor raises TypeError.
                print("\tImage, domain:\t({:.1f},{:.1f})".format(
                    torch.min(sample_tensor).item(),
                    torch.max(sample_tensor).item()))
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(
                    batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(
                    sample_image_path))

                # Save loss histories
                # BUG FIX: the original appended the TOTAL running mean into
                # the CONTENT history (copy-paste error).
                content_loss_history.append(batch_content_loss_sum /
                                            batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()

    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights (moved to CPU for portable loading)
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if (PLOT_LOSS):
        utils.plot_loss_hist(content_loss_history, style_loss_history,
                             total_loss_history)