def train_step(lr, hr, generator, discriminator, content):

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Re-scale inputs: lr to [0, 1], hr to [-1, 1]
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)
        lr = lr / 255
        hr = hr / 127.5 - 1

        sr = generator(lr, training=True)

        sr_output = discriminator(sr, training=True)
        hr_output = discriminator(hr, training=True)

        disc_loss = discriminator_loss(sr_output, hr_output)

        mse_loss = mse_based_loss(sr, hr)
        gen_loss = generator_loss(sr_output)
        cont_loss = content_loss(content, sr, hr)
        perceptual_loss = mse_loss + cont_loss + 1e-3 * gen_loss

    gradients_of_generator = gen_tape.gradient(
        perceptual_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))

    return perceptual_loss, disc_loss
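
The loss helpers used above (discriminator_loss, generator_loss, mse_based_loss and content_loss) are not defined in this snippet. A minimal sketch of what they might look like, assuming the standard SRGAN formulation; the names and details below are assumptions, not the original implementation:

import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)
mse = tf.keras.losses.MeanSquaredError()

def discriminator_loss(sr_output, hr_output):
    # Real HR patches should be classified as 1, generated SR patches as 0.
    real_loss = cross_entropy(tf.ones_like(hr_output), hr_output)
    fake_loss = cross_entropy(tf.zeros_like(sr_output), sr_output)
    return real_loss + fake_loss

def generator_loss(sr_output):
    # The generator is rewarded when the discriminator scores SR patches as real.
    return cross_entropy(tf.ones_like(sr_output), sr_output)

def mse_based_loss(sr, hr):
    return mse(hr, sr)

def content_loss(content_model, sr, hr):
    # Feature-space MSE; `content_model` would be a truncated VGG network.
    return mse(content_model(hr), content_model(sr))
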
Example #2
def train(model, img, art, photo, epoch_num, device, content_name_list,
          style_name_list):
    args = arg_parser()
    features = vgg19_features(model, content_name_list, style_name_list,
                              device)
    optimizer = torch.optim.SGD([img.requires_grad_()],
                                lr=args.lr,
                                momentum=args.momentum)
    _, art_style = features.extract_features(art)
    art_style = [i_style.detach() for i_style in art_style]
    photo_content, _ = features.extract_features(photo)
    photo_content = [i_content.detach() for i_content in photo_content]
    for epoch in range(epoch_num):
        start_time = time.time()
        img_content, img_style = features.extract_features(img)
        C_loss = content_loss(img_content, photo_content)
        S_loss = style_loss(img_style, art_style)
        loss = C_loss * args.content_weight + S_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % args.log == 0:
            print('[{0}/{1}]\ttime:{time:.2f}\tloss:{loss:.4f}'.format(
                epoch, epoch_num, time=time.time() - start_time, loss=loss.item() * 1e6))
            print(C_loss.item(), S_loss.item())

        if epoch % args.save_fre == 0:
            save_img(epoch, img)

    img.data.clamp_(0, 1)
    return img
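
content_loss and style_loss are imported from elsewhere in this project and are not shown. A minimal sketch of the usual Gatys-style definitions they could correspond to (feature-map MSE for content, Gram-matrix MSE for style), offered as an assumption rather than the original code:

import torch.nn.functional as F

def gram_matrix(feat):
    # Channel-wise correlations of a feature map, normalised by its size.
    b, c, h, w = feat.size()
    flat = feat.view(b * c, h * w)
    return flat @ flat.t() / (b * c * h * w)

def content_loss(img_content, photo_content):
    return sum(F.mse_loss(i, p) for i, p in zip(img_content, photo_content))

def style_loss(img_style, art_style):
    return sum(F.mse_loss(gram_matrix(i), gram_matrix(a))
               for i, a in zip(img_style, art_style))
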
Example #3
        # get style layer from constant network
        network = vgg_net.build(style_image, 0)
        style_layer = [
            sess.run(network['conv' + str(i) + '_1']) for i in range(1, 6)
        ]
        # get content layer from constant network
        network = vgg_net.build(content_image, 0)
        content_layer = sess.run(network['conv4_2'])

        # style transfer network
        network = vgg_net.build(pred_image, 1)
        pred_style = [network['conv' + str(i) + '_1'] for i in range(1, 6)]
        pred_content = network['conv4_2']

        style_loss = loss.style_loss(style_layer, pred_style)
        content_loss = loss.content_loss(content_layer, pred_content)

        total_loss = args.ALPHA * content_loss + args.BETA * style_loss

        default_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        vgg_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='vggnet')

        optimizer = tf.train.AdamOptimizer(args.learning_rate).minimize(
            loss=total_loss, var_list=default_vars + vgg_vars)

        saver = tf.train.Saver()

        # train
        print('Training Start !!!')
        sess.run(tf.global_variables_initializer())
Example #4
    correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1),
                                   tf.argmax(discrim_target, 1))
    discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # 2) content loss
    '''
    CONTENT_LAYER = 'relu5_4'

    enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
    dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))

    content_size = utils._tensor_size(dslr_vgg[CONTENT_LAYER]) * batch_size
    loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[CONTENT_LAYER] - dslr_vgg[CONTENT_LAYER]) / content_size
    '''
    loss_content = loss.content_loss(dslr_image, enhanced, batch_size)

    # 3) color loss
    enhanced_blur = lutils.blur(enhanced)
    dslr_blur = lutils.blur(dslr_image)

    #loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur, 2))/(2 * batch_size)
    loss_color = tf.reduce_sum(
        tf.abs(dslr_image - enhanced)) / (2 * batch_size)

    #loss_color = loss.color_loss(dslr_image, enhanced, batch_size)

    # 4) total variation loss

    batch_shape = (batch_size, PATCH_WIDTH, PATCH_HEIGHT, 3)
    tv_y_size = lutils._tensor_size(enhanced[:, 1:, :, :])
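
    # The snippet breaks off here; a hedged sketch of how the total variation
    # term is typically completed in this kind of enhancement pipeline (the
    # original lines are not shown, so the exact form is an assumption):
    tv_x_size = lutils._tensor_size(enhanced[:, :, 1:, :])
    y_tv = tf.nn.l2_loss(enhanced[:, 1:, :, :] - enhanced[:, :batch_shape[1] - 1, :, :])
    x_tv = tf.nn.l2_loss(enhanced[:, :, 1:, :] - enhanced[:, :, :batch_shape[2] - 1, :])
    loss_tv = 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size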
Example #5
def DST(input_im,
        content_im,
        style_im,
        extractor,
        content_path,
        style_path,
        content_pts,
        style_pts,
        style_pts_path,
        output_dir,
        output_prefix,
        im_size=256,
        max_iter=250,
        checkpoint_iter=50,
        content_weight=8.,
        warp_weight=0.3,
        reg_weight=10,
        scales=3,
        pyr_levs=5,
        sharp_warp=False,
        optim='adam',
        lr=1e-3,
        warp_lr_fac=1.,
        verbose=False,
        save_intermediate=False,
        save_extra=False,
        device='cuda:0'):

    # If warp weight is 0, run the base method STROTSS
    use_DST = True
    if warp_weight == 0.:
        use_DST = False

    # Initialize warp parameters
    src_Kpts, target_Kpts, border_Kpts, no_flow_Kpts = init_keypoint_params(
        input_im, content_path, content_pts, style_pts, device)
    thetas_Kpts = Variable(torch.rand_like(src_Kpts).data * 1e-4,
                           requires_grad=True)

    # Clamp the target points so that they don't go outside the boundary
    target_Kpts[:, 0] = torch.clamp(target_Kpts[:, 0],
                                    min=5,
                                    max=content_im.size(2) - 5)
    target_Kpts[:, 1] = torch.clamp(target_Kpts[:, 1],
                                    min=5,
                                    max=content_im.size(3) - 5)
    target_Kpts_o = target_Kpts.clone().detach()

    # Assign colors to each set of points (used for visualization only)
    np.random.seed(1)
    colors = []
    for j in range(src_Kpts.shape[0]):
        colors.append(np.random.random(size=3))

    # Initialize pixel parameters
    s_pyr = dec_lap_pyr(input_im, pyr_levs)
    s_pyr = [Variable(li.data, requires_grad=True) for li in s_pyr]

    # Define parameters to be optimized
    s_pyr_list = [{'params': si} for si in s_pyr]
    if use_DST:
        thetas_opt_list = [{'params': thetas_Kpts, 'lr': lr * warp_lr_fac}]
    else:
        thetas_opt_list = []

    # Construct optimizer
    if optim == 'sgd':
        optimizer = torch.optim.SGD(s_pyr_list + thetas_opt_list,
                                    lr=lr,
                                    momentum=0.9)
    elif optim == 'rmsprop':
        optimizer = torch.optim.RMSprop(s_pyr_list + thetas_opt_list, lr=lr)
    else:
        optimizer = torch.optim.Adam(s_pyr_list + thetas_opt_list, lr=lr)

    # Set scales
    scale_list = list(range(scales))
    if scales == 1:
        scale_list = [0]

    # Create lists to store various loss values
    ell_list = []
    ell_style_list = []
    ell_content_list = []
    ell_warp_list = []
    ell_warp_TV_list = []

    # Iteratively stylize over more levels of image pyramid
    for scale in scale_list:

        down_fac = 2**(scales - 1 - scale)
        begin_ind = (scales - 1 - scale)
        content_weight_scaled = content_weight * down_fac

        print('\nOptimizing at scale {}, image size ({}, {})'.format(
            scale + 1,
            content_im.size(2) // down_fac,
            content_im.size(3) // down_fac))

        if down_fac > 1.:
            content_im_scaled = F.interpolate(content_im,
                                              (content_im.size(2) // down_fac,
                                               content_im.size(3) // down_fac),
                                              mode='bilinear')
            style_im_scaled = F.interpolate(
                style_im,
                (style_im.size(2) // down_fac, style_im.size(3) // down_fac),
                mode='bilinear')
        else:
            content_im_scaled = content_im.clone()
            style_im_scaled = style_im.clone()

        # Compute feature maps that won't change for this scale
        with torch.no_grad():
            feat_content = extractor(content_im_scaled)

            feat_style = None
            for i in range(5):
                with torch.no_grad():
                    feat_e = extractor.forward_samples_hypercolumn(
                        style_im_scaled, samps=1000)
                    feat_style = feat_e if feat_style is None else torch.cat(
                        (feat_style, feat_e), dim=2)

            feat_max = 3 + 2 * 64 + 2 * 128 + 3 * 256 + 2 * 512  # 2179 = sum of all extracted channels
            spatial_style = feat_style.view(1, feat_max, -1, 1)

            xx, xy = sample_indices(feat_content[0], feat_style)

        # Begin optimization for this scale
        for i in range(max_iter):

            optimizer.zero_grad()

            # Get current stylized image from the laplacian pyramid
            curr_im = syn_lap_pyr(s_pyr[begin_ind:])
            new_im = curr_im.clone()
            content_im_warp = content_im_scaled.clone()

            # Generate destination points with the current thetas
            src_Kpts_aug, dst_Kpts_aug, flow_Kpts_aug = gen_dst_pts_keypoints(
                src_Kpts, thetas_Kpts, no_flow_Kpts, border_Kpts)

            # Calculate warp loss
            ell_warp = torch.norm(target_Kpts_o -
                                  dst_Kpts_aug[:target_Kpts.size(0)],
                                  dim=1).mean()

            # Scale points to [0-1]
            src_Kpts_aug = src_Kpts_aug / torch.max(
                src_Kpts_aug, 0, keepdim=True)[0]
            dst_Kpts_aug = dst_Kpts_aug / torch.max(
                dst_Kpts_aug, 0, keepdim=True)[0]
            dst_Kpts_aug = torch.clamp(dst_Kpts_aug, min=0., max=1.)

            # Warp
            new_im, content_im_warp, warp_field = apply_warp(
                new_im, [src_Kpts_aug], [dst_Kpts_aug],
                device,
                sharp=sharp_warp,
                im2=content_im_warp)
            new_im = new_im.to(device)

            # Calculate total variation
            ell_warp_TV = TV(warp_field)

            # Extract VGG features of warped and unwarped stylized images
            feat_result_warped = extractor(new_im)
            feat_result_unwarped = extractor(curr_im)

            # Sample features to calculate losses with
            n = 2048
            if i % 1 == 0 and i != 0:
                np.random.shuffle(xx)
                np.random.shuffle(xy)
            spatial_result_warped, spatial_content = spatial_feature_extract(
                feat_result_warped, feat_content, xx[:n], xy[:n])
            spatial_result_unwarped, _ = spatial_feature_extract(
                feat_result_unwarped, feat_content, xx[:n], xy[:n])

            # Content loss
            ell_content = content_loss(spatial_result_unwarped,
                                       spatial_content)

            # Style loss

            # Lstyle(Unwarped X, S)
            loss_remd1 = remd_loss(spatial_result_unwarped,
                                   spatial_style,
                                   cos_d=True)
            loss_moment1 = moment_loss(spatial_result_unwarped,
                                       spatial_style,
                                       moments=[1, 2])
            loss_color1 = remd_loss(spatial_result_unwarped[:, :3, :, :],
                                    spatial_style[:, :3, :, :],
                                    cos_d=False)
            loss_style1 = loss_remd1 + loss_moment1 + (
                1. / max(content_weight_scaled, 1.)) * loss_color1

            # Lstyle(Warped X, S)
            loss_remd2 = remd_loss(spatial_result_warped,
                                   spatial_style,
                                   cos_d=True)
            loss_moment2 = moment_loss(spatial_result_warped,
                                       spatial_style,
                                       moments=[1, 2])
            loss_color2 = remd_loss(spatial_result_warped[:, :3, :, :],
                                    spatial_style[:, :3, :, :],
                                    cos_d=False)
            loss_style2 = loss_remd2 + loss_moment2 + (
                1. / max(content_weight_scaled, 1.)) * loss_color2

            # Total loss
            if use_DST:
                ell_style = loss_style1 + loss_style2
                ell = content_weight_scaled * ell_content + ell_style + warp_weight * ell_warp + reg_weight * ell_warp_TV
            else:
                ell_style = loss_style1
                ell = content_weight_scaled * ell_content + ell_style

            # Record loss values
            ell_list.append(ell.item())
            ell_content_list.append(ell_content.item())
            ell_style_list.append(ell_style.item())
            ell_warp_list.append(ell_warp.item())
            ell_warp_TV_list.append(ell_warp_TV.item())

            # Output intermediate loss
            if i == 0 or i % checkpoint_iter == 0:
                print('   STEP {:03d}: Loss {:04.3f}'.format(i, ell))
                if verbose:
                    print('             = alpha*Lcontent {:04.3f}'.format(
                        content_weight_scaled * ell_content))
                    print('               + Lstyle {:04.3f}'.format(ell_style))
                    print('               + beta*Lwarp {:04.3f}'.format(
                        warp_weight * ell_warp))
                    print('               + gamma*TV {:04.3f}'.format(
                        reg_weight * ell_warp_TV))
                if save_intermediate:
                    plot_intermediate(new_im, content_im_warp, output_dir,
                                      output_prefix, colors, down_fac,
                                      src_Kpts, thetas_Kpts, target_Kpts,
                                      scale, i)

            # Take a gradient step
            ell.backward()
            optimizer.step()

    # Optimization finished
    src_Kpts_aug, dst_Kpts_aug, flow_Kpts_aug = gen_dst_pts_keypoints(
        src_Kpts, thetas_Kpts, no_flow_Kpts, border_Kpts)
    sizes = torch.FloatTensor([new_im.size(2), new_im.size(3)]).to(device)
    src_Kpts_aug = src_Kpts_aug / sizes
    dst_Kpts_aug = dst_Kpts_aug / sizes
    dst_Kpts_aug = torch.clamp(dst_Kpts_aug, min=0., max=1.)
    dst_Kpts = dst_Kpts_aug[:src_Kpts.size(0)]

    # Apply final warp
    sharp_final = True
    new_im = curr_im.clone()
    content_im_warp = content_im.clone()
    new_im, _ = apply_warp(new_im, [src_Kpts_aug], [dst_Kpts_aug],
                           device,
                           sharp=sharp_final)

    # Optionally save loss, keypoints, and optimized warp parameter thetas
    if save_extra:
        save_plots(im_size, curr_im, new_im, content_im, style_im, output_dir,
                   output_prefix, style_path, style_pts_path, colors, src_Kpts,
                   src_Kpts_aug, dst_Kpts * sizes, dst_Kpts_aug, target_Kpts,
                   target_Kpts_o, border_Kpts, device)
        save_loss(output_dir, output_prefix, content_weight, warp_weight,
                  reg_weight, max_iter, scale_list, ell_list, ell_style_list,
                  ell_content_list, ell_warp_list, ell_warp_TV_list)
        save_points(output_dir, output_prefix, src_Kpts, dst_Kpts * sizes,
                    src_Kpts_aug * sizes, dst_Kpts_aug * sizes, target_Kpts,
                    thetas_Kpts)

    # Return the stylized output image
    return new_im
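
dec_lap_pyr and syn_lap_pyr are not shown here; they are assumed to be a Laplacian-pyramid decomposition and its inverse over the image being optimized. A minimal sketch, with the resampling details as assumptions:

import torch.nn.functional as F

def dec_lap_pyr(x, levels):
    # Split an image into band-pass residuals plus a low-frequency remainder.
    pyr, cur = [], x
    for _ in range(levels):
        h, w = cur.size(2), cur.size(3)
        small = F.interpolate(cur, (max(h // 2, 1), max(w // 2, 1)),
                              mode='bilinear', align_corners=False)
        back = F.interpolate(small, (h, w), mode='bilinear', align_corners=False)
        pyr.append(cur - back)
        cur = small
    pyr.append(cur)
    return pyr

def syn_lap_pyr(pyr):
    # Invert dec_lap_pyr by upsampling and adding the residuals back in.
    cur = pyr[-1]
    for residual in reversed(pyr[:-1]):
        cur = residual + F.interpolate(cur, (residual.size(2), residual.size(3)),
                                       mode='bilinear', align_corners=False)
    return cur
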
Example #6
def style_transfer(content_img_path,
                   img_size,
                   style_img_path,
                   style_size,
                   content_layer,
                   content_weight,
                   style_layers,
                   style_weights,
                   tv_weight,
                   init_random=False):
    """Perform style transfer from style image to source content image
    
    Args:
        content_img_path (str): File location of the content image.
        img_size (int): Size of the smallest content image dimension.
        style_img_path (str): File location of the style image.
        style_size (int): Size of the smallest style image dimension.
        content_layer (int): Index of the layer to use for content loss.
        content_weight (float): Scalar weight for content loss.
        style_layers ([]int): Indices of layers to use for style loss.
        style_weights ([]float): List of scalar weights to use for each layer in style_layers.
        tv_weight (float): Scalar weight of the total variation regularization term.
        init_random (boolean): Whether to initialize the starting image to uniform random noise.
    """
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    try:
        model = SqueezeNet(ckpt_path=CKPT_PATH, sess=sess)
    except NotFoundError:
        raise ValueError('Checkpoint file not found; please check %s' %
                         CKPT_PATH)

    # Extract features from content image
    content_img = preprocess_image(load_image(content_img_path, size=img_size))
    content_feats = model.extract_features(model.image)

    # Create content target
    content_target = sess.run(content_feats[content_layer],
                              {model.image: content_img[None]})

    # Extract features from style image
    style_img = preprocess_image(load_image(style_img_path, size=style_size))
    style_feats_by_layer = [content_feats[i] for i in style_layers]

    # Create style targets
    style_targets = []
    for style_feats in style_feats_by_layer:
        style_targets.append(gram_matrix(style_feats))
    style_targets = sess.run(style_targets, {model.image: style_img[None]})

    if init_random:
        generated_img = tf.Variable(tf.random_uniform(content_img[None].shape,
                                                      0, 1),
                                    name="image")
    else:
        generated_img = tf.Variable(content_img[None], name="image")

    # Extract features from generated image
    current_feats = model.extract_features(generated_img)

    loss = content_loss(content_weight, current_feats[content_layer], content_target) + \
        style_loss(current_feats, style_layers, style_targets, style_weights) + \
        total_variation_loss(generated_img, tv_weight)

    # Set up optimization parameters
    init_learning_rate = 3.0
    decayed_learning_rate = 0.1
    max_iter = 200

    learning_rate = tf.Variable(init_learning_rate, name="lr")
    with tf.variable_scope("optimizer") as opt_scope:
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, var_list=[generated_img])

    opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=opt_scope.name)
    sess.run(
        tf.variables_initializer([learning_rate, generated_img] + opt_vars))

    # Create an op that will clamp the image values when run
    clamp_image_op = tf.assign(generated_img,
                               tf.clip_by_value(generated_img, -1.5, 1.5))

    display_content_and_style(content_img, style_img)

    for t in range(max_iter):
        sess.run(train_op)
        if t < int(0.90 * max_iter):
            sess.run(clamp_image_op)
        elif t == int(0.90 * max_iter):
            sess.run(tf.assign(learning_rate, decayed_learning_rate))

        if t % 20 == 0:
            current_loss = sess.run(loss)
            print('Iteration %d: %f' % (t, current_loss))

    img = sess.run(generated_img)
    plt.imshow(deprocess_image(img[0], rescale=True))
    plt.axis('off')
    plt.show()
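
A hypothetical invocation of style_transfer; every path, layer index and weight below is an illustrative placeholder rather than a value taken from the original project:

style_transfer(content_img_path='content.jpg',
               img_size=192,
               style_img_path='style.jpg',
               style_size=512,
               content_layer=3,
               content_weight=5e-2,
               style_layers=[1, 4, 6, 7],
               style_weights=[2e5, 1e2, 2e1, 5e-1],
               tv_weight=5e-2)
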
Example #7
discrim_target = tf.concat([adv_, 1 - adv_], 1)

loss_discrim = -tf.reduce_sum(
    discrim_target * tf.log(tf.clip_by_value(discrim_predictions, 1e-10, 1.0)))

loss_texture = -loss_discrim / 1000

correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1),
                               tf.argmax(discrim_target, 1))
discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

# 2. color loss:
loss_color = 1000 * tf.reduce_mean(tf.abs(out_image - gt_image))

# 3. content loss:
loss_content = loss.content_loss(crop_gt_and_out[:, :, :, 0:3],
                                 crop_gt_and_out[:, :, :, 3:6], 5) / 1200

G_loss = loss_color + loss_texture + loss_content  #+ loss_tv

#t_vars = tf.trainable_variables()
generator_vars = [
    v for v in tf.global_variables() if v.name.startswith("generator")
]
discriminator_vars = [
    v for v in tf.global_variables() if v.name.startswith("discriminator")
]
lr = tf.placeholder(tf.float32)

G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(
    G_loss, var_list=generator_vars)
D_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(
    loss_discrim, var_list=discriminator_vars)  # assumed completion mirroring G_opt; the snippet is truncated here
Example #8
def optimize():
    MODEL_DIR_NAME = os.path.dirname(FLAGS.MODEL_PATH)
    if not os.path.exists(MODEL_DIR_NAME):
        os.mkdir(MODEL_DIR_NAME)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    # style gram matrix
    style_features_t = loss.get_style_features(style_paths, style_layers,
                                               FLAGS.IMAGE_SIZE, FLAGS.STYLE_SCALE, FLAGS.VGG_PATH)

    with tf.Graph().as_default(), tf.Session() as sess:
        # train_images
        images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE,
                              FLAGS.TRAIN_IMAGES_FOLDER, FLAGS.EPOCHS)

        generated = transform.net(images - vgg.MEAN_PIXEL, training=True)
        net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat([generated, images], 0) - vgg.MEAN_PIXEL)

        # loss functions
        content_loss = loss.content_loss(net, content_layers)
        style_loss = loss.style_loss(
            net, style_features_t, style_layers) / len(style_paths)

        total_loss = FLAGS.STYLE_WEIGHT * style_loss + FLAGS.CONTENT_WEIGHT * content_loss + \
            FLAGS.TV_WEIGHT * loss.total_variation_loss(generated)

        # prepare for training
        global_step = tf.Variable(0, name="global_step", trainable=False)

        variable_to_train = []
        for variable in tf.trainable_variables():
            if not variable.name.startswith('vgg19'):
                variable_to_train.append(variable)

        train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(
            total_loss, global_step=global_step, var_list=variable_to_train)

        variables_to_restore = []
        for v in tf.global_variables():
            if not v.name.startswith('vgg19'):
                variables_to_restore.append(v)

        # start training
        saver = tf.train.Saver(variables_to_restore,
                               write_version=tf.train.SaverDef.V1)
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # load the latest checkpoint
        ckpt = tf.train.latest_checkpoint(MODEL_DIR_NAME)
        if ckpt:
            tf.logging.info('Restoring model from {}'.format(ckpt))
            saver.restore(sess, ckpt)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, total_loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()

                if step % 10 == 0:
                    tf.logging.info(
                        'step: %d,  total loss %f, secs/step: %f' % (step, loss_t, elapsed_time))

                if step % 10000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH, global_step=step)
                    tf.logging.info('Save model')

        except tf.errors.OutOfRangeError:
            saver.save(sess,  FLAGS.MODEL_PATH + '-done')
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
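
loss.content_loss(net, content_layers) is not defined in this snippet. Because the generated and original images are stacked into a single batch before the VGG pass, it presumably compares the two halves of each feature map; a hedged sketch of such a helper:

import tensorflow as tf

def content_loss(net, content_layers):
    # Compare the generated half of the batch against the original half.
    total = 0.0
    for layer in content_layers:
        generated_feat, original_feat = tf.split(net[layer], 2, axis=0)
        size = tf.to_float(tf.size(generated_feat))
        total += 2 * tf.nn.l2_loss(generated_feat - original_feat) / size
    return total
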
Example #9
def train(options):
    content_weight = options.content_weight
    tv_weight = options.tv_weight
    initial_lr = options.initial_lr
    max_iter = options.max_iter
    style_weights = options.style_weights
    print_iterations = options.print_iterations
    img_size = options.img_size
    content = options.content
    style = options.style
    output = options.output
    beta1 = options.beta1
    beta2 = options.beta2
    epsilon = options.epsilon
    h5_file = options.h5_file

    content_layer = 12
    style_layers = [0, 3, 6, 11, 16]
    style_target_vars = []
    print(tf.test.is_gpu_available(cuda_only=True))
    contentImg = pre_img.load_image(content, size=img_size)
    contentImg = pre_img.preprocess_image(contentImg)
    styleImg = pre_img.load_image(style, size=img_size)
    styleImg = pre_img.preprocess_image(styleImg)
    img_var = tf.Variable(contentImg[None], name="image", dtype=tf.float32)
    lr_var = tf.Variable(initial_lr, name="lr")
    new_img_feats = vgg.extract_features(img_var, h5_file)
    content_img_feats = vgg.extract_features(contentImg[None], h5_file)
    style_img_feats = vgg.extract_features(styleImg[None], h5_file)
    for idx in style_layers:
        style_target_vars.append(loss.gram_matrix(style_img_feats[idx]))
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        style_loss = loss.style_loss(new_img_feats, style_layers,
                                     style_target_vars, style_weights)
        content_loss = loss.content_loss(content_weight,
                                         new_img_feats[content_layer],
                                         content_img_feats[content_layer])
        tv_loss = loss.tv_loss(img_var, tv_weight)
        total_loss = style_loss + content_loss + tv_loss
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_var,
                                           beta1=beta1,
                                           beta2=beta2,
                                           epsilon=epsilon)
        training_op = optimizer.minimize(total_loss, var_list=[img_var])
        init = tf.global_variables_initializer()
        sess.run(init)
        sess.run(img_var.initializer)
        for t in range(max_iter):
            if print_iterations is not None and t % print_iterations == 0:
                new_image = img_var.eval()
                imageio.imwrite(output + '\\iteration_' + str(t) + '.jpg',
                                pre_img.deprocess_image(new_image[0]))
            sess.run(training_op)
            loss_val, s_loss, c_loss = sess.run(
                [total_loss, style_loss, content_loss])
            print(
                str(t) + ':' + str(loss_val) + '\t' + str(s_loss) + '\t' +
                str(c_loss))
        new_image = sess.run(img_var)
        imageio.imwrite(output + '\\final.jpg',
                        pre_img.deprocess_image(new_image[0]))
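
A hypothetical call to train(); the attribute names mirror the options read at the top of the function, while every value and file name is an illustrative placeholder:

from argparse import Namespace

options = Namespace(content_weight=5e-2, tv_weight=5e-3, initial_lr=3.0,
                    max_iter=200, style_weights=[2e5, 1e2, 2e1, 5e-1, 1e-1],
                    print_iterations=50, img_size=512,
                    content='content.jpg', style='style.jpg', output='out',
                    beta1=0.9, beta2=0.999, epsilon=1e-8,
                    h5_file='vgg_weights.h5')
train(options)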