def export_to_saved_model(checkpoint, alpha):
    """Export arbitrary style transfer trained checkpoints to SavedModel format.

    The model graph is built once inside a fresh graph/session, the weights
    are loaded from `checkpoint`, and two SavedModels are written from that
    same session into the 'predict' and 'transform' sub-folders of a newly
    created temporary directory. The temporary directory is not removed by
    this function.

    Args:
      checkpoint: str, path to the checkpoint file.
      alpha: Width Multiplier of the transform model.

    Returns:
      (str, str) Paths to the exported style predict and style transform
      SavedModel folders.
    """
    saved_model_dir = tempfile.mkdtemp()
    predict_saved_model_folder = os.path.join(saved_model_dir, 'predict')
    transform_saved_model_folder = os.path.join(saved_model_dir, 'transform')

    with tf.Graph().as_default(), tf.Session() as sess:
        # Defines place holder for the style image.
        style_image_tensor = tf.placeholder(
            tf.float32, shape=[None, None, None, 3], name='style_image')

        # Defines place holder for the content image.
        content_image_tensor = tf.placeholder(
            tf.float32, shape=[None, None, None, 3], name='content_image')

        # Defines the model. Nothing is trainable and no losses are added,
        # since the graph is only used for export.
        stylized_images, _, _, bottleneck_feat = (
            build_mobilenet_model.build_mobilenet_model(
                content_image_tensor,
                style_image_tensor,
                mobilenet_trainable=False,
                style_params_trainable=False,
                transformer_trainable=False,
                style_prediction_bottleneck=100,
                transformer_alpha=alpha,
                mobilenet_end_point='layer_19',
                adds_losses=False))

        # Load model weights from checkpoint file
        load_checkpoint(sess, checkpoint)

        # Write SavedModel for serving or conversion to TF Lite.
        # Predict model: style image -> style bottleneck.
        tf.saved_model.simple_save(
            sess,
            predict_saved_model_folder,
            inputs={style_image_tensor.name: style_image_tensor},
            outputs={'style_bottleneck': bottleneck_feat})
        # The message needs a %s placeholder: logging formats lazily with
        # `msg % args`, so an extra argument without a placeholder makes the
        # record fail to format and the path is never shown.
        tf.logging.debug('Export predict SavedModel to %s',
                         predict_saved_model_folder)

        # Transform model: content image + style bottleneck -> stylized image.
        tf.saved_model.simple_save(
            sess,
            transform_saved_model_folder,
            inputs={
                content_image_tensor.name: content_image_tensor,
                'style_bottleneck': bottleneck_feat
            },
            outputs={'stylized_image': stylized_images})
        tf.logging.debug('Export transform SavedModel to %s',
                         transform_saved_model_folder)

    return predict_saved_model_folder, transform_saved_model_folder
def main(unused_argv=None):
    """Builds the MobileNet distillation graph and runs slim training.

    The MobileNet model is built with `mobilenet_trainable=True` and its
    bottleneck features are regressed (mean squared error) onto the bottleneck
    features produced by `build_model.style_prediction` with the 'Mixed_6e'
    Inception end point. When `FLAGS.use_true_loss` is set, the loss returned
    by `build_mobilenet_model`, scaled by `FLAGS.true_loss_weight`, is added
    to the objective.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Input processing is pinned to the CPU in order to reserve the GPU
        # for the forward inference and back-propagation.
        if FLAGS.ps_tasks:
            input_device = '/job:worker/cpu:0'
        else:
            input_device = '/cpu:0'

        input_device_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=input_device)
        with tf.device(input_device_setter):
            # Content image batch.
            content_batch, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Style image batch (augmented) and the original style images.
            style_batch, _, style_batch_orig = (
                image_utils.arbitrary_style_image_inputs(
                    FLAGS.style_dataset_file,
                    batch_size=FLAGS.batch_size,
                    image_size=FLAGS.image_size,
                    shuffle=True,
                    center_crop=FLAGS.center_crop,
                    augment_style_images=FLAGS.augment_style_images,
                    random_style_image_size=FLAGS.random_style_image_size))

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The weight flags hold Python literals; parse them.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Student: MobileNet-based model.
            model_outputs = build_mobilenet_model.build_mobilenet_model(
                content_batch,
                style_batch,
                mobilenet_trainable=True,
                style_params_trainable=False,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )
            stylized_images, true_loss, _, bottleneck_feat = model_outputs

            # Teacher: frozen style prediction network.
            _, inception_bottleneck_feat = build_model.style_prediction(
                style_batch,
                [],
                [],
                is_training=False,
                trainable=False,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                reuse=None,
            )

            print('PRINTING TRAINABLE VARIABLES')
            for variable in tf.trainable_variables():
                print(variable)

            # Distillation objective, optionally combined with the weighted
            # loss returned by the model.
            mse_loss = tf.losses.mean_squared_error(
                inception_bottleneck_feat, bottleneck_feat)
            total_loss = mse_loss
            if FLAGS.use_true_loss:
                true_loss = FLAGS.true_loss_weight * true_loss
                total_loss = total_loss + true_loss
                tf.summary.scalar('mse', mse_loss)
                tf.summary.scalar('true_loss', true_loss)

            # Tensorboard summaries.
            tf.summary.scalar('total_loss', total_loss)
            tf.summary.image('image/0_content_inputs', content_batch, 3)
            tf.summary.image('image/1_style_inputs_orig', style_batch_orig, 3)
            tf.summary.image('image/2_style_inputs_aug', style_batch, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            # NOTE(review): this list is collected before the optimizer and
            # train op are created -- presumably so variables they add are
            # not part of the MobileNet restore list; confirm before
            # reordering.
            mobilenet_variables_to_restore = slim.get_variables_to_restore(
                include=['MobilenetV2'], exclude=['global_step'])

            # Training op.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Everything outside the MobileNet scopes comes from the initial
            # checkpoint; MobilenetV2 variables come from the MobileNet
            # checkpoint.
            remaining_variables = slim.get_variables_to_restore(
                exclude=['MobilenetV2', 'mobilenet_conv', 'global_step'])
            init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.initial_checkpoint, remaining_variables)
            init_pretrained_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_to_restore)

            def init_sub_networks(session):
                """Restores both checkpoints into `session`."""
                init_fn(session)
                init_pretrained_mobilenet(session)

            # Run training.
            train_dir = os.path.expanduser(FLAGS.train_dir)
            slim.learning.train(
                train_op=train_op,
                logdir=train_dir,
                master=FLAGS.master,
                is_chief=FLAGS.task == 0,
                number_of_steps=FLAGS.train_steps,
                init_fn=init_sub_networks,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv=None):
    """Builds the MobileNet style transfer training graph and trains it.

    The model is built with the MobileNet part frozen
    (`mobilenet_trainable=False`) and the style parameters and transformer
    trainable. VGG16 and MobileNet V2 variables are restored from their
    checkpoints before training starts.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Input processing is pinned to the CPU in order to reserve the GPU
        # for the forward inference and back-propagation.
        if FLAGS.ps_tasks:
            input_device = '/job:worker/cpu:0'
        else:
            input_device = '/cpu:0'

        input_device_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=input_device)
        with tf.device(input_device_setter):
            # Content image batch.
            content_batch, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Style image batch.
            style_batch, _, _ = image_utils.arbitrary_style_image_inputs(
                FLAGS.style_dataset_file,
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                shuffle=True,
                center_crop=FLAGS.center_crop,
                augment_style_images=FLAGS.augment_style_images,
                random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # The weight flags hold Python literals; parse them.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Model definition.
            model_outputs = build_mobilenet_model.build_mobilenet_model(
                content_batch,
                style_batch,
                mobilenet_trainable=False,
                style_params_trainable=True,
                transformer_trainable=True,
                mobilenet_end_point='layer_19',
                transformer_alpha=FLAGS.alpha,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )
            stylized_images, total_loss, loss_dict, _ = model_outputs

            # One scalar summary per entry of the loss dictionary.
            for loss_name, loss_value in loss_dict.items():
                tf.summary.scalar(loss_name, loss_value)

            # Image summaries for the tensorboard.
            tf.summary.image('image/0_content_inputs', content_batch, 3)
            tf.summary.image('image/1_style_inputs_aug', style_batch, 3)
            tf.summary.image('image/2_stylized_images', stylized_images, 3)

            # Training op.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Restore function for the VGG16 parameters.
            vgg_checkpoint = vgg.checkpoint_file()
            vgg_variables = slim.get_variables('vgg_16')
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg_checkpoint, vgg_variables)

            # Restore function for the Mobilenet V2 parameters, keyed by the
            # name of each variable's op.
            mobilenet_variables_dict = {}
            for variable in slim.get_model_variables('MobilenetV2'):
                mobilenet_variables_dict[variable.op.name] = variable
            init_fn_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_dict)

            def init_sub_networks(session):
                """Restores VGG16 and Mobilenet V2 parameters."""
                init_fn_vgg(session)
                init_fn_mobilenet(session)

            # Run training.
            train_dir = os.path.expanduser(FLAGS.train_dir)
            slim.learning.train(
                train_op=train_op,
                logdir=train_dir,
                master=FLAGS.master,
                is_chief=FLAGS.task == 0,
                number_of_steps=FLAGS.train_steps,
                init_fn=init_sub_networks,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv=None):
    """Sets up and runs distillation training of the MobileNet model.

    A mean squared error loss pulls the bottleneck features returned by
    `build_mobilenet_model` (built with `mobilenet_trainable=True`) towards
    the ones returned by `build_model.style_prediction` (end point
    'Mixed_6e', not trainable). If `FLAGS.use_true_loss` is set, the loss
    returned by `build_mobilenet_model` is scaled by `FLAGS.true_loss_weight`
    and added to the total loss.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # All input processing is forced onto the CPU in order to reserve the
        # GPU for the forward inference and back-propagation.
        if FLAGS.ps_tasks:
            cpu_device = '/job:worker/cpu:0'
        else:
            cpu_device = '/cpu:0'

        cpu_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=cpu_device)
        with tf.device(cpu_setter):
            # Load content images.
            content_images, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Load style images (augmented batch plus the originals).
            style_images, _, style_images_orig = (
                image_utils.arbitrary_style_image_inputs(
                    FLAGS.style_dataset_file,
                    batch_size=FLAGS.batch_size,
                    image_size=FLAGS.image_size,
                    shuffle=True,
                    center_crop=FLAGS.center_crop,
                    augment_style_images=FLAGS.augment_style_images,
                    random_style_image_size=FLAGS.random_style_image_size))

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the MobileNet model being trained.
            (stylized_images,
             true_loss,
             _,
             bottleneck_feat) = build_mobilenet_model.build_mobilenet_model(
                 content_images,
                 style_images,
                 mobilenet_trainable=True,
                 style_params_trainable=False,
                 style_prediction_bottleneck=100,
                 adds_losses=True,
                 content_weights=content_weights,
                 style_weights=style_weights,
                 total_variation_weight=FLAGS.total_variation_weight,
             )

            # Define the frozen reference style prediction network.
            _, inception_bottleneck_feat = build_model.style_prediction(
                style_images,
                [],
                [],
                is_training=False,
                trainable=False,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                reuse=None,
            )

            print('PRINTING TRAINABLE VARIABLES')
            for trainable_var in tf.trainable_variables():
                print(trainable_var)

            # Losses and their scalar summaries.
            mse_loss = tf.losses.mean_squared_error(
                inception_bottleneck_feat, bottleneck_feat)
            total_loss = mse_loss
            if FLAGS.use_true_loss:
                true_loss = FLAGS.true_loss_weight * true_loss
                total_loss = total_loss + true_loss
                tf.summary.scalar('mse', mse_loss)
                tf.summary.scalar('true_loss', true_loss)
            tf.summary.scalar('total_loss', total_loss)

            # Image summaries.
            tf.summary.image('image/0_content_inputs', content_images, 3)
            tf.summary.image('image/1_style_inputs_orig', style_images_orig, 3)
            tf.summary.image('image/2_style_inputs_aug', style_images, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            # NOTE(review): gathered before the optimizer/train op exist --
            # presumably to keep variables they create out of the MobileNet
            # restore list; confirm before moving this statement.
            mobilenet_variables_to_restore = slim.get_variables_to_restore(
                include=['MobilenetV2'], exclude=['global_step'])

            # Set up training.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Restore functions: non-MobileNet variables from the initial
            # checkpoint, MobilenetV2 variables from the MobileNet checkpoint.
            non_mobilenet_variables = slim.get_variables_to_restore(
                exclude=['MobilenetV2', 'mobilenet_conv', 'global_step'])
            init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.initial_checkpoint, non_mobilenet_variables)
            init_pretrained_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_to_restore)

            def init_sub_networks(session):
                """Runs both restore functions on `session`."""
                init_fn(session)
                init_pretrained_mobilenet(session)

            # Run training.
            logdir = os.path.expanduser(FLAGS.train_dir)
            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                master=FLAGS.master,
                is_chief=FLAGS.task == 0,
                number_of_steps=FLAGS.train_steps,
                init_fn=init_sub_networks,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)