Example #1
def transform(args, vgg):
    # Transform dataset
    dataset_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),           # scale shortest side to image_size
        transforms.CenterCrop(IMAGE_SIZE),      # crop center image_size out
        transforms.ToTensor(),                  # turn image from [0-255] to [0-1]
        utils.normalize_tensor_transform()      # normalize with ImageNet values
    ])
    train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
    train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
    # Transform style
    style_transform = transforms.Compose([
        transforms.ToTensor(),                  # turn image from [0-255] to [0-1]
        utils.normalize_tensor_transform()      # normalize with ImageNet values
    ])
    style = utils.load_image(args.style_image)
    style = style_transform(style)
#     style = Variable(style.repeat(BATCH_SIZE, 1, 1, 1)).type(dtype)
    style = Variable(style.repeat(BATCH_SIZE, 1, 1, 1))
    style_name = os.path.split(args.style_image)[-1].split('.')[0]

    # Calculate gram matrices for style features
    style_features = vgg(style)
    style_gram = [utils.gram(fmap) for fmap in style_features]
    return train_loader, style_gram, style_name
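Example #1 (and several later snippets) applies a utils.normalize_tensor_transform() helper after ToTensor(); judging from the inline comment ("normalize with ImageNet values"), a minimal sketch of such a helper, assuming it simply wraps transforms.Normalize with the standard ImageNet statistics, could look like this:

from torchvision import transforms

def normalize_tensor_transform():
    # Assumed helper: normalize a [0, 1] image tensor with ImageNet channel statistics
    return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])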
Example #2
def __call__(self, x, y):
    x_vgg, y_vgg = self.vgg(x), self.vgg(y)
    style_loss = 0.0
    style_loss += self.criterion(utils.gram(x_vgg['relu2_2']),
                                 utils.gram(y_vgg['relu2_2']))
    style_loss += self.criterion(utils.gram(x_vgg['relu3_4']),
                                 utils.gram(y_vgg['relu3_4']))
    style_loss += self.criterion(utils.gram(x_vgg['relu4_4']),
                                 utils.gram(y_vgg['relu4_4']))
    style_loss += self.criterion(utils.gram(x_vgg['relu5_2']),
                                 utils.gram(y_vgg['relu5_2']))
    return style_loss
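Example #1 and Example #2 both rely on a utils.gram helper applied to VGG feature maps. A minimal sketch of a batched Gram-matrix computation (the normalization by the feature-map size is an assumption, in the spirit of the perceptual-loss formulation) might be:

import torch

def gram(fmap):
    # fmap: VGG feature maps of shape (batch, channels, height, width)
    b, c, h, w = fmap.size()
    feats = fmap.view(b, c, h * w)
    # Channel-by-channel correlations, normalized by the per-map element count
    return torch.bmm(feats, feats.transpose(1, 2)) / (c * h * w)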
Example #3
def train(args):
    # GPU enabling (default to CPU so dtype and use_cuda are always defined)
    use_cuda = False
    dtype = torch.FloatTensor
    if args.gpu is not None:
        use_cuda = True
        dtype = torch.cuda.FloatTensor
        torch.cuda.set_device(args.gpu)
        print("Current device: %d" % torch.cuda.current_device())

    # visualization of training controlled by flag
    visualize = (args.visualize is not None)
    if visualize:
        img_transform_512 = transforms.Compose([
            transforms.Resize(512),             # scale shortest side to 512
            transforms.CenterCrop(512),         # crop center 512x512 patch
            transforms.ToTensor(),              # turn image from [0-255] to [0-1]
            utils.normalize_tensor_transform()  # normalize with ImageNet values
        ])

        testImage_amber = utils.load_image("content_imgs/amber.jpg")
        testImage_amber = img_transform_512(testImage_amber)
        testImage_amber = Variable(testImage_amber.repeat(1, 1, 1, 1),
                                   requires_grad=False).type(dtype)

        testImage_dan = utils.load_image("content_imgs/dan.jpg")
        testImage_dan = img_transform_512(testImage_dan)
        testImage_dan = Variable(testImage_dan.repeat(1, 1, 1, 1),
                                 requires_grad=False).type(dtype)

        testImage_maine = utils.load_image("content_imgs/maine.jpg")
        testImage_maine = img_transform_512(testImage_maine)
        testImage_maine = Variable(testImage_maine.repeat(1, 1, 1, 1),
                                   requires_grad=False).type(dtype)

    # define network
    image_transformer = ImageTransformNet().type(dtype)
    optimizer = Adam(image_transformer.parameters(), LEARNING_RATE)

    loss_mse = torch.nn.MSELoss()

    # load vgg network
    vgg = Vgg16().type(dtype)

    # get training dataset
    dataset_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),       # scale shortest side to image_size
        transforms.CenterCrop(IMAGE_SIZE),   # crop center image_size out
        transforms.ToTensor(),               # turn image from [0-255] to [0-1]
        utils.normalize_tensor_transform()   # normalize with ImageNet values
    ])
    train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)

    # style image
    style_transform = transforms.Compose([
        # turn image from [0-255] to [0-1]
        transforms.ToTensor(),
        utils.normalize_tensor_transform()  # normalize with ImageNet values
    ])
    style = utils.load_image(args.style_image)
    style = style_transform(style)
    style = Variable(style.repeat(BATCH_SIZE, 1, 1, 1)).type(dtype)
    style_name = os.path.split(args.style_image)[-1].split('.')[0]

    # calculate gram matrices for style feature layer maps we care about
    style_features = vgg(style)
    style_gram = [utils.gram(fmap) for fmap in style_features]

    for e in range(EPOCHS):

        # track values for...
        img_count = 0
        aggregate_style_loss = 0.0
        aggregate_content_loss = 0.0
        aggregate_tv_loss = 0.0

        # train network
        image_transformer.train()
        for batch_num, (x, label) in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read

            # zero out gradients
            optimizer.zero_grad()

            # input batch to transformer network
            x = Variable(x).type(dtype)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(fmap) for fmap in y_hat_features]
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j],
                                       style_gram[j][:img_batch_read])
            style_loss = STYLE_WEIGHT * style_loss
            aggregate_style_loss += style_loss.item()

            # calculate content loss (h_relu_2_2)
            recon = y_c_features[1]
            recon_hat = y_hat_features[1]
            content_loss = CONTENT_WEIGHT * loss_mse(recon_hat, recon)
            aggregate_content_loss += content_loss.item()

            # calculate total variation regularization (anisotropic version)
            # https://www.wikiwand.com/en/Total_variation_denoising
            diff_i = torch.sum(
                torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]))
            diff_j = torch.sum(
                torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]))
            tv_loss = TV_WEIGHT * (diff_i + diff_j)
            aggregate_tv_loss += tv_loss.item()

            # total loss
            total_loss = style_loss + content_loss + tv_loss

            # backprop
            total_loss.backward()
            optimizer.step()

            # print out status message
            if ((batch_num + 1) % 100 == 0):
                status = "{}  Epoch {}:  [{}/{}]  Batch:[{}]  agg_style: {:.6f}  agg_content: {:.6f}  agg_tv: {:.6f}  style: {:.6f}  content: {:.6f}  tv: {:.6f} ".format(
                    time.ctime(), e + 1, img_count, len(train_dataset),
                    batch_num + 1, aggregate_style_loss / (batch_num + 1.0),
                    aggregate_content_loss / (batch_num + 1.0),
                    aggregate_tv_loss / (batch_num + 1.0), style_loss.item(),
                    content_loss.item(), tv_loss.item())
                print(status)

            if ((batch_num + 1) % 1000 == 0) and (visualize):
                image_transformer.eval()

                if not os.path.exists("visualization"):
                    os.makedirs("visualization")
                if not os.path.exists("visualization/%s" % style_name):
                    os.makedirs("visualization/%s" % style_name)

                outputTestImage_amber = image_transformer(
                    testImage_amber).cpu()

                amber_path = "visualization/%s/amber_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(amber_path, outputTestImage_amber.data[0])

                outputTestImage_dan = image_transformer(testImage_dan).cpu()
                dan_path = "visualization/%s/dan_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(dan_path, outputTestImage_dan.data[0])

                outputTestImage_maine = image_transformer(
                    testImage_maine).cpu()
                maine_path = "visualization/%s/maine_%d_%05d.jpg" % (
                    style_name, e + 1, batch_num + 1)
                utils.save_image(maine_path, outputTestImage_maine.data[0])

                print("images saved")
                image_transformer.train()

    # save model
    image_transformer.eval()

    if use_cuda:
        image_transformer.cpu()

    if not os.path.exists("models"):
        os.makedirs("models")
    filename = "models/" + str(style_name) + "_" + \
        str(time.ctime()).replace(' ', '_') + ".model"
    torch.save(image_transformer.state_dict(), filename)

    if use_cuda:
        image_transformer.cuda()
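A hypothetical driver for this train(args) function, showing the attributes the snippet reads (args.dataset, args.style_image, args.gpu, args.visualize); the flag spellings and defaults below are assumptions:

import argparse

def main():
    parser = argparse.ArgumentParser(description="train a fast style-transfer network")
    parser.add_argument("--dataset", type=str, required=True,
                        help="path to an ImageFolder-style training set")
    parser.add_argument("--style-image", dest="style_image", type=str, required=True,
                        help="path to the style image")
    parser.add_argument("--gpu", type=int, default=None,
                        help="GPU index; omit to train on the CPU")
    parser.add_argument("--visualize", type=int, default=None,
                        help="pass any value to save test stylizations during training")
    train(parser.parse_args())

if __name__ == "__main__":
    main()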
Example #4
def train(args):
    # Use the GPU if requested, otherwise the CPU
    device = torch.device("cuda" if args.cuda else "cpu")
    # Set the random seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Load and preprocess the dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    dataSet = datasets.ImageFolder(args.dataset, transform)
    data = DataLoader(dataSet, batch_size=args.batch_size)
    # Initialize the transformer network to be trained
    transformer = transformNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    # Pre-trained VGG and style targets
    vgg = vgg19(requires_grad=False).to(device)
    styleTransform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])
    style = loadImage(args.style_image, size=args.style_size)
    style = styleTransform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    features_style = vgg(normalizeBatch(style))
    gram_style = [gram(y) for y in features_style]
    # Training loop
    for epoch in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batchId, (x, _) in enumerate(data):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()
            # Move the data to the GPU or CPU
            x = x.to(device)
            y = transformer(x)
            # Normalize the batches
            y = normalizeBatch(y)
            x = normalizeBatch(x)
            # Extract VGG features
            features_y = vgg(y)
            features_x = vgg(x)
            # Compute the content loss
            content_loss = args.content_weight * mse_loss(
                features_y.relu3_3, features_x.relu3_3)
            # Compute the style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = gram(ft_y)
                style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
            style_loss *= args.style_weight
            # Compute the total loss
            total_loss = content_loss + style_loss
            # Backpropagate
            total_loss.backward()
            # Update the model parameters
            optimizer.step()
            # Accumulate the aggregate losses
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()
            # Log progress
            if (batchId + 1) % args.log_interval == 0:
                msg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {}\tstyle: {}\ttotal: {}".format(
                    time.ctime(), epoch + 1, count, len(dataSet),
                    agg_content_loss / (batchId + 1),
                    agg_style_loss / (batchId + 1),
                    (agg_content_loss + agg_style_loss) / (batchId + 1))
                print(msg)
            # Save a checkpoint
            if args.checkpoint_model_dir is not None and (batchId + 1) % args.checkpoint_interval == 0:
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(epoch) + "_batch_id_" + str(batchId + 1) + ".pth"
                ckpt_model_path = args.checkpoint_model_dir + '/' + args.save_model_name
                ckpt_model_path += '/' + ckpt_model_filename
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()
    # Save the final model
    transformer.eval().cpu()
    save_model_path = args.save_model_dir + '/' + args.save_model_name + '.pth'
    torch.save(transformer.state_dict(), save_model_path)

    print("model saved at", save_model_path)
Example #5
#style_predictions = [x for x in style_model.predict(style_img_arr[np.newaxis, :, :, :])]
#with open(os.path.join(results_path, 'style_predictions'), 'wb') as f:
#    pickle.dump(style_predictions, f, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(results_path, 'style_predictions'), 'rb') as f:
    style_predictions = pickle.load(f)
style_predictions_tensor = [K.variable(x) for x in style_predictions]

# Define the loss functions
L_content = K.mean(
    mean_squared_error(content_predictions_tensor, content_layer))

w_style = [0.05, 0.2, 0.2, 0.25, 0.3]
L_style = [
    K.mean(
        mean_squared_error(utils.gram(style_predictions_tensor[x][0]),
                           utils.gram(style_layers[x][0])) * w_style[x])
    for x in range(len(style_predictions_tensor))
]
L_style = sum(L_style)

alpha = 0.1
beta = 1
L_total = (alpha * L_content) + (beta * L_style)

# Define the loss gradients
L_gradients = K.gradients(L_total, vgg_model.input)

loss_grads_function = K.function([vgg_model.input],
                                 [L_content, L_style, L_total] + L_gradients)
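A hypothetical usage of the resulting Keras backend function, assuming img is a preprocessed array with the same shape as vgg_model.input (the variable name img is an assumption):

# Evaluate the losses and the gradient of the total loss w.r.t. the input image
content_val, style_val, total_val, grads = loss_grads_function([img])
# grads has the same shape as img; a gradient-based optimizer such as L-BFGS
# can use total_val and grads to iteratively update the image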
Example #6
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)
    VGG = vgg.VGG16().to(device)

    # Get Style Features
    imagenet_neg_mean = (torch.tensor([-103.939, -116.779, -123.68],
                                      dtype=torch.float32).reshape(
                                          1, 3, 1, 1).to(device))
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    style_tensor = utils.itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * MSELoss(
                generated_features["relu2_2"], content_features["relu2_2"])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if ((batch_count - 1) % SAVE_MODEL_EVERY
                    == 0) or (batch_count == NUM_EPOCHS * len(train_loader)):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum /
                                                       batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum /
                                                     batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum /
                                                     batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() -
                                                         start_time))

                # Save Model
                checkpoint_path = (SAVE_MODEL_PATH + "checkpoint_" +
                                   str(batch_count - 1) + ".pth")
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(
                    checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(
                    dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = (SAVE_IMAGE_PATH + "sample0_" +
                                     str(batch_count - 1) + ".png")
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(
                    sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()
    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + "transformer_weight.pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history,
                             total_loss_history)
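The utils.itot and utils.ttoi helpers convert between images and tensors. A rough sketch of what they might look like, assuming OpenCV-style HWC images and CHW float tensors in the [0, 255] range (matching the mul(255) transform used for the training data; resizing via max_size is omitted here):

import numpy as np
import torch

def itot(img, max_size=None):
    # Assumed helper: HWC image array -> 1xCxHxW float tensor in [0, 255]
    tensor = torch.from_numpy(np.asarray(img, dtype=np.float32))
    return tensor.permute(2, 0, 1).unsqueeze(0)

def ttoi(tensor):
    # Assumed helper: 1xCxHxW float tensor in [0, 255] -> HWC image array
    img = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return np.clip(img, 0, 255).astype(np.uint8)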
Example #7
def train(args, image_transformer, train_loader, optimizer, vgg, loss_mse, style_gram):
    for e in range(EPOCHS):
        img_count = 0
        aggregate_style_loss = 0.0
        aggregate_content_loss = 0.0
        aggregate_tv_loss = 0.0

        # train network
        image_transformer.train()
        for batch_num, (x, label) in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read
            optimizer.zero_grad()
            
            # input batch to transformer network
#             x = Variable(x).type(dtype)
            x = Variable(x)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(fmap) for fmap in y_hat_features]
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j], style_gram[j][:img_batch_read])
            style_loss = STYLE_WEIGHT*style_loss
            aggregate_style_loss += style_loss.item()

            # calculate content loss
            recon = y_c_features[1]      
            recon_hat = y_hat_features[1]
            content_loss = CONTENT_WEIGHT*loss_mse(recon_hat, recon)
            aggregate_content_loss += content_loss.item()

            # calculate total variation regularization
            diff_i = torch.sum(torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]))
            diff_j = torch.sum(torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]))
            tv_loss = TV_WEIGHT*(diff_i + diff_j)
            aggregate_tv_loss += tv_loss.item()

            # total loss
            total_loss = style_loss + content_loss + tv_loss

            # back propagation
            total_loss.backward()
            optimizer.step()

            # print a status message every 100 batches
            if (batch_num + 1) % 100 == 0:
                status = "{}  Epoch {}:  [{}/{}]  Batch:[{}]  agg_style: {:.6f}  agg_content: {:.6f}  agg_tv: {:.6f}  style: {:.6f}  content: {:.6f}  tv: {:.6f} ".format(
                    time.ctime(), e + 1, img_count, len(train_loader.dataset),
                    batch_num + 1, aggregate_style_loss / (batch_num + 1.0),
                    aggregate_content_loss / (batch_num + 1.0),
                    aggregate_tv_loss / (batch_num + 1.0), style_loss.item(),
                    content_loss.item(), tv_loss.item())
                print(status)

            if ((batch_num + 1) % 1000 == 0) and (visualize):
                image_transformer.eval()

                if not os.path.exists("visualization"):
                    os.makedirs("visualization")
                if not os.path.exists("visualization/%s" %style_name):
                    os.makedirs("visualization/%s" %style_name)

                outputTestImage_amber = image_transformer(testImage_amber).cpu()
                amber_path = "visualization/%s/amber_%d_%05d.jpg" %(style_name, e+1, batch_num+1)
                utils.save_image(amber_path, outputTestImage_amber.data[0])

                outputTestImage_dan = image_transformer(testImage_dan).cpu()
                dan_path = "visualization/%s/dan_%d_%05d.jpg" %(style_name, e+1, batch_num+1)
                utils.save_image(dan_path, outputTestImage_dan.data[0])

                outputTestImage_maine = image_transformer(testImage_maine).cpu()
                maine_path = "visualization/%s/maine_%d_%05d.jpg" %(style_name, e+1, batch_num+1)
                utils.save_image(maine_path, outputTestImage_maine.data[0])

                print("images saved")
                image_transformer.train()
Example #8
def training(args):

    dtype = torch.FloatTensor
    if args.gpu:
        use_cuda = True
        print("Current device: %d" % torch.cuda.current_device())
        dtype = torch.cuda.FloatTensor

    print('content = {}'.format(args.content))
    print('style = {}'.format(args.style))

    img_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.Resize(
            args.image_size,
            interpolation=Image.NEAREST),  # scale shortest side to image_size
        transforms.CenterCrop(args.image_size),  # crop center image_size out
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])

    content = Image.open(args.content)
    content = img_transform(content)  # Loaded already cropped
    content = Variable(content.repeat(1, 1, 1, 1),
                       requires_grad=False).type(dtype)

    # define network
    image_transformer = ImageTransformNet().type(dtype)
    optimizer = Adam(image_transformer.parameters(), 1e-5)

    loss_mse = torch.nn.MSELoss()

    # load vgg network
    vgg = Vgg19().type(dtype)

    # get training dataset
    dataset_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.Resize(
            args.image_size,
            interpolation=Image.NEAREST),  # scale shortest side to image_size
        transforms.CenterCrop(args.image_size),  # crop center image_size out
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])
    train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    # style image
    style_transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_imagenet(norm)  # normalize with ImageNet values
    ])

    style = Image.open(args.style)
    if "clinical" in args.style:
        style = style.crop(
            (20, 0, style.size[0],
             style.size[1]))  # Remove left bar from the style image
    style = style_transform(style)
    style = Variable(style.repeat(args.batch_size, 1, 1, 1)).type(dtype)

    # calculate gram matrices for target style layers
    style_features = vgg(style)
    style_gram = [utils.gram(feature) for feature in style_features]

    if args.loss == 1:
        print("Using average style on features")
        if "clinical" in args.style:
            with open('models/perceptual/us_clinical_ft_dict.pickle',
                      'rb') as handle:
                style_features = pickle.load(handle)
        else:
            with open('models/perceptual/us_hq_ft_dict.pickle',
                      'rb') as handle:
                style_features = pickle.load(handle)
        style_features = [
            style_features[label].type(dtype)
            for label in style_features.keys()
        ]
        style_gram = [utils.gram(feature) for feature in style_features]

    style_loss_list, content_loss_list, total_loss_list = [], [], []

    for e in range(args.epochs):
        count = 0
        img_count = 0

        # train network
        image_transformer.train()
        for batch_num, (x, label) in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read

            # zero out gradients
            optimizer.zero_grad()

            # input batch to transformer network
            x = Variable(x).type(dtype)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(feature) for feature in y_hat_features]
            style_loss = 0.0
            for j in range(5):
                style_loss += loss_mse(y_hat_gram[j],
                                       style_gram[j][:img_batch_read])
            style_loss = args.weights[0] * style_loss

            # calculate content loss (block5_conv2)
            recon = y_c_features[5]
            recon_hat = y_hat_features[5]
            content_loss = args.weights[1] * loss_mse(recon_hat, recon)

            # total loss
            total_loss = style_loss + content_loss

            # backprop
            total_loss.backward()
            optimizer.step()

            # print out status message
            if ((batch_num + 1) % 100 == 0):
                count = count + 1
                total_loss_list.append(total_loss.item())
                content_loss_list.append(content_loss.item())
                style_loss_list.append(style_loss.item())
                print(
                    "Epoch {}:\t [{}/{}]\t\t Batch:[{}]\t total: {:.6f}\t style: {:.6f}\t content: {:.6f}\t"
                    .format(e, img_count, len(train_dataset), batch_num + 1,
                            total_loss.item(), style_loss.item(),
                            content_loss.item()))

        image_transformer.eval()

        stylized = image_transformer(content).cpu()
        out_path = args.save_dir + "/opt/perc%d_%d.png" % (e, batch_num + 1)
        utils.save_image(out_path, stylized.data[0], norm)

        image_transformer.train()

    # save model
    image_transformer.eval()

    filename = 'models/perceptual/' + str(args.model_name)
    if '.model' not in filename:
        filename = filename + '.model'
    torch.save(image_transformer.state_dict(), filename)

    total_loss = np.array(total_loss_list)
    style_loss = np.array(style_loss_list)
    content_loss = np.array(content_loss_list)
    x = np.arange(0, np.size(total_loss)) / (count + 1)

    fig = plt.figure('Perceptual Loss')
    plt.plot(x, total_loss)
    plt.plot(x, content_loss)
    plt.plot(x, style_loss)
    plt.legend(['Total', 'Content', 'Style'])
    plt.title('Perceptual Loss')
    plt.savefig(args.save_dir + '/perc_loss.png')
def trainer(args):
    print_args(args)

    # If CUDA is available the GPU will be used; otherwise the CPU will be used.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    imgs_path = args.dataset
    style_image = args.style_image

    size = (args.image_width, args.image_height)

    # define network
    if args.model_mode != "slim":
        image_transformer = normal.ImageTransformNet(size).to(device)
    else:
        image_transformer = slim.ImageTransformNet(size).to(device)

    # set optimizer
    optimizer = Adam(image_transformer.parameters(), args.learning_rate)

    # define loss function
    loss_mse = nn.MSELoss()

    # load vgg network
    vgg = Vgg16(args.VGG_path).to(device)

    # get training dataset
    dataset_transform = transforms.Compose([
        transforms.Resize(size),  # scale shortest side to image_size
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_tensor_transform()  # normalize with ImageNet values
    ])
    train_dataset = IDataset(imgs_path, dataset_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    # style image
    style_transform = transforms.Compose([
        transforms.ToTensor(),  # turn image from [0-255] to [0-1]
        utils.normalize_tensor_transform()  # normalize with ImageNet values
    ])
    style = utils.load_image(style_image)
    style = style_transform(style)
    style = Variable(style.repeat(args.batch_size, 1, 1, 1)).to(device)
    style_name = os.path.split(style_image)[-1].split('.')[0]

    # calculate gram matrices for style feature layer maps we care about
    style_features = vgg(style)
    style_gram = [utils.gram(fmap) for fmap in style_features]

    print("Start training. . .")
    best_style_loss = 1e9
    best_content_loss = 1e9
    for e in range(args.epoch):

        # track values for...
        img_count = 0
        aggregate_style_loss = 0.0
        aggregate_content_loss = 0.0
        aggregate_tv_loss = 0.0

        # train network
        image_transformer.train()
        for batch_num, x in enumerate(train_loader):
            img_batch_read = len(x)
            img_count += img_batch_read

            # zero out gradients
            optimizer.zero_grad()

            # input batch to transformer network
            x = Variable(x).to(device)
            y_hat = image_transformer(x)

            # get vgg features
            y_c_features = vgg(x)
            y_hat_features = vgg(y_hat)

            # calculate style loss
            y_hat_gram = [utils.gram(fmap) for fmap in y_hat_features]
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j],
                                       style_gram[j][:img_batch_read])
            style_loss = args.style_weight * style_loss
            aggregate_style_loss += style_loss.item()

            # calculate content loss (h_relu_2_2)
            recon = y_c_features[1]
            recon_hat = y_hat_features[1]
            content_loss = args.content_weight * loss_mse(recon_hat, recon)
            aggregate_content_loss += content_loss.item()

            # calculate total variation regularization (anisotropic version)
            # https://www.wikiwand.com/en/Total_variation_denoising
            diff_i = torch.sum(
                torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]))
            diff_j = torch.sum(
                torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]))
            tv_loss = args.tv_weight * (diff_i + diff_j)
            aggregate_tv_loss += tv_loss.item()

            # total loss
            total_loss = style_loss + content_loss + tv_loss

            # backprop
            total_loss.backward()
            optimizer.step()

            # print out status message
            if (batch_num + 1) % 50 == 0:
                status = "{}  Epoch {}:  [{}/{}]  Batch:[{}]  agg_style: {:.6f}  agg_content: {:.6f}  " \
                         "agg_tv: {:.6f}  style: {:.6f}  content: {:.6f}  tv: {:.6f} "\
                    .format(
                            time.ctime(), e + 1, img_count, len(train_dataset), batch_num+1,
                            aggregate_style_loss/(batch_num+1.0),
                            aggregate_content_loss/(batch_num+1.0),
                            aggregate_tv_loss/(batch_num+1.0), style_loss.item(),
                            content_loss.item(), tv_loss.item())
                print(status)

        # save model
        image_transformer.eval()

        model_folder = args.model_folder
        if not os.path.exists("models/{}".format(model_folder)):
            os.makedirs("models/{}".format(model_folder))
        num = len(train_dataset) / args.batch_size

        aggregate_style_loss /= num
        aggregate_content_loss /= num
        aggregate_tv_loss /= num

        filename = "models/{}/{}_epoch={}_style={:.4f}_content={:.4f}_tv={:.4f}.pth".format(
            model_folder, style_name, e + 1, aggregate_style_loss,
            aggregate_content_loss, aggregate_tv_loss)

        if aggregate_style_loss < best_style_loss or aggregate_content_loss < best_content_loss:
            torch.save(image_transformer.state_dict(), filename)

        if aggregate_style_loss < best_style_loss:
            best_style_loss = aggregate_style_loss
        if aggregate_content_loss < best_content_loss:
            best_content_loss = aggregate_content_loss
def train():
    # Seeds
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Device
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        # transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

    # Load networks
    TransformerNetwork = transformer.TransformerNetwork().to(device)

    if USE_LATEST_CHECKPOINT:
        files = glob.glob(
            "/home/clng/github/fast-neural-style-pytorch/models/checkpoint*")
        if len(files) == 0:
            print("use latest checkpoint but no checkpoint found")
        else:
            files.sort(key=os.path.getmtime, reverse=True)
            latest_checkpoint_path = files[0]
            print("using latest checkpoint %s" % (latest_checkpoint_path))
            params = torch.load(latest_checkpoint_path, map_location=device)
            TransformerNetwork.load_state_dict(params)

    VGG = vgg.VGG19().to(device)

    # Get Style Features
    imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68],
                                     dtype=torch.float32).reshape(1, 3, 1,
                                                                  1).to(device)
    style_image = utils.load_image(STYLE_IMAGE_PATH)
    if ADJUST_BRIGHTNESS == "1":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = utils.hist_norm(style_image,
                                      [0, 64, 96, 128, 160, 192, 255],
                                      [0, 0.05, 0.15, 0.5, 0.85, 0.95, 1],
                                      inplace=True)
    elif ADJUST_BRIGHTNESS == "2":
        style_image = cv2.cvtColor(style_image, cv2.COLOR_BGR2GRAY)
        style_image = cv2.equalizeHist(style_image)
    elif ADJUST_BRIGHTNESS == "3":
        a = 1
        # hsv = cv2.cvtColor(style_image, cv2.COLOR_BGR2HSV)
        # hsv = utils.auto_brightness(hsv)
        # style_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    style_image = ensure_three_channels(style_image)
    sname = os.path.splitext(os.path.basename(STYLE_IMAGE_PATH))[0] + "_train"
    cv2.imwrite(
        "/home/clng/datasets/bytenow/neural_styles/{s}.jpg".format(s=sname),
        style_image)

    style_tensor = utils.itot(style_image,
                              max_size=TRAIN_STYLE_SIZE).to(device)

    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = utils.gram(value)

    # Optimizer settings
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    # Loss trackers
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch + 1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            # Get current batch size in case of odd batch sizes
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            # torch.cuda.empty_cache()

            # Zero-out Gradients
            optimizer.zero_grad()

            # Generate images and get features
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            # Content Loss
            MSELoss = nn.MSELoss().to(device)
            content_loss = CONTENT_WEIGHT * \
                MSELoss(generated_features['relu3_4'],
                        content_features['relu3_4'])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                s_loss = MSELoss(utils.gram(value),
                                 style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if (((batch_count - 1) % SAVE_MODEL_EVERY == 0)
                    or (batch_count == NUM_EPOCHS * len(train_loader))):
                # Print Losses
                print("========Iteration {}/{}========".format(
                    batch_count, NUM_EPOCHS * len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum /
                                                       batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum /
                                                     batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum /
                                                     batch_count))
                print("Time elapsed:\t{} seconds".format(time.time() -
                                                         start_time))

                # Save Model
                checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str(
                    batch_count - 1) + ".pth"
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(
                    checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(
                    dim=0)
                sample_image = utils.ttoi(sample_tensor.clone().detach())
                sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str(
                    batch_count - 1) + ".png"
                utils.saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(
                    sample_image_path))

                # Save loss histories
                content_loss_history.append(batch_content_loss_sum / batch_count)
                style_loss_history.append(batch_style_loss_sum / batch_count)
                total_loss_history.append(batch_total_loss_sum / batch_count)

            # Iterate Batch Counter
            batch_count += 1

    stop_time = time.time()
    # Print loss histories
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time - start_time))
    print("========Content Loss========")
    print(content_loss_history)
    print("========Style Loss========")
    print(style_loss_history)
    print("========Total Loss========")
    print(total_loss_history)

    # Save TransformerNetwork weights
    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = SAVE_MODEL_PATH + STYLE_NAME + ".pth"
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    # Plot Loss Histories
    if PLOT_LOSS:
        utils.plot_loss_hist(content_loss_history, style_loss_history,
                             total_loss_history)