Example #1
def total_loss(combination_image, base_image, style_image):
    input_tensor = tf.concat([base_image, style_image, combination_image],
                             axis=0)
    features = feature_model(input_tensor)

    # Initialize the loss
    loss = tf.zeros(shape=())

    layer_features = features[content_layer_name]
    content_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]

    loss = loss + CONTENT_WEIGHT * content_loss(content_image_features,
                                                combination_features)

    for layer_name in style_layer_names:
        layer_features = features[layer_name]
        style_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        sl = style_loss(style_features,
                        combination_features,
                        size=image_height * image_width)
        loss += (STYLE_WEIGHT / len(style_layer_names)) * sl

    loss += TV_WEIGHT * \
        total_variation_loss(combination_image, image_height, image_width)
    return loss
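Example #1 assumes content_loss, style_loss, and total_variation_loss helpers that are not shown. The sketch below is a minimal, assumed implementation in the Gatys style, written to match the call signatures used above; the source project's actual helpers may differ in normalization details.

import tensorflow as tf

def gram_matrix(x):
    # Correlations between channels of a (height, width, channels) feature map.
    x = tf.transpose(x, (2, 0, 1))
    features = tf.reshape(x, (tf.shape(x)[0], -1))
    return tf.matmul(features, tf.transpose(features))

def content_loss(base, combination):
    # Squared error between the content and generated feature maps.
    return tf.reduce_sum(tf.square(combination - base))

def style_loss(style, combination, size):
    # Squared Frobenius distance between Gram matrices, scaled as in Gatys et al.
    S = gram_matrix(style)
    C = gram_matrix(combination)
    channels = 3  # convention used by the Keras reference example
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))

def total_variation_loss(x, image_height, image_width):
    # Penalizes differences between neighbouring pixels to keep the image smooth.
    a = tf.square(x[:, : image_height - 1, : image_width - 1, :] -
                  x[:, 1:, : image_width - 1, :])
    b = tf.square(x[:, : image_height - 1, : image_width - 1, :] -
                  x[:, : image_height - 1, 1:, :])
    return tf.reduce_sum(tf.pow(a + b, 1.25))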
Example #2
def main(argv=None):
    network_fn = nets_factory.get_network_fn('vgg_16',
                                             num_classes=1,
                                             is_training=False)
    image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
        'vgg_16', is_training=False)

    preprocess_content_image = reader.get_image(FLAGS.CONTENT_IMAGE,
                                                FLAGS.IMAGE_SIZE)

    # Add a batch dimension before feeding the image to the VGG network.
    preprocess_content_image = tf.expand_dims(preprocess_content_image, 0)
    _, endpoints_dict = network_fn(preprocess_content_image,
                                   spatial_squeeze=False)

    # Log the structure of loss network
    tf.logging.info(
        'Loss network layers(You can define them in "content_layers" and "style_layers"):'
    )
    for key in endpoints_dict:
        tf.logging.info(key)
    """Build Losses"""
    # style_features_t = losses.get_style_features(endpoints_dict, FLAGS.STYLE_LAYERS)
    content_loss, generated_image = losses.content_loss(
        endpoints_dict, FLAGS.CONTENT_LAYERS, FLAGS.CONTENT_IMAGE)
    style_loss, style_loss_summary = losses.style_loss(endpoints_dict,
                                                       FLAGS.STYLE_LAYERS,
                                                       FLAGS.STYLE_IMAGE)
    tv_loss = losses.total_variation_loss(
        generated_image)  # use the unprocessed image

    loss = FLAGS.STYLE_WEIGHT * style_loss + FLAGS.CONTENT_WEIGHT * content_loss + FLAGS.TV_WEIGHT * tv_loss
    train_op = tf.train.AdamOptimizer(FLAGS.LEARNING_RATE).minimize(loss)

    output_image = tf.image.encode_png(
        tf.saturate_cast(
            tf.squeeze(generated_image) + reader.mean_pixel, tf.uint8))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        start_time = time.time()
        for step in range(FLAGS.NUM_ITERATIONS):
            _, loss_t, cl, sl = sess.run(
                [train_op, loss, content_loss, style_loss])
            elapsed = time.time() - start_time
            start_time = time.time()
            print(step, elapsed, loss_t, cl, sl)
        image_t = sess.run(output_image)
        with open('out.png', 'wb') as f:
            f.write(image_t)
Example #3
def _total_content_loss(sess, network, content, config: ArtistConfig):
    sess.run(network['input'].assign(content))
    loss = 0.
    if config.verbose:
        print('Content Layer: ')
    for indx, w in zip(config.content_layers,
                       config.content_layer_weights):
        if config.verbose:
            print(f'\t{list(network.keys())[indx]}')
        p = tf.convert_to_tensor(
            sess.run(network[list(network.keys())[indx]]))
        x = network[list(network.keys())[indx]]
        loss += content_loss(p, x)
    loss /= float(len(config.content_layer_weights))
    return loss
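Example #3 only calls a content_loss(p, x) helper, where p holds the target features (evaluated once as a constant) and x is the corresponding symbolic layer. A minimal sketch under the usual sum-of-squared-differences definition (an assumption, not the project's exact code):

import tensorflow as tf

def content_loss(p, x):
    # p: precomputed target feature map, x: feature map of the image being optimized.
    return 0.5 * tf.reduce_sum(tf.square(x - p))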
Example #4
def main():
    # Make sure the training path exists.
    training_path = 'models/log/'
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)
    style_features = get_style_feature()

    with tf.Graph().as_default():
        train_image = reader.get_train_image(batch_size, image_size,
                                             image_size, dataset_path)
        generated = model.net(train_image)
        processed_generated = [
            reader.prepose_image(image, image_size, image_size)
            for image in tf.unstack(generated, axis=0, num=batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        net = model.load_model(
            tf.concat([processed_generated, train_image], 0), vgg16_ckpt_path)
        with tf.Session() as sess:
            """Build Losses"""
            content_loss = losses.content_loss(net, content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                net, style_features, style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = style_weight * style_loss + content_weight * content_loss + tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image(
                'origin',
                tf.stack([
                    reader.mean_add(image) for image in tf.unstack(
                        train_image, axis=0, num=batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)
            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss,
                global_step=global_step,
                var_list=tf.trainable_variables())

            saver = tf.train.Saver(tf.trainable_variables())

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    loss_c, loss_s = sess.run([content_loss, style_loss])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    if step % 10 == 0:
                        print(
                            'step: %d, content Loss %f, style Loss %f, total Loss %f, secs/step: %f'
                            % (step, loss_c, loss_s, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        print('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
Example #5
                    include_top=False)
print('Loaded Model')

#creating a dictionary with the layer name as key and the layer output as value
output_dict = dict([(layer.name, layer.output) for layer in model.layers])

#initializing the loss as a tensorflow variable
loss = K.variable(0.)

layer_features = output_dict['block2_conv2']
#content_image_features output at block2_conv2
content_image_features = layer_features[0, :, :, :]
#combination_image_features output at block2_conv2
combination_image_features = layer_features[2, :, :, :]
#calculating the content loss
loss += content_weight * content_loss(content_image_features,
                                      combination_image_features)

#layers at which the style loss is to be calculated
feature_layers = [
    'block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3',
    'block5_conv3'
]

#iterating over the feature layers
for layer_name in feature_layers:
    layer_features = output_dict[layer_name]
    #style_image_features output at the particular feature layer
    style_image_features = layer_features[1, :, :, :]
    #combination_image_features output at the particular feature layer
    combination_image_features = layer_features[2, :, :, :]
    #calculating the style loss (assumed completion following the pattern of the
    #other examples; the original snippet is truncated at this point)
    sl = style_loss(style_image_features, combination_image_features)
    loss += (style_weight / len(feature_layers)) * sl
Example #6
def solve(Config):
    gc.enable()
    # get the style feature
    style_features = losses.get_style_feature(Config)
    # prepare some dirs for use
    # tf.reset_default_graph()
    model_dir = Config.model_dir
    if not osp.exists(model_dir):
        os.mkdir(model_dir)

    # construct the graph and model
    # prepare the dataset
    images = Dataset(Config).imagedata_pipelines()
    # the trainnet

    generated = model.inference_trainnet(images)
    # Concatenate the generated and content images so a single VGG forward pass covers both.
    # preprocess the generated
    preprocess_generated = preprocess(generated, Config)
    layer_infos = Vgg(Config.feature_path).build(
        tf.concat([preprocess_generated, images], 0))
    # get the loss
    content_loss = losses.content_loss(layer_infos, Config.content_layers)
    style_loss = losses.style_loss(layer_infos, Config.style_layers,
                                   style_features)
    tv_loss = losses.tv_loss(generated)
    loss = Config.style_weight * style_loss + Config.content_weight * content_loss + Config.tv_weight * tv_loss
    # train op
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tf.train.AdamOptimizer(Config.lr).minimize(
        loss, global_step=global_step)

    # add summary
    with tf.name_scope('losses'):
        tf.summary.scalar('content_loss', content_loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('tv_loss', tv_loss)
    with tf.name_scope('weighted_losses'):
        tf.summary.scalar('weighted_content_loss',
                          content_loss * Config.content_weight)
        tf.summary.scalar('weighted_style_loss',
                          style_loss * Config.style_weight)
        tf.summary.scalar('weighted_tv_loss', tv_loss * Config.tv_weight)
    tf.summary.scalar('total_loss', loss)
    tf.summary.image('generated', generated)
    tf.summary.image('original', images)
    summary = tf.summary.merge_all()
    summary_path = osp.join(model_dir, 'summary')
    if not osp.exists(summary_path):
        os.mkdir(summary_path)
    writer = tf.summary.FileWriter(summary_path)

    # the saver loader
    saver = tf.train.Saver(tf.global_variables())
    #for var in tf.global_variables():
    #    print var
    restore = tf.train.latest_checkpoint(model_dir)

    # begin training work
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # restore the variables
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # Fine-tune from an existing checkpoint if requested.
        if Config.finetune:
            if restore:
                print('restoring model from {}'.format(restore))
                saver.restore(sess, restore)
            else:
                print('no model exists, training from scratch')

            # pop the data queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        print('begin training')
        start_time = time.time()
        local_time = time.time()
        for step in range(Config.max_iter + 1):
            _, loss_value = sess.run([train_op, loss])
            #plt.imshow(np.uint8(gen[0,...]))
            if step % Config.display == 0 or step == Config.max_iter:
                print("{} [iterations], train loss {}, time consumed {}s".format(
                    step, loss_value,
                    time.time() - local_time))
                local_time = time.time()
            assert not np.isnan(loss_value), 'model with loss nan'
            if step != 0 and (step % Config.snapshot == 0
                              or step == Config.max_iter):
                # save the generated to see
                print('adding summary and saving snapshot...')
                saver.save(sess,
                           osp.join(model_dir, 'model.ckpt'),
                           global_step=step)
                summary_str = sess.run(summary)
                writer.add_summary(summary_str, global_step=step)
                writer.flush()
        coord.request_stop()
        coord.join(threads)
        sess.close()

        print('done, consumed time {}s'.format(time.time() - start_time))
Example #7
def main(unused_argv=None):
    """Main training entry point.

    Configuration is read from the module-level FLAGS object rather than
    from an argparse.Namespace.
    """
    # Unpack command-line arguments.
    train_dir = FLAGS.train_dir
    style_dataset = FLAGS.style_dataset
    model_name = FLAGS.model_name
    preprocess_size = [FLAGS.image_size, FLAGS.image_size]
    batch_size = FLAGS.batch_size
    n_epochs = FLAGS.n_epochs
    learn_rate = FLAGS.learning_rate
    content_weights = FLAGS.content_weights
    style_weights = FLAGS.style_weights
    num_pipe_buffer = FLAGS.num_pipe_buffer
    num_styles = FLAGS.num_styles
    train_steps = FLAGS.train_steps
    upsample_method = FLAGS.upsample_method

    # Setup input pipeline (delegate it to CPU to let GPU handle neural net)
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    style_files = tf.train.match_filenames_once(style_dataset)
    print("style %s" % style_files)

    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        _, style_labels, style_grams = datapipe.style_batcher(
            style_files, batch_size, preprocess_size, n_epochs,
            num_pipe_buffer)
        batch_op = datapipe.batcher(files, batch_size, preprocess_size,
                                    n_epochs, num_pipe_buffer)
    """ Set up weight of style and content image """
    content_weights = ast.literal_eval(content_weights)
    style_weights = ast.literal_eval(style_weights)

    target_grams = []
    for name, val in style_weights.items():
        target_grams.append(style_grams[name])

    # Alter the names to include a namescope that we'll use + output suffix.
    loss_style_layers = []
    loss_style_weights = []
    loss_content_layers = []
    loss_content_weights = []
    for key, val in style_weights.items():
        loss_style_layers.append(key + ':0')
        loss_style_weights.append(val)
    for key, val in content_weights.items():
        loss_content_layers.append(key + ':0')
        loss_content_weights.append(val)

    # Load in image transformation network into default graph.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('styleNet'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        Y = transform(X, style_labels, num_styles, upsample_method)
        print(Y)

    # Connect vgg directly to the image transformation network.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = losses.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = losses.get_layers(loss_content_layers)

    # Create loss function
    content_targets = tuple(
        tf.placeholder(tf.float32,
                       shape=layer.get_shape(),
                       name='content_input_{}'.format(i))
        for i, layer in enumerate(content_layers))
    cont_loss = losses.content_loss(content_layers, content_targets,
                                    loss_content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   loss_style_weights)
    tv_loss = losses.tv_loss(Y)
    loss = cont_loss + style_loss + tv_loss

    # We do not want to train VGG, so we must grab the subset.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='styleNet')

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(
        loss, global_step, train_vars)

    if not os.path.exists('./models'):  # Dir to save final models to
        os.makedirs('./models')
    final_saver = tf.train.Saver(train_vars)

    # We must include local variables because of batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Begin training.
    print('Starting training...')
    with tf.Session() as sess:
        # Initialization
        sess.run(init_op)
        vggnet.load_weights(vgg16.checkpoint_file(), sess)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)

                # Collect content targets
                content_data = sess.run(content_layers, feed_dict={Y: batch})
                feed_dict = {X: batch, content_targets: content_data}

                _, loss_out = sess.run([optimizer, loss], feed_dict=feed_dict)
                if (current_step % 10 == 0):
                    print(current_step, loss_out)

                # Stop once we reach the requested number of training steps.
                if current_step == train_steps:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the model (the image transformation network) for later usage
            # in predict.py
            final_saver.save(sess,
                             'models/' + model_name + '_final.ckpt',
                             write_meta_graph=False)
            coord.request_stop()

        coord.join(threads)
Example #8
def main(args):
    """main
    :param args:
        argparse.Namespace object from argparse.parse_args().
    """
    # Unpack command-line arguments.
    train_dir = args.train_dir
    style_img_path = args.style_img_path
    model_name = args.model_name
    preprocess_size = args.preprocess_size
    batch_size = args.batch_size
    n_epochs = args.n_epochs
    run_name = args.run_name
    learn_rate = args.learn_rate
    loss_content_layers = args.loss_content_layers
    loss_style_layers = args.loss_style_layers
    content_weights = args.content_weights
    style_weights = args.style_weights
    num_steps_ckpt = args.num_steps_ckpt
    num_pipe_buffer = args.num_pipe_buffer
    num_steps_break = args.num_steps_break
    beta_val = args.beta
    style_target_resize = args.style_target_resize
    upsample_method = args.upsample_method

    # Load in style image that will define the model.
    style_img = utils.imread(style_img_path)
    style_img = utils.imresize(style_img, style_target_resize)
    style_img = style_img[np.newaxis, :].astype(np.float32)

    # Alter the names to include a namescope that we'll use + output suffix.
    loss_style_layers = ['vgg/' + i + ':0' for i in loss_style_layers]
    loss_content_layers = ['vgg/' + i + ':0' for i in loss_content_layers]

    # Get target Gram matrices from the style image.
    with tf.variable_scope('vgg'):
        X_vgg = tf.placeholder(tf.float32, shape=style_img.shape, name='input')
        vggnet = vgg16.vgg16(X_vgg)
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target style layers.')
        target_grams = sess.run(utils.get_grams(loss_style_layers),
                                feed_dict={X_vgg: style_img})

    # Clean up so we can re-create vgg connected to our image network.
    print('Resetting default graph.')
    tf.reset_default_graph()

    # Load in image transformation network into default graph.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('img_t_net'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        Y = create_net(X, upsample_method)

    # Connect vgg directly to the image transformation network.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = utils.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = utils.get_layers(loss_content_layers)

    # Create loss function
    content_targets = tuple(
        tf.placeholder(tf.float32,
                       shape=layer.get_shape(),
                       name='content_input_{}'.format(i))
        for i, layer in enumerate(content_layers))
    cont_loss = losses.content_loss(content_layers, content_targets,
                                    content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   style_weights)
    tv_loss = losses.tv_loss(Y)
    beta = tf.placeholder(tf.float32, shape=[], name='tv_scale')
    loss = cont_loss + style_loss + beta * tv_loss
    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('content_loss', cont_loss)
        tf.summary.scalar('tv_loss', beta * tv_loss)

    # Setup input pipeline (delegate it to CPU to let GPU handle neural net)
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        batch_op = datapipe.batcher(files, batch_size, preprocess_size,
                                    n_epochs, num_pipe_buffer)

    # We do not want to train VGG, so we must grab the subset.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='img_t_net')

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(
        loss, global_step, train_vars)

    # Set up a subdirectory for this run's TensorBoard logs.
    if not os.path.exists('./summaries/train/'):
        os.makedirs('./summaries/train/')
    if run_name is None:
        current_dirs = [
            name for name in os.listdir('./summaries/train/')
            if os.path.isdir('./summaries/train/' + name)
        ]
        name = model_name + '0'
        count = 0
        while name in current_dirs:
            count += 1
            name = model_name + '{}'.format(count)
        run_name = name

    # Savers and summary writers
    if not os.path.exists('./training'):  # Dir that we'll later save .ckpts to
        os.makedirs('./training')
    if not os.path.exists('./models'):  # Dir to save final models to
        os.makedirs('./models')
    saver = tf.train.Saver()
    final_saver = tf.train.Saver(train_vars)
    merged = tf.summary.merge_all()
    full_log_path = './summaries/train/' + run_name
    train_writer = tf.summary.FileWriter(full_log_path)

    # We must include local variables because of batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Begin training.
    print('Starting training...')
    with tf.Session() as sess:
        # Initialization
        sess.run(init_op)
        vggnet.load_weights('libs/vgg16_weights.npz', sess)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)

                # Collect content targets
                content_data = sess.run(content_layers, feed_dict={Y: batch})

                feed_dict = {
                    X: batch,
                    content_targets: content_data,
                    beta: beta_val
                }
                if (current_step % num_steps_ckpt == 0):
                    # Save a checkpoint
                    save_path = 'training/' + model_name + '.ckpt'
                    saver.save(sess, save_path, global_step=global_step)
                    summary, _, loss_out = sess.run([merged, optimizer, loss],
                                                    feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)
                    print(current_step, loss_out)

                elif (current_step % 10 == 0):
                    # Collect some diagnostic data for Tensorboard.
                    summary, _, loss_out = sess.run([merged, optimizer, loss],
                                                    feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)

                    # Do some standard output.
                    print(current_step, loss_out)
                else:
                    _, loss_out = sess.run([optimizer, loss],
                                           feed_dict=feed_dict)

                # Stop once we reach the requested number of training steps.
                if current_step == num_steps_break:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the model (the image transformation network) for later usage
            # in predict.py
            final_saver.save(sess, 'models/' + model_name + '_final.ckpt')

            coord.request_stop()

        coord.join(threads)
Example #9
                'block2_conv1',
                'block3_conv1',
                'block4_conv1',
                'block5_conv1']
total_variation_weight = 1e-4
style_weight = 1.
content_weight = 0.025


loss = K.variable(0.)
layer_features = outputs_dict[content_layer]

target_image_features = layer_features[0, :, :, :]
combination_features = layer_features[2, :, :, :]
# Accumulate the loss symbolically (loss = loss + ...) so that K.gradients below
# can differentiate it with respect to combination_image.
loss = loss + content_weight * content_loss(target_image_features,
                                            combination_features)

for layer_name in style_layers:
    layer_features = outputs_dict[layer_name]
    style_reference_features = layer_features[1, :, :, :]
    combination_features = layer_features[2, :, :, :]
    sl = style_loss(style_reference_features, combination_features,
                    img_height, img_width)
    loss = loss + (style_weight / len(style_layers)) * sl

loss = loss + total_variation_weight * total_variation_loss(
    combination_image, img_height, img_width)

grads = K.gradients(loss, combination_image)[0]
fetch_loss_and_grads = K.function([combination_image], [loss, grads])
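The fetch_loss_and_grads function built above is usually handed to SciPy's L-BFGS optimizer in this kind of setup. The wrapper below is a hedged sketch of that usage; the Evaluator class, the starting image, and the iteration counts are assumptions and not part of the example.

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

class Evaluator(object):
    """Caches loss and gradients so fmin_l_bfgs_b can query them separately."""

    def loss(self, x):
        x = x.reshape((1, img_height, img_width, 3))
        loss_value, grad_values = fetch_loss_and_grads([x])
        self.grad_values = grad_values.flatten().astype('float64')
        return loss_value

    def grads(self, x):
        return np.copy(self.grad_values)

evaluator = Evaluator()
x = np.random.uniform(0, 255, (1, img_height, img_width, 3)).flatten()  # assumed starting point
for i in range(20):
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x,
                                     fprime=evaluator.grads, maxfun=20)
    print('Iteration {}: loss {}'.format(i, min_val))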

Example #10
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)  # Gram matrices of the style target

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)  # e.g. model/wave/; stores the trained model
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():  # default computation graph
        with tf.Session() as sess:  # no as_default(), so the session is closed once this block exits
            """build loss network"""
            # Fetch the loss network; it is only used for feature extraction and is not trained.
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model, num_classes=1, is_training=False)
            # Preprocessing / unpreprocessing functions for images entering the loss network.
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(FLAGS.loss_model, is_training=False)
            # A batch of training (content) images, already preprocessed.
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            # The generation network (the network being trained) takes the preprocessed images as input.
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)]
            processed_generated = tf.stack(processed_generated)
            # Run the generated and content images through the loss network in one pass;
            # endpoints_dict holds the per-layer activations of both.
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)
            # Log the structure of the loss network.
            tf.logging.info('loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)  # print each layer name of the loss network

            """build losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            tf.summary.image('origin', tf.stack([image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """prepare to train"""
            global_step = tf.Variable(0, name='global_step', trainable=False)  # iteration step

            variable_to_train = []  # variables that need to be trained
            for variable in tf.trainable_variables():  # among all trainable parameters, keep only the generation network's
                if not (variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []  # global variables to save/restore (local variables are thread-local temporaries and are not saved)
            for v in tf.global_variables():
                if not (v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)  # saver for the variables listed above

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])  # initialize global and local variables

            # Restore the pretrained weights of the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables of the training model if a checkpoint file exists
            # (only the generation-network variables are restored; otherwise this step is skipped).
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """start training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():  # stops once the input queue has been exhausted
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)  # save the values of variables_to_restore
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()  # ask all threads to stop
            coord.join(threads)  # wait for the threads to finish
Example #11
def train():

    style_feature, style_grams = get_style_feature()

    with tf.Graph().as_default():
        with tf.Session() as sess:

            #loss_input_style = tf.placeholder(dtype = tf.float32 , shape = [args.batch , args.size , args.size , args.in_dim ])
            #loss_input_target =tf.placeholder(dtype = tf.float32 , shape = [args.batch , args.size , args.size , args.in_dim ])

            # For online optimization problem, use testing preprocess for both train and test
            preprocess_func, unprocess_func = preprocessing.preprocessing_factory.get_preprocessing(
                args.loss_model, is_training=False)



            images = reader.image(args.batch, args.size, args.size, args.target_dir,
                                  preprocess_func, args.epoch, shuffle=True)

            model = transform(sess, args)
            transformed_images = model.generator(images, reuse=False)

            #print('qqq')
            #print( tf.shape(transformed_images).eval())

            unprocess_transform = [(img) for img in tf.unstack(
                transformed_images, axis=0, num=args.batch)]

            processed_generated = [
                preprocess_func(img, args.size, args.size)
                for img in unprocess_transform
            ]
            processed_generated = tf.stack(processed_generated)

            loss_model = nets.nets_factory.get_network_fn(args.loss_model,
                                                          num_classes=1,
                                                          is_training=False)

            pair = tf.concat([processed_generated, images], axis=0)
            _, end_dicts = loss_model(pair, spatial_squeeze=False)

            init_loss_model = load_pretrained_weight(args.loss_model)

            c_loss = losses.content_loss(
                end_dicts, loss_config.content_loss_dict[args.loss_model])

            s_loss, s_loss_sum = losses.style_loss(
                end_dicts, loss_config.style_loss_dict[args.loss_model],
                style_grams)

            tv_loss = losses.total_variation_loss(transformed_images)

            loss = args.c_weight * c_loss + args.s_weight * s_loss + args.tv_weight * tv_loss

            print('shapes')
            print(pair.get_shape())

            #tf.summary.scalar('average', tf.reduce_mean(images))
            #tf.summary.scalar('gram average', tf.reduce_mean(tf.stack(style_feature)))

            tf.summary.scalar('losses/content_loss', c_loss)
            tf.summary.scalar('losses/style_loss', s_loss)
            tf.summary.scalar('losses/tv_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              c_loss * args.c_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              s_loss * args.s_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * args.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in loss_config.style_loss_dict[args.loss_model]:
                tf.summary.scalar('style_losses/' + layer, s_loss_sum[layer])

            tf.summary.image('transformed',
                             tf.stack(unprocess_transform, axis=0))
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'ori',
                tf.stack([
                    unprocess_func(image)
                    for image in tf.unstack(images, axis=0, num=args.batch)
                ]))

            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(args.log_dir)

            step = tf.Variable(0, name='global_step', trainable=False)

            all_trainables = tf.trainable_variables()
            all_vars = tf.global_variables()
            to_train = [
                var for var in all_trainables
                if not args.loss_model in var.name
            ]
            to_restore = [
                var for var in all_vars if not args.loss_model in var.name
            ]


            optim = tf.train.AdamOptimizer(1e-3).minimize(
                loss=loss, var_list=to_train, global_step=step)

            saver = tf.train.Saver(to_restore)
            style_name = (args.style_dir.split('/')[-1]).split('.')[0]

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            init_loss_model(sess)

            # Restore an existing checkpoint only after running the initializers,
            # so the restored weights are not overwritten.
            ckpt = tf.train.latest_checkpoint(
                os.path.join(args.ckpt_dir, style_name))
            if ckpt:
                tf.logging.info('Restoring model from {}'.format(ckpt))
                saver.restore(sess, ckpt)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time()
            #i = 0
            try:
                while True:

                    _, gs, sum_info, c_info, s_info, tv_info, loss_info = sess.run(
                        [optim, step, summary, c_loss, s_loss, tv_loss, loss])
                    writer.add_summary(sum_info, gs)
                    elapsed = time() - start_time

                    print(gs)

                    if gs % 10 == 0:
                        tf.logging.info(
                            'step: %d, c_loss %f  s_loss %f  tv_loss %f total Loss %f, secs/step: %f'
                            %
                            (gs, c_info, s_info, tv_info, loss_info, elapsed))

                    if gs % args.save_freq == 0:
                        saver.save(
                            sess,
                            os.path.join(args.ckpt_dir, style_name,
                                         style_name + '.ckpt'))

            except tf.errors.OutOfRangeError:
                print('run out of images!  save final model: ' +
                      os.path.join(args.ckpt_dir, style_name,
                                   style_name + '.ckpt-done'))
                saver.save(
                    sess,
                    os.path.join(args.ckpt_dir, style_name,
                                 style_name + '.ckpt-done'))
                tf.logging.info('Done -- file ran out of range')
            finally:
                coord.request_stop()

            coord.join(threads)

            print('end training')
Example #12
def main(style_img_path: str,
         content_img_path: str, 
         img_dim: int,
         num_iter: int,
         style_weight: int,
         content_weight: int,
         variation_weight: int,
         print_every: int,
         save_every: int):

    assert style_img_path is not None
    assert content_img_path is not None

    # define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # read the images
    style_img = Image.open(style_img_path)
    cont_img = Image.open(content_img_path)
    
    # define the transform
    transform = transforms.Compose([transforms.Resize((img_dim, img_dim)),
                                    transforms.ToTensor(), 
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])
    
    # get the tensor of the image
    content_image = transform(cont_img).unsqueeze(0).to(device)
    style_image = transform(style_img).unsqueeze(0).to(device)
    
    # init the network
    vgg = VGG().to(device).eval()
    
    # replace the MaxPool with the AvgPool layers
    for name, child in vgg.vgg.named_children():
        if isinstance(child, nn.MaxPool2d):
            vgg.vgg[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)
            
    # lock the gradients
    for param in vgg.parameters():
        param.requires_grad = False
    
    # get the content activations of the content image and detach them from the graph
    content_activations = vgg.get_content_activations(content_image).detach()
    
    # unroll the content activations
    content_activations = content_activations.view(512, -1)
    
    # get the style activations of the style image
    style_activations = vgg.get_style_activations(style_image)
    
    # for every layer in the style activations
    for i in range(len(style_activations)):

        # unroll the activations and detach them from the graph
        style_activations[i] = style_activations[i].squeeze().view(style_activations[i].shape[1], -1).detach()

    # calculate the gram matrices of the style image
    style_grams = [gram(style_activations[i]) for i in range(len(style_activations))]
    
    # generate the Gaussian noise
    noise = torch.randn(1, 3, img_dim, img_dim, device=device, requires_grad=True)
    
    # define the adam optimizer
    # pass the feature map pixels to the optimizer as parameters
    adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

    # run the iteration
    for iteration in range(num_iter):

        # zero the gradient
        adam.zero_grad()

        # get the content activations of the Gaussian noise
        noise_content_activations = vgg.get_content_activations(noise)

        # unroll the feature maps of the noise
        noise_content_activations = noise_content_activations.view(512, -1)

        # calculate the content loss
        content_loss_ = content_loss(noise_content_activations, content_activations)

        # get the style activations of the noise image
        noise_style_activations = vgg.get_style_activations(noise)

        # for every layer
        for i in range(len(noise_style_activations)):

            # unroll the noise style activations
            noise_style_activations[i] = noise_style_activations[i].squeeze().view(noise_style_activations[i].shape[1], -1)

        # calculate the noise gram matrices
        noise_grams = [gram(noise_style_activations[i]) for i in range(len(noise_style_activations))]

        # calculate the total weighted style loss
        style_loss = 0
        for i in range(len(style_activations)):
            N, M = noise_style_activations[i].shape[0], noise_style_activations[i].shape[1]
            style_loss += (gram_loss(noise_grams[i], style_grams[i], N, M) / 5.)

        # put the style loss on device
        style_loss = style_loss.to(device)
            
        # calculate the total variation loss
        variation_loss = total_variation_loss(noise).to(device)

        # weight the final losses and add them together
        total_loss = content_weight * content_loss_ + style_weight * style_loss + variation_weight * variation_loss

        if iteration % print_every == 0:
            print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(iteration, 
                                                                                                     content_weight * content_loss_.item(),
                                                                                                     style_weight * style_loss.item(), 
                                                                                                     variation_weight * variation_loss.item()))

        # create the folder for the generated images
        if not os.path.exists('./generated/'):
            os.mkdir('./generated/')
        
        # generate the image
        if iteration % save_every == 0:
            save_image(noise.cpu().detach(), filename='./generated/iter_{}.png'.format(iteration))

        # backprop
        total_loss.backward()
        
        # update parameters
        adam.step()
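The PyTorch example above relies on gram, gram_loss, content_loss, and total_variation_loss helpers that are not shown. Below is a minimal sketch following the standard Gatys et al. definitions; the project's actual helpers may use different normalization.

import torch

def gram(features):
    # features: a (channels, height * width) unrolled feature map.
    return torch.mm(features, features.t())

def gram_loss(noise_gram, style_gram, N, M):
    # Per-layer style loss, normalized by channel count N and spatial size M.
    return torch.sum(torch.pow(noise_gram - style_gram, 2)) / (4 * (N ** 2) * (M ** 2))

def content_loss(noise_features, content_features):
    # Squared error between generated and content feature maps.
    return 0.5 * torch.sum(torch.pow(noise_features - content_features, 2))

def total_variation_loss(img):
    # Sum of absolute differences between neighbouring pixels (smoothness prior).
    return (torch.sum(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:])) +
            torch.sum(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :])))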
Example #13
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)

            """Preprocess the training images"""
            processed_images = reader.batch_image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                                  'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.transform_network(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)
            variable_to_train = []
            for variable in tf.trainable_variables():
                # Only train and save the variables of the generation network.
                if not (variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)

            """Optimize"""
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not (v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d,  total Loss %f, secs/step: %f,%s' % (step, loss_t, elapsed_time, time.asctime()))
                    """checkpoint"""
                    if step % 50 == 0:
                        tf.logging.info('saving check point...')
                        saver.save(sess, os.path.join(training_path, FLAGS.naming + '.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
                tf.logging.info('coordinator stop')
            coord.join(threads)
Example #14
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not(os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                if not(variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not(v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d,  total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
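This example, like several of the later ones, regularizes the generated image with losses.total_variation_loss(generated), whose body is not shown in the snippet. A minimal sketch of such a helper, assuming NHWC tensors and the usual anisotropic formulation (the actual repository code may normalize differently):

import tensorflow as tf

def total_variation_loss(layer):
    # Differences between vertically and horizontally adjacent pixels (NHWC).
    y_diff = layer[:, :-1, :, :] - layer[:, 1:, :, :]
    x_diff = layer[:, :, :-1, :] - layer[:, :, 1:, :]
    # Mean squared differences; penalizes high-frequency noise in the output.
    return tf.nn.l2_loss(x_diff) / tf.to_float(tf.size(x_diff)) + \
           tf.nn.l2_loss(y_diff) / tf.to_float(tf.size(y_diff))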
Ejemplo n.º 15
0
    def compute_losses(self):
        """
        In this function we are defining the variables for loss calculations
        and training model.

        d_loss_A/d_loss_B -> loss for discriminator A/B
        g_loss_A/g_loss_B -> loss for generator A/B
        *_trainer -> Various trainer for above loss functions
        *_summ -> Summary variables for above loss functions
        """
        #cycle loss
        cycle_consistency_loss_a = \
            self._lambda_a * losses.cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * losses.cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        #vgg_loss
        content_loss_a = self._delta_a * losses.content_loss(real_images=self.input_a, generated_images=self.cycle_images_a)
        content_loss_b = self._delta_b * losses.content_loss(real_images=self.input_b, generated_images=self.cycle_images_b)

        #adv_loss
        lsgan_loss_a = losses.lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = losses.lsgan_loss_generator(self.prob_fake_b_is_real)

        g_loss_A = cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b + content_loss_a + content_loss_b 
        g_loss_B = cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a + content_loss_b + content_loss_a 

        d_loss_A = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_B = losses.lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]

        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
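The helpers referenced above (losses.cycle_consistency_loss and the LSGAN terms) are not part of this snippet. The following is a hedged sketch of the standard CycleGAN formulations they usually implement; the repository's versions may differ in scaling or naming:

import tensorflow as tf

def cycle_consistency_loss(real_images, generated_images):
    # L1 distance between an input image and its reconstruction after a full cycle.
    return tf.reduce_mean(tf.abs(real_images - generated_images))

def lsgan_loss_generator(prob_fake_is_real):
    # The generator wants the discriminator to score its fakes as real (1).
    return tf.reduce_mean(tf.squared_difference(prob_fake_is_real, 1))

def lsgan_loss_discriminator(prob_real_is_real, prob_fake_is_real):
    # The discriminator wants reals scored as 1 and fakes as 0.
    return (tf.reduce_mean(tf.squared_difference(prob_real_is_real, 1)) +
            tf.reduce_mean(tf.square(prob_fake_is_real))) * 0.5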
def main(args):
    # Unpack command-line arguments.
    style_img_path = args.style_img_path
    cont_img_path = args.cont_img_path
    learn_rate = args.learn_rate
    loss_content_layers = args.loss_content_layers
    loss_style_layers = args.loss_style_layers
    content_weights = args.content_weights
    style_weights = args.style_weights
    num_steps_break = args.num_steps_break
    beta = args.beta
    style_target_resize = args.style_target_resize
    cont_target_resize = args.cont_target_resize
    output_img_path = args.output_img_path

    # Load in style image that will define the model.
    style_img = utils.imread(style_img_path)
    style_img = utils.imresize(style_img, style_target_resize)
    style_img = style_img[np.newaxis, :].astype(np.float32)

    # Alter the names to include a namescope that we'll use + output suffix.
    loss_style_layers = ['vgg/' + i + ':0' for i in loss_style_layers]
    loss_content_layers = ['vgg/' + i + ':0' for i in loss_content_layers]

    # Get target Gram matrices from the style image.
    with tf.variable_scope('vgg'):
        X_vgg = tf.placeholder(tf.float32, shape=style_img.shape, name='input')
        vggnet = vgg16.vgg16(X_vgg)
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target style layers.')
        target_grams = sess.run(utils.get_grams(loss_style_layers),
                                feed_dict={'vgg/input:0': style_img})

    # Clean up so we can re-create vgg at size of input content image for
    # training.
    print('Resetting default graph.')
    tf.reset_default_graph()

    # Read in + resize the content image.
    cont_img = utils.imread(cont_img_path)
    cont_img = utils.imresize(cont_img, cont_target_resize)
    cont_img = cont_img[np.newaxis, :].astype(np.float32)

    # Setup VGG and initialize it with white noise image that we'll optimize.
    shape = cont_img.shape
    with tf.variable_scope('to_train'):
        white_noise = np.random.rand(shape[0], shape[1], shape[2],
                                     shape[3]) * 255.0
        white_noise = tf.constant(white_noise.astype(np.float32))
        X = tf.get_variable('input', dtype=tf.float32, initializer=white_noise)
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(X)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = utils.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = utils.get_layers(loss_content_layers)

    # Get the target content features
    with tf.Session() as sess:
        vggnet.load_weights('libs/vgg16_weights.npz', sess)
        print('Precomputing target content layers.')
        content_targets = sess.run(content_layers,
                                   feed_dict={'to_train/input:0': cont_img})

    # Create loss function
    cont_loss = losses.content_loss(content_layers, content_targets,
                                    content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   style_weights)
    tv_loss = losses.tv_loss(X)
    loss = cont_loss + style_loss + beta * tv_loss

    # We do not want to train VGG, so we must grab the subset.
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='to_train')

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate) \
                  .minimize(loss, global_step, train_vars)

    # Initializer
    init_op = tf.global_variables_initializer()

    # Begin training
    with tf.Session() as sess:
        sess.run(init_op)
        vggnet.load_weights('libs/vgg16_weights.npz', sess)

        current_step = 0
        while current_step < num_steps_break:
            current_step = sess.run(global_step)

            if current_step % 10 == 0:
                # Run a training step and print some standard output.
                _, loss_out = sess.run([optimizer, loss])
                print(current_step, loss_out)
            else:
                # optimizer.minimize(sess)
                _, loss_out = sess.run([optimizer, loss])

        # Upon finishing, get the X tensor (our image).
        img_out = sess.run(X)

    # Save it.
    img_out = np.squeeze(img_out)
    utils.imwrite(output_img_path, img_out)
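utils.get_grams, used twice above, is not shown in the snippet. A plausible sketch of a Gram-matrix helper that looks tensors up by name in the default graph; the normalization by layer size is an assumption and the repository may scale differently:

import tensorflow as tf

def get_grams(layer_names):
    grams = []
    for name in layer_names:
        # Names carry the 'vgg/' prefix and ':0' suffix added above.
        layer = tf.get_default_graph().get_tensor_by_name(name)
        _, h, w, c = layer.get_shape().as_list()
        features = tf.reshape(layer, (-1, h * w, c))
        # Gram matrix of channel correlations, normalized by the layer size.
        grams.append(tf.matmul(features, features, transpose_a=True) / (h * w * c))
    return grams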
Ejemplo n.º 17
0
                         feed_dict={image: style_pre})

# create image merging content and style
g = tf.Graph()
with g.as_default(), g.device('/gpu:0'), tf.Session() as sess:
    # init randomly
    # white noise
    target = tf.random_normal((1, ) + content.shape)

    target_pre_var = tf.Variable(target)

    # build model with empty layer activations for generated target image
    model = network_model.get_model(target_pre_var)

    # compute loss
    cont_cost = losses.content_loss(content_out, model, C_LAYER,
                                    options.content_weight)
    style_cost = losses.style_loss(style_out, model, S_LAYERS,
                                   style_weight_layer)
    tv_cost = losses.total_var_loss(target_pre_var, options.tv_weight)

    total_loss = cont_cost + tf.add_n(style_cost) + tv_cost
    # total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

    sess.run(tf.global_variables_initializer())
    min_loss = float("inf")
    best = None
    for i in range(options.iter):
        train_step.run()
        print('Iteration %d/%d' % (i + 1, options.iter))
def main(unused_argv=None):
    """main

    :param unused_argv:
        Unused; configuration is read from the FLAGS object.
    """
    # Unpack command-line arguments.
    train_dir = FLAGS.train_dir
    style_dataset = FLAGS.style_dataset
    model_name = FLAGS.model_name
    preprocess_size = [FLAGS.image_size, FLAGS.image_size]
    batch_size = FLAGS.batch_size
    n_epochs = FLAGS.n_epochs
    run_name = FLAGS.run_name
    checkpoint = FLAGS.checkpoint
    learn_rate = FLAGS.learning_rate
    content_weights = FLAGS.content_weights
    style_weights = FLAGS.style_weights
    num_pipe_buffer = FLAGS.num_pipe_buffer
    style_coefficients = FLAGS.style_coefficients
    num_styles = FLAGS.num_styles
    train_steps = FLAGS.train_steps
    upsample_method = FLAGS.upsample_method

    # Setup input pipeline (delegate it to CPU to let GPU handle neural net)
    files = tf.train.match_filenames_once(train_dir + '/train-*')
    style_files = tf.train.match_filenames_once(style_dataset)
    print("style %s" % style_files)

    with tf.variable_scope('input_pipe'), tf.device('/cpu:0'):
        _, style_labels, style_grams = datapipe.style_batcher(
            style_files, batch_size, preprocess_size, n_epochs,
            num_pipe_buffer)
        batch_op = datapipe.batcher(files, batch_size, preprocess_size,
                                    n_epochs, num_pipe_buffer)
    """ Set up the style coefficients """
    if style_coefficients is None:
        style_coefficients = [1.0 for _ in range(num_styles)]
    else:
        style_coefficients = ast.literal_eval(style_coefficients)
    if len(style_coefficients) != num_styles:
        raise ValueError(
            'number of style coefficients differs from number of styles')
    style_coefficient = tf.gather(tf.constant(style_coefficients),
                                  style_labels)
    """ Set up weight of style and content image """
    content_weights = ast.literal_eval(content_weights)
    style_weights = ast.literal_eval(style_weights)
    style_weights = dict([(key, style_coefficient * val)
                          for key, val in style_weights.items()])

    target_grams = []
    for name, val in style_weights.items():
        target_grams.append(style_grams[name])

    # Alter the names to include a name_scope that we'll use + output suffix.
    loss_style_layers = []
    loss_style_weights = []
    loss_content_layers = []
    loss_content_weights = []
    for key, val in style_weights.items():
        loss_style_layers.append(key + ':0')
        loss_style_weights.append(val)
    for key, val in content_weights.items():
        loss_content_layers.append(key + ':0')
        loss_content_weights.append(val)

    # Load in image transformation network into default graph.
    shape = [batch_size] + preprocess_size + [3]
    with tf.variable_scope('styleNet'):
        X = tf.placeholder(tf.float32, shape=shape, name='input')
        Y = transform(X, style_labels, num_styles, upsample_method)
        print(Y)

    # Connect vgg directly to the image transformation network.
    with tf.variable_scope('vgg'):
        vggnet = vgg16.vgg16(Y)

    # Get the gram matrices' tensors for the style loss features.
    input_img_grams = utils.get_grams(loss_style_layers)

    # Get the tensors for content loss features.
    content_layers = utils.get_layers(loss_content_layers)

    # Create loss function
    content_targets = tuple(
        tf.placeholder(tf.float32,
                       shape=layer.get_shape(),
                       name='content_input_{}'.format(i))
        for i, layer in enumerate(content_layers))

    cont_loss = losses.content_loss(content_layers, content_targets,
                                    loss_content_weights)
    style_loss = losses.style_loss(input_img_grams, target_grams,
                                   loss_style_weights)
    loss = cont_loss + style_loss
    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('content_loss', cont_loss)

    # We do not want to train VGG, so we must grab the subset.
    other_vars = [
        var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         scope='styleNet')
        if 'CondInstNorm' not in var.name
    ]

    train_vars = [
        var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='styleNet')
        if 'CondInstNorm' in var.name
    ]

    # Setup step + optimizer
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learn_rate).minimize(
        loss, global_step, train_vars)

    # Set up a subdirectory for this run's TensorBoard logs.
    if not os.path.exists('./summaries/train/'):
        os.makedirs('./summaries/train/')
    if run_name is None:
        current_dirs = [
            name for name in os.listdir('./summaries/train/')
            if os.path.isdir('./summaries/train/' + name)
        ]
        name = model_name + '0'
        count = 0
        while name in current_dirs:
            count += 1
            name = model_name + '{}'.format(count)
        run_name = name

    # Savers and summary writers
    if not os.path.exists('./training'):  # Dir that we'll later save .ckpts to
        os.makedirs('./training')
    if not os.path.exists('./models'):  # Dir that we'll save final models to
        os.makedirs('./models')

    saver = tf.train.Saver()
    saver_n_stylee = tf.train.Saver(other_vars)
    final_saver = tf.train.Saver(train_vars)
    merged = tf.summary.merge_all()
    full_log_path = './summaries/train/' + run_name
    train_writer = tf.summary.FileWriter(full_log_path, tf.Session().graph)

    # We must include local variables because of batch pipeline.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Begin training.
    print('Starting training...')
    with tf.Session() as sess:
        # Initialization
        sess.run(init_op)
        vggnet.load_weights(vgg16.checkpoint_file(), sess)
        saver_n_stylee.restore(sess, checkpoint)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                current_step = sess.run(global_step)
                batch = sess.run(batch_op)

                # Collect content targets
                content_data = sess.run(content_layers, feed_dict={Y: batch})
                feed_dict = {X: batch, content_targets: content_data}

                if (current_step % 1000 == 0):
                    # Save a checkpoint
                    save_path = 'training/' + model_name + '.ckpt'
                    saver.save(sess, save_path, global_step=global_step)
                    summary, _, loss_out, c_loss, s_loss = sess.run(
                        [merged, optimizer, loss, cont_loss, style_loss],
                        feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)
                    print(current_step, loss_out, c_loss, s_loss)

                elif (current_step % 10 == 0):
                    # Collect some diagnostic data for Tensorboard.
                    summary, _, loss_out, c_loss, s_loss = sess.run(
                        [merged, optimizer, loss, cont_loss, style_loss],
                        feed_dict=feed_dict)
                    train_writer.add_summary(summary, current_step)

                    # Do some standard output.
                    # if (current_step % 1000 == 0):
                    print(current_step, loss_out, c_loss, s_loss)
                else:
                    _, loss_out = sess.run([optimizer, loss],
                                           feed_dict=feed_dict)

                # Stop once we reach the requested number of training steps.
                if current_step == train_steps:
                    print('Done training.')
                    break
        except tf.errors.OutOfRangeError:
            print('Done training.')
        finally:
            # Save the model (the image transformation network) for later usage
            # in predict.py
            final_saver.save(sess,
                             'models/' + model_name + '_final.ckpt',
                             write_meta_graph=False)

            coord.request_stop()

        coord.join(threads)
Ejemplo n.º 19
0
            loss += (analogy_weight / len(analogy_layers)) * al

    if mrf_weight != 0.0:
        for layer_name in mrf_layers:
            ap_image_features = K.variable(all_ap_image_features[layer_name][0])
            layer_features = outputs_dict[layer_name]
            combination_features = layer_features[0, :, :, :]
            sl = losses.mrf_loss(ap_image_features, combination_features,
                patch_size=patch_size, patch_stride=patch_stride)
            loss += (mrf_weight / len(mrf_layers)) * sl

    if b_bp_content_weight != 0.0:
        for layer_name in b_content_layers:
            b_features = K.variable(all_b_features[layer_name][0])
            bp_features = outputs_dict[layer_name]
            cl = losses.content_loss(bp_features, b_features)
            loss += b_bp_content_weight / len(b_content_layers) * cl

    loss += total_variation_weight * losses.total_variation_loss(vgg_input, img_width, img_height)

    # get the gradients of the generated image wrt the loss
    grads = K.gradients(loss, vgg_input)

    outputs = [loss]
    if type(grads) in {list, tuple}:
        outputs += grads
    else:
        outputs.append(grads)

    f_outputs = K.function([vgg_input], outputs)
    evaluator = Evaluator()
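The Evaluator constructed on the last line is not defined in this snippet. In Keras-based neural style code it typically caches the loss and gradients from a single f_outputs call so that scipy's L-BFGS optimizer can request them separately. A sketch under that assumption, taking a channels-last input of shape (1, img_height, img_width, 3):

import numpy as np

def eval_loss_and_grads(x):
    # Reshape the flat vector back into an image batch and evaluate once.
    x = x.reshape((1, img_height, img_width, 3))
    outs = f_outputs([x])
    loss_value = outs[0]
    grad_values = np.array(outs[1:]).flatten().astype('float64')
    return loss_value, grad_values

class Evaluator(object):
    def __init__(self):
        self.loss_value = None
        self.grad_values = None

    def loss(self, x):
        self.loss_value, self.grad_values = eval_loss_and_grads(x)
        return self.loss_value

    def grads(self, x):
        # Return the gradients cached by the preceding loss() call.
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values

# Typical use (x0 being the flattened initial image):
# from scipy.optimize import fmin_l_bfgs_b
# x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x0.flatten(), fprime=evaluator.grads, maxfun=20)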
Ejemplo n.º 20
0
def main(argv=None):
    content_layers = FLAGS.content_layers.split(',')
    style_layers = FLAGS.style_layers.split(',')
    style_layers_weights = [
        float(i) for i in FLAGS.style_layers_weights.split(",")
    ]
    #num_steps_decay = 82786 / FLAGS.batch_size
    num_steps_decay = 10000

    style_features_t = losses.get_style_features(FLAGS)
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Session() as sess:
        """Build Network"""
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)
        processed_images = reader.image(FLAGS.batch_size,
                                        FLAGS.image_size,
                                        FLAGS.image_size,
                                        'train2014/',
                                        image_preprocessing_fn,
                                        epochs=FLAGS.epoch)
        generated = model.net(processed_images, FLAGS.alpha)
        processed_generated = [
            image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
            for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
        ]
        processed_generated = tf.stack(processed_generated)
        _, endpoints_dict = network_fn(tf.concat(
            [processed_generated, processed_images], 0),
                                       spatial_squeeze=False)
        """Build Losses"""
        content_loss = losses.content_loss(endpoints_dict, content_layers)
        style_loss, style_losses = losses.style_loss(endpoints_dict,
                                                     style_features_t,
                                                     style_layers,
                                                     style_layers_weights)
        tv_loss = losses.total_variation_loss(
            generated)  # use the unprocessed image
        content_loss = FLAGS.content_weight * content_loss
        style_loss = FLAGS.style_weight * style_loss
        tv_loss = FLAGS.tv_weight * tv_loss
        loss = style_loss + content_loss + tv_loss
        """Prepare to Train"""
        global_step = tf.Variable(0, name="global_step", trainable=False)
        variable_to_train = []
        for variable in tf.trainable_variables():
            if not (variable.name.startswith(FLAGS.loss_model)):
                variable_to_train.append(variable)

        lr = tf.train.exponential_decay(learning_rate=1e-1,
                                        global_step=global_step,
                                        decay_steps=num_steps_decay,
                                        decay_rate=1e-1,
                                        staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-8)
        train_op = optimizer.minimize(loss,
                                      global_step=global_step,
                                      var_list=variable_to_train)
        #train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)
        variables_to_restore = []
        for v in tf.global_variables():
            if not (v.name.startswith(FLAGS.loss_model)):
                variables_to_restore.append(v)
        saver = tf.train.Saver(variables_to_restore)
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])
        init_func = utils._get_init_fn(FLAGS)
        init_func(sess)
        last_file = tf.train.latest_checkpoint(training_path)
        if last_file:
            saver.restore(sess, last_file)
        """Start Training"""
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                _, c_loss, s_losses, t_loss, total_loss, step = sess.run([
                    train_op, content_loss, style_losses, tv_loss, loss,
                    global_step
                ])
                """logging"""
                if step % 10 == 0:
                    print(step, c_loss, s_losses, t_loss, total_loss)
                """checkpoint"""
                if step % 10000 == 0:
                    saver.save(sess,
                               os.path.join(training_path, 'fast-style-model'),
                               global_step=step)
                if step == FLAGS.max_iter:
                    saver.save(
                        sess,
                        os.path.join(training_path, 'fast-style-model-done'))
                    break
        except tf.errors.OutOfRangeError:
            saver.save(sess,
                       os.path.join(training_path, 'fast-style-model-done'))
            tf.logging.info('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
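losses.content_loss(endpoints_dict, content_layers) above relies on the loss network having been fed tf.concat([processed_generated, processed_images], 0), so each endpoint holds the generated activations stacked on top of the content activations. A hedged sketch of such a split-batch content loss (the real helper may normalize differently):

import tensorflow as tf

def content_loss(endpoints_dict, content_layers):
    loss = 0.0
    for layer in content_layers:
        # First half of the batch: generated images; second half: content images.
        generated_images, content_images = tf.split(endpoints_dict[layer], 2, axis=0)
        size = tf.to_float(tf.size(generated_images))
        # Mean squared error between the two sets of feature maps.
        loss += tf.nn.l2_loss(generated_images - content_images) * 2 / size
    return loss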
Ejemplo n.º 21
0
        for layer_name in mrf_layers:
            ap_image_features = K.variable(
                all_ap_image_features[layer_name][0])
            layer_features = outputs_dict[layer_name]
            combination_features = layer_features[0, :, :, :]
            sl = losses.mrf_loss(ap_image_features,
                                 combination_features,
                                 patch_size=patch_size,
                                 patch_stride=patch_stride)
            loss += (mrf_weight / len(mrf_layers)) * sl

    if b_bp_content_weight != 0.0:
        for layer_name in b_content_layers:
            b_features = K.variable(all_b_features[layer_name][0])
            bp_features = outputs_dict[layer_name]
            cl = losses.content_loss(bp_features, b_features)
            loss += b_bp_content_weight / len(b_content_layers) * cl

    loss += total_variation_weight * losses.total_variation_loss(
        vgg_input, img_width, img_height)

    # get the gradients of the generated image wrt the loss
    grads = K.gradients(loss, vgg_input)

    outputs = [loss]
    if type(grads) in {list, tuple}:
        outputs += grads
    else:
        outputs.append(grads)

    f_outputs = K.function([vgg_input], outputs)
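losses.total_variation_loss(vgg_input, img_width, img_height), called in both Keras examples above, is likewise not shown. A sketch in the style of the Keras neural-style examples, assuming channels-last tensors; the 1.25 exponent is the conventional choice there and may not match the actual helper:

from keras import backend as K

def total_variation_loss(x, img_width, img_height):
    # Squared differences between vertically and horizontally adjacent pixels.
    a = K.square(x[:, :img_height - 1, :img_width - 1, :] - x[:, 1:, :img_width - 1, :])
    b = K.square(x[:, :img_height - 1, :img_width - 1, :] - x[:, :img_height - 1, 1:, :])
    return K.sum(K.pow(a + b, 1.25))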
Ejemplo n.º 22
0
def train(**kwargs):
    opt = Config()
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    '''Data loading'''
    # Collect the file names
    filenames = [os.path.join(opt.data_root, f)
                 for f in os.listdir(opt.data_root)
                 if os.path.isfile(os.path.join(opt.data_root, f))]
    # Determine the file format: True for png, False for jpeg
    png = filenames[0].lower().endswith('png')  # If first file is a png, assume they all are
    # Maintain a filename queue
    filename_queue = tf.train.string_input_producer(filenames, shuffle=True, num_epochs=opt.epoches)
    # Initialize the file reader
    reader = tf.WholeFileReader()
    # read() returns a (key, value) tuple
    _, img_bytes = reader.read(filename_queue)
    # Decode the image
    image_row = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
    # Preprocess
    image = utils.img_proprocess(image_row, opt.image_size)
    image_batch = tf.train.batch([image], opt.batch_size, dynamic_pad=True)

    '''Generate images with the transformation network'''
    generated = net(image_batch, training=True)
    generated = tf.image.resize_bilinear(generated, [opt.image_size, opt.image_size], align_corners=False)
    generated.set_shape([opt.batch_size, opt.image_size, opt.image_size, 3])
    # unstack splits the given axis into size-1 slices and drops that axis; split lets you pick the split sizes but keeps the axis
    # processed_generated = tf.stack([utils.img_proprocess(tf.squeeze(img, axis=0), opt.image_size)
    #                                 for img in tf.split(generated, num_or_size_splits=opt.batch_size, axis=0)])
    processed_generated = tf.stack([utils.img_proprocess(img, opt.image_size) for img in tf.unstack(generated, axis=0)])

    '''Run the data through the loss network (VGG)'''
    # Feed 2 x batch_size images at once: [generated batch + original batch]
    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0)):  # build VGG-16 under its arg scope
        _, endpoint = vgg.vgg_16(tf.concat([processed_generated, image_batch], 0),
                                 num_classes=1, is_training=False, spatial_squeeze=False)
    tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
    for key in endpoint:
        tf.logging.info(key)

    '''Build the loss functions'''
    style_gram = utils.get_style_feature(opt.style_path,
                                         opt.image_size,
                                         opt.style_layers,
                                         opt.model_path,
                                         opt.exclude_scopes)
    content_loss, content_loss_summary = losses.content_loss(endpoint, opt.content_layers)
    style_loss, style_loss_summary = losses.style_loss(endpoint, style_gram, opt.style_layers)
    tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image, which is also the image we want
    loss = opt.style_weight * style_loss + opt.content_weight * content_loss + opt.tv_weight * tv_loss

    '''Build the optimizer'''
    # The optimizer only updates trainable variables outside vgg_16
    variables_to_train = []
    for variable in tf.trainable_variables():
        if not (variable.name.startswith("vgg_16")):  # "vgg16"
            variables_to_train.append(variable)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variables_to_train)

    '''Build the saver'''
    # The saver keeps the global variables that are not part of vgg_16

    # tf.global_variables(): returns the global variables.
    # Global variables are the variables shared across machines in a distributed
    # environment. The Variable() constructor and get_variable() automatically add
    # new variables to the collection GraphKeys.GLOBAL_VARIABLES; this convenience
    # function returns the contents of that collection.
    # The alternative to global variables is local variables; see tf.local_variables.
    variables_to_restore = []  # the extras beyond the trainable set are mostly the optimizer's bookkeeping variables
    for variable in tf.global_variables():
        if not (variable.name.startswith("vgg_16")):  # "vgg16"
            variables_to_restore.append(variable)

    saver = tf.train.Saver(var_list=variables_to_restore, write_version=tf.train.SaverDef.V2)

    """添加监测项"""
    # 添加总体loss监测
    tf.summary.scalar('losses/content_loss', content_loss)
    tf.summary.scalar('losses/style_loss', style_loss)
    tf.summary.scalar('losses/regularizer_loss', tv_loss)
    tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * opt.content_weight)
    tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * opt.style_weight)
    tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * opt.tv_weight)
    tf.summary.scalar('total_loss', loss)
    # Per-layer style loss summaries
    for layer in opt.style_layers:
        tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
    # Generated image summary
    tf.summary.image('generated', generated)
    # tf.image_summary('processed_generated', processed_generated)  # May be better?
    # Original image summary
    tf.summary.image('origin', image_batch)

    summary_path = "./logs"
    model_path = "./logs/model"
    summary = tf.summary.merge_all()
    # with open('train_v.txt', 'w') as f:
    #     for s in variable_to_train:
    #         f.write(s.name + '\n')
    # with open('restore_v.txt', 'w') as f:
    #     for s in variables_to_restore:
    #         f.write(s.name + '\n')

    '''Training'''
    with tf.Session(config=config) as sess:
        writer = tf.summary.FileWriter(summary_path, sess.graph) 
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))

        # Load the pretrained VGG weights
        param_init_fn = utils.param_load_fn(opt.model_path, opt.exclude_scopes)
        param_init_fn(sess)

        # Because we use the saver above, the restored variables do not include the vgg_16 ones
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt:
            tf.logging.info("Success to read {}".format(ckpt.model_checkpoint_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.info("Failed to find a checkpoint")

        coord = tf.train.Coordinator()  # thread coordinator
        threads = tf.train.start_queue_runners(coord=coord)  # start the input queues
        start_time = time.time()  # start timing
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()

                if step % 50 == 0:
                    tf.logging.info('step: {0:d}, total Loss {1:.2f}, secs/step: {2:.3f}'.
                                    format(step, loss_t, elapsed_time))
                if step % 100 == 0:
                    tf.logging.info('adding summary...')
                    summary_str = sess.run(summary)
                    writer.add_summary(summary_str, step)
                    writer.flush()
                if step % 1000 == 0:
                    saver.save(sess, os.path.join(model_path, 'fast_style_model'), global_step=step)

        except tf.errors.OutOfRangeError:
            saver.save(sess, os.path.join(model_path, 'fast_style_model'))
            tf.logging.info('Epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
        writer.close()

    '''Debug output'''
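utils.img_proprocess, used throughout this example, is not included in the snippet. A plausible sketch under the usual VGG convention, resizing to a square and subtracting the VGG mean pixel; the mean values and resize policy here are assumptions, not the repository's code:

import tensorflow as tf

_VGG_MEAN = [123.68, 116.779, 103.939]  # RGB means commonly used with VGG

def img_proprocess(image, image_size):
    # Resize to a square and cast to float before mean subtraction.
    image = tf.image.resize_images(image, [image_size, image_size])
    image = tf.to_float(image)
    return image - tf.constant(_VGG_MEAN, dtype=tf.float32, shape=[1, 1, 3])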
Ejemplo n.º 23
0
def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not(os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated)
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

            """Build Losses"""
            content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

            """Prepare to Train"""
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                if not(variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                if not(v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    # print(step)
                    if step % 10 == 0:
                        tf.logging.info('step: %d,  total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
def main(FLAGS):
    # Get the style features (target Gram matrices)
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not (os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            # Build the VGG loss network; the name in FLAGS.loss_model maps to an entry in networks_map in /nets/nets_factory.py
            network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                     num_classes=1,
                                                     is_training=False)
            # Each network gets its own preprocessing
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model, is_training=False)
            # Read one batch of data and preprocess it
            # You do not have to use COCO here; any folder containing many images
            # will do, since COCO is rather large
            processed_images = reader.image(FLAGS.batch_size,
                                            FLAGS.image_height,
                                            FLAGS.image_width,
                                            'F:/CASIA/train_frame/real/',
                                            image_preprocessing_fn,
                                            epochs=FLAGS.epoch)
            # Generate images with the transformation network; these correspond to y^
            generated = model.net(processed_images, training=True)
            # The generated images will be fed into VGG to compute the two losses, so preprocess them first
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_height,
                                       FLAGS.image_width) for image in
                tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            # The result above is a Python list, so stack it back into a tensor
            processed_generated = tf.stack(processed_generated)
            # Concatenate along the batch dimension: two [batch_size, h, w, c] tensors become [2*batch_size, h, w, c],
            # so a single forward pass computes the features of both y^ and y_c
            _, endpoints_dict = network_fn(tf.concat(
                [processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of loss network
            tf.logging.info(
                'Loss network layers(You can define them in "content_layers" and "style_layers"):'
            )
            for key in endpoints_dict:
                tf.logging.info(key)
            """Build Losses"""
            # Compute the three losses
            content_loss = losses.content_loss(endpoints_dict,
                                               FLAGS.content_layers)
            style_loss, style_loss_summary = losses.style_loss(
                endpoints_dict, style_features_t, FLAGS.style_layers)
            tv_loss = losses.total_variation_loss(
                generated)  # use the unprocessed image

            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add Summary for visualization in tensorboard.
            """Add Summary"""
            # For TensorBoard only; can be ignored
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss',
                              content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss',
                              style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss',
                              tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer,
                                  style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.image_summary('processed_generated', processed_generated)  # May be better?
            tf.summary.image(
                'origin',
                tf.stack([
                    image_unprocessing_fn(image) for image in tf.unstack(
                        processed_images, axis=0, num=FLAGS.batch_size)
                ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)
            """Prepare to Train"""
            # Global step counter
            global_step = tf.Variable(0, name="global_step", trainable=False)

            variable_to_train = []
            for variable in tf.trainable_variables():
                # Add trainable variables that are not part of the VGG network to variable_to_train
                if not (variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)
            # Note the var_list argument
            train_op = tf.train.AdamOptimizer(1e-3).minimize(
                loss, global_step=global_step, var_list=variable_to_train)

            variables_to_restore = []
            for v in tf.global_variables():
                # Add savable variables that are not part of VGG to variables_to_restore
                if not (v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)

            # Note the variables_to_restore argument
            saver = tf.train.Saver(variables_to_restore,
                                   write_version=tf.train.SaverDef.V1)

            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])

            # Restore variables for loss network.
            # From slim; loads the loss-network weights into this session according to FLAGS
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for training model if the checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)
            """Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()
                    """logging"""
                    print(step)
                    if step % 10 == 0:
                        tf.logging.info(
                            'step: %d,  total Loss %f, secs/step: %f' %
                            (step, loss_t, elapsed_time))
                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()
                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess,
                                   os.path.join(training_path,
                                                'fast-style-model.ckpt'),
                                   global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(
                    sess,
                    os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)
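Several of the TensorFlow examples above call losses.get_style_features(FLAGS) without showing it. In this family of fast style transfer scripts it typically runs the style image once through the frozen loss network and returns the target Gram matrices for FLAGS.style_layers. A hedged sketch under those assumptions; the flag name style_image and the reader call are placeholders for whatever the repository actually uses:

import tensorflow as tf

def get_style_features(FLAGS):
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(FLAGS.loss_model,
                                                 num_classes=1,
                                                 is_training=False)
        image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model, is_training=False)

        # Load the style image, preprocess it like the training images,
        # and add a batch axis. (Hypothetical reader call.)
        style_image = reader.get_image(FLAGS.style_image, FLAGS.image_size)
        style_image = image_preprocessing_fn(style_image, FLAGS.image_size,
                                             FLAGS.image_size)
        style_image = tf.expand_dims(style_image, 0)
        _, endpoints_dict = network_fn(style_image, spatial_squeeze=False)

        # One normalized Gram matrix per style layer.
        grams = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            _, h, w, c = feature.get_shape().as_list()
            flat = tf.reshape(feature, (-1, h * w, c))
            grams.append(tf.matmul(flat, flat, transpose_a=True) / (h * w * c))

        with tf.Session() as sess:
            # Load the pretrained loss-network weights, then evaluate the Grams.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)
            return sess.run(grams)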