import os
from itertools import product

import numpy as np
import tensorflow as tf

# Project helpers (load_image, prepare_image, save_image, load_mask, coral,
# center_crop_np, get_filename, extract_image_names_recursive, _build_graph)
# are assumed to come from the repo's own modules.


def style_transfer(content=None,
                   content_dir=None,
                   content_size=512,
                   style=None,
                   style_dir=None,
                   style_size=512,
                   crop=None,
                   preserve_color=None,
                   alpha=1.0,
                   style_interp_weights=None,
                   mask=None,
                   output_dir='/output',
                   save_ext='jpg',
                   gpu=0,
                   vgg_weights='/floyd_models/vgg19_weights_normalized.h5',
                   decoder_weights='/floyd_models/decoder_weights.h5',
                   tf_checkpoint_dir=None):
    assert bool(content) != bool(content_dir), 'Either content or content_dir should be given'
    assert bool(style) != bool(style_dir), 'Either style or style_dir should be given'

    if not os.path.exists(output_dir):
        print('Creating output dir at', output_dir)
        os.mkdir(output_dir)

    # Assume the decoder weights are either an h5 file or the name of a
    # TensorFlow checkpoint.
    # NOTE: For now, artificially switching off pretrained h5 weights.
    decoder_in_h5 = False  # decoder_weights.endswith('.h5')

    if content:
        content_batch = [content]
    else:
        assert mask is None, 'For spatial control use the --content option'
        content_batch = extract_image_names_recursive(content_dir)

    if style:
        style = style.split(',')
        if mask:
            assert len(style) == 2, 'For spatial control provide two style images'
            style_batch = [style]
        elif len(style) > 1:  # Style blending
            if not style_interp_weights:
                # By default, all styles get equal weights.
                style_interp_weights = np.array([1.0 / len(style)] * len(style))
            else:
                # Normalize the weights so that they sum to one.
                style_interp_weights = [float(w) for w in style_interp_weights.split(',')]
                style_interp_weights = np.array(style_interp_weights)
                style_interp_weights /= np.sum(style_interp_weights)
            assert len(style) == len(style_interp_weights), \
                '--style and --style_interp_weights must have the same number of elements'
            style_batch = [style]
        else:
            style_batch = style
    else:
        assert mask is None, 'For spatial control use the --style option'
        style_batch = extract_image_names_recursive(style_dir)

    print('Number of content images:', len(content_batch))
    print('Number of style images:', len(style_batch))

    if gpu >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
        data_format = 'channels_first'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        data_format = 'channels_last'

    image, content, style, target, encoder, decoder, otherLayers = _build_graph(
        vgg_weights,
        decoder_weights if decoder_in_h5 else None,
        alpha,
        data_format=data_format)

    with tf.Session() as sess:
        if decoder_in_h5:
            sess.run(tf.global_variables_initializer())
        elif tf_checkpoint_dir is not None:
            # A TensorFlow checkpoint was provided.
            saver = tf.train.Saver()
            saver.restore(sess, os.path.join(tf_checkpoint_dir, 'adain-final'))

        for content_path, style_path in product(content_batch, style_batch):
            content_name = get_filename(content_path)
            content_image = load_image(content_path, content_size, crop=True)

            if isinstance(style_path, list):  # Style blending / spatial control
                style_paths = style_path
                style_name = '_'.join(map(get_filename, style_paths))

                # Gather all style images into one numpy array in order to
                # get their activations in a single pass.
                style_images = None
                for i, spath in enumerate(style_paths):
                    style_image = load_image(spath, style_size, crop)
                    if preserve_color:
                        style_image = coral(style_image, content_image)
                    style_image = prepare_image(style_image)
                    if style_images is None:
                        shape = tuple([len(style_paths)]) + style_image.shape
                        style_images = np.empty(shape)
                    assert style_images.shape[1:] == style_image.shape, \
                        'Style images must have the same shape'
                    style_images[i] = style_image

                style_features = sess.run(encoder, feed_dict={image: style_images})

                content_image = prepare_image(content_image)
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis, :]
                })

                if mask:
                    # For spatial control, extract foreground and background
                    # parts of the content using the corresponding masks,
                    # run them individually through AdaIN, then combine.
                    if data_format == 'channels_first':
                        _, c, h, w = content_feature.shape
                        content_view_shape = (c, -1)
                        mask_shape = lambda m: (c, len(m), 1)
                        mask_slice = lambda m: (slice(None), m)
                    else:
                        _, h, w, c = content_feature.shape
                        content_view_shape = (-1, c)
                        mask_shape = lambda m: (1, len(m), c)
                        mask_slice = lambda m: (m, slice(None))

                    # Keep the flattened mask in a local so the `mask`
                    # argument (a path) is not clobbered between iterations.
                    mask_values = load_mask(mask, h, w).reshape(-1)
                    fg_mask = np.flatnonzero(mask_values == 1)
                    bg_mask = np.flatnonzero(mask_values == 0)

                    content_feat_view = content_feature.reshape(content_view_shape)
                    content_feat_fg = content_feat_view[mask_slice(fg_mask)].reshape(mask_shape(fg_mask))
                    content_feat_bg = content_feat_view[mask_slice(bg_mask)].reshape(mask_shape(bg_mask))

                    style_feature_fg = style_features[0]
                    style_feature_bg = style_features[1]

                    target_feature_fg = sess.run(target, feed_dict={
                        content: content_feat_fg[np.newaxis, :],
                        style: style_feature_fg[np.newaxis, :]
                    })
                    target_feature_fg = np.squeeze(target_feature_fg)

                    target_feature_bg = sess.run(target, feed_dict={
                        content: content_feat_bg[np.newaxis, :],
                        style: style_feature_bg[np.newaxis, :]
                    })
                    target_feature_bg = np.squeeze(target_feature_bg)

                    target_feature = np.zeros_like(content_feat_view)
                    target_feature[mask_slice(fg_mask)] = target_feature_fg
                    target_feature[mask_slice(bg_mask)] = target_feature_bg
                    target_feature = target_feature.reshape(content_feature.shape)
                else:
                    # For style blending, get activations for each style,
                    # then take a weighted sum.
                    target_feature = np.zeros(content_feature.shape)
                    for style_feature, weight in zip(style_features, style_interp_weights):
                        target_feature += weight * sess.run(target, feed_dict={
                            content: content_feature,
                            style: style_feature[np.newaxis, :]
                        })
            else:
                # NOTE: This is the part we care about when only one style
                # image is provided.
                style_name = get_filename(style_path)
                style_image = load_image(style_path, style_size, crop=True)  # This only gives us a square crop
                style_image = center_crop_np(style_image)  # Actually crop the center out
                if preserve_color:
                    style_image = coral(style_image, content_image)
                style_image = prepare_image(style_image, True, data_format)
                content_image = prepare_image(content_image, True, data_format)

                # Extract the other layers.
                conv3_1_layer, conv4_1_layer = otherLayers
                style_feature, conv3_1_out_style, conv4_1_out_style = sess.run(
                    [encoder, conv3_1_layer, conv4_1_layer],
                    feed_dict={image: style_image[np.newaxis, :]})
                content_feature = sess.run(encoder, feed_dict={
                    image: content_image[np.newaxis, :]
                })
                target_feature = sess.run(target, feed_dict={
                    content: content_feature,
                    style: style_feature
                })

            # The decoder blends the fed target feature with the content
            # feature (via the alpha passed to _build_graph), so both are
            # fed here. The style placeholder is not needed once
            # target_feature is precomputed; feeding it would also break
            # the blending path, where no single style_feature exists.
            output = sess.run(decoder, feed_dict={
                content: content_feature,
                target: target_feature
            })

            filename = '%s_stylized_%s.%s' % (content_name, style_name, save_ext)
            filename = os.path.join(output_dir, filename)
            save_image(filename, output[0], data_format=data_format)
            print('Output image saved at', filename)

            if not isinstance(style_path, list):
                # Grab the relevant layer outputs to see what's being
                # minimized (only meaningful in the single-style path,
                # where the style-side activations were computed).
                conv3_1_out_output, conv4_1_out_output = sess.run(
                    [conv3_1_layer, conv4_1_layer],
                    feed_dict={image: output})
                # TODO: Change these layers.
                layersToViz = [conv3_1_out_style, conv4_1_out_style,
                               conv3_1_out_output, conv4_1_out_output]
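# For reference, the transform computed by the `target` tensor above is
# adaptive instance normalization (AdaIN, Huang & Belongie 2017): each
# channel of the content feature map is re-normalized to match the mean and
# standard deviation of the corresponding style channel. A minimal NumPy
# sketch (not the graph code itself), assuming channels-last (H, W, C)
# feature maps:
def adain_reference(content_feat, style_feat, eps=1e-5):
    # Per-channel statistics over the spatial dimensions.
    c_mean = content_feat.mean(axis=(0, 1), keepdims=True)
    c_std = content_feat.std(axis=(0, 1), keepdims=True)
    s_mean = style_feat.mean(axis=(0, 1), keepdims=True)
    s_std = style_feat.std(axis=(0, 1), keepdims=True)
    # Whiten the content statistics, then re-color them with the style's.
    return s_std * (content_feat - c_mean) / (c_std + eps) + s_mean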
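# Example invocations of the batch entry point above (a sketch; the image
# and mask paths are placeholders for illustration, not files in the repo):
#
#   # Single style image:
#   style_transfer(content='content/avril.jpg', style='style/wave.jpg')
#
#   # Style blending; the weights are normalized to sum to one:
#   style_transfer(content='content/avril.jpg',
#                  style='style/wave.jpg,style/sketch.jpg',
#                  style_interp_weights='0.7,0.3')
#
#   # Spatial control: exactly two styles (foreground, background) plus a
#   # binary mask aligned with the content image:
#   style_transfer(content='content/avril.jpg',
#                  style='style/wave.jpg,style/sketch.jpg',
#                  mask='masks/avril_fg.png')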
# Web-app variant of style_transfer: takes uploaded file objects (anything
# exposing .filename and .read(), e.g. Flask's FileStorage) instead of
# paths, runs on CPU, and handles a single content/style pair.
def style_transfer(content_img=None,
                   content_size=512,
                   style_img=None,
                   style_size=512,
                   crop=None,
                   alpha=1.0,
                   content_dir='content',
                   style_dir='style',
                   output_dir='output',
                   vgg_weights='models/vgg19_weights_normalized.h5',
                   decoder_weights='models/decoder_weights.h5'):
    decoder_in_h5 = decoder_weights.endswith('.h5')

    # CPU-only: hide all GPUs and use the channels-last layout.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    data_format = 'channels_last'

    for directory in (content_dir, style_dir, output_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    image, content, style, target, encoder, decoder = init_graph(
        vgg_weights,
        decoder_weights if decoder_in_h5 else None,
        alpha,
        data_format=data_format)

    with tf.Session() as sess:
        if decoder_in_h5:
            sess.run(tf.global_variables_initializer())
        else:
            saver = tf.train.Saver()
            saver.restore(sess, decoder_weights)

        # The filenames come straight from the upload; this assumes they
        # have already been sanitized (e.g. with werkzeug's secure_filename).
        content_name = content_img.filename
        style_name = style_img.filename
        content_path = os.path.join(content_dir, content_name)
        style_path = os.path.join(style_dir, style_name)

        # Persist the uploads so load_image can read them from disk.
        with open(content_path, "wb") as f:
            f.write(content_img.read())
        with open(style_path, "wb") as f:
            f.write(style_img.read())

        content_image = load_image(content_path, content_size, crop)
        style_image = load_image(style_path, style_size, crop)

        style_image = prepare_image(style_image)
        content_image = prepare_image(content_image)

        style_feature = sess.run(encoder, feed_dict={
            image: style_image[np.newaxis, :]
        })
        content_feature = sess.run(encoder, feed_dict={
            image: content_image[np.newaxis, :]
        })
        target_feature = sess.run(target, feed_dict={
            content: content_feature,
            style: style_feature
        })
        output = sess.run(decoder, feed_dict={
            content: content_feature,
            target: target_feature
        })

        name = f"{content_name.split('.')[0]}_stylized_{style_name.split('.')[0]}.jpg"
        filename = os.path.join(output_dir, name)
        save_image(filename, output[0], data_format=data_format)
        return name.split('.')[0], filename
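# A minimal sketch of how the web-app variant above might be served. Flask,
# the route, and the form field names ('content', 'style') are assumptions
# for illustration; the repo may wire this up differently:
from flask import Flask, request, send_file

app = Flask(__name__)


@app.route('/stylize', methods=['POST'])
def stylize():
    # request.files values are FileStorage objects, which expose the
    # .filename and .read() interface that style_transfer() relies on.
    name, path = style_transfer(content_img=request.files['content'],
                                style_img=request.files['style'])
    return send_file(path, mimetype='image/jpeg')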