Code Example #1
def webcam(style_transform_path, width=1280, height=720):
    """
    Captures and saves an image, perform style transfer, and again saves the styled image.
    Reads the styled image and show in window. 

    Saving and loading SHOULD BE eliminated, however this produces too much whitening in
    the "generated styled image". This may be caused by the async nature of VideoCapture,
    and I don't know how to fix it. 
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    print("Loading Transformer Network")
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_transform_path))
    net = net.to(device)
    print("Done Loading Transformer Network")

    # Set webcam settings
    cam = cv2.VideoCapture(0)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, width)    # property id 3
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, height)  # property id 4

    # Main loop
    with torch.no_grad():
        count = 1
        while True:
            # Get webcam input
            ret_val, img = cam.read()

            # Mirror
            img = cv2.flip(img, 1)
            utils.saveimg(img, str(count) + ".png")

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate image
            content_image = utils.load_image(str(count) + ".png")
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image,
                                                       generated_image)
            utils.saveimg(generated_image, str(count + 1) + ".png")
            img2 = cv2.imread(str(count + 1) + ".png")

            count += 2
            # Show webcam
            cv2.imshow('Demo webcam', img2)
            if cv2.waitKey(1) == 27:
                break  # esc to quit

    # Free-up memories
    cam.release()
    cv2.destroyAllWindows()
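
A minimal usage sketch (the checkpoint path is hypothetical; press Esc in the preview window to quit):

if __name__ == "__main__":
    # Hypothetical checkpoint produced by the train() examples further down
    webcam("models/mosaic.pth", width=1280, height=720)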
Code Example #2
def stylize_folder(style_path, folder_containing_the_content_folder, save_folder, batch_size=1):
    """Stylizes images in a folder by batch
    If the images  are of different dimensions, use transform.resize() or use a batch size of 1
    IMPORTANT: Put content_folder inside another folder folder_containing_the_content_folder

    folder_containing_the_content_folder
        content_folder
            pic1.ext
            pic2.ext
            pic3.ext
            ...

    and saves the styled images in save_folder as follows:

    save_folder
        pic1.ext
        pic2.ext
        pic3.ext
        ...
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Image loader
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    image_dataset = utils.ImageFolderWithPaths(folder_containing_the_content_folder, transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path))
    net = net.to(device)

    # Stylize batches of images
    with torch.no_grad():
        for content_batch, _, path in image_loader:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Generate image
            generated_tensor = net(content_batch.to(device)).detach()

            # Save images
            for i in range(len(path)):
                generated_image = utils.ttoi(generated_tensor[i])
                if PRESERVE_COLOR:
                    # Use the matching content image from the batch for color transfer
                    content_image = utils.ttoi(content_batch[i])
                    generated_image = utils.transfer_color(content_image, generated_image)
                image_name = os.path.basename(path[i])
                utils.saveimg(generated_image, os.path.join(save_folder, image_name))

#stylize()
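
A minimal usage sketch, assuming the nested folder layout from the docstring (all paths hypothetical):

# images/content_folder/ holds the pictures; styled copies land in styled/
stylize_folder("models/mosaic.pth", "images/", "styled/", batch_size=4)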
Code Example #3
File: train.py Project: ldf921/video-representation
def test(framework):
    from utils import saveimg
    load_network(framework.model, os.path.join(exp_dir, args.checkpoint))
    framework.cuda()
    if config['framework'] == 'voxelflow':
        results = framework.predict(generator_k(1, framework.val_loader))
        saveimg(results['truth'], results['pred'], os.path.join(exp_dir, 'vis.jpg'), n=10)
        print(((results['truth'] - results['pred']) ** 2).mean())
    for subset in args.test:
        data, loader = framework.build_test_dataset(subset)
        result = framework.test(data, loader)
        print(print_dict(result, subset))
Code Example #4
def stylize_folder_single(style_path, content_folder, save_folder):
    """
    Reads frames/pictures as follows:

    content_folder
        pic1.ext
        pic2.ext
        pic3.ext
        ...

    and saves the styled images in save_folder as follows:

    save_folder
        pic1.ext
        pic2.ext
        pic3.ext
        ...
    """
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(style_path))
    net = net.to(device)

    # Stylize every frame
    images = [img for img in os.listdir(content_folder) if img.endswith(".jpg")]
    with torch.no_grad():
        for image_name in images:
            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()
            
            # Load content image
            content_image = utils.load_image(os.path.join(content_folder, image_name))
            content_tensor = utils.itot(content_image).to(device)

            # Generate image
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            # Save image
            utils.saveimg(generated_image, os.path.join(save_folder, image_name))
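
A minimal usage sketch (paths hypothetical; note that this variant only picks up .jpg files):

stylize_folder_single("models/mosaic.pth", "frames/", "styled_frames/")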
Code Example #5
def stylize():
    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Load Transformer Network
    net = transformer.TransformerNetwork()
    net.load_state_dict(torch.load(STYLE_TRANSFORM_PATH))
    net = net.to(device)

    with torch.no_grad():
        while True:
            torch.cuda.empty_cache()
            print("Stylize Image~ Press Ctrl+C and Enter to close the program")
            content_image_path = input("Enter the image path: ")
            content_image = utils.load_image(content_image_path)
            starttime = time.time()
            content_tensor = utils.itot(content_image).to(device)
            generated_tensor = net(content_tensor)
            generated_image = utils.ttoi(generated_tensor.detach())
            if PRESERVE_COLOR:
                generated_image = utils.transfer_color(content_image, generated_image)
            print("Transfer Time: {}".format(time.time() - starttime))
            utils.show(generated_image)
            utils.saveimg(generated_image, "helloworld.jpg")
Code Example #6
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)
    VGG = vgg.VGG16().to(device)

    # Get Style Features
    imagenet_neg_mean = (torch.tensor([-103.939, -116.779, -123.68],
                                      dtype=torch.float32).reshape(
                                          1, 3, 1, 1).to(device))
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    style_tensor = utils.itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * MSELoss(
                generated_features["relu2_2"], content_features["relu2_2"])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if ((batch_count - 1) % SAVE_MODEL_EVERY
                    == 0) or (batch_count == NUM_EPOCHS * len(train_loader)):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum /
                                                       batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum /
                                                     batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum /
                                                     batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() -
                                                         start_time))

                # Save Model
                checkpoint_path = (SAVE_MODEL_PATH + "checkpoint_" +
                                   str(batch_count - 1) + ".pth")
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(
                    checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(
                    dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = (SAVE_IMAGE_PATH + "sample0_" +
                                     str(batch_count - 1) + ".png")
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(
                    sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()
    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history,
                             total_loss_history)
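
The style loss above compares Gram matrices of VGG feature maps. A plausible sketch of the utils.gram helper (the normalization constant is an assumption; implementations vary):

def gram(tensor):
    # tensor: BxCxHxW feature maps from one VGG layer
    B, C, H, W = tensor.shape
    feats = tensor.view(B, C, H * W)
    # batched channel-by-channel inner products
    return torch.bmm(feats, feats.transpose(1, 2)) / (C * H * W)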
Code Example #7
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor()
    ])
    x_trainloader, x_testloader = utils.prepare_loader(DATASET_PATH,
                                                       X_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)
    y_trainloader, y_testloader = utils.prepare_loader(DATASET_PATH,
                                                       Y_CLASS,
                                                       transform=transform,
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)

    # Load Networks
    Gxy = models.Generator(64).to(device)
    Gyx = models.Generator(64).to(device)
    Dxy = models.Discriminator(64).to(device)
    Dyx = models.Discriminator(64).to(device)

    # Optimizer Settings
    G_param = list(Gxy.parameters()) + list(Gyx.parameters())
    G_optim = optim.Adam(G_param, lr=LR, betas=[BETA_1, BETA_2])
    Dxy_optim = optim.Adam(Dxy.parameters(), lr=LR, betas=[BETA_1, BETA_2])
    Dyx_optim = optim.Adam(Dyx.parameters(), lr=LR, betas=[BETA_1, BETA_2])

    # Losses
    from losses import real_loss, fake_loss, cycle_loss

    # Fixed test samples
    x_testiter = iter(x_testloader)
    y_testiter = iter(y_testloader)
    fixed_X = next(x_testiter)[0]
    fixed_X = fixed_X.to(device)
    fixed_Y = next(y_testiter)[0]
    fixed_Y = fixed_Y.to(device)

    # Tensor to Image
    fixed_X_image = utils.ttoi(fixed_X)
    fixed_Y_image = utils.ttoi(fixed_Y)

    # Number of batches
    x_trainiter = iter(x_trainloader)
    y_trainiter = iter(y_trainloader)
    iter_per_epoch = min(len(x_trainiter), len(y_trainiter))

    # Training the CycleGAN
    for epoch in range(1, NUM_EPOCHS + 1):
        print("========Epoch {}/{}========".format(epoch, NUM_EPOCHS))
        for _ in range(iter_per_epoch -
                       1):  # -1 in case of imbalanced sizes of the last batch

            # Fetch the dataset
            x_real = next(x_trainiter)[0]
            x_real = x_real.to(device)
            y_real = next(y_trainiter)[0]
            y_real = y_real.to(device)

            # ========= Discriminator ==========
            # In training the discriminators, we fix the generators' parameters.
            # It is alright to train both discriminators separately because
            # their forward passes don't share any parameters with each other

            # Discriminator X -> Y Adversarial Loss
            Dxy_optim.zero_grad()  # Zero-out gradients
            Dxy_real_out = Dxy(y_real)  # Dxy Forward Pass
            Dxy_real_loss = real_loss(Dxy_real_out)  # Dxy Real Loss
            Dxy_fake_out = Dxy(Gxy(x_real))  # Gxy produces fake-y images
            Dxy_fake_loss = fake_loss(Dxy_fake_out)  # Dxy Fake Loss
            Dxy_loss = Dxy_real_loss + Dxy_fake_loss  # Dxy Total Loss
            Dxy_loss.backward()  # Dxy Backprop
            Dxy_optim.step()  # Dxy Gradient Descent

            # Discriminator Y-> X Adversarial Loss
            Dyx_optim.zero_grad()  # Zero-out gradients
            Dyx_real_out = Dyx(x_real)  # Dyx Forward Pass
            Dyx_real_loss = real_loss(Dyx_real_out)  # Dyx Real Loss
            Dyx_fake_out = Dyx(Gyx(y_real))  # Gyx produces fake-x images
            Dyx_fake_loss = fake_loss(Dyx_fake_out)  # Dyx Fake Loss
            Dyx_loss = Dyx_real_loss + Dyx_fake_loss  # Dyx Total Loss
            Dyx_loss.backward()  # Dyx Backprop
            Dyx_optim.step()  # Dyx Gradient Descent

            # ============= Generator ==============
            # Similar to training discriminator networks, in training
            # generator networks, we fix discriminator networks.
            # However, cycle consistency prohibits us
            # from training generators separately.

            # Generator X -> Y Adversarial Loss
            G_optim.zero_grad()  # Zero-out gradients
            Gxy_out = Gxy(x_real)  # Gxy Forward Pass
            D_Gxy_out = Dxy(Gxy_out)  # Gxy -> Dxy Forward
            Gxy_loss = real_loss(D_Gxy_out)  # Gxy Real Loss

            # Generator Y -> X Adversarial Loss
            Gyx_out = Gyx(y_real)  # Gyx Forward Pass
            D_Gyx_out = Dyx(Gyx_out)  # Gyx -> Dyx Forward
            Gyx_loss = real_loss(D_Gyx_out)  # Gyx Real Loss

            # Cycle Consistency Loss
            y_x_y = Gxy(Gyx_out)  # Reconstruct Y: y -> Gyx -> Gxy
            yxy_cycle_loss = cycle_loss(
                y_x_y, y_real)  # Y-X-Y Cycle Reconstruction Loss
            x_y_x = Gyx(Gxy_out)  # Reconstruct X: x -> Gxy -> Gyx
            xyx_cycle_loss = cycle_loss(
                x_y_x, x_real)  # X-Y-X Cycle Reconstruction Loss

            # Generator Total Loss
            G_loss = Gxy_loss + Gyx_loss + CYCLE_WEIGHT * (xyx_cycle_loss +
                                                           yxy_cycle_loss)
            G_loss.backward()
            G_optim.step()

        # Print Losses
        print("Dxy Loss: {} Dyx Loss: {} Generator Loss: {}".format(
            Dxy_loss.item(), Dyx_loss.item(), G_loss.item()))

        # Generate Sample Fake Images
        Gxy.eval()
        Gyx.eval()
        with torch.no_grad():
            generated_y = Gxy(fixed_X)  # Gxy maps X -> Y: fake-Y samples
            generated_y_img = utils.ttoi(generated_y.clone().detach())

            generated_x = Gyx(fixed_Y)  # Gyx maps Y -> X: fake-X samples
            generated_x_img = utils.ttoi(generated_x.clone().detach())

            H = W = TRAIN_IMAGE_SIZE
            concat_y = utils.concatenate_images(fixed_X_image, generated_y_img,
                                                H, W)
            concat_x = utils.concatenate_images(fixed_Y_image, generated_x_img,
                                                H, W)

            utils.saveimg(concat_x, "generated_x.png")
            utils.saveimg(concat_y, "generated_y.png")
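
real_loss, fake_loss, and cycle_loss are imported from losses; a plausible least-squares-GAN sketch (an assumption, since the project may use BCE-based adversarial losses instead):

import torch

def real_loss(D_out):
    # discriminator output should be near 1 for real (or convincing fake) images
    return torch.mean((D_out - 1) ** 2)

def fake_loss(D_out):
    # discriminator output should be near 0 for fake images
    return torch.mean(D_out ** 2)

def cycle_loss(reconstructed, real):
    # L1 error between the cycled image and the original
    return torch.mean(torch.abs(reconstructed - real))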
Code Example #8
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))

    # Prepare dataset to make it usable with ImageFolder. Please only do this once
    # Uncomment this when you encounter "RuntimeError: Found 0 files in subfolders of:"
    # prepare_dataset_and_folder(DATASET_PATH, [IMAGE_SAVE_FOLDER, MODEL_SAVE_FOLDER])

    # Tranform, Dataset, DataLoaders
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor()
    ])
    x_trainloader, x_testloader = utils.prepare_loader(DATASET_PATH, X_CLASS, transform=transform, batch_size=BATCH_SIZE, shuffle=True)
    y_trainloader, y_testloader = utils.prepare_loader(DATASET_PATH, Y_CLASS, transform=transform, batch_size=BATCH_SIZE, shuffle=True)

    # [Iterators]: We need iterators for DataLoader because we 
    # are fetching training samples from 2 or more DataLoaders
    x_trainiter, x_testiter = iter(x_trainloader), iter(x_testloader)
    y_trainiter, y_testiter = iter(y_trainloader), iter(y_testloader)

    # Load Networks
    Gxy = models.Generator(64).to(device)
    Gyx = models.Generator(64).to(device)
    Dxy = models.Discriminator(64).to(device)
    Dyx = models.Discriminator(64).to(device)

    # Initialize weights with Normal(0, 0.02)
    Gxy.apply(models.init_weights)
    Gyx.apply(models.init_weights)
    Dxy.apply(models.init_weights)
    Dyx.apply(models.init_weights)

    # Optimizers
    # Concatenate Generator Params. See Training Loop for explanation.
    # Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py#L94             
    G_optim = optim.Adam(itertools.chain(Gxy.parameters(), Gyx.parameters()), lr=LR, betas=[BETA_1, BETA_2]) 
    Dxy_optim = optim.Adam(Dxy.parameters(), lr=LR, betas=[BETA_1, BETA_2])
    Dyx_optim = optim.Adam(Dyx.parameters(), lr=LR, betas=[BETA_1, BETA_2])

    # LR Scheduler
    # Reference: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L51
    def lambda_rule(epoch):
        lr_l = 1.0 - max(0, epoch - START_LR_DECAY)/(NUM_EPOCHS - START_LR_DECAY)
        return lr_l
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(G_optim, lr_lambda=lambda_rule)
    lr_scheduler_Dxy = torch.optim.lr_scheduler.LambdaLR(Dxy_optim, lr_lambda=lambda_rule)
    lr_scheduler_Dyx = torch.optim.lr_scheduler.LambdaLR(Dyx_optim, lr_lambda=lambda_rule)

    # Some Helper Functions!    
    # Output of generator is Tanh! so we need to scale real images accordingly
    def scale(tensor, mini=-1, maxi=1):
        return tensor * (maxi-mini) + mini

    # Fixed test samples
    fixed_X, _ = next(x_testiter)
    fixed_X = fixed_X.to(device)   
    fixed_Y, _ = next(y_testiter)
    fixed_Y = fixed_Y.to(device)

    # Loss History for the Whole Training Process
    loss_hist = losses.createLogger(["Gxy", "Gyx", "Dxy", "Dyx", "cycle"])

    # Number of batches
    iter_per_epoch = min(len(x_trainiter), len(y_trainiter))
    print("There are {} batches per epoch".format(iter_per_epoch))

    for epoch in range(1, NUM_EPOCHS+1):
        print("========Epoch {}/{}========".format(epoch, NUM_EPOCHS))
        start_time = time.time()
        # Loss Logger for the Current Batch
        curr_loss_hist = losses.createLogger(["Gxy", "Gyx", "Dxy", "Dyx", "cycle"])

        # Reset iterators every epoch; otherwise we'll have unequal batch sizes
        # or, worse, reach the end of an iterator and get a StopIteration error
        x_trainiter = iter(x_trainloader)
        y_trainiter = iter(y_trainloader)

        for i in range(1, iter_per_epoch):

            # Get current batches
            x_real, _ = next(x_trainiter) 
            x_real = scale(x_real)
            x_real = x_real.to(device)
            
            y_real, _ = next(y_trainiter)
            y_real = scale(y_real)
            y_real = y_real.to(device)
            
            # ========= Discriminator ==========
            # In training the discriminators, we fix the generators' parameters
            # It is alright to train both discriminators separately because
            # their forward passes don't share any parameters with each other

            # Discriminator Y -> X Adversarial Loss
            Dyx_optim.zero_grad()                       # Zero-out Gradients
            Dyx_real_out = Dyx(x_real)                  # Dyx Forward Pass
            Dyx_real_loss = real_loss(Dyx_real_out)     # Dyx Real Loss 
            Dyx_fake_out = Dyx(Gyx(y_real))             # Gyx produces fake-x images
            Dyx_fake_loss = fake_loss(Dyx_fake_out)     # Dyx Fake Loss
            Dyx_loss = Dyx_real_loss + Dyx_fake_loss    # Dyx Total Loss
            Dyx_loss.backward()                         # Dyx Backprop
            Dyx_optim.step()                            # Dyx Gradient Descent
            
            # Discriminator X -> Y Adversarial Loss
            Dxy_optim.zero_grad()                       # Zero-out Gradients
            Dxy_real_out = Dxy(y_real)                  # Dxy Forward Pass
            Dxy_real_loss = real_loss(Dxy_real_out)     # Dxy Real Loss
            Dxy_fake_out = Dxy(Gxy(x_real))             # Gxy produces fake y-images
            Dxy_fake_loss = fake_loss(Dxy_fake_out)     # Dxy Fake Loss
            Dxy_loss = Dxy_real_loss + Dxy_fake_loss    # Dxy Total Loss
            Dxy_loss.backward()                         # Dxy Backprop
            Dxy_optim.step()                            # Dxy Gradient Descent

            # ============= Generator ==============
            # Similar to training discriminator networks, in training 
            # generator networks, we fix discriminator networks.
            # However, cycle consistency prohibits us 
            # from training generators separately.

            # Generator X -> Y Adversarial Loss
            G_optim.zero_grad()                         # Zero-out Gradients
            Gxy_out = Gxy(x_real)                       # Gxy Forward Pass : generates fake-y images
            D_Gxy_out = Dxy(Gxy_out)                    # Gxy -> Dxy Forward Pass
            Gxy_loss = real_loss(D_Gxy_out)             # Gxy Real Loss
            
            # Generator Y -> X Adversarial Loss
            Gyx_out = Gyx(y_real)                       # Gyx Forward Pass : generates fake-x images
            D_Gyx_out = Dyx(Gyx_out)                    # Gyx -> Dyx Forward Pass    
            Gyx_loss = real_loss(D_Gyx_out)             # Gyx Real Loss
            
            # Cycle Consistency Loss
            yxy = Gxy(Gyx_out)                          # Reconstruct Y
            yxy_cycle_loss = cycle_loss(yxy, y_real)    # Y-X-Y L1 Cycle Reconstruction Loss
            xyx = Gyx(Gxy_out)                          # Reconstruct X
            xyx_cycle_loss = cycle_loss(xyx, x_real)    # X-Y-X L1 Cycle Reconstruction Loss
            G_cycle_loss = CYCLE_WEIGHT * (yxy_cycle_loss + xyx_cycle_loss)
            
            # Generator Total Loss
            G_loss = Gxy_loss + Gyx_loss + G_cycle_loss
            G_loss.backward()
            G_optim.step()

            # Record Losses
            curr_loss_hist = losses.updateEpochLogger(curr_loss_hist, [Gxy_loss, Gyx_loss, Dxy_loss, Dyx_loss, G_cycle_loss])

        # Learning Rate Scheduler Step
        lr_scheduler_G.step()
        lr_scheduler_Dxy.step()
        lr_scheduler_Dyx.step()

        # Record and Print Losses
        print("Dxy: {} Dyx: {} G: {} Cycle: {}".format(Dxy_loss.item(), Dyx_loss.item(), G_loss.item(), G_cycle_loss.item()))
        print("Time Elapsed: {}".format(time.time() - start_time))
        loss_hist = losses.updateGlobalLogger(loss_hist, curr_loss_hist)

        # Generate Fake Images
        Gxy.eval()
        Gyx.eval()
        with torch.no_grad():
            # Generate Fake X Images  
            x_tensor = generate.evaluate(Gyx, fixed_Y)  # Generate Image Tensor
            x_images = utils.concat_images(x_tensor)    # Merge Image Tensors -> Numpy Array
            save_path = IMAGE_SAVE_FOLDER + DATASET_PATH[:-1] + "X" + str(epoch) + ".png"
            utils.saveimg(x_images, save_path)

            # Generate Fake Y Images
            y_tensor = generate.evaluate(Gxy, fixed_X)  # Generate Image Tensor
            y_images = utils.concat_images(y_tensor)    # Merge Image Tensors -> Numpy Array
            save_path = IMAGE_SAVE_FOLDER + DATASET_PATH[:-1] + "Y" + str(epoch) + ".png"
            utils.saveimg(y_images, save_path)

        # Save Model Checkpoints
        if ((epoch % SAVE_MODEL_EVERY == 0) or (epoch == 1)):
            save_str = MODEL_SAVE_FOLDER + DATASET_PATH[:-1] + str(epoch)
            torch.save(Gxy.cpu().state_dict(), save_str + "_Gxy.pth")
            torch.save(Gyx.cpu().state_dict(), save_str + "_Gyx.pth")
            torch.save(Dxy.cpu().state_dict(), save_str + "_Dxy.pth")
            torch.save(Dxy.cpu().state_dict(), save_str + "_Dyx.pth")

            Gxy.to(device)
            Gyx.to(device)
            Dxy.to(device)
            Dyx.to(device)

        Gxy.train()
        Gyx.train()

    utils.plot_loss(loss_hist)
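
lambda_rule above keeps the learning rate flat until START_LR_DECAY and then decays it linearly to zero at NUM_EPOCHS. A quick sanity check with hypothetical values:

NUM_EPOCHS, START_LR_DECAY = 200, 100

def lambda_rule(epoch):
    return 1.0 - max(0, epoch - START_LR_DECAY) / (NUM_EPOCHS - START_LR_DECAY)

for epoch in (0, 100, 150, 200):
    print(epoch, lambda_rule(epoch))  # multipliers: 1.0, 1.0, 0.5, 0.0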
Code Example #9
def allfilters(img, k, tc, mu, beta, imagename, original):
    """ This function applies all the filters and saves the output images and logs the RMSE.
    Filters applied: A, B, C, R1, R2, R3, R4, R3-Crisp, R4-Crisp and Median """
    inp_img = img.copy()

    img_A = filterA(img, k, mu, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_A.png')
    img_A = saveimg(op_imagepath, img_A)

    img_B = filterB(img, k, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_B.png')
    img_B = saveimg(op_imagepath, img_B)

    img_C = filterC(img, k, mu, beta)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_C.png')
    img_C = saveimg(op_imagepath, img_C)

    img_Med = medfilt2d(img, k)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_Med.png')
    img_Med = saveimg(op_imagepath, img_Med)

    img_R1 = filterR1(img, k, tc, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R1.png')
    img_R1 = saveimg(op_imagepath, img_R1)

    img_R2 = filterR2(img, k, tc, mu, beta, img_A=img_A,
                      img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R2.png')
    img_R2 = saveimg(op_imagepath, img_R2)

    img_R3 = filterR3(img, k, tc, mu, beta, img_A=img_A,
                      img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R3.png')
    img_R3 = saveimg(op_imagepath, img_R3)

    img_R3Crisp = filterR3Crisp(
        img, k, tc, mu, beta, img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R3Crisp.png')
    img_R3Crisp = saveimg(op_imagepath, img_R3Crisp)

    img_R4 = filterR4(img, k, tc, mu, beta, img_A=img_A,
                      img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R4.png')
    img_R4 = saveimg(op_imagepath, img_R4)

    img_R4Crisp = filterR4Crisp(img, k, tc, mu, beta,
                                img_A=img_A, img_B=img_B, img_C=img_C)
    op_imagepath = os.path.join('images', 'enhanced', imagename+'_R4Crisp.png')
    img_R4Crisp = saveimg(op_imagepath, img_R4Crisp)

    if original is not None:
        orig_imagepath = os.path.join('images', original)
        orig_img = imageio.imread(orig_imagepath)

        err = rmse(img_A, orig_img)
        print('RMSE of filterA (against original image):{}'.format(err))
        err = rmse(img_B, orig_img)
        print('RMSE of filterB (against original image):{}'.format(err))
        err = rmse(img_C, orig_img)
        print('RMSE of filterC (against original image):{}'.format(err))
        err = rmse(img_Med, orig_img)
        print('RMSE of filterMed (against original image):{}'.format(err))
        err = rmse(img_R1, orig_img)
        print('RMSE of filterR1 (against original image):{}'.format(err))
        err = rmse(img_R2, orig_img)
        print('RMSE of filterR2 (against original image):{}'.format(err))
        err = rmse(img_R3, orig_img)
        print('RMSE of filterR3 (against original image):{}'.format(err))
        err = rmse(img_R3Crisp, orig_img)
        print('RMSE of filterR3Crisp (against original image):{}'.format(err))
        err = rmse(img_R4, orig_img)
        print('RMSE of filterR4 (against original image):{}'.format(err))
        err = rmse(img_R4Crisp, orig_img)
        print('RMSE of filterR4Crisp (against original image):{}'.format(err))

        noisy_err = rmse(orig_img, inp_img)
        print('RMSE of input image (against original image):{}'.format(noisy_err))
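
A minimal sketch of the rmse helper these comparisons assume (casting to float first so uint8 subtraction cannot wrap around):

import numpy as np

def rmse(a, b):
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    return np.sqrt(np.mean((a - b) ** 2))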
Code Example #10
def main(args):
    config = Config(args.config)
    cfg = config(vars(args), mode=['infer', 'init'])
    scale = cfg['infer']['scale']
    mdname = cfg['infer']['model']
    imgname = ''.join(mdname)  # + '/' + str(scale)
    #dirname = ''.join(mdname)  + '_' + str(scale)
    sz = cfg['infer']['sz']
    infer_size = cfg['infer']['infer_size']
    #save_path = os.path.join(args.save_dir, cfg['init']['result'])
    save_path = create_path(args.save_dir, cfg['init']['result'])
    save_path = create_path(save_path, imgname)
    save_path = create_path(save_path, str(scale))

    tif_path = create_path(save_path, cfg['infer']['lab'])
    color_path = create_path(save_path, cfg['infer']['color'])
    gray_path = create_path(save_path, cfg['infer']['gray'])
    vdl_dir = os.path.join(args.save_dir, cfg['init']['vdl_dir'])
    palette = cfg['infer']['palette']
    palette = np.array(palette, dtype=np.uint8)
    num_class = cfg['init']['num_classes']
    batchsz = cfg['infer']['batchsz']
    infer_path = os.path.join(cfg['infer']['root_path'], cfg['infer']['path'])
    tagname = imgname + '/' + str(scale)
    vdl_dir = os.path.join(vdl_dir, 'infer')
    writer = LogWriter(logdir=vdl_dir)

    infer_ds = TeDataset(path=cfg['infer']['root_path'],
                         fl=cfg['infer']['path'],
                         sz=sz)
    total = len(infer_ds)
    # select model
    #addresult = np.zeros((total//batchsz,batchsz,num_class,sz,sz))
    addresult = np.zeros((total, num_class, sz, sz))
    for mnet in mdname:
        result_list = []
        net = modelset(mode=mnet, num_classes=cfg['init']['num_classes'])
        # load model
        input = InputSpec([None, 3, 64, 64], 'float32', 'x')
        label = InputSpec([None, 1, 64, 64], 'int64', 'label')
        model = paddle.Model(net, input, label)
        model.load(path=os.path.join(args.save_dir, mnet) + '/' + mnet)
        model.prepare()
        result = model.predict(
            infer_ds,
            batch_size=batchsz,
            num_workers=cfg['infer']['num_workers'],
            stack_outputs=True  # [160,2,64,64]
        )

        addresult = result[0] + scale * addresult

    pred = construct(addresult, infer_size, sz=sz)
    # pred = construct(addresult,infer_size,sz = sz)
    # # erosion / dilation
    # read vdl
    file_list = os.listdir(infer_path)
    file_list.sort(key=lambda x: int(x[-5:-4]))
    step = 0
    for i, fl in enumerate(file_list):
        name, _ = fl.split(".")
        # save pred
        lab_img = Image.fromarray(pred[i].astype(np.uint8)).convert("L")
        saveimg(lab_img, tif_path, name=name, type='.tif')

        # gray_label
        label = colorize(pred[i], palette)
        writer.add_image(tag=tagname,
                         img=saveimg(label,
                                     gray_path,
                                     name=name,
                                     type='.png',
                                     re_out=True),
                         step=step,
                         dataformats='HW')
        step += 1

        # color_label
        file = os.path.join(infer_path, fl)
        out = blend_image(file, label, alpha=0.25)
        writer.add_image(tag=tagname,
                         img=saveimg(out,
                                     color_path,
                                     name=name,
                                     type='.png',
                                     re_out=True),
                         step=step,
                         dataformats='HWC')
        step += 1
    writer.close()
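
create_path is not shown in this project snippet; a plausible sketch, assuming it joins a path component and creates the directory in one step:

import os

def create_path(base, name):
    path = os.path.join(base, str(name))
    os.makedirs(path, exist_ok=True)
    return path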
Code Example #11
inp_img = imageio.imread(imagepath)
img = inp_img.astype(float)  # Work on a float copy; inp_img keeps the original values
assert img.shape == (256, 256), 'Image is not of size (256, 256).'

orig_imagepath = os.path.join('images', args.original)
orig_img = imageio.imread(orig_imagepath)

if args.method == 'Sharpen':
    blurred_img = ndimage.gaussian_filter(img, 3)
    filter_blurred_img = ndimage.gaussian_filter(blurred_img, 1)
    alpha = 30
    sharpened = blurred_img + alpha * (blurred_img - filter_blurred_img)
    op_imagepath = os.path.join(
        'images', 'enhanced',
        os.path.basename(imagepath)[:-4] + '_sharpen.png')
    sharpened = saveimg(op_imagepath, sharpened)
    err = rmse(sharpened, orig_img)
    print('RMSE of filterSharpen (against original image):{}'.format(err))

if args.method == 'Gauss':
    gauss_denoised = ndimage.gaussian_filter(img, 2)
    op_imagepath = os.path.join(
        'images', 'enhanced',
        os.path.basename(imagepath)[:-4] + '_gauss.png')
    gauss_denoised = saveimg(op_imagepath, gauss_denoised)
    err = rmse(gauss_denoised, orig_img)
    print('RMSE of filterGauss (against original image):{}'.format(err))

if args.method == 'TVC':
    tvc = denoise_tv_chambolle(img, weight=30, multichannel=False)
    op_imagepath = os.path.join('images', 'enhanced',
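
This snippet and Code Example #9 call saveimg(path, img) and keep using its return value; a plausible sketch, assuming it clips to the displayable range before writing:

import numpy as np
import imageio

def saveimg(path, img):
    # Clip to [0, 255], write to disk, and return the written image
    # so later rmse() comparisons use exactly what was saved
    img = np.clip(img, 0, 255).astype(np.uint8)
    imageio.imwrite(path, img)
    return img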
Code Example #12
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        # transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)

    if USE_LATEST_CHECKPOINT:
        files = glob.glob(
            "/home/clng/github/fast-neural-style-pytorch/models/checkpoint*")
        if len(files) == 0:
            print("use latest checkpoint but no checkpoint found")
        else:
            files.sort(key=os.path.getmtime, reverse=True)
            latest_checkpoint_path = files[0]
            print("using latest checkpoint %s" % (latest_checkpoint_path))
            params = torch.load(latest_checkpoint_path, map_location=device)
            TransformerNetwork.load_state_dict(params)

    VGG = vgg.VGG19().to(device)

    # Get Style Features
    imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68],
                                     dtype=torch.float32).reshape(1, 3, 1,
                                                                  1).to(device)
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    if ADJUST_BRIGHTNESS == "1":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = utils.hist_norm(style_image,
                                      [0, 64, 96, 128, 160, 192, 255],
                                      [0, 0.05, 0.15, 0.5, 0.85, 0.95, 1],
                                      inplace=True)
    elif ADJUST_BRIGHTNESS == "2":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = cv2.equalizeHist(style_image)
    elif ADJUST_BRIGHTNESS == "3":
        a = 1
        # hsv = cv2.cvtColor(style_image, cv2.COLOR_BGR2HSV)
        # hsv = utils.auto_brightness(hsv)
        # style_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    style_image = ensure_three_channels(style_image)
    sname = os.path.splitext(os.path.basename(STYLE_IMAGE_PATH))[0] + "_train"
    cv2.imwrite(
        "/home/clng/datasets/bytenow/neural_styles/{s}.jpg".format(s=sname),
        style_image)

    style_tensor = utils.itot(style_image,
                              max_size=TRAIN_STYLE_SIZE).to(device)

    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            # torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * \
                MSELoss(generated_features['relu3_4'],
                        content_features['relu3_4'])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count - 1) % SAVE_MODEL_EVERY == 0)
                    or (batch_count == NUM_EPOCHS * len(train_loader))):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum /
                                                       batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum /
                                                     batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum /
                                                     batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() -
                                                         start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(
                    batch_count - 1) + ".pth"
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(
                    checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(
                    dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(
                    batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(
                    sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()
    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + STYLE_NAME + ".pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history,
                             total_loss_history)
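
Several examples above guard on PRESERVE_COLOR and call utils.transfer_color(content, generated). A rough sketch of luminance-only color transfer in YCrCb space (an assumption; the actual helper may differ in color space and dtype handling):

import cv2
import numpy as np

def transfer_color(src, dest):
    # Keep the chrominance of the content image (src) and swap in the
    # luminance of the stylized image (dest)
    src = np.clip(src, 0, 255).astype(np.uint8)
    dest = np.clip(dest, 0, 255).astype(np.uint8)
    src_ycc = cv2.cvtColor(src, cv2.COLOR_BGR2YCrCb)
    src_ycc[..., 0] = cv2.cvtColor(dest, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(src_ycc, cv2.COLOR_YCrCb2BGR)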