Example #1
0
def total_loss(combination_image, base_image, style_image):
    """Compute the weighted style-transfer objective for ``combination_image``.

    The base, style and combination images are concatenated along axis 0
    (batch positions 0, 1 and 2) and run through ``feature_model`` once.
    The returned scalar tensor is the sum of:

    * ``CONTENT_WEIGHT`` x content loss between the base and combination
      activations at ``content_layer_name``;
    * for each layer in ``style_layer_names``,
      ``STYLE_WEIGHT / len(style_layer_names)`` x the style loss between the
      style and combination activations of that layer;
    * ``TV_WEIGHT`` x total-variation loss of ``combination_image``.
    """
    stacked = tf.concat([base_image, style_image, combination_image], axis=0)
    activations = feature_model(stacked)

    # Content term: batch position 0 (base) vs. position 2 (combination).
    content_acts = activations[content_layer_name]
    base_acts = content_acts[0, :, :, :]
    combo_acts = content_acts[2, :, :, :]
    loss = tf.zeros(shape=())
    loss = loss + CONTENT_WEIGHT * content_loss(base_acts, combo_acts)

    # Style terms: position 1 (style) vs. position 2 (combination); every
    # style layer gets the same share of STYLE_WEIGHT.
    pixel_count = image_height * image_width
    for name in style_layer_names:
        acts = activations[name]
        per_layer = style_loss(acts[1, :, :, :],
                               acts[2, :, :, :],
                               size=pixel_count)
        loss += (STYLE_WEIGHT / len(style_layer_names)) * per_layer

    # Total-variation term computed on the combination image itself.
    tv = total_variation_loss(combination_image, image_height, image_width)
    loss += TV_WEIGHT * tv
    return loss
Example #2
0
def main(argv=None):
    """Optimise a single content image directly for neural style transfer.

    Builds a VGG-16 loss network over ``FLAGS.CONTENT_IMAGE``, combines the
    style, content and total-variation losses (weighted by
    ``FLAGS.STYLE_WEIGHT``, ``FLAGS.CONTENT_WEIGHT`` and ``FLAGS.TV_WEIGHT``)
    and minimises them with Adam for ``FLAGS.NUM_ITERATIONS`` steps, printing
    per-step timings and losses. The generated image tensor returned by
    ``losses.content_loss`` is finally encoded as PNG and written to
    ``out.png``.

    Args:
        argv: Unused.
    """
    network_fn = nets_factory.get_network_fn('vgg_16',
                                             num_classes=1,
                                             is_training=False)
    # NOTE(review): neither returned function is used below.
    image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
        'vgg_16', is_training=False)

    preprocess_content_image = reader.get_image(FLAGS.CONTENT_IMAGE,
                                                FLAGS.IMAGE_SIZE)

    # Add a batch dimension for the VGG network.
    preprocess_content_image = tf.expand_dims(preprocess_content_image, 0)
    _, endpoints_dict = network_fn(preprocess_content_image,
                                   spatial_squeeze=False)

    # Log the structure of the loss network.
    tf.logging.info(
        'Loss network layers(You can define them in "content_layers" and "style_layers"):'
    )
    for key in endpoints_dict:
        tf.logging.info(key)

    # Build losses. losses.content_loss also returns the generated image
    # tensor, which is used below for the TV loss and the output PNG.
    content_loss, generated_image = losses.content_loss(
        endpoints_dict, FLAGS.CONTENT_LAYERS, FLAGS.CONTENT_IMAGE)
    # Upper-case flag name, matching every other flag read in this function
    # (the lower-case FLAGS.style_layers broke that convention).
    style_loss, style_loss_summary = losses.style_loss(endpoints_dict,
                                                       FLAGS.STYLE_LAYERS,
                                                       FLAGS.STYLE_IMAGE)
    tv_loss = losses.total_variation_loss(
        generated_image)  # use the unprocessed image

    loss = (FLAGS.STYLE_WEIGHT * style_loss
            + FLAGS.CONTENT_WEIGHT * content_loss
            + FLAGS.TV_WEIGHT * tv_loss)
    train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(loss)

    # Add the mean pixel back, clamp to uint8 and encode as PNG bytes.
    output_image = tf.image.encode_png(
        tf.saturate_cast(
            tf.squeeze(generated_image) + reader.mean_pixel, tf.uint8))

    with tf.Session() as sess:
        # Replacement for the deprecated tf.initialize_all_variables().
        sess.run(tf.global_variables_initializer())
        start_time = time.time()
        for step in range(FLAGS.NUM_ITERATIONS):
            _, loss_t, cl, sl = sess.run(
                [train_op, loss, content_loss, style_loss])
            elapsed = time.time() - start_time
            start_time = time.time()
            print(step, elapsed, loss_t, cl, sl)
        image_t = sess.run(output_image)
        with open('out.png', 'wb') as f:
            f.write(image_t)
Example #3
0
    def __init__(self,
                 base_img_path,
                 style_img_path,
                 output_img_path,
                 output_width,
                 convnet,
                 content_weight,
                 style_weight,
                 tv_weight,
                 content_layer,
                 style_layers,
                 iterations):
        """Load the images, build the VGG loss graph and compile a loss/grad function.

        Args:
            base_img_path: Path of the content (base) image.
            style_img_path: Path of the style image.
            output_img_path: Output path; stored on the instance.
            output_width: Target width in pixels, or ``None`` to keep the base
                image's size. When given, the height is scaled by the same
                factor (floored to an int).
            convnet: ``'vgg16'`` selects VGG16; any other value selects VGG19.
            content_weight: Weight of the content (feature reconstruction) loss.
            style_weight: Weight of the style loss.
            tv_weight: Weight of the total-variation loss.
            content_layer: Layer name used for the content loss.
            style_layers: Layer names used for the style loss; each contributes
                with weight ``1 / len(style_layers)``.
            iterations: Iteration count; stored on the instance.

        After construction, ``self.loss_and_grads`` is a Keras function mapping
        ``[output_image]`` to ``[total_loss, gradient(s)...]``.
        """
        self.base_img_path = base_img_path
        self.style_img_path = style_img_path
        self.output_img_path = output_img_path

        self.width = output_width
        # load_img(...).size is unpacked as (width, height).
        width, height = load_img(self.base_img_path).size
        new_dims = (height, width)

        self.img_nrows = height
        self.img_ncols = width

        if self.width is not None:
            # Keep the aspect ratio. Exact floor division replaces the former
            # int(np.floor(float(height * self.width / width))).
            num_rows = int(height * self.width // width)
            new_dims = (num_rows, self.width)

            self.img_nrows = num_rows
            self.img_ncols = self.width

        self.content_img = K.variable(preprocess_image(self.base_img_path, new_dims))
        self.style_img = K.variable(preprocess_image(self.style_img_path, new_dims))

        # Placeholder for the image being optimised: channels-first for 'th'
        # ordering, channels-last otherwise.
        if K.image_dim_ordering() == 'th':
            self.output_img = K.placeholder((1, 3, new_dims[0], new_dims[1]))
        else:
            self.output_img = K.placeholder((1, new_dims[0], new_dims[1], 3))

        print("\tSize of content image is: {}".format(K.int_shape(self.content_img)))
        print("\tSize of style image is: {}".format(K.int_shape(self.style_img)))
        print("\tSize of output image is: {}".format(K.int_shape(self.output_img)))

        # Batch layout: index 0 = content, 1 = style, 2 = output image.
        self.input_img = K.concatenate([self.content_img,
                                        self.style_img,
                                        self.output_img], axis=0)

        self.convnet = convnet
        self.iterations = iterations

        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

        self.content_layer = content_layer
        self.style_layers = style_layers

        print('\tLoading {} model'.format(self.convnet.upper()))

        if self.convnet == 'vgg16':
            self.model = vgg16.VGG16(input_tensor=self.input_img,
                                     weights='imagenet',
                                     include_top=False)
        else:
            self.model = vgg19.VGG19(input_tensor=self.input_img,
                                     weights='imagenet',
                                     include_top=False)

        outputs_dict = {layer.name: layer.output for layer in self.model.layers}
        content_features = outputs_dict[self.content_layer]

        base_image_features = content_features[0, :, :, :]
        combination_features = content_features[2, :, :, :]

        content_loss = self.content_weight * \
            feature_reconstruction_loss(base_image_features,
                                        combination_features)

        # Style loss: equally weighted sum over the style layers.
        temp_style_loss = K.variable(0.0)
        weight = 1.0 / float(len(self.style_layers))

        for layer in self.style_layers:
            style_features = outputs_dict[layer]
            style_image_features = style_features[1, :, :, :]
            output_style_features = style_features[2, :, :, :]
            temp_style_loss += weight * \
                style_reconstruction_loss(style_image_features,
                                          output_style_features,
                                          self.img_nrows,
                                          self.img_ncols)
        style_loss = self.style_weight * temp_style_loss

        tv_loss = self.tv_weight * total_variation_loss(self.output_img,
                                                        self.img_nrows,
                                                        self.img_ncols)

        total_loss = content_loss + style_loss + tv_loss

        print('\tComputing gradients...')
        grads = K.gradients(total_loss, self.output_img)

        # K.gradients may return a list/tuple or a single tensor; flatten
        # either form into [loss, grad, ...].
        outputs = [total_loss]
        if isinstance(grads, (list, tuple)):
            outputs += grads
        else:
            outputs.append(grads)

        self.loss_and_grads = K.function([self.output_img], outputs)
Example #4
0
        # Forward pass of the `perceptual` model on both inputs. Per the
        # unpacking, it yields a mixed image and one reconstruction per input.
        mixed_image, reconstructionA, reconstructionB = perceptual(
            real_A, real_B)

        # Reconstruction loss: each reconstruction against its own input,
        # scaled by 30.0.
        loss_reconstruction_A = criterion_identity(reconstructionA,
                                                   real_A) * 30.0
        loss_reconstruction_B = criterion_identity(reconstructionB,
                                                   real_B) * 30.0

        # Adversarial loss: the discriminator's output on the mixed image,
        # scored against `target_real`.
        pred_fake = discriminator(mixed_image)
        loss_adv = criterion_adv(pred_fake, target_real)

        # Total variation loss of the mixed image.
        TV_loss = los.total_variation_loss(mixed_image)

        # Perceptual (VGG feature) losses.
        # NOTE(review): if real_A / real_B do not already require grad,
        # `.clone().requires_grad_(True)` makes autograd also track gradients
        # w.r.t. the real inputs. That looks unnecessary for the real images;
        # confirm intent.
        cuda_mixed_image = mixed_image.clone().requires_grad_(True).to(device)
        cuda_real_A = real_A.clone().requires_grad_(True).to(device)
        cuda_real_B = real_B.clone().requires_grad_(True).to(device)
        # real_A provides the style features, real_B the content features.
        style_features = get_features(cuda_real_A, vgg)
        content_features = get_features(cuda_real_B, vgg)
        target_features = get_features(cuda_mixed_image, vgg)
        # Content loss on the 'conv4_2' features (x0.1); style loss over the
        # feature dicts (x0.05).
        content_loss = los.compute_content_loss(
            target_features['conv4_2'], content_features['conv4_2']) * 0.1
        style_loss = los.compute_style_loss(style_features,
                                            target_features) * 0.05

        # Sum all (already weighted) terms and backpropagate.
        total_loss = loss_reconstruction_A + loss_reconstruction_B + loss_adv + content_loss + style_loss + TV_loss
        total_loss.backward()
Example #5
0
]

# Style loss: iterate over the style feature layers, each weighted equally.
for layer_name in feature_layers:
    layer_features = output_dict[layer_name]
    # Style image activations at this layer (batch index 1).
    style_image_features = layer_features[1, :, :, :]
    # Combination image activations at this layer (batch index 2).
    combination_image_features = layer_features[2, :, :, :]
    sl = style_loss(style_image_features, combination_image_features,
                    img_nrows, img_ncols)
    loss += (style_weight / len(feature_layers)) * sl

# Total variation loss of the combination image.
loss += total_variation_weight * total_variation_loss(combination_image,
                                                      img_nrows, img_ncols)

# Gradients of the loss with respect to the combination image's pixels.
grads = K.gradients(loss, combination_image)

# K.gradients may hand back a list/tuple or a single tensor. Accept both so
# the output list is always [loss, grad, ...]; `outs += grads` alone fails
# for a bare tensor.
outs = [loss]
if isinstance(grads, (list, tuple)):
    outs += grads
else:
    outs.append(grads)

# Keras function mapping the combination image to [loss, gradients...].
f_outputs = K.function([combination_image], outs)


# Define the function that evaluates the loss and gradients.
def eval_loss_and_grads(x):
    """Evaluate ``f_outputs`` on ``x`` reshaped to (1, img_nrows, img_ncols, 3).

    NOTE(review): as written, the function computes ``out`` but has no
    ``return`` statement, so it returns None. The remainder of the body
    appears to be missing; confirm.
    """
    x = x.reshape((1, img_nrows, img_ncols, 3))
    out = f_outputs([x])
def main():
    """Train the fast style-transfer generator until the input queue is exhausted.

    Reads batches with ``reader.get_train_image``, runs them through the
    generator ``model.net``, and feeds the preprocessed generated images
    together with the original batch through the loss network returned by
    ``model.load_model``. The weighted sum of content, style and
    total-variation losses is minimised with Adam (learning rate 1e-3).
    TensorBoard summaries and checkpoints are written to ``models/log/``, and
    training resumes from the latest checkpoint there if one exists. When the
    input queue raises ``OutOfRangeError``, a final
    ``fast-style-model.ckpt-done`` checkpoint is saved.

    Relies on module-level settings defined elsewhere (``batch_size``,
    ``image_size``, ``dataset_path``, ``vgg16_ckpt_path``, ``content_layers``,
    ``style_layers``, ``content_weight``, ``style_weight``, ``tv_weight``).
    """
    # Make sure the training path exists.
    training_path = 'models/log/'
    if not os.path.exists(training_path):
        os.makedirs(training_path)
    style_features = get_style_feature()

    with tf.Graph().as_default():
        train_image = reader.get_train_image(batch_size, image_size,
                                             image_size, dataset_path)
        generated = model.net(train_image)
        # Preprocess each generated image before it enters the loss network.
        processed_generated = tf.stack([
            reader.prepose_image(image, image_size, image_size)
            for image in tf.unstack(generated, axis=0, num=batch_size)
        ])
        # One forward pass over [generated batch, original batch].
        net = model.load_model(
            tf.concat([processed_generated, train_image], 0), vgg16_ckpt_path)
        with tf.Session() as sess:
            # Build losses.
            content_loss = losses.content_loss(net, content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                net, style_features, style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = (style_weight * style_loss
                    + content_weight * content_loss
                    + tv_weight * tv_loss)

            # Summaries for visualization in TensorBoard.
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    reader.mean_add(image) for image in tf.unstack(
                        train_image, axis=0, num=batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # Prepare to train.
            global_step = tf.Variable(0, name="global_step", trainable=False)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss,
                global_step=global_step,
                var_list=tf.trainable_variables())

            # NOTE(review): this Saver only covers trainable variables, so the
            # non-trainable global_step is not checkpointed; a resumed run
            # restarts step counting at 0. Confirm whether that is intended.
            saver = tf.train.Saver(tf.trainable_variables())

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    # Fetch the component losses in the same run as the train
                    # step. The input is queue-fed, so a separate sess.run
                    # would dequeue (and skip) another batch and report
                    # losses for a different batch than loss_t.
                    _, loss_t, loss_c, loss_s, step = sess.run(
                        [train_op, loss, content_loss, style_loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    # Logging.
                    if step % 10 == 0:
                        print(
                            'step: %d, content Loss %f, style Loss %f, total Loss %f, secs/step: %f'
                            % (step, loss_c, loss_s, loss_t, elapsed_time))
                    # Summary.
                    if step % 25 == 0:
                        print('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    # Checkpoint.
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
            writer.close()
Example #7
0
def train(**kwargs):
    """Train the fast style-transfer generator ``net`` against a VGG-16 loss network.

    Keyword arguments override attributes of a fresh ``Config()`` (for example
    ``data_root``, ``epoches``, ``batch_size``, ``image_size``, ``style_path``,
    the layer lists and the loss weights). Images are read from
    ``opt.data_root`` with a TF1 queue pipeline limited to ``opt.epoches``
    passes. Training runs until that limit raises ``OutOfRangeError``.
    Summaries are written to ``./logs`` and checkpoints to ``./logs/model``,
    from which training resumes if a checkpoint exists.

    Raises:
        ValueError: If ``opt.data_root`` contains no regular files.
    """
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # --- Data loading ---
    # Collect the regular files in the data directory.
    filenames = [os.path.join(opt.data_root, f)
                 for f in os.listdir(opt.data_root)
                 if os.path.isfile(os.path.join(opt.data_root, f))]
    if not filenames:
        # Fail with a clear message rather than an IndexError on filenames[0].
        raise ValueError('No image files found in {}'.format(opt.data_root))
    # Decode as PNG if the first file is a PNG, otherwise as JPEG.
    png = filenames[0].lower().endswith('png')  # If first file is a png, assume they all are
    # Shuffled filename queue, limited to opt.epoches passes.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True, num_epochs=opt.epoches)
    reader = tf.WholeFileReader()
    # read() returns a (key, value) pair; the value is the raw file bytes.
    _, img_bytes = reader.read(filename_queue)
    image_row = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
    # Preprocess and batch.
    image = utils.img_proprocess(image_row, opt.image_size)
    image_batch = tf.train.batch([image], opt.batch_size, dynamic_pad=True)

    # --- Generator network ---
    generated = net(image_batch, training=True)
    generated = tf.image.resize_bilinear(generated, [opt.image_size, opt.image_size], align_corners=False)
    generated.set_shape([opt.batch_size, opt.image_size, opt.image_size, 3])
    # tf.unstack splits along axis 0 and drops that axis (tf.split would keep
    # it), so each element is a single image.
    processed_generated = tf.stack([utils.img_proprocess(img, opt.image_size) for img in tf.unstack(generated, axis=0)])

    # --- Loss network (VGG-16) ---
    # Feed 2 x batch_size images at once: [generated batch, original batch].
    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0)):
        _, endpoint = vgg.vgg_16(tf.concat([processed_generated, image_batch], 0),
                                 num_classes=1, is_training=False, spatial_squeeze=False)
    tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
    for key in endpoint:
        tf.logging.info(key)

    # --- Losses ---
    style_gram = utils.get_style_feature(opt.style_path,
                                         opt.image_size,
                                         opt.style_layers,
                                         opt.model_path,
                                         opt.exclude_scopes)
    content_loss, content_loss_summary = losses.content_loss(endpoint, opt.content_layers)
    style_loss, style_loss_summary = losses.style_loss(endpoint, style_gram, opt.style_layers)
    tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image (the output we actually want)
    loss = opt.style_weight * style_loss + opt.content_weight * content_loss + opt.tv_weight * tv_loss

    # --- Optimizer: update only trainable variables outside the vgg_16 scope ---
    variables_to_train = [v for v in tf.trainable_variables()
                          if not v.name.startswith("vgg_16")]

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variables_to_train)

    # --- Saver: save/restore every global variable outside the vgg_16 scope ---
    # tf.global_variables() also includes non-trainable variables such as
    # global_step and the optimizer's slot variables, which are needed to
    # resume training.
    variables_to_restore = [v for v in tf.global_variables()
                            if not v.name.startswith("vgg_16")]

    saver = tf.train.Saver(var_list=variables_to_restore, write_version=tf.train.SaverDef.V2)

    # --- Summaries ---
    # Overall losses.
    tf.summary.scalar('losses/content_loss', content_loss)
    tf.summary.scalar('losses/style_loss', style_loss)
    tf.summary.scalar('losses/regularizer_loss', tv_loss)
    tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * opt.content_weight)
    tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * opt.style_weight)
    tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * opt.tv_weight)
    tf.summary.scalar('total_loss', loss)
    # Per-layer style losses.
    for layer in opt.style_layers:
        tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
    # Generated and original images.
    tf.summary.image('generated', generated)
    tf.summary.image('origin', image_batch)

    summary_path = "./logs"
    model_path = "./logs/model"
    summary = tf.summary.merge_all()

    # --- Training ---
    # NOTE(review): `config` is not defined in this function; presumably a
    # module-level tf.ConfigProto. Confirm.
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter(summary_path, sess.graph)
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))

        # Load the pretrained VGG parameters.
        param_init_fn = utils.param_load_fn(opt.model_path, opt.exclude_scopes)
        param_init_fn(sess)

        # Resume the generator from its own checkpoint (which, given the
        # Saver's var_list, excludes vgg_16 variables).
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Success to read {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.info("Failed to find a checkpoint")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()

                if step % 50 == 0:
                    tf.logging.info('step: {0:d}, total Loss {1:.2f}, secs/step: {2:.3f}'.
                                    format(step, loss_t, elapsed_time))
                if step % 100 == 0:
                    tf.logging.info('adding summary...')
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, step)
                    writer.flush()
                if step % 1000 == 0:
                    saver.save(sess, os.path.join(model_path, 'fast_style_model'), global_step=step)

        except tf.errors.OutOfRangeError:
            saver.save(sess, os.path.join(model_path, 'fast_style_model'))
            tf.logging.info('Epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
        writer.close()
def main(FLAGS):
    """Train the fast style-transfer network ``model.net`` with a perceptual loss.

    FLAGS supplies the loss network name (``loss_model``), image height and
    width, batch size, epoch count, content/style layer lists, loss weights,
    and the checkpoint location (``model_path``/``naming``). Training images
    are read from ``FLAGS.dataset_path`` when that attribute is set and
    non-empty, otherwise from the previously hard-coded
    ``'F:/CASIA/train_frame/real/'``. Training resumes from the latest
    checkpoint in the training directory and runs until the input queue
    reports its epoch limit.
    """
    # Style features computed by losses.get_style_features from FLAGS.
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    # Training image directory: overridable via FLAGS, falling back to the
    # path that used to be hard-coded here.
    dataset_path = getattr(FLAGS, 'dataset_path', None) or 'F:/CASIA/train_frame/real/'

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # --- Build network ---
            # Loss network named by FLAGS.loss_model (see networks_map in
            # nets/nets_factory.py).
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)
            # Network-specific preprocessing / un-preprocessing.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # Read and preprocess one batch. Any folder with many images
            # works here; COCO is not required (it is very large).
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_height,
                                            FLAGS.image_width,
                                            dataset_path,
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            # Generated images (y_hat) from the transform network.
            generated = model.net(processed_images, training=True)
            # The generated images are fed to the loss network next, so they
            # get the same preprocessing first.
            processed_generated = tf.stack([
                image_preprocessing_fn(image, FLAGS.image_height,
                                       FLAGS.image_width) for image in
                tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ])
            # Concatenate along the batch axis
            # ([batch_size,h,w,c] x 2 -> [2*batch_size,h,w,c]) so one forward
            # pass yields features for both y_hat and the content images.
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info(
                'Loss network layers(You can define them in "content_layers" and "style_layers"):'
            )
            for key in endpoints_dict:
                tf.logging.info(key)

            # --- Build the three losses ---
            content_loss = losses.content_loss(endpoints_dict,
                                               FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # --- Summaries (TensorBoard only) ---
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image) for image in tf.unstack(
                        processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # --- Prepare to train ---
            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Train only the variables outside the loss network's scope.
            variable_to_train = [v for v in tf.trainable_variables()
                                 if not v.name.startswith(FLAGS.loss_model)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # Save/restore only the variables outside the loss network.
            variables_to_restore = [v for v in tf.global_variables()
                                    if not v.name.startswith(FLAGS.loss_model)]
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore the loss-network weights (slim init fn configured by FLAGS).
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # --- Start training ---
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    # Per-step progress output.
                    print(step)
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d,  total Loss %f, secs/step: %f' %
                            (step, loss_t, elapsed_time))
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
            writer.close()
Example #9
0
            # MRF loss between `ap_image_features` and this layer's
            # activations at batch index 0, averaged over `mrf_layers`.
            layer_features = outputs_dict[layer_name]
            combination_features = layer_features[0, :, :, :]
            sl = losses.mrf_loss(ap_image_features,
                                 combination_features,
                                 patch_size=patch_size,
                                 patch_stride=patch_stride)
            loss += (mrf_weight / len(mrf_layers)) * sl

    # Content loss against features taken from `all_b_features`; skipped
    # entirely when its weight is 0.
    if b_bp_content_weight != 0.0:
        for layer_name in b_content_layers:
            # Wrap the stored features (first entry) in a backend variable.
            b_features = K.variable(all_b_features[layer_name][0])
            bp_features = outputs_dict[layer_name]
            cl = losses.content_loss(bp_features, b_features)
            loss += b_bp_content_weight / len(b_content_layers) * cl

    # Total variation loss on the image being optimised.
    # NOTE(review): arguments are passed as (width, height) here, while the
    # reshape below uses (height, width); check the argument order expected
    # by losses.total_variation_loss.
    loss += total_variation_weight * losses.total_variation_loss(
        vgg_input, img_width, img_height)

    # Gradients of the loss with respect to the generated image.
    grads = K.gradients(loss, vgg_input)

    # Output list is [loss, grad, ...]; K.gradients may return a list/tuple
    # or a single tensor.
    outputs = [loss]
    if type(grads) in {list, tuple}:
        outputs += grads
    else:
        outputs.append(grads)

    f_outputs = K.function([vgg_input], outputs)

    def eval_loss_and_grads(x):
        # Reshape x to a (1, 3, img_height, img_width) batch and evaluate.
        # NOTE(review): no return statement is visible; `outs` is computed
        # but not returned, so the rest of this function may be missing.
        x = x.reshape((1, 3, img_height, img_width))
        outs = f_outputs([x])
def main(FLAGS):
    """Train a fast style-transfer generator and monitor it in TensorBoard.

    The generator ``model.net`` is fed preprocessed batches read from
    ``train2014/``. Its output is re-preprocessed, concatenated with the
    input batch and passed through the loss network ``FLAGS.loss_model``.
    The resulting endpoints drive a weighted sum of content, style and
    total-variation losses that Adam (lr 1e-3) minimises. Only variables
    whose names do not start with ``FLAGS.loss_model`` are trained and
    checkpointed. Summaries and checkpoints go to
    ``FLAGS.model_path/FLAGS.naming``; if that directory already holds a
    checkpoint, training resumes from it.

    Args:
        FLAGS: configuration object. Attributes read here: ``model_path``,
            ``naming``, ``loss_model``, ``batch_size``, ``image_size``,
            ``epoch``, ``content_layers``, ``style_layers``,
            ``content_weight``, ``style_weight``, ``tv_weight`` (the helpers
            ``losses.get_style_features`` / ``utils._get_init_fn`` may read more).
    """
    style_targets = losses.get_style_features(FLAGS)

    # Checkpoints and TensorBoard event files share this directory.
    run_dir = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---- Networks ---------------------------------------------------
            loss_net = nets_factory.get_network_fn(
                FLAGS.loss_model, num_classes=1, is_training=False)
            preprocess, unprocess = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)

            side, batch = FLAGS.image_size, FLAGS.batch_size
            inputs = reader.image(batch, side, side, 'train2014/', preprocess,
                                  epochs=FLAGS.epoch)
            generated = model.net(inputs, training=True)
            generated_for_loss = tf.stack(
                [preprocess(img, side, side)
                 for img in tf.unstack(generated, axis=0, num=batch)])
            # Generated images fill the first half of the loss-network batch,
            # the original inputs the second half.
            _, endpoints = loss_net(tf.concat([generated_for_loss, inputs], 0),
                                    spatial_squeeze=False)

            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for layer_name in endpoints:
                tf.logging.info(layer_name)

            # ---- Losses -----------------------------------------------------
            content_loss = losses.content_loss(endpoints, FLAGS.content_layers)
            style_loss, per_layer_style = losses.style_loss(
                endpoints, style_targets, FLAGS.style_layers)
            # TV loss is computed on the raw generator output, not the preprocessed copy.
            tv_loss = losses.total_variation_loss(generated)
            total_loss = (FLAGS.style_weight * style_loss
                          + FLAGS.content_weight * content_loss
                          + FLAGS.tv_weight * tv_loss)

            # ---- TensorBoard summaries --------------------------------------
            for tag, value in (('losses/content_loss', content_loss),
                               ('losses/style_loss', style_loss),
                               ('losses/regularizer_loss', tv_loss)):
                tf.summary.scalar(tag, value)
            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', total_loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, per_layer_style[layer])
            tf.summary.image('generated', generated)
            originals = [unprocess(img)
                         for img in tf.unstack(inputs, axis=0, num=batch)]
            tf.summary.image('origin', tf.stack(originals))
            summary_op = tf.summary.merge_all()
            writer = tf.summary.FileWriter(run_dir)

            # ---- Optimisation: everything outside the loss network ----------
            global_step = tf.Variable(0, name="global_step", trainable=False)

            trainable = [var for var in tf.trainable_variables()
                         if not var.name.startswith(FLAGS.loss_model)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                total_loss, global_step=global_step, var_list=trainable)

            to_checkpoint = [var for var in tf.global_variables()
                             if not var.name.startswith(FLAGS.loss_model)]
            saver = tf.train.Saver(to_checkpoint,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])

            # Load the loss-network weights.
            utils._get_init_fn(FLAGS)(sess)

            # Resume the generator when a checkpoint is already present.
            latest = tf.train.latest_checkpoint(run_dir)
            if latest:
                tf.logging.info('Restoring model from {}'.format(latest))
                saver.restore(sess, latest)

            # ---- Training loop ----------------------------------------------
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            tick = time.time()
            try:
                while not coord.should_stop():
                    _, loss_value, step = sess.run([train_op, total_loss, global_step])
                    secs_per_step = time.time() - tick
                    tick = time.time()
                    if step % 10 == 0:
                        tf.logging.info('step: %d,  total Loss %f, secs/step: %f' % (step, loss_value, secs_per_step))
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        writer.add_summary(sess.run(summary_op), step)
                        writer.flush()
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(run_dir, 'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                # The input pipeline has delivered its last epoch.
                saver.save(sess, os.path.join(run_dir, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
Example #11
0
def main(FLAGS):
    """Train a fast style-transfer generator against a frozen loss network.

    The generator ``model.net`` is trained on preprocessed batches read from
    ``train2014/``. The generated batch and the input batch are concatenated
    and run through the loss network ``FLAGS.loss_model``, whose endpoints
    feed the content, style and total-variation losses. Only variables whose
    names do not start with ``FLAGS.loss_model`` are trained and
    checkpointed. Checkpoints and TensorBoard summaries are written to
    ``FLAGS.model_path/FLAGS.naming``; an existing checkpoint there is
    restored before training resumes.

    Args:
        FLAGS: configuration object. Attributes read here: ``model_path``,
            ``naming``, ``loss_model``, ``batch_size``, ``image_size``,
            ``epoch``, ``content_layers``, ``style_layers``,
            ``content_weight``, ``style_weight``, ``tv_weight`` (the helpers
            ``losses.get_style_features`` / ``utils._get_init_fn`` may read more).
    """
    # Style targets computed once from the style image (the author's note
    # describes them as Gram matrices).
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path (e.g. model/wave/) exists; checkpoints and
    # TensorBoard logs are written there.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        # The session is closed automatically when this block exits.
        with tf.Session() as sess:
            # ---- Build the networks -----------------------------------------
            # The loss network only extracts features; it is never trained.
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model, num_classes=1, is_training=False)
            # Pre-/un-processing functions matched to the loss network.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # A preprocessed batch of training images.
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            # The image-generation network: the network being trained.
            generated = model.net(processed_images, training=True)
            processed_generated = tf.stack([
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ])
            # Generated images (first half) and inputs (second half) go through
            # the loss network together; endpoints_dict maps layer names to
            # the activations of that combined batch.
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure (layer names) of the loss network.
            tf.logging.info('loss network layers(you can define them in "content layer" and "style layer"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---- Build the losses -------------------------------------------
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            # The TV loss uses the unprocessed generator output.
            tv_loss = losses.total_variation_loss(generated)

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # ---- TensorBoard summaries --------------------------------------
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            # Underscore-only tags: TensorFlow rewrites illegal characters
            # (such as spaces) in summary names.
            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image)
                for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            # ---- Prepare to train -------------------------------------------
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Train only the generator: skip variables belonging to the loss network.
            variable_to_train = [variable for variable in tf.trainable_variables()
                                 if not variable.name.startswith(FLAGS.loss_model)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            # Checkpoint every global variable outside the loss network
            # (generator weights, global_step and, presumably, the Adam slot
            # variables created by minimize() above). Local variables are not
            # checkpointed.
            variables_to_restore = [v for v in tf.global_variables()
                                    if not v.name.startswith(FLAGS.loss_model)]
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            # Initialize all global and local variables.
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore the pretrained loss-network variables.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Resume the generator from the latest checkpoint if there is one;
            # otherwise the freshly initialized weights are kept.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            # ---- Start training ---------------------------------------------
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                # Runs until the input queue is exhausted (OutOfRangeError) or
                # the coordinator is asked to stop.
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    # Duration of this step only: the reference point is reset each iteration.
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    if step % 10 == 0:
                        tf.logging.info('step:%d, total loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    if step % 1000 == 0:
                        # Saves the variables listed in variables_to_restore.
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()  # ask the queue-runner threads to stop
            coord.join(threads)  # wait for them to finish
Example #12
0
        for layer_name in mrf_layers:
            ap_image_features = K.variable(all_ap_image_features[layer_name][0])
            layer_features = outputs_dict[layer_name]
            combination_features = layer_features[0, :, :, :]
            sl = losses.mrf_loss(ap_image_features, combination_features,
                patch_size=patch_size, patch_stride=patch_stride)
            loss += (mrf_weight / len(mrf_layers)) * sl

    if b_bp_content_weight != 0.0:
        for layer_name in b_content_layers:
            b_features = K.variable(all_b_features[layer_name][0])
            bp_features = outputs_dict[layer_name]
            cl = losses.content_loss(bp_features, b_features)
            loss += b_bp_content_weight / len(b_content_layers) * cl

    loss += total_variation_weight * losses.total_variation_loss(vgg_input, img_width, img_height)

    # get the gradients of the generated image wrt the loss
    grads = K.gradients(loss, vgg_input)

    outputs = [loss]
    if type(grads) in {list, tuple}:
        outputs += grads
    else:
        outputs.append(grads)

    f_outputs = K.function([vgg_input], outputs)
    evaluator = Evaluator()

    # run scipy-based optimization (L-BFGS) over the pixels of the generated image
    # so as to minimize the neural style loss
Example #13
0
def train():
    """Train a feed-forward style-transfer generator against a fixed loss network.

    Training images come from ``args.target_dir`` and pass through
    ``transform(sess, args).generator``. The generated and original batches
    are fed together through the loss network ``args.loss_model``, and Adam
    (lr 1e-3) minimises the weighted sum of content, style and
    total-variation losses. Only variables whose names do not contain
    ``args.loss_model`` are trained and checkpointed.

    Summaries are written to ``args.log_dir``. Checkpoints go to
    ``args.ckpt_dir/<style_name>/``, where ``<style_name>`` is the style
    image's file name up to its first dot. If a checkpoint already exists
    there, it is restored AFTER variable initialization, so training really
    resumes from it.

    Relies on the module-level ``args`` namespace and on helpers defined
    elsewhere in the file (``get_style_feature``, ``transform``,
    ``load_pretrained_weight``, ``loss_config``, ``reader``, ``losses``).
    """
    _, style_grams = get_style_feature()

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # For this online optimization problem the test-time preprocessing
            # is used for both training and testing.
            preprocess_func, unprocess_func = preprocessing.preprocessing_factory.get_preprocessing(
                args.loss_model, is_training=False)

            images = reader.image(args.batch, args.size, args.size, args.target_dir,
                                  preprocess_func, args.epoch, shuffle=True)

            model = transform(sess, args)
            transformed_images = model.generator(images, reuse=False)

            # Per-image generator outputs, used as-is for the image summary.
            unprocess_transform = list(tf.unstack(transformed_images, axis=0, num=args.batch))

            processed_generated = tf.stack([
                preprocess_func(img, args.size, args.size)
                for img in unprocess_transform
            ])

            loss_model = nets.nets_factory.get_network_fn(args.loss_model,
                                                          num_classes=1,
                                                          is_training=False)

            # Generated images first, original images second.
            pair = tf.concat([processed_generated, images], axis=0)
            _, end_dicts = loss_model(pair, spatial_squeeze=False)

            init_loss_model = load_pretrained_weight(args.loss_model)

            c_loss = losses.content_loss(
                end_dicts, loss_config.content_loss_dict[args.loss_model])

            s_loss, s_loss_sum = losses.style_loss(
                end_dicts, loss_config.style_loss_dict[args.loss_model],
                style_grams)

            tv_loss = losses.total_variation_loss(transformed_images)

            loss = args.c_weight * c_loss + args.s_weight * s_loss + args.tv_weight * tv_loss

            print('shapes')
            print(pair.get_shape())

            tf.summary.scalar('losses/content_loss', c_loss)
            tf.summary.scalar('losses/style_loss', s_loss)
            tf.summary.scalar('losses/tv_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              c_loss * args.c_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              s_loss * args.s_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * args.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in loss_config.style_loss_dict[args.loss_model]:
                tf.summary.scalar('style_losses/' + layer, s_loss_sum[layer])

            tf.summary.image('transformed',
                             tf.stack(unprocess_transform, axis=0))
            tf.summary.image(
                'ori',
                tf.stack([
                    unprocess_func(image)
                    for image in tf.unstack(images, axis=0, num=args.batch)
                ]))

            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(args.log_dir)

            step = tf.Variable(0, name='global_step', trainable=False)

            # Everything outside the loss network is trained and checkpointed.
            to_train = [var for var in tf.trainable_variables()
                        if args.loss_model not in var.name]
            to_restore = [var for var in tf.global_variables()
                          if args.loss_model not in var.name]

            optim = tf.train.AdamOptimizer(1e-3).minimize(
                loss=loss, var_list=to_train, global_step=step)

            saver = tf.train.Saver(to_restore)
            style_name = (args.style_dir.split('/')[-1]).split('.')[0]
            ckpt_dir = os.path.join(args.ckpt_dir, style_name)
            # tf.train.Saver.save does not create missing parent directories.
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            # Initialize first, then load the pretrained loss network, then
            # restore the generator. Running the initializers after
            # saver.restore would overwrite the restored weights.
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            init_loss_model(sess)

            ckpt = tf.train.latest_checkpoint(ckpt_dir)
            if ckpt:
                tf.logging.info('Restoring model from {}'.format(ckpt))
                saver.restore(sess, ckpt)

            ckpt_path = os.path.join(ckpt_dir, style_name + '.ckpt')
            done_path = os.path.join(ckpt_dir, style_name + '.ckpt-done')

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time()
            try:
                while not coord.should_stop():
                    _, gs, sum_info, c_info, s_info, tv_info, loss_info = sess.run(
                        [optim, step, summary, c_loss, s_loss, tv_loss, loss])
                    writer.add_summary(sum_info, gs)
                    # Duration of this step only: reset the reference point every iteration.
                    elapsed = time() - start_time
                    start_time = time()

                    print(gs)

                    if gs % 10 == 0:
                        tf.logging.info(
                            'step: %d, c_loss %f  s_loss %f  tv_loss %f total Loss %f, secs/step: %f'
                            % (gs, c_info, s_info, tv_info, loss_info, elapsed))

                    if gs % args.save_freq == 0:
                        saver.save(sess, ckpt_path)

            except tf.errors.OutOfRangeError:
                # Report the path that is actually written.
                print('run out of images!  save final model: ' + done_path)
                saver.save(sess, done_path)
                tf.logging.info('Done -- file ran out of range')
            finally:
                coord.request_stop()

            coord.join(threads)
            # Flush pending summaries and release the event file.
            writer.close()

            print('end training')
            '''
Example #14
0
def main(style_img_path: str,
         content_img_path: str, 
         img_dim: int,
         num_iter: int,
         style_weight: int,
         content_weight: int,
         variation_weight: int,
         print_every: int,
         save_every: int):
    """Optimize a Gaussian-noise image to match a content image and a style image.

    Both images are resized to ``img_dim x img_dim``, normalized with the
    ImageNet channel mean/std and passed through ``VGG`` (max-pooling
    replaced by average pooling, parameters frozen). The noise image is then
    optimized with Adam against the weighted sum of the content, style
    (Gram) and total-variation losses.

    Args:
        style_img_path: path of the style image.
        content_img_path: path of the content image.
        img_dim: side length the images are resized to.
        num_iter: number of Adam iterations.
        style_weight: multiplier for the style loss.
        content_weight: multiplier for the content loss.
        variation_weight: multiplier for the total-variation loss.
        print_every: print the weighted losses every this many iterations.
        save_every: save the current image to ``./generated/`` every this
            many iterations.

    Raises:
        ValueError: if either image path is ``None``.
    """
    # Explicit checks: asserts are stripped when Python runs with -O.
    if style_img_path is None:
        raise ValueError('style_img_path must not be None')
    if content_img_path is None:
        raise ValueError('content_img_path must not be None')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resize, convert to a tensor and normalize with the ImageNet mean/std.
    transform = transforms.Compose([transforms.Resize((img_dim, img_dim)),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])

    def load_image(path):
        # Normalize above expects exactly three channels, so grayscale, RGBA
        # and palette images are converted to RGB. The with-block closes the
        # file handle that Image.open keeps.
        with Image.open(path) as img:
            return transform(img.convert('RGB')).unsqueeze(0).to(device)

    content_image = load_image(content_img_path)
    style_image = load_image(style_img_path)

    vgg = VGG().to(device).eval()

    # Replace every max-pooling layer with average pooling.
    for name, child in vgg.vgg.named_children():
        if isinstance(child, nn.MaxPool2d):
            vgg.vgg[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)

    # The network is fixed; only the noise image is optimized.
    for param in vgg.parameters():
        param.requires_grad = False

    def flatten(act):
        # (1, C, H, W) -> (C, H*W); assumes NCHW activations with batch 1.
        return act.squeeze().view(act.shape[1], -1)

    # Content target: activations of the content image, detached from the graph.
    content_activations = vgg.get_content_activations(content_image).detach()
    content_activations = content_activations.view(content_activations.shape[1], -1)

    # Style targets: Gram matrices of the style image's flattened activations.
    style_grams = [gram(flatten(act).detach())
                   for act in vgg.get_style_activations(style_image)]
    num_style_layers = len(style_grams)

    # Start from Gaussian noise; its pixels are the optimized parameters.
    noise = torch.randn(1, 3, img_dim, img_dim, device=device, requires_grad=True)
    adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

    # Output folder for the intermediate images (created once, not per iteration).
    os.makedirs('./generated/', exist_ok=True)

    for iteration in range(num_iter):
        adam.zero_grad()

        # Content loss on the flattened content-layer activations.
        noise_content_activations = vgg.get_content_activations(noise)
        noise_content_activations = noise_content_activations.view(
            noise_content_activations.shape[1], -1)
        content_loss_ = content_loss(noise_content_activations, content_activations)

        # Style loss: mean over style layers of the Gram-matrix loss.
        noise_style_activations = [flatten(act) for act in vgg.get_style_activations(noise)]
        style_loss = 0
        for acts, style_gram in zip(noise_style_activations, style_grams):
            n_rows, n_cols = acts.shape[0], acts.shape[1]
            style_loss += gram_loss(gram(acts), style_gram, n_rows, n_cols) / num_style_layers
        style_loss = style_loss.to(device)

        variation_loss = total_variation_loss(noise).to(device)

        total_loss = content_weight * content_loss_ + style_weight * style_loss + variation_weight * variation_loss

        if iteration % print_every == 0:
            print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(
                iteration,
                content_weight * content_loss_.item(),
                style_weight * style_loss.item(),
                variation_weight * variation_loss.item()))

        if iteration % save_every == 0:
            save_image(noise.cpu().detach(), filename='./generated/iter_{}.png'.format(iteration))

        total_loss.backward()
        adam.step()
Example #15
0
def main(FLAGS):
    """Train a style-transfer generator, checkpointing every 50 steps.

    The generator ``model.transform_network`` consumes preprocessed batches
    from ``train2014/``. Its output is re-preprocessed, concatenated with
    the input batch and run through the loss network ``FLAGS.loss_model``
    to form content, style and total-variation losses, which Adam (lr 1e-3)
    minimises. Only variables whose names do not start with
    ``FLAGS.loss_model`` are trained and saved. Checkpoints are written to
    ``FLAGS.model_path/FLAGS.naming`` and an existing one there is restored
    first. No TensorBoard summaries are written.

    Args:
        FLAGS: configuration object. Attributes read here: ``model_path``,
            ``naming``, ``loss_model``, ``batch_size``, ``image_size``,
            ``epoch``, ``content_layers``, ``style_layers``,
            ``content_weight``, ``style_weight``, ``tv_weight`` (the helpers
            ``losses.get_style_features`` / ``utils._get_init_fn`` may read more).
    """
    style_targets = losses.get_style_features(FLAGS)
    out_dir = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # -- Networks ---------------------------------------------------
            loss_network = nets_factory.get_network_fn(
                FLAGS.loss_model, num_classes=1, is_training=False)
            preprocess, _unprocess = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)

            # -- Training images and generator ------------------------------
            side = FLAGS.image_size
            inputs = reader.batch_image(FLAGS.batch_size, side, side,
                                        'train2014/', preprocess, epochs=FLAGS.epoch)
            generated = model.transform_network(inputs, training=True)
            generated_preprocessed = tf.stack(
                [preprocess(img, side, side)
                 for img in tf.unstack(generated, axis=0, num=FLAGS.batch_size)])
            _, endpoints = loss_network(
                tf.concat([generated_preprocessed, inputs], 0), spatial_squeeze=False)
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for name in endpoints:
                tf.logging.info(name)

            # -- Losses -----------------------------------------------------
            content_loss = losses.content_loss(endpoints, FLAGS.content_layers)
            style_loss, _ = losses.style_loss(endpoints, style_targets, FLAGS.style_layers)
            # TV loss on the unprocessed generator output.
            tv_loss = losses.total_variation_loss(generated)
            total_loss = (FLAGS.style_weight * style_loss
                          + FLAGS.content_weight * content_loss
                          + FLAGS.tv_weight * tv_loss)

            # -- Optimisation: only generator-side variables ----------------
            global_step = tf.Variable(0, name="global_step", trainable=False)

            def outside_loss_net(var):
                # True for variables that do not belong to the loss network.
                return not var.name.startswith(FLAGS.loss_model)

            generator_vars = [var for var in tf.trainable_variables() if outside_loss_net(var)]
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                total_loss, global_step=global_step, var_list=generator_vars)

            saved_vars = [var for var in tf.global_variables() if outside_loss_net(var)]
            saver = tf.train.Saver(saved_vars, write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            utils._get_init_fn(FLAGS)(sess)
            latest = tf.train.latest_checkpoint(out_dir)
            if latest:
                tf.logging.info('Restoring model from {}'.format(latest))
                saver.restore(sess, latest)

            # -- Training loop ----------------------------------------------
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            periodic_ckpt = os.path.join(out_dir, FLAGS.naming + '.ckpt')
            last_tick = time.time()
            try:
                while not coord.should_stop():
                    _, loss_value, step = sess.run([train_op, total_loss, global_step])
                    secs = time.time() - last_tick
                    last_tick = time.time()
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d,  total Loss %f, secs/step: %f,%s' % (step, loss_value, secs, time.asctime()))
                    if step % 50 == 0:
                        tf.logging.info('saving check point...')
                        saver.save(sess, periodic_ckpt, global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(out_dir, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                tf.logging.info('coordinator stop')
            coord.join(threads)
def run_stroke_style_transfer(num_steps=100,
                              style_weight=3.,
                              content_weight=1.,
                              tv_weight=0.008,
                              curv_weight=4):
    """Optimize brush-stroke parameters so the rendered canvas takes on the style.

    Relies on names defined outside this function: ``vgg_weight_file``,
    ``content_img``, ``style_img``, ``bs_content_layers``,
    ``bs_style_layers``, the canvas/brush settings passed to
    ``BrushStrokeRenderer``, ``device``, ``logger`` and the monitor ``mon``.

    Args:
        num_steps: number of optimization iterations.
        style_weight: multiplier applied to the style score.
        content_weight: multiplier applied to the content score.
        tv_weight: multiplier for ``losses.total_variation_loss`` computed
            from the renderer's ``location``, ``curve_s`` and ``curve_e``.
        curv_weight: multiplier for ``losses.curvature_loss`` computed from
            ``curve_s``, ``curve_e`` and ``curve_c``.

    Returns:
        The canvas produced by ``bs_renderer()`` after the final step,
        evaluated under ``T.no_grad()``.
    """
    # VGG-based content/style scoring against the global content and style images.
    vgg_loss = losses.StyleTransferLosses(vgg_weight_file,
                                          content_img,
                                          style_img,
                                          bs_content_layers,
                                          bs_style_layers,
                                          scale_by_y=True)
    vgg_loss.to(device).eval()

    # brush stroke init
    # The renderer also receives the first content image with its first axis
    # moved last, as a NumPy array (presumably CHW -> HWC; confirm).
    bs_renderer = BrushStrokeRenderer(canvas_height,
                                      canvas_width,
                                      num_strokes,
                                      samples_per_curve,
                                      brushes_per_pixel,
                                      canvas_color,
                                      length_scale,
                                      width_scale,
                                      content_img=content_img[0].permute(
                                          1, 2, 0).cpu().numpy())
    bs_renderer.to(device)

    # Two optimizers: stroke geometry (lr 0.1) and stroke colour (lr 0.01).
    optimizer = optim.Adam([
        bs_renderer.location, bs_renderer.curve_s, bs_renderer.curve_e,
        bs_renderer.curve_c, bs_renderer.width
    ],
                           lr=1e-1)
    optimizer_color = optim.Adam([bs_renderer.color], lr=1e-2)

    logger.info('Optimizing brushstroke-styled canvas..')
    for _ in mon.iter_batch(range(num_steps)):
        optimizer.zero_grad()
        optimizer_color.zero_grad()
        input_img = bs_renderer()
        # Add a batch axis and move the last axis to position 1
        # (presumably HWC -> NCHW for the VGG loss; confirm renderer layout).
        input_img = input_img[None].permute(0, 3, 1, 2).contiguous()
        content_score, style_score = vgg_loss(input_img)

        # Scale the scores in place; style_score stays weighted for the
        # colour backward pass below.
        style_score *= style_weight
        content_score *= content_weight
        tv_score = tv_weight * losses.total_variation_loss(
            bs_renderer.location,
            bs_renderer.curve_s,
            bs_renderer.curve_e,
            K=10)
        curv_score = curv_weight * losses.curvature_loss(
            bs_renderer.curve_s, bs_renderer.curve_e, bs_renderer.curve_c)
        loss = style_score + content_score + tv_score + curv_score
        # Geometry parameters get gradients of the full loss only;
        # retain_graph keeps the graph alive for the second backward pass.
        loss.backward(inputs=[
            bs_renderer.location, bs_renderer.curve_s, bs_renderer.curve_e,
            bs_renderer.curve_c, bs_renderer.width
        ],
                      retain_graph=True)
        optimizer.step()
        # Colours are driven by the weighted style score alone.
        # NOTE(review): optimizer.step() has already updated the geometry
        # parameters in place; this backward relies on the retained graph not
        # having saved those tensors directly -- confirm.
        style_score.backward(inputs=[bs_renderer.color])
        optimizer_color.step()

        # Log each weighted loss term via the monitor; periodically write the canvas.
        mon.plot('stroke style loss', style_score.item())
        mon.plot('stroke content loss', content_score.item())
        mon.plot('stroke tv loss', tv_score.item())
        mon.plot('stroke curvature loss', curv_score.item())
        if mon.iter % mon.print_freq == 0:
            mon.imwrite('stroke stylized', input_img)

    with T.no_grad():
        return bs_renderer()
Example #17
0
def get_optimal_image(model_fn,
                      model_kwargs,
                      checkpoint_path,
                      params,
                      preproc=True,
                      preproc_params=None,
                      layer_name=None,
                      image_resolution=128):
    """
    Does gradient ascent to get the optimal image for a given model.

    Inputs
        model_fn (fn): function to which image tensor (batch x h x w x c) can be
            passed; it is called with ``layer_name``, ``tensor_name`` and
            ``**model_kwargs`` and its result is handed to ``get_network_aspect``
        model_kwargs (dict): other keyword arguments for the model function
        checkpoint_path (str): where to find the model checkpoint
        params (dict): keys include
            - "learning_rate": Adam learning rate
            - "regularization": scale of the L2 penalty on the image
            - "steps": how many optimization steps to run
            - "channel": which channel to do optimization for
            - optional: "unit_index"
            - optional: "tensor_name": forwarded to ``model_fn``
            - optional: "loss": when present and not None, a total-variation
              penalty on the image is added on top of the L2 penalty
              (default is just L2 regularization)
            - optional: "loss_lambda" (number): constant to scale the
              total-variation penalty by (default is 1.0)
        preproc (bool): whether or not to preprocess the images ala lucid
            (default is True)
        preproc_params (dict): passed to ``preprocess``; if not included,
            default to lucid transform. Keys include
            - "pad": amount to pad by
            - "scale": to scale or not (bool)
            - "rotate": to rotate or not (bool)
            - "pre_jitter": how much to jitter the image before scaling and rotating
            - "post_jitter": how much to jitter the image at the end of preprocessing
        layer_name (str): which layer to get image for
        image_resolution (int): how many pixels to make the image on each side

    Outputs
        (optimal_image, loss_list): ``norm_image`` applied to the squeezed
        final image (h x w x 3), and the loss value recorded before each
        optimization step.
    """
    # set up model
    tf.reset_default_graph()
    init = tf.random_uniform_initializer(minval=0, maxval=1)  # random noise
    # L2 penalty on the image; registered in REGULARIZATION_LOSSES via the
    # `regularizer=` argument of get_variable below.
    reg = tf.contrib.layers.l2_regularizer(scale=params['regularization'])

    # set up the image variable -- the only thing that gets optimized
    image_shape = (1, image_resolution, image_resolution, 3)
    images = tf.get_variable("images",
                             image_shape,
                             initializer=init,
                             regularizer=reg)

    # preprocess the images
    # NOTE(review): `images` is rebound to the preprocessed tensor, so the
    # total-variation penalty and the returned image are computed on the
    # transformed image rather than on the raw variable -- confirm intended.
    if preproc:
        images = preprocess(images, preproc_params)

    # get features for a given layer from a given model
    tensor_name = params.get('tensor_name', None)
    layer = model_fn(images,
                     layer_name=layer_name,
                     tensor_name=tensor_name,
                     **model_kwargs)

    # extract specified aspect of the network to optimize from conv or fc layer
    target = get_network_aspect(params, layer)

    # set up loss function: L2 only, or L2 + scaled total variation
    if params.get('loss') is None:
        total_reg = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    else:
        tv_loss = total_variation_loss(images)
        scale_loss = tf.constant(params.get('loss_lambda', 1.0))
        # check and fix data type (e.g. an int loss_lambda vs a float loss)
        scale_loss = tf.cast(scale_loss, tv_loss.dtype)
        total_reg = tf.reduce_sum(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        ) + tf.scalar_mul(scale_loss, tv_loss)
    # gradient *ascent* on the target: minimize its negated mean
    loss_tensor = tf.negative(tf.reduce_mean(target)) + total_reg

    # set up optimizer
    lr_tensor = tf.constant(params['learning_rate'])

    # Snapshot (a copy of) the global variables before the optimizer exists:
    # these are the model weights plus the image variable. Adam's slot and
    # beta-power variables are created by `minimize` below, so they are
    # excluded from the restore list without matching on names such as
    # "beta", which would also have skipped BatchNorm `beta` weights.
    pre_optimizer_vars = tf.global_variables()

    # restrict trainable variables to the image itself
    train_vars = [
        var for var in tf.trainable_variables() if 'images' in var.name
    ]
    train_op = tf.train.AdamOptimizer(lr_tensor).minimize(loss_tensor,
                                                          var_list=train_vars)

    # exclude the new image variable from the restore
    restore_vars = [v for v in pre_optimizer_vars if 'images' not in v.name]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # restore model weights (overwrites their random initialization)
        tf.train.Saver(var_list=restore_vars).restore(sess, checkpoint_path)

        ## Main Loop
        loss_list = []
        for _ in range(params['steps']):
            # keep track of the loss over steps, then optimize the image
            loss_list.append(sess.run(loss_tensor))
            sess.run(train_op)

        final_image = sess.run(images)
    return norm_image(final_image.squeeze()), loss_list
Example #18
0
    def __init__(self, base_img_path, style_img_path, output_img_path,
                 output_width, convnet, content_weight, style_weight,
                 tv_weight, content_layer, style_layers, iterations):
        """
        Initialize and store parameters of the neural styler. Load the
        desired convnet on top of the stacked content/style/output images,
        build the 3 losses and the gradient of their weighted sum with
        respect to the output image.

        Params
        ------
        - base_img_path: path to the content image.
        - style_img_path: path to the style image.
        - output_img_path: path to output image (stored for later use).
        - output_width: desired output width in pixels, or None to keep the
          content image's size; see [1].
        - convnet: [string], which network to use: 'vgg16', 'vgg19' or
          'xception'; any other value loads ResNet50.
        - content_weight: weight for the content loss.
        - style_weight: weight for the style loss.
        - tv_weight: weight for the total variation loss.
        - content_layer: string containing name of layer to use for content
          reconstruction. Also defined in Gatys et. al.
        - style_layers: list containing name of layers to use for style
          reconstruction. Defined in Gatys et. al but can be changed.
        - iterations: iterations for optimization algorithm.

        Notes
        -----
        [1] If user specifies output width, then calculate the corresponding
            image height (same scale factor, floored). Else, output image
            height and width are the same as that of the content image. The
            content and style images are both resized to the chosen shape.

        [2] PIL returns (width, height) whereas numpy returns (height, width)
            since nrows=height and ncols=width.

        [3] L_BFGS requires that loss and grad be two functions so we create
            a keras function (``self.loss_and_grads``) that computes the
            loss and gradients together; the class returns each separately.

        [4] The three images are concatenated along axis 0 in the order
            content, style, output, so batch index 0/1/2 of every feature map
            refers to content/style/output respectively.
        """
        print('\nInitializing Neural Style model...')

        # store paths
        self.base_img_path = base_img_path
        self.style_img_path = style_img_path
        self.output_img_path = output_img_path

        # configuring image sizes [1, 2]
        print('\n\tResizing images...')
        self.width = output_width
        width, height = load_img(self.base_img_path).size

        if self.width is not None:
            # scale the height by the same factor as the width (floored)
            self.img_nrows = int(height * self.width // width)
            self.img_ncols = self.width
        else:
            self.img_nrows = height
            self.img_ncols = width
        new_dims = (self.img_nrows, self.img_ncols)

        # Normalize the content loss by the pixel count of the images that
        # are actually fed to the network (the resized ones), consistent
        # with the nrows/ncols used for the style and TV losses below.
        size = self.img_nrows * self.img_ncols

        # resize content and style images to this desired shape
        self.content_img = K.variable(
            preprocess_image(self.base_img_path, new_dims))
        self.style_img = K.variable(
            preprocess_image(self.style_img_path, new_dims))

        # output placeholder with the same spatial shape, laid out according
        # to the backend's dimension ordering
        if K.image_dim_ordering() == 'th':
            self.output_img = K.placeholder((1, 3, new_dims[0], new_dims[1]))
            print("USE Th")
        else:
            self.output_img = K.placeholder((1, new_dims[0], new_dims[1], 3))
            print("USE TF")

        # sanity check on dimensions
        print("\tSize of content image is: {}".format(
            K.int_shape(self.content_img)))
        print("\tSize of style image is: {}".format(K.int_shape(
            self.style_img)))
        print("\tSize of output image is: {}".format(
            K.int_shape(self.output_img)))

        # combine the 3 images into a single Keras tensor [4]
        self.input_img = K.concatenate(
            [self.content_img, self.style_img, self.output_img], axis=0)

        self.convnet = convnet
        self.iterations = iterations

        # store weights of the loss components
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

        # store convnet layers
        self.content_layer = content_layer
        self.style_layers = style_layers

        # build the requested convnet on top of the stacked images
        print('\tLoading {} model'.format(self.convnet.upper()))

        if self.convnet == 'vgg16':
            self.model = vgg16.VGG16(input_tensor=self.input_img,
                                     weights='imagenet',
                                     include_top=False)
        elif self.convnet == 'vgg19':
            self.model = vgg19.VGG19(input_tensor=self.input_img,
                                     weights='imagenet',
                                     include_top=False)
        elif self.convnet == 'xception':
            # NOTE(review): unlike the VGG branches, the Xception/ResNet50
            # branches do not pass include_top=False -- confirm intended.
            self.model = xception.Xception(input_tensor=self.input_img,
                                           weights='imagenet')
        else:
            self.model = resnet50.ResNet50(input_tensor=self.input_img,
                                           weights='imagenet')

        print('\tComputing losses...')
        # symbolic output of every layer, keyed by layer name
        outputs_dict = {layer.name: layer.output
                        for layer in self.model.layers}

        # content loss: content image (index 0) vs output image (index 2)
        content_features = outputs_dict[self.content_layer]
        base_image_features = content_features[0, :, :, :]
        combination_features = content_features[2, :, :, :]

        print("Size is :", size)
        content_loss = self.content_weight * feature_reconstruction_loss(
            base_image_features, combination_features, size)

        # style loss: equally weighted average over the style layers of
        # style image (index 1) vs output image (index 2)
        temp_style_loss = K.variable(0.0)
        weight = 1.0 / float(len(self.style_layers))

        for layer in self.style_layers:
            style_features = outputs_dict[layer]
            style_image_features = style_features[1, :, :, :]
            output_style_features = style_features[2, :, :, :]
            temp_style_loss += weight * style_reconstruction_loss(
                style_image_features,
                output_style_features,
                self.img_nrows,
                self.img_ncols)
        style_loss = self.style_weight * temp_style_loss

        # compute total variational loss on the output image only
        tv_loss = self.tv_weight * total_variation_loss(
            self.output_img, self.img_nrows, self.img_ncols)

        # composite loss
        total_loss = content_loss + style_loss + tv_loss

        # compute gradients of output img with respect to loss
        print('\tComputing gradients...')
        grads = K.gradients(total_loss, self.output_img)

        outputs = [total_loss]
        if isinstance(grads, (list, tuple)):
            outputs.extend(grads)
        else:
            outputs.append(grads)

        # [3]
        self.loss_and_grads = K.function([self.output_img], outputs)
# Assemble the style-transfer objective and a function returning it together
# with its gradient w.r.t. the generated image.
# `loss` is rebound to a new tensor for every term instead of being updated
# with `assign_add`: the assign ops would never be executed by
# `fetch_loss_and_grads`, and the gradient of a bare variable read with
# respect to `combination_image` is None, which left the objective
# disconnected from the image being optimized.
# NOTE(review): assumes `loss` was initialised above (e.g. to a zero scalar);
# batch indices 0/1/2 select target/style-reference/combination features.
layer_features = outputs_dict[content_layer]
target_image_features = layer_features[0, :, :, :]
combination_features = layer_features[2, :, :, :]
loss = loss + content_weight * content_loss(target_image_features,
                                            combination_features)

# style term: equal share of style_weight for every style layer
for layer_name in style_layers:
    layer_features = outputs_dict[layer_name]
    style_reference_features = layer_features[1, :, :, :]
    combination_features = layer_features[2, :, :, :]
    sl = style_loss(style_reference_features, combination_features,
                    img_height, img_width)
    loss = loss + (style_weight / len(style_layers)) * sl

loss = loss + total_variation_weight * total_variation_loss(
    combination_image, img_height, img_width)

grads = K.gradients(loss, combination_image)[0]
fetch_loss_and_grads = K.function([combination_image], [loss, grads])


class Evaluator(object):
    """Serve one combined loss/gradient evaluation through two callables.

    ``fetch_loss_and_grads`` computes the loss and the gradient in a single
    pass, presumably for an optimizer (such as scipy's ``fmin_l_bfgs_b``)
    that takes them as separate callables. ``loss`` runs that pass and
    caches the gradient; ``grads`` returns the cached gradient and clears
    the cache, so the two calls must alternate.
    """

    def __init__(self):
        self.loss_value = None
        self.grads_values = None

    def loss(self, x):
        """Evaluate the loss at flat image ``x`` and cache its gradient.

        ``x`` is reshaped to (1, img_height, img_width, 3). The gradient is
        flattened to float64 and kept until ``grads`` is called.
        Raises AssertionError if the previous result was not consumed.
        """
        assert self.loss_value is None
        x = x.reshape((1, img_height, img_width, 3))
        outs = fetch_loss_and_grads([x])
        self.loss_value = outs[0]
        self.grads_values = outs[1].flatten().astype('float64')
        return self.loss_value

    def grads(self, x):
        """Return the gradient cached by the preceding ``loss`` call.

        ``x`` is ignored; the cache is reset so the next ``loss`` call can
        run. Raises AssertionError if ``loss`` was not called first.
        """
        assert self.loss_value is not None
        grads_values = self.grads_values
        self.loss_value = None
        self.grads_values = None
        return grads_values
Example #20
0
def main(FLAGS):
    """Train a fast style-transfer generator against a frozen loss network.

    The generator (``model.net``) is fed images from ``train2014/`` and its
    output is scored by ``FLAGS.loss_model`` with content, style and
    total-variation losses. Summaries and checkpoints are written under
    ``FLAGS.model_path/FLAGS.naming``; training runs until the input
    pipeline raises ``OutOfRangeError`` (epoch limit).
    """
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)
    checkpoint_prefix = os.path.join(training_path, 'fast-style-model.ckpt')

    with tf.Graph().as_default(), tf.Session() as sess:
        # ---- Build network ----
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
        preprocess_fn, unprocess_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)
        processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size,
                                        FLAGS.image_size, 'train2014/',
                                        preprocess_fn, epochs=FLAGS.epoch)
        generated = model.net(processed_images, training=True)

        # Run the loss network on generated and real images in one batch.
        generated_frames = tf.unstack(generated, axis=0, num=FLAGS.batch_size)
        processed_generated = tf.stack([
            preprocess_fn(frame, FLAGS.image_size, FLAGS.image_size)
            for frame in generated_frames
        ])
        loss_net_input = tf.concat([processed_generated, processed_images], 0)
        _, endpoints_dict = network_fn(loss_net_input, spatial_squeeze=False)

        # Log the structure of loss network
        tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
        for endpoint_name in endpoints_dict:
            tf.logging.info(endpoint_name)

        # ---- Build losses ----
        content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
        style_loss, style_loss_summary = losses.style_loss(
            endpoints_dict, style_features_t, FLAGS.style_layers)
        # total variation is measured on the unprocessed generator output
        tv_loss = losses.total_variation_loss(generated)

        loss = (FLAGS.style_weight * style_loss
                + FLAGS.content_weight * content_loss
                + FLAGS.tv_weight * tv_loss)

        # ---- Summaries for TensorBoard ----
        scalar_summaries = [
            ('losses/content_loss', content_loss),
            ('losses/style_loss', style_loss),
            ('losses/regularizer_loss', tv_loss),
            ('weighted_losses/weighted_content_loss',
             content_loss * FLAGS.content_weight),
            ('weighted_losses/weighted_style_loss',
             style_loss * FLAGS.style_weight),
            ('weighted_losses/weighted_regularizer_loss',
             tv_loss * FLAGS.tv_weight),
            ('total_loss', loss),
        ]
        for tag, value in scalar_summaries:
            tf.summary.scalar(tag, value)
        for style_layer in FLAGS.style_layers:
            tf.summary.scalar('style_losses/' + style_layer,
                              style_loss_summary[style_layer])

        tf.summary.image('generated', generated)
        origin_frames = [
            unprocess_fn(frame)
            for frame in tf.unstack(processed_images, axis=0,
                                    num=FLAGS.batch_size)
        ]
        tf.summary.image('origin', tf.stack(origin_frames))
        summary = tf.summary.merge_all()
        writer = tf.summary.FileWriter(training_path)

        # ---- Prepare to train ----
        global_step = tf.Variable(0, name="global_step", trainable=False)

        def _outside_loss_net(var):
            # Everything not under the loss-network scope belongs to the
            # generator (or is bookkeeping) and is trained/checkpointed.
            return not var.name.startswith(FLAGS.loss_model)

        variable_to_train = [var for var in tf.trainable_variables()
                             if _outside_loss_net(var)]
        train_op = tf.train.AdamOptimizer(1e-3).minimize(
            loss, global_step=global_step, var_list=variable_to_train)

        variables_to_restore = [var for var in tf.global_variables()
                                if _outside_loss_net(var)]
        saver = tf.train.Saver(variables_to_restore,
                               write_version=tf.train.SaverDef.V1)

        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])

        # Restore variables for the loss network.
        utils._get_init_fn(FLAGS)(sess)

        # Resume the generator if a checkpoint already exists.
        last_file = tf.train.latest_checkpoint(training_path)
        if last_file:
            tf.logging.info('Restoring model from {}'.format(last_file))
            saver.restore(sess, last_file)

        # ---- Train ----
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()

                if step % 10 == 0:
                    tf.logging.info('step: %d,  total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))

                if step % 25 == 0:
                    tf.logging.info('adding summary...')
                    # NOTE(review): this is a separate sess.run and the
                    # 'origin' summary depends on the input pipeline, so it
                    # likely pulls one extra batch -- confirm acceptable.
                    writer.add_summary(sess.run(summary), step)
                    writer.flush()

                if step % 1000 == 0:
                    saver.save(sess, checkpoint_prefix, global_step=step)
        except tf.errors.OutOfRangeError:
            saver.save(sess, checkpoint_prefix + '-done')
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
Example #21
0
def main(argv=None):
    """Train a fast style-transfer generator against a frozen loss network.

    Configuration is read from the module-level ``FLAGS``; ``argv`` is not
    used (presumably kept for ``tf.app.run`` compatibility).

    Training stops once the global step reaches ``FLAGS.max_iter`` (when it
    is not None) or when the input pipeline raises ``OutOfRangeError``; in
    both cases a final ``fast-style-model-done`` checkpoint is written.
    """
    content_layers = FLAGS.content_layers.split(',')
    style_layers = FLAGS.style_layers.split(',')
    style_layers_weights = [
        float(i) for i in FLAGS.style_layers_weights.split(",")
    ]
    # Learning-rate decay period in steps (an earlier setting derived it
    # from the dataset size: 82786 / FLAGS.batch_size).
    num_steps_decay = 10000

    style_features_t = losses.get_style_features(FLAGS)
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Session() as sess:
        # ---- Build network ----
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)
        processed_images = reader.image(FLAGS.batch_size,
                                        FLAGS.image_size,
                                        FLAGS.image_size,
                                        'train2014/',
                                        image_preprocessing_fn,
                                        epochs=FLAGS.epoch)
        generated = model.net(processed_images, FLAGS.alpha)
        processed_generated = [
            image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
            for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        # generated and real images go through the loss network in one batch
        _, endpoints_dict = network_fn(tf.concat(
            [processed_generated, processed_images], 0),
                                       spatial_squeeze=False)

        # ---- Build losses (each already scaled by its weight) ----
        content_loss = losses.content_loss(endpoints_dict, content_layers)
        style_loss, style_losses = losses.style_loss(endpoints_dict,
                                                     style_features_t,
                                                     style_layers,
                                                     style_layers_weights)
        tv_loss = losses.total_variation_loss(
            generated)  # use the unprocessed image
        content_loss = FLAGS.content_weight * content_loss
        style_loss = FLAGS.style_weight * style_loss
        tv_loss = FLAGS.tv_weight * tv_loss
        loss = style_loss + content_loss + tv_loss

        # ---- Prepare to train ----
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # only variables outside the loss-network scope are optimized
        variable_to_train = [
            variable for variable in tf.trainable_variables()
            if not variable.name.startswith(FLAGS.loss_model)
        ]

        # lr starts at 0.1 and is multiplied by 0.1 every num_steps_decay
        # steps (staircase schedule)
        lr = tf.train.exponential_decay(learning_rate=1e-1,
                                        global_step=global_step,
                                        decay_steps=num_steps_decay,
                                        decay_rate=1e-1,
                                        staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-8)
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=variable_to_train)

        # checkpoints cover everything outside the loss network
        variables_to_restore = [
            v for v in tf.global_variables()
            if not v.name.startswith(FLAGS.loss_model)
        ]
        saver = tf.train.Saver(variables_to_restore)
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        # restore the loss-network weights
        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)
        # resume the generator if a checkpoint already exists
        last_file = tf.train.latest_checkpoint(training_path)
        if last_file:
            saver.restore(sess, last_file)

        # ---- Train ----
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                _, c_loss, s_losses, t_loss, total_loss, step = sess.run([
                    train_op, content_loss, style_losses, tv_loss, loss,
                    global_step
                ])
                # logging
                if step % 10 == 0:
                    print(step, c_loss, s_losses, t_loss, total_loss)
                # periodic checkpoint
                if step % 10000 == 0:
                    saver.save(sess,
                               os.path.join(training_path, 'fast-style-model'),
                               global_step=step)
                # `>=` rather than `==`: a run resumed from a checkpoint whose
                # global step is already past max_iter must still stop.
                if FLAGS.max_iter is not None and step >= FLAGS.max_iter:
                    saver.save(
                        sess,
                        os.path.join(training_path, 'fast-style-model-done'))
                    break
        except tf.errors.OutOfRangeError:
            saver.save(sess,
                       os.path.join(training_path, 'fast-style-model-done'))
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)