Example #1
0
def style_transfer(
        content=None,
        content_dir=None,
        content_size=512,
        style=None,
        style_dir=None,
        style_size=512,
        crop=None,
        preserve_color=None,
        alpha=1.0,
        style_interp_weights=None,
        mask=None,
        output_dir='/output',
        save_ext='jpg',
        gpu=0,
        vgg_weights='/floyd_models/vgg19_weights_normalized.h5',
        decoder_weights='/floyd_models/decoder_weights.h5',
        tf_checkpoint_dir=None):
    """Run AdaIN style transfer for every (content, style) pair and save results.

    Exactly one of ``content`` / ``content_dir`` and exactly one of
    ``style`` / ``style_dir`` must be given.

    Args:
        content: Path of a single content image.
        content_dir: Directory searched recursively for content images.
        content_size: Size passed to ``load_image`` for content images.
        style: Path of a style image, or a comma-separated list of paths.
            Several paths enable style blending, or spatial control when
            ``mask`` is given (which requires exactly two paths).
        style_dir: Directory searched recursively for style images.
        style_size: Size passed to ``load_image`` for style images.
        crop: Crop flag passed to ``load_image`` for blended/masked styles.
        preserve_color: If truthy, each style image is passed through
            ``coral(style_image, content_image)`` before encoding.
        alpha: Forwarded to ``_build_graph``.
        style_interp_weights: Comma-separated blending weights, one per style
            image; normalized to sum to one. Defaults to equal weights.
        mask: Mask path for spatial control; requires ``content`` and ``style``.
        output_dir: Directory for the stylized images; created if missing.
        save_ext: Extension of the saved image files.
        gpu: GPU index; a negative value hides all GPUs and uses
            ``channels_last``, otherwise ``channels_first`` is used.
        vgg_weights: Path of the VGG encoder weights.
        decoder_weights: Path of the decoder h5 weights (currently ignored,
            see ``decoder_in_h5``).
        tf_checkpoint_dir: Directory holding the ``adain-final`` checkpoint.

    Raises:
        AssertionError: On inconsistent or missing arguments.
    """
    assert bool(content) != bool(content_dir), 'Either content or content_dir should be given'
    assert bool(style) != bool(style_dir), 'Either style or style_dir should be given'

    if not os.path.exists(output_dir):
        print('Creating output dir at', output_dir)
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(output_dir)

    # Assume that it is either an h5 file or a name of a TensorFlow checkpoint.
    # NOTE: For now, artificially switching off pretrained h5 weights
    # (originally: decoder_weights.endswith('.h5')).
    decoder_in_h5 = False

    if content:
        content_batch = [content]
    else:
        assert mask is None, 'For spatial control use the --content option'
        content_batch = extract_image_names_recursive(content_dir)

    if style:
        style = style.split(',')
        if mask:
            assert len(style) == 2, 'For spatial control provide two style images'
            style_batch = [style]
        elif len(style) > 1:  # Style blending
            if not style_interp_weights:
                # By default, all styles get equal weights.
                style_interp_weights = np.array([1.0 / len(style)] * len(style))
            else:
                # Normalize weights so that their sum equals one.
                style_interp_weights = np.array(
                    [float(w) for w in style_interp_weights.split(',')])
                style_interp_weights /= np.sum(style_interp_weights)
                assert len(style) == len(style_interp_weights), """--style and --style_interp_weights must have the same number of elements"""
            # A list entry in style_batch marks "all these styles at once".
            style_batch = [style]
        else:
            style_batch = style
    else:
        assert mask is None, 'For spatial control use the --style option'
        style_batch = extract_image_names_recursive(style_dir)

    print('Number of content images:', len(content_batch))
    print('Number of style images:', len(style_batch))

    if gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        data_format = 'channels_first'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        data_format = 'channels_last'

    image, content, style, target, encoder, decoder, otherLayers = _build_graph(vgg_weights,
        decoder_weights if decoder_in_h5 else None, alpha, data_format=data_format)
    # Extra encoder layers used for visualisation. They are bound once here
    # because every branch below (not only the single-style one) reads them.
    conv3_1_layer, conv4_1_layer = otherLayers

    with tf.Session() as sess:
        if decoder_in_h5:
            sess.run(tf.global_variables_initializer())
        elif tf_checkpoint_dir is not None:  # Some checkpoint was provided
            saver = tf.train.Saver()
            saver.restore(sess, os.path.join(tf_checkpoint_dir, 'adain-final'))
        # NOTE(review): with decoder_in_h5 hard-wired to False and no
        # tf_checkpoint_dir, no variables are initialized or restored here;
        # confirm _build_graph creates no variables in that case.

        for content_path, style_path in product(content_batch, style_batch):
            content_name = get_filename(content_path)
            content_image = load_image(content_path, content_size, crop=True)

            # Only the single-style path produces these; they stay None for
            # blending and spatial control.
            style_feature = None
            conv3_1_out_style = conv4_1_out_style = None

            if isinstance(style_path, list):  # Style blending/Spatial control
                style_paths = style_path
                style_name = '_'.join(map(get_filename, style_paths))

                # Gather all style images in one numpy array in order to get
                # their activations in one pass.
                style_images = None
                for i, path in enumerate(style_paths):
                    style_image = load_image(path, style_size, crop)
                    if preserve_color:
                        style_image = coral(style_image, content_image)
                    # Same layout as the single-style branch so CPU
                    # (channels_last) mode feeds the encoder correctly.
                    style_image = prepare_image(style_image, True, data_format)
                    if style_images is None:
                        shape = tuple([len(style_paths)]) + style_image.shape
                        style_images = np.empty(shape)
                    assert style_images.shape[1:] == style_image.shape, """Style images must have the same shape"""
                    style_images[i] = style_image
                style_features = sess.run(encoder, feed_dict={
                    image: style_images
                })

                content_image = prepare_image(content_image, True, data_format)
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis, :]
                })

                if mask:
                    # For spatial control, extract foreground and background
                    # parts of the content using the corresponding masks,
                    # run them individually through AdaIN then combine.
                    if data_format == 'channels_first':
                        _, c, h, w = content_feature.shape
                        content_view_shape = (c, -1)
                        mask_shape = lambda idx: (c, len(idx), 1)
                        mask_slice = lambda idx: (slice(None), idx)
                    else:
                        _, h, w, c = content_feature.shape
                        content_view_shape = (-1, c)
                        mask_shape = lambda idx: (1, len(idx), c)
                        mask_slice = lambda idx: (idx, slice(None))

                    # Separate name so the `mask` path argument is not
                    # replaced by the loaded array.
                    mask_values = load_mask(mask, h, w).reshape(-1)
                    fg_mask = np.flatnonzero(mask_values == 1)
                    bg_mask = np.flatnonzero(mask_values == 0)

                    content_feat_view = content_feature.reshape(content_view_shape)
                    content_feat_fg = content_feat_view[mask_slice(fg_mask)].reshape(mask_shape(fg_mask))
                    content_feat_bg = content_feat_view[mask_slice(bg_mask)].reshape(mask_shape(bg_mask))

                    style_feature_fg = style_features[0]
                    style_feature_bg = style_features[1]

                    target_feature_fg = sess.run(target, feed_dict={
                        content: content_feat_fg[np.newaxis, :],
                        style: style_feature_fg[np.newaxis, :]
                    })
                    target_feature_fg = np.squeeze(target_feature_fg)

                    target_feature_bg = sess.run(target, feed_dict={
                        content: content_feat_bg[np.newaxis, :],
                        style: style_feature_bg[np.newaxis, :]
                    })
                    target_feature_bg = np.squeeze(target_feature_bg)

                    target_feature = np.zeros_like(content_feat_view)
                    target_feature[mask_slice(fg_mask)] = target_feature_fg
                    target_feature[mask_slice(bg_mask)] = target_feature_bg
                    target_feature = target_feature.reshape(content_feature.shape)
                else:
                    # For style blending, get activations for each style then
                    # take a weighted sum.
                    target_feature = np.zeros(content_feature.shape)
                    for blend_feature, weight in zip(style_features, style_interp_weights):
                        target_feature += sess.run(target, feed_dict={
                            content: content_feature,
                            style: blend_feature[np.newaxis, :]
                        }) * weight
            else:
                # NOTE: This is the part we care about, if only 1 style image is provided.
                style_name = get_filename(style_path)
                style_image = load_image(style_path, style_size, crop=True)  # This only gives us square crop
                style_image = center_crop_np(style_image)  # Actually crop the center out

                if preserve_color:
                    style_image = coral(style_image, content_image)
                style_image = prepare_image(style_image, True, data_format)
                content_image = prepare_image(content_image, True, data_format)

                # Encoder output plus the extra layers, in one pass.
                style_feature, conv3_1_out_style, conv4_1_out_style = sess.run(
                    [encoder, conv3_1_layer, conv4_1_layer], feed_dict={
                        image: style_image[np.newaxis, :]
                    })

                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis, :]
                })

                target_feature = sess.run(target, feed_dict={
                    content: content_feature,
                    style: style_feature
                })

            # `style` is fed only when a single batched style feature exists;
            # the multi-style paths provide their result through `target`.
            decoder_feed = {content: content_feature, target: target_feature}
            if style_feature is not None:
                decoder_feed[style] = style_feature
            output = sess.run(decoder, feed_dict=decoder_feed)

            # Grab the relevant layer outputs to see what's being minimized.
            conv3_1_out_output, conv4_1_out_output = sess.run([conv3_1_layer, conv4_1_layer], feed_dict={
                image: output
            })

            filename = '%s_stylized_%s.%s' % (content_name, style_name, save_ext)
            filename = os.path.join(output_dir, filename)
            save_image(filename, output[0], data_format=data_format)
            print('Output image saved at', filename)

            # TODO: Change these layers. Style entries are None for the
            # blending / spatial-control paths.
            layersToViz = [conv3_1_out_style, conv4_1_out_style, conv3_1_out_output, conv4_1_out_output]
def initialize_model():
    """Build the AdaIN inference graph once, then serve requests forever.

    The graph tensors and a persistent ``tf.Session`` are published as
    module-level globals. After setup the function loops indefinitely: each
    request obtained from ``ai_integration.get_next_input`` carries a
    ``content`` and a ``style`` image as raw bytes. The stylized image is
    returned through ``ai_integration.send_result`` with content type
    ``image/jpeg``.

    NOTE(review): ``data_format`` is declared global and read below but never
    assigned in this function. It must already exist at module level, and it
    presumably must be 'channels_first' to match the hard-coded
    (batch, channels, height, width) placeholders. Confirm this.
    """
    global vgg
    global encoder
    global decoder
    global target
    global weighted_target
    global image
    global content
    global style
    global persistent_session
    global data_format
    blend_alpha = 1.0

    model_graph = tf.Graph()
    # Build the whole model inside its own graph.
    with model_graph.as_default():
        # Image batch, plus content/style feature maps (512 channels).
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)

        target = adain(content, style, data_format=data_format)
        # With blend_alpha == 1.0 the content term vanishes.
        weighted_target = target * blend_alpha + (1 - blend_alpha) * content

        with open_weights('models/vgg19_weights_normalized.h5') as vgg_file:
            vgg = build_vgg(image, vgg_file, data_format=data_format)
            encoder = vgg['conv4_1']

        with open_weights('models/decoder_weights.h5') as decoder_file:
            decoder = build_decoder(weighted_target, decoder_file, trainable=False, data_format=data_format)

        # Limit the GPU memory share; by default the session would claim all of it.
        session_config = tf.ConfigProto()
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.12

        # One session reused across every request handled below.
        persistent_session = tf.Session(graph=model_graph, config=session_config)
        persistent_session.run(tf.global_variables_initializer())

    print('Initialized model')

    while True:
        request_schema = {
            "style": {"type": "image"},
            "content": {"type": "image"},
        }
        with ai_integration.get_next_input(inputs_schema=request_schema) as inputs_dict:
            # Failure defaults; overwritten only once inference has finished.
            result_data = {"content-type": 'text/plain',
                           "data": None,
                           "success": False,
                           "error": None}

            print('Starting inference')
            started_at = time.time()

            content_size = style_size = 512
            crop = False
            preserve_color = False

            content_img = load_image(io.BytesIO(inputs_dict['content']), content_size, crop)
            style_img = load_image(io.BytesIO(inputs_dict['style']), style_size, crop)

            if preserve_color:
                style_img = coral(style_img, content_img)
            style_img = prepare_image(style_img)
            content_img = prepare_image(content_img)

            run = persistent_session.run
            style_feat = run(encoder, feed_dict={image: style_img[np.newaxis, :]})
            content_feat = run(encoder, feed_dict={image: content_img[np.newaxis, :]})
            adain_feat = run(target, feed_dict={content: content_feat, style: style_feat})
            # `target` is fed with the precomputed AdaIN features; `content`
            # is still supplied because weighted_target references it.
            output = run(decoder, feed_dict={content: content_feat, target: adain_feat})

            result_data.update({
                "content-type": 'image/jpeg',
                "data": save_image_in_memory(output[0], data_format=data_format),
                "success": True,
                "error": None,
            })

            print('Finished inference and it took ' + str(time.time() - started_at))
            ai_integration.send_result(result_data)