Example no. 1
def train(args):

    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_superpixel = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_cartoon = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    # output => fake (cartoonized) image produced by the generator
    output = network.unet_generator(input_photo)
    # refine the raw output with a guided filter, using the input photo as guidance
    output = guided_filter(input_photo, output, r=1)

    blur_fake = guided_filter(output, output, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

    gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)

    d_loss_gray, g_loss_gray = loss.lsgan_loss(network.disc_sn,
                                               gray_cartoon,
                                               gray_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_gray')
    d_loss_blur, g_loss_blur = loss.lsgan_loss(network.disc_sn,
                                               blur_cartoon,
                                               blur_fake,
                                               scale=1,
                                               patch=True,
                                               name='disc_blur')

    vgg_model = loss.Vgg19('vgg19_no_fc.npy')
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(output)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]

    photo_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_photo, vgg_output)) / (h * w * c)
    superpixel_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_superpixel, vgg_output)) / (h * w * c)
    recon_loss = photo_loss + superpixel_loss
    tv_loss = loss.total_variation_loss(output)

    g_loss_total = 1e4 * tv_loss + 1e-1 * g_loss_blur + g_loss_gray + 2e2 * recon_loss
    d_loss_total = d_loss_blur + d_loss_gray

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]
    disc_vars = [var for var in all_vars if 'disc' in var.name]

    tf.summary.scalar('tv_loss', tv_loss)
    tf.summary.scalar('photo_loss', photo_loss)
    tf.summary.scalar('superpixel_loss', superpixel_loss)
    tf.summary.scalar('recon_loss', recon_loss)
    tf.summary.scalar('d_loss_gray', d_loss_gray)
    tf.summary.scalar('g_loss_gray', g_loss_gray)
    tf.summary.scalar('d_loss_blur', d_loss_blur)
    tf.summary.scalar('g_loss_blur', g_loss_blur)
    tf.summary.scalar('d_loss_total', d_loss_total)
    tf.summary.scalar('g_loss_total', g_loss_total)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):

        g_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
                                        .minimize(g_loss_total, var_list=gene_vars)

        d_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
                                        .minimize(d_loss_total, var_list=disc_vars)
    '''
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    '''
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    train_writer = tf.summary.FileWriter(args.save_dir + '/train_log')
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    with tf.device('/device:GPU:0'):

        sess.run(tf.global_variables_initializer())
        saver.restore(sess,
                      tf.train.latest_checkpoint('pretrain/saved_models'))

        face_photo_dir = 'dataset/photo_face'
        face_photo_list = utils.load_image_list(face_photo_dir)
        scenery_photo_dir = 'dataset/photo_scenery'
        scenery_photo_list = utils.load_image_list(scenery_photo_dir)

        face_cartoon_dir = 'dataset/cartoon_face'
        face_cartoon_list = utils.load_image_list(face_cartoon_dir)
        scenery_cartoon_dir = 'dataset/cartoon_scenery'
        scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)

        for total_iter in tqdm(range(args.total_iter)):

            if np.mod(total_iter, 5) == 0:
                photo_batch = utils.next_batch(face_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(face_cartoon_list,
                                                 args.batch_size)
            else:
                photo_batch = utils.next_batch(scenery_photo_list,
                                               args.batch_size)
                cartoon_batch = utils.next_batch(scenery_cartoon_list,
                                                 args.batch_size)

            inter_out = sess.run(output,
                                 feed_dict={
                                     input_photo: photo_batch,
                                     input_superpixel: photo_batch,
                                     input_cartoon: cartoon_batch
                                 })
            '''
            Adaptive coloring has to be applied together with the clip_by_value
            in the last layer of the generator network, which is not very stable.
            To stably reproduce our results, please use power=1.0 first
            and comment out the clip_by_value call in network.py.
            If this works, then try adaptive coloring with clip_by_value.
            '''
            if args.use_enhance:
                superpixel_batch = utils.selective_adacolor(inter_out,
                                                            power=1.2)
            else:
                superpixel_batch = utils.simple_superpixel(inter_out,
                                                           seg_num=200)

            _, g_loss, r_loss = sess.run(
                [g_optim, g_loss_total, recon_loss],
                feed_dict={
                    input_photo: photo_batch,
                    input_superpixel: superpixel_batch,
                    input_cartoon: cartoon_batch
                })

            _, d_loss, train_info = sess.run(
                [d_optim, d_loss_total, summary_op],
                feed_dict={
                    input_photo: photo_batch,
                    input_superpixel: superpixel_batch,
                    input_cartoon: cartoon_batch
                })

            train_writer.add_summary(train_info, total_iter)

            if np.mod(total_iter + 1, 50) == 0:

                print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'.format(
                    total_iter, d_loss, g_loss, r_loss))
                if np.mod(total_iter + 1, 500) == 0:
                    saver.save(sess,
                               args.save_dir + '/saved_models/model',
                               write_meta_graph=False,
                               global_step=total_iter)

                    photo_face = utils.next_batch(face_photo_list,
                                                  args.batch_size)
                    cartoon_face = utils.next_batch(face_cartoon_list,
                                                    args.batch_size)
                    photo_scenery = utils.next_batch(scenery_photo_list,
                                                     args.batch_size)
                    cartoon_scenery = utils.next_batch(scenery_cartoon_list,
                                                       args.batch_size)

                    result_face = sess.run(output,
                                           feed_dict={
                                               input_photo: photo_face,
                                               input_superpixel: photo_face,
                                               input_cartoon: cartoon_face
                                           })

                    result_scenery = sess.run(output,
                                              feed_dict={
                                                  input_photo: photo_scenery,
                                                  input_superpixel:
                                                  photo_scenery,
                                                  input_cartoon:
                                                  cartoon_scenery
                                              })

                    utils.write_batch_image(
                        result_face, args.save_dir + '/images',
                        str(total_iter) + '_face_result.jpg', 4)
                    utils.write_batch_image(
                        photo_face, args.save_dir + '/images',
                        str(total_iter) + '_face_photo.jpg', 4)

                    utils.write_batch_image(
                        result_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_result.jpg', 4)
                    utils.write_batch_image(
                        photo_scenery, args.save_dir + '/images',
                        str(total_iter) + '_scenery_photo.jpg', 4)
Example no. 2
def train(cfg):
    # Set device if gpu is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build network
    net = ImageTransformationNet().to(device)

    # Setup optimizer
    optimizer = optim.Adam(net.parameters())

    # Load state if resuming training
    if cfg['resume']:
        checkpoint = torch.load(cfg['resume'])
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['opt_state_dict'])

        # Get starting epoch and batch (expects a weight file named EPOCH_<>_BATCH_<>.pth)
        parts = cfg['resume'].split('_')
        first_epoch = int(checkpoint['epoch'])
        first_batch = int(parts[-1].split('.')[0])

        # Setup dataloader
        train_data = tqdm(build_data_loader(cfg), initial=first_batch)

    else:
        # Setup dataloader
        train_data = tqdm(build_data_loader(cfg))

        # Set first epoch and batch
        first_epoch = 1
        first_batch = 0

    # Fetch style image and style grams
    style_im = load_image(cfg['style_image'], cfg)
    style_grams = get_style_grams(style_im, cfg)

    # Setup log file if specified
    log_dir = Path('logs')
    log_dir.mkdir(parents=True, exist_ok=True)
    if cfg['log_file'] and not cfg['resume']:
        today = datetime.datetime.today().strftime('%m/%d/%Y')
        header = f'Feed-Forward Style Transfer Training Log - {today}'
        with open(cfg['log_file'], 'w+') as file:
            file.write(header + '\n\n')

    # Setup log CSV if specified
    if cfg['csv_log_file'] and not cfg['resume']:
        utils.setup_csv(cfg)

    for epoch in range(first_epoch, cfg['epochs'] + 1):

        # Keep track of per epoch loss
        content_loss = 0
        style_loss = 0
        total_var_loss = 0
        train_loss = 0
        num_batches = 0

        # Setup first batch to start enumerate at proper place
        if epoch == first_epoch:
            start = first_batch
        else:
            start = 0

        for i, batch in enumerate(train_data, start=start):
            batch = batch.to(device)

            # Put batch through network
            batch_styled = net(batch)

            # Get vgg activations for styled and unstyled batch
            features = vgg_activations(batch_styled)
            content_features = vgg_activations(batch)

            # Get loss
            c_loss, s_loss = perceptual_loss(features=features,
                                             content_features=content_features,
                                             style_grams=style_grams,
                                             cfg=cfg)
            tv_loss = total_variation_loss(batch_styled, cfg)
            total_loss = c_loss + s_loss + tv_loss

            # Backpropagate
            total_loss.backward()

            # Do one step of optimization
            optimizer.step()

            # Clear gradients before next batch
            optimizer.zero_grad()

            # Update summary statistics
            with torch.no_grad():
                content_loss += c_loss.item()
                style_loss += s_loss.item()
                total_var_loss += tv_loss.item()
                train_loss += total_loss.item()
                num_batches += 1

            # Update progress bar
            avg_loss = round(train_loss / num_batches, 2)
            avg_c_loss = round(content_loss / num_batches, 2)
            avg_s_loss = round(style_loss / num_batches, 1)
            avg_tv_loss = round(total_var_loss / num_batches, 3)
            train_data.set_description(
                f'C - {avg_c_loss} | S - {avg_s_loss} | TV - {avg_tv_loss} | Total - {avg_loss}'
            )
            train_data.refresh()

            # Create progress image if specified
            if cfg['image_checkpoint'] and ((i + 1) % cfg['image_checkpoint']
                                            == 0):
                save_path = str(
                    Path(
                        cfg['image_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.png'
                    ))
                utils.make_checkpoint_image(cfg, net, save_path)

            # Save weights if specified
            if cfg['save_checkpoint'] and ((i + 1) % cfg['save_checkpoint']
                                           == 0):
                save_path = str(
                    Path(
                        cfg['save_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.pth'
                    ))
                checkpoint = {
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'loss': avg_loss
                }
                torch.save(checkpoint, save_path)

            # Write progress row to CSV
            if cfg['csv_checkpoint'] and ((i + 1) % cfg['csv_checkpoint']
                                          == 0):
                row = [
                    epoch, i + 1, avg_c_loss, avg_s_loss, avg_tv_loss, avg_loss
                ]
                utils.write_progress_row(cfg, row)

        # Write loss at end of each epoch
        if cfg['log_file']:
            avg_loss = round(train_loss / num_batches, 4)
            line = f'EPOCH {epoch} | Loss - {avg_loss}'
            with open(cfg['log_file'], 'a') as file:
                file.write(line + '\n')

        # Save network if specified
        if cfg['epoch_save_checkpoint'] and (
                epoch % cfg['epoch_save_checkpoint'] == 0):
            save_path = str(
                Path(cfg['save_checkpoint_dir'],
                     f'EPOCH_{str(epoch).zfill(3)}.pth'))
            checkpoint = {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'opt_state_dict': optimizer.state_dict(),
                'loss': round(train_loss / num_batches, 4)
            }
            torch.save(checkpoint, save_path)
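
Example no. 2 delegates its style term to get_style_grams and perceptual_loss, which live elsewhere in the project. A minimal sketch of the Gram-matrix computation such a style loss is usually built on follows; the function name and the normalization constant are assumptions.

import torch

def gram_matrix_sketch(features):
    # features: (batch, channels, height, width) activations from one VGG layer.
    b, c, h, w = features.size()
    flat = features.reshape(b, c, h * w)
    # Channel-to-channel correlations, normalised by the number of elements.
    return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)
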
Example no. 3

def style_transfer(content_img_path,
                   img_size,
                   style_img_path,
                   style_size,
                   content_layer,
                   content_weight,
                   style_layers,
                   style_weights,
                   tv_weight,
                   init_random=False):
    """Perform style transfer from style image to source content image
    
    Args:
        content_img_path (str): File location of the content image.
        img_size (int): Size of the smallest content image dimension.
        style_img_path (str): File location of the style image.
        style_size (int): Size of the smallest style image dimension.
        content_layer (int): Index of the layer to use for content loss.
        content_weight (float): Scalar weight for content loss.
        style_layers ([]int): Indices of layers to use for style loss.
        style_weights ([]float): List of scalar weights to use for each layer in style_layers.
        tv_weight (float): Scalar weight of the total variation regularization term.
        init_random (boolean): Whether to initialize the starting image to uniform random noise.
    """
    tf.reset_default_graph()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    try:
        model = SqueezeNet(ckpt_path=CKPT_PATH, sess=sess)
    except NotFoundError:
        raise ValueError('checkpoint file not found, please check %s' %
                         CKPT_PATH)

    # Extract features from content image
    content_img = preprocess_image(load_image(content_img_path, size=img_size))
    content_feats = model.extract_features(model.image)

    # Create content target
    content_target = sess.run(content_feats[content_layer],
                              {model.image: content_img[None]})

    # Extract features from style image
    style_img = preprocess_image(load_image(style_img_path, size=style_size))
    style_feats_by_layer = [content_feats[i] for i in style_layers]

    # Create style targets
    style_targets = []
    for style_feats in style_feats_by_layer:
        style_targets.append(gram_matrix(style_feats))
    style_targets = sess.run(style_targets, {model.image: style_img[None]})

    if init_random:
        generated_img = tf.Variable(tf.random_uniform(content_img[None].shape,
                                                      0, 1),
                                    name="image")
    else:
        generated_img = tf.Variable(content_img[None], name="image")

    # Extract features from generated image
    current_feats = model.extract_features(generated_img)

    loss = content_loss(content_weight, current_feats[content_layer], content_target) + \
        style_loss(current_feats, style_layers, style_targets, style_weights) + \
        total_variation_loss(generated_img, tv_weight)

    # Set up optimization parameters
    init_learning_rate = 3.0
    decayed_learning_rate = 0.1
    max_iter = 200

    learning_rate = tf.Variable(init_learning_rate, name="lr")
    with tf.variable_scope("optimizer") as opt_scope:
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, var_list=[generated_img])

    opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=opt_scope.name)
    sess.run(
        tf.variables_initializer([learning_rate, generated_img] + opt_vars))

    # Create an op that will clamp the image values when run
    clamp_image_op = tf.assign(generated_img,
                               tf.clip_by_value(generated_img, -1.5, 1.5))

    display_content_and_style(content_img, style_img)

    for t in range(max_iter):
        sess.run(train_op)
        if t < int(0.90 * max_iter):
            sess.run(clamp_image_op)
        elif t == int(0.90 * max_iter):
            sess.run(tf.assign(learning_rate, decayed_learning_rate))

        if t % 20 == 0:
            current_loss = sess.run(loss)
            print('Iteration %d: %f' % (t, current_loss))

    img = sess.run(generated_img)
    plt.imshow(deprocess_image(img[0], rescale=True))
    plt.axis('off')
    plt.show()
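
For reference, a hedged usage sketch of the style_transfer function above; the file paths, sizes, layer indices, and weights are illustrative assumptions rather than values from the original project.

style_transfer(content_img_path='content.jpg',
               img_size=192,
               style_img_path='style.jpg',
               style_size=192,
               content_layer=3,
               content_weight=5e-2,
               style_layers=[1, 4, 6, 7],
               style_weights=[2e5, 1e2, 1e1, 1e1],
               tv_weight=5e-2)
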
Example no. 4
def optimize():
    MODEL_DIR_NAME = os.path.dirname(FLAGS.MODEL_PATH)
    if not os.path.exists(MODEL_DIR_NAME):
        os.mkdir(MODEL_DIR_NAME)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    # style gram matrix
    style_features_t = loss.get_style_features(style_paths, style_layers,
                                               FLAGS.IMAGE_SIZE, FLAGS.STYLE_SCALE, FLAGS.VGG_PATH)

    with tf.Graph().as_default(), tf.Session() as sess:
        # train_images
        images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE,
                              FLAGS.TRAIN_IMAGES_FOLDER, FLAGS.EPOCHS)

        generated = transform.net(images - vgg.MEAN_PIXEL, training=True)
        net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat([generated, images], 0) - vgg.MEAN_PIXEL)

        # loss functions
        content_loss = loss.content_loss(net, content_layers)
        style_loss = loss.style_loss(
            net, style_features_t, style_layers) / len(style_paths)

        total_loss = FLAGS.STYLE_WEIGHT * style_loss + FLAGS.CONTENT_WEIGHT * content_loss + \
            FLAGS.TV_WEIGHT * loss.total_variation_loss(generated)

        # prepare for training
        global_step = tf.Variable(0, name="global_step", trainable=False)

        variable_to_train = []
        for variable in tf.trainable_variables():
            if not variable.name.startswith('vgg19'):
                variable_to_train.append(variable)

        train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(
            total_loss, global_step=global_step, var_list=variable_to_train)

        variables_to_restore = []
        for v in tf.global_variables():
            if not v.name.startswith('vgg19'):
                variables_to_restore.append(v)

        # start training
        saver = tf.train.Saver(variables_to_restore,
                               write_version=tf.train.SaverDef.V1)
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # load the latest checkpoint if one exists
        ckpt = tf.train.latest_checkpoint(MODEL_DIR_NAME)
        if ckpt:
            tf.logging.info('Restoring model from {}'.format(ckpt))
            saver.restore(sess, ckpt)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, total_loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()

                if step % 10 == 0:
                    tf.logging.info(
                        'step: %d,  total loss %f, secs/step: %f' % (step, loss_t, elapsed_time))

                if step % 10000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH, global_step=step)
                    tf.logging.info('Save model')

        except tf.errors.OutOfRangeError:
            saver.save(sess, FLAGS.MODEL_PATH + '-done')
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
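
Example no. 4 computes its content term through loss.content_loss, another project-specific helper. A minimal sketch of a feature-reconstruction (content) loss under the common definition is given below; the function name and scaling are assumptions.

import tensorflow as tf

def content_loss_sketch(generated_feats, content_feats):
    # Mean squared difference between the VGG activations of the generated
    # image and those of the original content image.
    size = tf.cast(tf.size(content_feats), tf.float32)
    return tf.reduce_sum(tf.square(generated_feats - content_feats)) / size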