Example #1
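Note: every snippet below is an excerpt and omits its imports. The standard modules they rely on are roughly the ones sketched here; the project-specific helpers (adain, build_vgg, build_decoder, open_weights, vgg_layer_params, load_image, prepare_image, coral, save_image_in_memory, ...) come from the repository's own modules and are not reproduced in these listings.

import io
import os
import time

import numpy as np
import tensorflow as tf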
def _build_graph(vgg_weights, decoder_weights, alpha, data_format,
                 representation_layer, rep_index):
    # Number of feature channels at the representation layer
    # (64 at conv1_x, doubling after each pooling stage).
    num_channels = 64 * (2**rep_index)
    if data_format == 'channels_first':
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, num_channels, None, None),
                                 dtype=tf.float32)
        style = tf.placeholder(shape=(1, num_channels, None, None),
                               dtype=tf.float32)
    else:
        image = tf.placeholder(shape=(None, None, None, 3), dtype=tf.float32)
        content = tf.placeholder(shape=(1, None, None, num_channels),
                                 dtype=tf.float32)
        style = tf.placeholder(shape=(1, None, None, num_channels),
                               dtype=tf.float32)

    target = adain(content, style, data_format=data_format)
    weighted_target = target * alpha + (1 - alpha) * content

    with open_weights(vgg_weights) as w:
        vgg = build_vgg(image, w, data_format=data_format)
        encoder = vgg[representation_layer]

    if decoder_weights:
        with open_weights(decoder_weights) as w:
            decoder = build_decoder(weighted_target,
                                    w,
                                    trainable=False,
                                    data_format=data_format)
    else:
        decoder = build_decoder(weighted_target,
                                None,
                                trainable=False,
                                data_format=data_format)

    return image, content, style, target, encoder, decoder
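Every example calls adain(content, style, data_format=...) without showing it. As a point of reference, here is a minimal sketch of the adaptive instance normalization step that helper is expected to perform (re-normalize the content features to the per-channel mean and standard deviation of the style features); the actual implementation in the repository may differ in details such as the epsilon value.

def adain_sketch(content, style, epsilon=1e-5, data_format='channels_first'):
    """Illustrative AdaIN: align content feature statistics with the style's."""
    # Spatial axes over which per-channel mean/variance are computed.
    axes = [2, 3] if data_format == 'channels_first' else [1, 2]
    content_mean, content_var = tf.nn.moments(content, axes=axes, keep_dims=True)
    style_mean, style_var = tf.nn.moments(style, axes=axes, keep_dims=True)
    content_std = tf.sqrt(content_var + epsilon)
    style_std = tf.sqrt(style_var + epsilon)
    # Normalize the content features, then rescale and shift with the style statistics.
    return style_std * (content - content_mean) / content_std + style_mean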
Example #2
def init_graph(vgg_weights, decoder_weights, alpha, data_format):
    if data_format == 'channels_first':
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
    else:
        image = tf.placeholder(shape=(None, None, None, 3), dtype=tf.float32)
        content = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)
        style = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)

    target = adain(content, style, data_format=data_format)
    weighted_target = target * alpha + (1 - alpha) * content

    with open_weights(vgg_weights) as w:
        vgg = build_vgg(image, w, data_format=data_format)
        encoder = vgg['conv4_1']

    if decoder_weights:
        with open_weights(decoder_weights) as w:
            decoder = build_decoder(weighted_target,
                                    w,
                                    trainable=False,
                                    data_format=data_format)
    else:
        decoder = build_decoder(weighted_target,
                                None,
                                trainable=False,
                                data_format=data_format)

    return image, content, style, target, encoder, decoder
Example #3
def _build_graph(vgg_weights, decoder_weights, alpha, data_format):
    if data_format == 'channels_first':
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
    else:
        image = tf.placeholder(shape=(None, None, None, 3), dtype=tf.float32)
        content = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)
        style = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)

    target = adain(content, style, data_format=data_format)
    weighted_target = target * alpha + (1 - alpha) * content

    with open_weights(vgg_weights) as w:
        vgg = build_vgg(image, w, data_format=data_format)
        encoder = vgg['conv4_1']

        # Here we can return other layers to check out style encoding
        otherLayerNames = ["conv3_1", "conv4_1"]
        otherLayers = [vgg[i] for i in otherLayerNames]

    if decoder_weights:
        with open_weights(decoder_weights) as w:
            decoder = build_decoder(weighted_target,
                                    w,
                                    trainable=False,
                                    data_format=data_format)
    else:
        decoder = build_decoder(weighted_target,
                                None,
                                trainable=False,
                                data_format=data_format)

    # Return other layers on top of original outputs
    return image, content, style, target, encoder, decoder, otherLayers
Example #4
def _build_graph(vgg_weights, decoder_weights, alpha, data_format):
    if data_format == CHANNELS_FIRST:
        content = tf.placeholder(shape=(1, 3, None, None),
                                 dtype=tf.float32,
                                 name='content')
        style = tf.placeholder(shape=(1, 3, None, None),
                               dtype=tf.float32,
                               name='style')
    else:
        content = tf.placeholder(shape=(1, None, None, 3),
                                 dtype=tf.float32,
                                 name='content')
        style = tf.placeholder(shape=(1, None, None, 3),
                               dtype=tf.float32,
                               name='style')

    with open_weights(vgg_weights) as w:
        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            vgg_content = build_vgg(content, w, data_format=data_format)
            vgg_style = build_vgg(style, w, data_format=data_format)
            content_feature = vgg_content['conv4_1']
            style_feature = vgg_style['conv4_1']

    target = adain(content_feature, style_feature, data_format=data_format)
    weighted_target = target * alpha + (1 - alpha) * content_feature

    if decoder_weights:
        with open_weights(decoder_weights) as w:
            decoder = build_decoder(weighted_target,
                                    w,
                                    trainable=False,
                                    data_format=data_format)
    else:
        decoder = build_decoder(weighted_target,
                                None,
                                trainable=False,
                                data_format=data_format)

    decoder = tf.identity(decoder, name='output')

    return content, style, decoder
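A hedged usage sketch for the single-pass graph above: both VGG encodings and the decoder live in one graph, so inference is a single session call. The weight file paths and the random stand-in images are assumptions for illustration only; real inputs would come from load_image()/prepare_image() as in Example #7, and CHANNELS_FIRST is the module constant the function itself compares against.

content_ph, style_ph, output = _build_graph('models/vgg19_weights_normalized.h5',
                                            'models/decoder_weights.h5',
                                            alpha=1.0,
                                            data_format=CHANNELS_FIRST)

# Stand-ins for preprocessed (3, H, W) images.
content_image = np.random.rand(3, 256, 256).astype(np.float32)
style_image = np.random.rand(3, 256, 256).astype(np.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    stylized = sess.run(output, feed_dict={
        content_ph: content_image[np.newaxis, :],  # shape (1, 3, H, W)
        style_ph: style_image[np.newaxis, :],
    })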
Example #5
def _build_graph(vgg_weights, decoder_weights, alpha, data_format):
    concat_axis = None  # axis along which to concatenate the target and style encodings
    if data_format == 'channels_first':
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        concat_axis = 1
    else:
        image = tf.placeholder(shape=(None, None, None, 3), dtype=tf.float32)
        content = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)
        style = tf.placeholder(shape=(1, None, None, 512), dtype=tf.float32)
        concat_axis = 3

    target = adain(content, style, data_format=data_format)
    # weighted_target = target * alpha + (1 - alpha) * content

    # NOTE: Here is where we add the style encoding to the decoder
    combined_target = tf.concat([target, style], axis=concat_axis)
    
    with open_weights(vgg_weights) as w:
        vgg = build_vgg(image, w, data_format=data_format)
        encoder = vgg['conv4_1']

        # Here we can return other layers to check out style encoding
        otherLayerNames = ["conv3_1", "conv4_1"]
        otherLayers = [vgg[i] for i in otherLayerNames]

    if decoder_weights:
        with open_weights(decoder_weights) as w:
            decoder = build_decoder(combined_target, w, trainable=False,
                data_format=data_format)
    else:
        decoder = build_decoder(combined_target, None, trainable=False,
            data_format=data_format)

    # Return other layers on top of original outputs
    return image, content, style, target, encoder, decoder, otherLayers
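Because combined_target concatenates two 512-channel feature maps, the decoder built from it must accept 1024 input channels on its first layer (build_decoder is assumed to size that layer from the input tensor it receives). A trivial, purely illustrative shape check:

target_feat = tf.zeros((1, 512, 16, 16))   # AdaIN output, channels_first
style_feat = tf.zeros((1, 512, 16, 16))    # style encoding
combined = tf.concat([target_feat, style_feat], axis=1)
print(combined.shape)                      # (1, 1024, 16, 16)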
Example #6
def train(content_dir='/floyd_images/',
          style_dir='/floyd_images/',
          checkpoint_dir='output',
          decoder_activation='relu',
          initial_size=512,
          random_crop_size=256,
          resume=False,
          optimizer='adam',
          learning_rate=1e-4,
          learning_rate_decay=5e-5,
          momentum=0.9,
          batch_size=8,
          num_epochs=44,
          content_layer='conv4_1',
          style_layers='conv1_1,conv2_1,conv3_1,conv4_1',
          tv_weight=0,
          style_weight=1e-2,
          content_weight=0.75,
          save_every=10000,
          print_every=10,
          gpu=0,
          vgg='/floyd_models/vgg19_weights_normalized.h5'):
    assert initial_size >= random_crop_size, 'Images are too small to be cropped'
    assert gpu >= 0, 'CPU mode is not supported'

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if not os.path.exists(checkpoint_dir):
        print('Creating checkpoint dir at', checkpoint_dir)
        os.mkdir(checkpoint_dir)

    style_layers = style_layers.split(',')

    # the content layer is also used as the encoder layer
    encoder_layer = content_layer
    encoder_layer_filters = vgg_layer_params(encoder_layer)[
        'filters']  # Just gives you the number of filters
    encoder_layer_shape = (None, encoder_layer_filters, None, None)

    # decoder->encoder setup
    if decoder_activation == 'relu':
        decoder_activation = tf.nn.relu
    elif decoder_activation == 'elu':
        decoder_activation = tf.nn.elu
    else:
        raise ValueError('Unknown activation: ' + decoder_activation)

    # This is a placeholder because we are going to feed it the output
    # from the encoder defined below.
    content_encoded = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_encoded = tf.placeholder(tf.float32,
                                   shape=encoder_layer_shape)  # conv4_1
    output_encoded = adain(content_encoded, style_encoded)

    # TRIVIAL MASK
    trivial_mask_value = gen_trivial_mask()
    trivial_mask = tf.constant(trivial_mask_value,
                               dtype=tf.bool,
                               name="trivial_mask")

    window_mask_value = gen_window_mask()
    window_mask = tf.constant(window_mask_value,
                              dtype=tf.bool,
                              name="window_mask")

    # The tensors concatenated as the decoder input must be the same ones used
    # to compute the loss later.

    # Concatenate the AdaIN output with the style encoding before decoding.
    output_combined = tf.concat([output_encoded, style_encoded], axis=1)
    images = build_decoder(output_combined,
                           weights=None,
                           trainable=True,
                           activation=decoder_activation)

    with open_weights(vgg) as w:
        vgg = build_vgg(images, w, last_layer=encoder_layer)
        encoder = vgg[encoder_layer]

    # loss setup
    # content_target, style_targets will hold activations of content and style
    # images respectively
    content_layer = vgg[
        content_layer]  # In this case it's the same as encoder_layer
    content_target = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_layers = {layer: vgg[layer] for layer in style_layers}

    conv3_1_output_width_t = tf.shape(style_layers["conv3_1"], out_type=tf.int32)
    conv4_1_output_width_t = tf.shape(style_layers["conv4_1"], out_type=tf.int32)

    style_targets = {
        layer: tf.placeholder(tf.float32, shape=style_layers[layer].shape)
        for layer in style_layers
    }

    conv3_1_output_width = tf.placeholder(tf.int32,
                                          shape=(),
                                          name="conv3_1_output_width")
    conv4_1_output_width = tf.placeholder(tf.int32,
                                          shape=(),
                                          name="conv4_1_output_width")

    content_loss = build_content_loss(content_layer, content_target, 0.75)

    style_texture_losses = build_style_texture_losses(style_layers,
                                                      style_targets,
                                                      style_weight * 0.1 * 2.0)
    style_content_loss = build_style_content_loss_guided(
        style_layers, style_targets, output_encoded, trivial_mask, window_mask,
        1.0)

    loss = tf.reduce_sum(list(
        style_texture_losses.values())) + style_content_loss

    if tv_weight:
        tv_loss = tf.reduce_sum(tf.image.total_variation(images)) * tv_weight
    else:
        tv_loss = tf.constant(0, dtype=tf.float32)
    loss += tv_loss

    # training setup
    batch = setup_input_pipeline(content_dir, style_dir, batch_size,
                                 num_epochs, initial_size, random_crop_size)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    rate = tf.train.inverse_time_decay(learning_rate,
                                       global_step,
                                       decay_steps=1,
                                       decay_rate=learning_rate_decay)

    if optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(rate, beta1=momentum)
    elif optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(rate)
    else:
        raise ValueError('Unknown optimizer: ' + optimizer)

    train_op = optimizer.minimize(loss, global_step=global_step)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())

        if resume:
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            saver.restore(sess, latest)
        else:
            sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        with coord.stop_on_exception():
            while not coord.should_stop():
                content_batch, style_batch = sess.run(batch)

                # step 1
                # encode content and style images,
                # compute target style activations,
                # run content and style through AdaIN
                content_batch_encoded = sess.run(
                    encoder, feed_dict={images: content_batch})

                style_batch_encoded, style_target_vals = sess.run(
                    [encoder, style_layers], feed_dict={images: style_batch})

                # This is the AdaIN step
                output_batch_encoded = sess.run(output_encoded,
                                                feed_dict={
                                                    content_encoded:
                                                    content_batch_encoded,
                                                    style_encoded:
                                                    style_batch_encoded
                                                })

                # step 2
                # run the output batch through the decoder, compute loss
                feed_dict = {
                    output_encoded: output_batch_encoded,
                    style_encoded: style_batch_encoded,
                    # "We use the AdaIN output as the content target, instead of
                    # the commonly used feature responses of the content image"
                    content_target: output_batch_encoded
                    # filtered_x_target: filt_x_targ,
                    # filtered_y_target: filt_y_targ,
                    # conv3_1_output_width: conv3_1_shape[2],
                    # conv4_1_output_width: conv4_1_shape[2]
                }

                for layer in style_targets:
                    feed_dict[style_targets[layer]] = style_target_vals[layer]

                fetches = [
                    train_op, loss, content_loss, style_texture_losses,
                    style_content_loss, tv_loss, global_step
                ]
                result = sess.run(fetches, feed_dict=feed_dict)
                (_, loss_val, content_loss_val, style_texture_loss_vals,
                 style_content_loss_val, tv_loss_val, i) = result

                # Print out the masks
                # fig = plt.figure()
                # for k in range(8):
                #     mask = fg_val[k, 0, :, :]
                #     pd.DataFrame(mask).to_csv("/output/fg_mask_" + str(k) + ".csv")
                #     fig.add_subplot(2, 4, k+1)
                #     plt.imshow(mask, cmap='gray')
                # plt.savefig("/output/fg_masks_" + str(i) + ".eps", format="eps", dpi=75)

                # fig = plt.figure()
                # for k in range(8):
                #     mask = bg_val[k, 0, :, :]
                #     pd.DataFrame(mask).to_csv("/output/bg_mask_" + str(k) + ".csv")
                #     fig.add_subplot(2, 4, k+1)
                #     plt.imshow(mask, cmap='gray')
                # plt.savefig("/output/bg_masks_" + str(i) + ".eps", format="eps", dpi=75)
                # for k in range(8):
                #     mask = tar_val[k, 0, :, :]
                #     fig.add_subplot(2, 4, k+1)
                #     mask_flattened = mask.flatten()
                #     print("Here is the shape")
                #     print(mask_flattened.shape)
                #     print(mask_flattened[:10])
                #     plt.hist(mask_flattened)
                #     plt.show()
                # plt.savefig("/output/first_layer_hist" + str(i) + ".eps", format="eps", dpi=75)
                # for k in range(8):
                #     mask = tar_val[k, 1, :, :]
                #     fig.add_subplot(2, 4, k+1)
                #     mask_flattened = mask.flatten()
                #     plt.hist(mask_flattened)
                #     plt.show()
                # plt.savefig("/output/second_layer_hist" + str(i) + ".eps", format="eps", dpi=75)
                # for k in range(8):
                #     first_activation = tar_val[k, 0, :, :]
                #     second_activation = tar_val[k, 1, :, :]
                #     pd.DataFrame(first_activation).to_csv("/output/first_activation_" + str(k) + ".csv")
                #     pd.DataFrame(second_activation).to_csv("/output/second_activation_" + str(k) + ".csv")

                if i % print_every == 0:
                    style_texture_loss_val = sum(
                        style_texture_loss_vals.values())
                    # style_loss_vals = '\t'.join(sorted(['%s = %0.4f' % (name, val) for name, val in style_loss_vals.items()]))
                    print(i,
                          'loss = %0.4f' % loss_val,
                          'content = %0.4f' % content_loss_val,
                          'style_texture = %0.4f' % style_texture_loss_val,
                          'style_content = %0.4f' % style_content_loss_val,
                          'tv = %0.4f' % tv_loss_val,
                          sep='\t')

                if i % save_every == 0:
                    print('Saving checkpoint')
                    saver.save(sess,
                               os.path.join(checkpoint_dir, 'adain'),
                               global_step=i)

        coord.join(threads)
        saver.save(sess, os.path.join(checkpoint_dir, 'adain-final'))
Example #7
def initialize_model():
    global vgg
    global encoder
    global decoder
    global target
    global weighted_target
    global image
    global content
    global style
    global persistent_session
    global data_format
    alpha = 1.0
    # NCHW placeholder shapes below imply channels_first
    data_format = 'channels_first'

    graph = tf.Graph()
    # build the style transfer graph (VGG encoder, AdaIN, decoder) from the saved weights
    with graph.as_default():
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)

        target = adain(content, style, data_format=data_format)
        weighted_target = target * alpha + (1 - alpha) * content

        with open_weights('models/vgg19_weights_normalized.h5') as w:
            vgg = build_vgg(image, w, data_format=data_format)
            encoder = vgg['conv4_1']

        with open_weights('models/decoder_weights.h5') as w:
            decoder = build_decoder(weighted_target, w, trainable=False, data_format=data_format)

        # the default session behavior is to consume the entire GPU RAM during inference!
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.12

        # the persistent session across function calls exposed to external code interfaces
        persistent_session = tf.Session(graph=graph, config=config)

        persistent_session.run(tf.global_variables_initializer())

    print('Initialized model')

    while True:
        with ai_integration.get_next_input(inputs_schema={
            "style": {
                "type": "image"
            },
            "content": {
                "type": "image"
            },
        }) as inputs_dict:

            # failure defaults; only overwritten if inference completes successfully
            result_data = {"content-type": 'text/plain',
                           "data": None,
                           "success": False,
                           "error": None}

            print('Starting inference')
            start = time.time()

            content_size = 512
            style_size = 512
            crop = False
            preserve_color = False

            content_image = load_image(io.BytesIO(inputs_dict['content']), content_size, crop)
            style_image = load_image(io.BytesIO(inputs_dict['style']), style_size, crop)

            if preserve_color:
                style_image = coral(style_image, content_image)
            style_image = prepare_image(style_image)
            content_image = prepare_image(content_image)
            style_feature = persistent_session.run(encoder, feed_dict={
                image: style_image[np.newaxis, :]
            })
            content_feature = persistent_session.run(encoder, feed_dict={
                image: content_image[np.newaxis, :]
            })
            target_feature = persistent_session.run(target, feed_dict={
                content: content_feature,
                style: style_feature
            })

            output = persistent_session.run(decoder, feed_dict={
                content: content_feature,
                target: target_feature
            })

            output_img_bytes = save_image_in_memory(output[0], data_format=data_format)

            result_data["content-type"] = 'image/jpeg'
            result_data["data"] = output_img_bytes
            result_data["success"] = True
            result_data["error"] = None

            print('Finished inference and it took ' + str(time.time() - start))
            ai_integration.send_result(result_data)
Example #8
def train(
        content_dir='/floyd_images/',
        style_dir='/floyd_images/',
        checkpoint_dir='output',
        decoder_activation='relu',
        initial_size=512,
        random_crop_size=256,
        resume=False,
        optimizer='adam',
        learning_rate=1e-4,
        learning_rate_decay=5e-5,
        momentum=0.9,
        batch_size=8,
        num_epochs=64,
        content_layer='conv4_1',
        style_layers='conv1_1,conv2_1,conv3_1,conv4_1',
        tv_weight=0,
        style_weight=1e-2,
        content_weight=0.75,
        save_every=10000,
        print_every=10,
        gpu=0,
        vgg='/floyd_models/vgg19_weights_normalized.h5'):
    assert initial_size >= random_crop_size, 'Images are too small to be cropped'
    assert gpu >= 0, 'CPU mode is not supported'

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if not os.path.exists(checkpoint_dir):
        print('Creating checkpoint dir at', checkpoint_dir)
        os.mkdir(checkpoint_dir)

    style_layers = style_layers.split(',')

    # the content layer is also used as the encoder layer
    encoder_layer = content_layer
    encoder_layer_filters = vgg_layer_params(encoder_layer)['filters'] # Just gives you the number of filters
    encoder_layer_shape = (None, encoder_layer_filters, None, None)

    # decoder->encoder setup
    if decoder_activation == 'relu':
        decoder_activation = tf.nn.relu
    elif decoder_activation == 'elu':
        decoder_activation = tf.nn.elu
    else:
        raise ValueError('Unknown activation: ' + decoder_activation)

    # This is a placeholder because we are going to feed it the output
    # from the encoder defined below.
    content_encoded = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_encoded = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    output_encoded = adain(content_encoded, style_encoded)
    images = build_decoder(output_encoded, weights=None, trainable=True,
        activation=decoder_activation)

    with open_weights(vgg) as w:
        vgg = build_vgg(images, w, last_layer=encoder_layer)
        encoder = vgg[encoder_layer]

    # loss setup
    # content_target, style_targets will hold activations of content and style
    # images respectively
    content_layer = vgg[content_layer] # In this case it's the same as encoder_layer
    content_target = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_layers = {layer: vgg[layer] for layer in style_layers}
    style_targets = {
        layer: tf.placeholder(tf.float32, shape=style_layers[layer].shape)
        for layer in style_layers
    }

    content_loss = build_content_loss(content_layer, content_target, content_weight)

    style_texture_losses = build_style_texture_losses(style_layers, style_targets, style_weight)

    # Test with different style weights empirically
    style_content_loss = build_style_content_loss(style_layers, style_targets, 0.15)

    loss = content_loss + tf.reduce_sum(list(style_texture_losses.values())) + style_content_loss

    if tv_weight:
        tv_loss = tf.reduce_sum(tf.image.total_variation(images)) * tv_weight
    else:
        tv_loss = tf.constant(0, dtype=tf.float32)
    loss += tv_loss

    # training setup
    batch = setup_input_pipeline(content_dir, style_dir, batch_size,
        num_epochs, initial_size, random_crop_size)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    rate = tf.train.inverse_time_decay(learning_rate, global_step,
        decay_steps=1, decay_rate=learning_rate_decay)

    if optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(rate, beta1=momentum)
    elif optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(rate)
    else:
        raise ValueError('Unknown optimizer: ' + optimizer)

    train_op = optimizer.minimize(loss, global_step=global_step)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())

        if resume:
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            saver.restore(sess, latest)
        else:
            sess.run(tf.global_variables_initializer())
        
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        with coord.stop_on_exception():
            while not coord.should_stop():
                content_batch, style_batch = sess.run(batch)

                # step 1
                # encode content and style images,
                # compute target style activations,
                # run content and style through AdaIN
                content_batch_encoded = sess.run(encoder, feed_dict={
                    images: content_batch
                })

                style_batch_encoded, style_target_vals = sess.run([encoder, style_layers], feed_dict={
                    images: style_batch
                })

                # This is the AdaIN step
                output_batch_encoded = sess.run(output_encoded, feed_dict={
                    content_encoded: content_batch_encoded,
                    style_encoded: style_batch_encoded
                })

                # step 2
                # run the output batch through the decoder, compute loss
                feed_dict = {
                    output_encoded: output_batch_encoded,
                    # "We use the AdaIN output as the content target, instead of
                    # the commonly used feature responses of the content image"
                    content_target: output_batch_encoded
                }

                for layer in style_targets:
                    feed_dict[style_targets[layer]] = style_target_vals[layer]

                fetches = [train_op, loss, content_loss, style_texture_losses,
                    style_content_loss, tv_loss, global_step]
                result = sess.run(fetches, feed_dict=feed_dict)
                (_, loss_val, content_loss_val, style_texture_loss_vals,
                    style_content_loss_val, tv_loss_val, i) = result

                if i % print_every == 0:
                    style_texture_loss_val = sum(style_texture_loss_vals.values())
                    # style_loss_vals = '\t'.join(sorted(['%s = %0.4f' % (name, val) for name, val in style_loss_vals.items()]))
                    print(i,
                        'loss = %0.4f' % loss_val,
                        'content = %0.4f' % content_loss_val,
                        'style_texture = %0.4f' % style_texture_loss_val,
                        'style_content = %0.4f' % style_content_loss_val,
                        'tv = %0.4f' % tv_loss_val, sep='\t')

                if i % save_every == 0:
                    print('Saving checkpoint')
                    saver.save(sess, os.path.join(checkpoint_dir, 'adain'), global_step=i)

        coord.join(threads)
        saver.save(sess, os.path.join(checkpoint_dir, 'adain-final'))
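The loss helpers are not shown in these listings. As a rough, hedged sketch of what they are expected to compute, following the AdaIN paper (content loss as an MSE on encoder features, per-layer style loss as an MSE on channel-wise means and standard deviations); the real build_content_loss and build_style_texture_losses in this repository may differ:

def build_content_loss_sketch(content_layer, content_target, weight):
    # MSE between the decoded image's encoder features and the AdaIN target.
    return weight * tf.reduce_mean(
        tf.squared_difference(content_layer, content_target))

def build_style_texture_losses_sketch(style_layers, style_targets, weight,
                                      epsilon=1e-5):
    losses = {}
    for name in style_layers:
        layer, target = style_layers[name], style_targets[name]
        # Channel-wise statistics over the spatial axes (channels_first layout).
        l_mean, l_var = tf.nn.moments(layer, axes=[2, 3])
        t_mean, t_var = tf.nn.moments(target, axes=[2, 3])
        mean_loss = tf.reduce_mean(tf.squared_difference(l_mean, t_mean))
        std_loss = tf.reduce_mean(tf.squared_difference(
            tf.sqrt(l_var + epsilon), tf.sqrt(t_var + epsilon)))
        losses[name] = weight * (mean_loss + std_loss)
    return losses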
Example #9
def train(content_dir='/floyd_images/',
          style_dir='/floyd_images/',
          checkpoint_dir='output',
          decoder_activation='relu',
          initial_size=512,
          random_crop_size=256,
          resume=False,
          optimizer='adam',
          learning_rate=1e-4,
          learning_rate_decay=5e-5,
          momentum=0.9,
          batch_size=8,
          num_epochs=64,
          content_layer='conv4_1',
          style_layers='conv1_1,conv2_1,conv3_1,conv4_1',
          tv_weight=0,
          style_weight=1e-2,
          content_weight=0.75,
          save_every=10000,
          print_every=10,
          gpu=0,
          vgg='/floyd_models/vgg19_weights_normalized.h5'):
    assert initial_size >= random_crop_size, 'Images are too small to be cropped'
    assert gpu >= 0, 'CPU mode is not supported'

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if not os.path.exists(checkpoint_dir):
        print('Creating checkpoint dir at', checkpoint_dir)
        os.mkdir(checkpoint_dir)

    style_layers = style_layers.split(',')

    # the content layer is also used as the encoder layer
    encoder_layer = content_layer
    encoder_layer_filters = vgg_layer_params(encoder_layer)[
        'filters']  # Just gives you the number of filters
    encoder_layer_shape = (None, encoder_layer_filters, None, None)

    # decoder->encoder setup
    if decoder_activation == 'relu':
        decoder_activation = tf.nn.relu
    elif decoder_activation == 'elu':
        decoder_activation = tf.nn.elu
    else:
        raise ValueError('Unknown activation: ' + decoder_activation)

    # This is a placeholder because we are going to feed it the output
    # from the encoder defined below.
    content_encoded = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_encoded = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    output_encoded = adain(content_encoded, style_encoded)
    # NOTE: "images" contains the output of the decoder
    images = build_decoder(output_encoded,
                           weights=None,
                           trainable=True,
                           activation=decoder_activation)

    # New placeholder just to hold content images
    # content_image = tf.placeholder(tf.float32, shape=(None, 3, random_crop_size, random_crop_size))
    images_reshaped = tf.transpose(images, perm=(0, 2, 3, 1))
    grayscaled_content = tf.image.rgb_to_grayscale(images_reshaped)
    # Run sobel operators on it
    filtered_x, filtered_y = edge_detection(grayscaled_content)

    with open_weights(vgg) as w:
        # We need the VGG for loss computation
        vgg = build_vgg(images, w, last_layer=encoder_layer)
        encoder = vgg[encoder_layer]

    # loss setup
    # content_target, style_targets will hold activations of content and style
    # images respectively
    content_layer = vgg[
        content_layer]  # In this case it's the same as encoder_layer
    content_target = tf.placeholder(tf.float32, shape=encoder_layer_shape)
    style_layers = {layer: vgg[layer] for layer in style_layers}

    conv3_1_output_width_t = tf.shape(style_layers["conv3_1"], out_type=tf.int32)
    conv4_1_output_width_t = tf.shape(style_layers["conv4_1"], out_type=tf.int32)

    style_targets = {
        layer: tf.placeholder(tf.float32, shape=style_layers[layer].shape)
        for layer in style_layers
    }

    # Define placeholders for the targets
    filtered_x_target = tf.placeholder(tf.float32,
                                       shape=filtered_x.get_shape())
    filtered_y_target = tf.placeholder(tf.float32,
                                       shape=filtered_y.get_shape())

    conv3_1_output_width = tf.placeholder(tf.int32,
                                          shape=(),
                                          name="conv3_1_output_width")
    conv4_1_output_width = tf.placeholder(tf.int32,
                                          shape=(),
                                          name="conv4_1_output_width")

    content_general_loss = build_content_general_loss(content_layer,
                                                      content_target, 0.25)
    content_edge_loss = build_content_edge_loss(filtered_x, filtered_y,
                                                filtered_x_target,
                                                filtered_y_target, 3.0)
    style_texture_losses = build_style_texture_losses(style_layers,
                                                      style_targets,
                                                      style_weight)
    style_content_loss, rel_pixels_sum, pos_act_sum = build_style_content_loss(
        style_layers, style_targets, 2.5)

    loss = content_general_loss + content_edge_loss + tf.reduce_sum(
        list(style_texture_losses.values())) + style_content_loss

    if tv_weight:
        tv_loss = tf.reduce_sum(tf.image.total_variation(images)) * tv_weight
    else:
        tv_loss = tf.constant(0, dtype=tf.float32)
    loss += tv_loss

    # training setup
    batch = setup_input_pipeline(content_dir, style_dir, batch_size,
                                 num_epochs, initial_size, random_crop_size)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    rate = tf.train.inverse_time_decay(learning_rate,
                                       global_step,
                                       decay_steps=1,
                                       decay_rate=learning_rate_decay)

    if optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(rate, beta1=momentum)
    elif optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(rate)
    else:
        raise ValueError('Unknown optimizer: ' + optimizer)

    train_op = optimizer.minimize(loss, global_step=global_step)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())

        if resume:
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            saver.restore(sess, latest)
        else:
            sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        with coord.stop_on_exception():
            while not coord.should_stop():
                content_batch, style_batch = sess.run(batch)

                # step 1
                # encode content and style images,
                # compute target style activations,
                # run content and style through AdaIN
                content_batch_encoded = sess.run(
                    encoder, feed_dict={images: content_batch})

                style_batch_encoded, style_target_vals = sess.run(
                    [encoder, style_layers], feed_dict={images: style_batch})

                # This is the AdaIN step
                output_batch_encoded = sess.run(output_encoded,
                                                feed_dict={
                                                    content_encoded:
                                                    content_batch_encoded,
                                                    style_encoded:
                                                    style_batch_encoded
                                                })

                # Actual target values for edge loss
                filt_x_targ, filt_y_targ = sess.run(
                    [filtered_x, filtered_y],
                    feed_dict={images: content_batch})

                # TODO: Need to compute output shapes before we can actually compute guided COS loss.
                conv3_1_shape, conv4_1_shape = sess.run(
                    [conv3_1_output_width_t, conv4_1_output_width_t],
                    feed_dict={images: content_batch})

                # step 2
                # run the output batch through the decoder, compute loss
                feed_dict = {
                    output_encoded: output_batch_encoded,
                    # "We use the AdaIN output as the content target, instead of
                    # the commonly used feature responses of the content image"
                    content_target: output_batch_encoded,
                    filtered_x_target: filt_x_targ,
                    filtered_y_target: filt_y_targ,
                    conv3_1_output_width: conv3_1_shape[2],
                    conv4_1_output_width: conv4_1_shape[2]
                }

                for layer in style_targets:
                    feed_dict[style_targets[layer]] = style_target_vals[layer]

                fetches = [
                    train_op, images, loss, content_general_loss,
                    content_edge_loss, style_texture_losses,
                    style_content_loss, rel_pixels_sum, pos_act_sum, tv_loss,
                    global_step
                ]
                result = sess.run(fetches, feed_dict=feed_dict)
                (_, output_images, loss_val, content_general_loss_val,
                    content_edge_loss_val, style_texture_loss_vals,
                    style_content_loss_val, rel_pixels_sum_val, pos_act_sum_val,
                    tv_loss_val, i) = result

                # Try to plot these out?
                # (8, 256, 256, 1)
                # save_edge_images(filt_x_orig, batch_size, "x_filters")
                # save_edge_images(filt_y_orig, batch_size, "y_filters")
                # original_content_batch = np.transpose(content_batch, axes=(0, 2, 3, 1))
                # save_edge_images(original_content_batch, batch_size, "original_r")
                # exit()
                if i % print_every == 0:
                    style_texture_loss_val = sum(
                        style_texture_loss_vals.values())
                    # style_loss_vals = '\t'.join(sorted(['%s = %0.4f' % (name, val) for name, val in style_loss_vals.items()]))
                    print(i,
                          'loss = %0.4f' % loss_val,
                          'content_general = %0.4f' % content_general_loss_val,
                          'content_edge = %0.4f' % content_edge_loss_val,
                          'style_texture = %0.4f' % style_texture_loss_val,
                          'style_content = %0.4f' % style_content_loss_val,
                          'rel_pixels_sum_val = %0.4f' % rel_pixels_sum_val,
                          'pos_act_sum_val = %0.4f' % pos_act_sum_val,
                          'tv = %0.4f' % tv_loss_val,
                          sep='\t')

                if i % save_every == 0:
                    print('Saving checkpoint')
                    saver.save(sess,
                               os.path.join(checkpoint_dir, 'adain'),
                               global_step=i)

        coord.join(threads)
        saver.save(sess, os.path.join(checkpoint_dir, 'adain-final'))
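edge_detection() (used in Example #9) is also not shown. A plausible minimal version based on tf.image.sobel_edges, operating on the NHWC grayscale tensor; this is an assumption about how the helper works, not the repository's actual code:

def edge_detection_sketch(gray_nhwc):
    # tf.image.sobel_edges returns shape [batch, h, w, channels, 2];
    # index 0 is the vertical (dy) response, index 1 the horizontal (dx).
    sobel = tf.image.sobel_edges(gray_nhwc)
    filtered_y = sobel[..., 0]
    filtered_x = sobel[..., 1]
    return filtered_x, filtered_y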