def style_transfer(content=None,
                   content_dir=None,
                   content_size=512,
                   style=None,
                   style_dir=None,
                   style_size=512,
                   crop=None,
                   preserve_color=None,
                   alpha=1.0,
                   style_interp_weights=None,
                   mask=None,
                   output_dir='/output',
                   save_ext='jpg',
                   gpu=0,
                   vgg_weights='/floyd_models/vgg19_weights_normalized.h5',
                   decoder_weights='/floyd_models/decoder_weights.h5',
                   tf_checkpoint_dir=None):
    assert bool(content) != bool(content_dir), 'Either content or content_dir should be given'
    assert bool(style) != bool(style_dir), 'Either style or style_dir should be given'

    if not os.path.exists(output_dir):
        print('Creating output dir at', output_dir)
        os.mkdir(output_dir)

    # Assume decoder_weights is either an h5 file or the name of a TensorFlow checkpoint
    # NOTE: For now, artificially switching off pretrained h5 weights
    decoder_in_h5 = False  # decoder_weights.endswith('.h5')

    if content:
        content_batch = [content]
    else:
        assert mask is None, 'For spatial control use the --content option'
        content_batch = extract_image_names_recursive(content_dir)

    if style:
        style = style.split(',')
        if mask:
            assert len(style) == 2, 'For spatial control provide two style images'
            style_batch = [style]
        elif len(style) > 1:  # Style blending
            if not style_interp_weights:
                # by default, all styles get equal weights
                style_interp_weights = np.array([1.0/len(style)] * len(style))
            else:
                # normalize weights so that they sum to one
                style_interp_weights = [float(w) for w in style_interp_weights.split(',')]
                style_interp_weights = np.array(style_interp_weights)
                style_interp_weights /= np.sum(style_interp_weights)

                assert len(style) == len(style_interp_weights), \
                    '--style and --style_interp_weights must have the same number of elements'
            style_batch = [style]
        else:
            style_batch = style
    else:
        assert mask is None, 'For spatial control use the --style option'
        style_batch = extract_image_names_recursive(style_dir)

    print('Number of content images:', len(content_batch))
    print('Number of style images:', len(style_batch))

    if gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        data_format = 'channels_first'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        data_format = 'channels_last'

    image, content, style, target, encoder, decoder, otherLayers = _build_graph(
        vgg_weights,
        decoder_weights if decoder_in_h5 else None,
        alpha,
        data_format=data_format)

    with tf.Session() as sess:
        if decoder_in_h5:
            sess.run(tf.global_variables_initializer())
        elif tf_checkpoint_dir is not None:
            # Some checkpoint was provided
            saver = tf.train.Saver()
            saver.restore(sess, os.path.join(tf_checkpoint_dir, 'adain-final'))

        for content_path, style_path in product(content_batch, style_batch):
            content_name = get_filename(content_path)
            content_image = load_image(content_path, content_size, crop=True)

            if isinstance(style_path, list):  # Style blending/Spatial control
                style_paths = style_path
                style_name = '_'.join(map(get_filename, style_paths))

                # Gather all style images in one numpy array in order to get
                # their activations in one pass
                style_images = None
                for i, style_path in enumerate(style_paths):
                    style_image = load_image(style_path, style_size, crop)
                    if preserve_color:
                        style_image = coral(style_image, content_image)
                    style_image = prepare_image(style_image)
                    if style_images is None:
                        shape = tuple([len(style_paths)]) + style_image.shape
                        style_images = np.empty(shape)
                    assert style_images.shape[1:] == style_image.shape, \
                        'Style images must have the same shape'
                    style_images[i] = style_image

                style_features = sess.run(encoder, feed_dict={
                    image: style_images
                })

                content_image = prepare_image(content_image)
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis,:]
                })

                if mask:
                    # For spatial control, extract foreground and background
                    # parts of the content using the corresponding masks,
                    # run them individually through AdaIN and then combine
                    if data_format == 'channels_first':
                        _, c, h, w = content_feature.shape
                        content_view_shape = (c, -1)
                        mask_shape = lambda mask: (c, len(mask), 1)
                        mask_slice = lambda mask: (slice(None), mask)
                    else:
                        _, h, w, c = content_feature.shape
                        content_view_shape = (-1, c)
                        mask_shape = lambda mask: (1, len(mask), c)
                        mask_slice = lambda mask: (mask, slice(None))

                    mask = load_mask(mask, h, w).reshape(-1)
                    fg_mask = np.flatnonzero(mask == 1)
                    bg_mask = np.flatnonzero(mask == 0)

                    content_feat_view = content_feature.reshape(content_view_shape)
                    content_feat_fg = content_feat_view[mask_slice(fg_mask)].reshape(mask_shape(fg_mask))
                    content_feat_bg = content_feat_view[mask_slice(bg_mask)].reshape(mask_shape(bg_mask))

                    style_feature_fg = style_features[0]
                    style_feature_bg = style_features[1]

                    target_feature_fg = sess.run(target, feed_dict={
                        content: content_feat_fg[np.newaxis,:],
                        style: style_feature_fg[np.newaxis,:]
                    })
                    target_feature_fg = np.squeeze(target_feature_fg)

                    target_feature_bg = sess.run(target, feed_dict={
                        content: content_feat_bg[np.newaxis,:],
                        style: style_feature_bg[np.newaxis,:]
                    })
                    target_feature_bg = np.squeeze(target_feature_bg)

                    target_feature = np.zeros_like(content_feat_view)
                    target_feature[mask_slice(fg_mask)] = target_feature_fg
                    target_feature[mask_slice(bg_mask)] = target_feature_bg
                    target_feature = target_feature.reshape(content_feature.shape)
                else:
                    # For style blending, get activations for each style and then
                    # take a weighted sum.
                    target_feature = np.zeros(content_feature.shape)
                    for style_feature, weight in zip(style_features, style_interp_weights):
                        target_feature += sess.run(target, feed_dict={
                            content: content_feature,
                            style: style_feature[np.newaxis,:]
                        }) * weight
            else:
                # NOTE: This is the part we care about, if only 1 style image is provided.
                style_name = get_filename(style_path)
                style_image = load_image(style_path, style_size, crop=True)  # This only gives us a square crop
                style_image = center_crop_np(style_image)  # Actually crop the center out
                if preserve_color:
                    style_image = coral(style_image, content_image)
                style_image = prepare_image(style_image, True, data_format)
                content_image = prepare_image(content_image, True, data_format)

                # Extract other layers
                conv3_1_layer, conv4_1_layer = otherLayers
                style_feature, conv3_1_out_style, conv4_1_out_style = sess.run(
                    [encoder, conv3_1_layer, conv4_1_layer],
                    feed_dict={
                        image: style_image[np.newaxis,:]
                    })
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis,:]
                })
                target_feature = sess.run(target, feed_dict={
                    content: content_feature,
                    style: style_feature
                })

            output = sess.run(decoder, feed_dict={
                content: content_feature,
                target: target_feature,
                style: style_feature
            })

            # Grab the relevant layer outputs to see what's being minimized.
            conv3_1_out_output, conv4_1_out_output = sess.run(
                [conv3_1_layer, conv4_1_layer],
                feed_dict={
                    image: output
                })

            filename = '%s_stylized_%s.%s' % (content_name, style_name, save_ext)
            filename = os.path.join(output_dir, filename)
            save_image(filename, output[0], data_format=data_format)
            print('Output image saved at', filename)

            # TODO: Change these layers.
            layersToViz = [conv3_1_out_style, conv4_1_out_style,
                           conv3_1_out_output, conv4_1_out_output]
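
# A hypothetical invocation of style_transfer() for the single-style path that the
# NOTE above says is the one we care about. The image paths are placeholders, not
# files that ship with this repo:
#
#   style_transfer(content='content.jpg', style='style.jpg',
#                  output_dir='/output', gpu=0)
#
# Passing multiple comma-separated styles ('a.jpg,b.jpg') triggers the blending
# branch, and passing mask= together with exactly two styles triggers spatial control.
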
def initialize_model():
    global vgg
    global encoder
    global decoder
    global target
    global weighted_target
    global image
    global content
    global style
    global persistent_session
    global data_format

    alpha = 1.0

    graph = tf.Graph()
    # build the AdaIN style-transfer graph from the saved HDF5 weights
    with graph.as_default():
        image = tf.placeholder(shape=(None, 3, None, None), dtype=tf.float32)
        content = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)
        style = tf.placeholder(shape=(1, 512, None, None), dtype=tf.float32)

        target = adain(content, style, data_format=data_format)
        weighted_target = target * alpha + (1 - alpha) * content

        with open_weights('models/vgg19_weights_normalized.h5') as w:
            vgg = build_vgg(image, w, data_format=data_format)
            encoder = vgg['conv4_1']

        with open_weights('models/decoder_weights.h5') as w:
            decoder = build_decoder(weighted_target, w, trainable=False,
                                    data_format=data_format)

        # By default a session would consume the entire GPU RAM during inference,
        # so cap the per-process memory fraction instead.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.12

        # persistent session shared across calls and exposed to external code
        persistent_session = tf.Session(graph=graph, config=config)
        persistent_session.run(tf.global_variables_initializer())

    print('Initialized model')

    while True:
        with ai_integration.get_next_input(inputs_schema={
            "style": {"type": "image"},
            "content": {"type": "image"},
        }) as inputs_dict:
            # result_data starts out marked as failed; its fields are only flipped
            # to a successful result if we reach the end of the loop body
            result_data = {"content-type": 'text/plain',
                           "data": None,
                           "success": False,
                           "error": None}

            print('Starting inference')
            start = time.time()

            content_size = 512
            style_size = 512
            crop = False
            preserve_color = False

            content_image = load_image(io.BytesIO(inputs_dict['content']), content_size, crop)
            style_image = load_image(io.BytesIO(inputs_dict['style']), style_size, crop)

            if preserve_color:
                style_image = coral(style_image, content_image)
            style_image = prepare_image(style_image)
            content_image = prepare_image(content_image)

            style_feature = persistent_session.run(encoder, feed_dict={
                image: style_image[np.newaxis, :]
            })
            content_feature = persistent_session.run(encoder, feed_dict={
                image: content_image[np.newaxis, :]
            })
            target_feature = persistent_session.run(target, feed_dict={
                content: content_feature,
                style: style_feature
            })
            output = persistent_session.run(decoder, feed_dict={
                content: content_feature,
                target: target_feature
            })

            output_img_bytes = save_image_in_memory(output[0], data_format=data_format)

            result_data["content-type"] = 'image/jpeg'
            result_data["data"] = output_img_bytes
            result_data["success"] = True
            result_data["error"] = None

            print('Finished inference and it took ' + str(time.time() - start))
            ai_integration.send_result(result_data)
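

# For reference, a minimal NumPy sketch of the AdaIN operation that the imported
# adain() helper computes on the conv4_1 features above. This standalone version is
# illustrative only (the name _adain_reference is ours) and is not called anywhere
# in this file: AdaIN re-normalizes the content feature map so that its per-channel
# mean and standard deviation match those of the style feature map.
def _adain_reference(content_feat, style_feat, eps=1e-5):
    # content_feat, style_feat: arrays shaped (1, C, H, W) (channels_first)
    c_mean = content_feat.mean(axis=(2, 3), keepdims=True)
    c_std = content_feat.std(axis=(2, 3), keepdims=True)
    s_mean = style_feat.mean(axis=(2, 3), keepdims=True)
    s_std = style_feat.std(axis=(2, 3), keepdims=True)
    normalized = (content_feat - c_mean) / (c_std + eps)
    return normalized * s_std + s_mean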