def testCreateOnecloneWithPS(self):
    """Optimizes one BatchNormClassifier clone with a single PS task.

    After creating the clone the test expects 5 slim variables and 2 update
    ops. It then runs ``optimize_clones`` and checks the number of
    gradient/variable pairs, the name of the total loss op, and the device
    reported for every gradient and variable.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, BatchNormClassifier, (inputs, labels))
        self.assertEqual(len(slim.get_variables()), 5)
        pending_updates = tf.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        self.assertEqual(len(pending_updates), 2)

        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        total_loss, grads_and_vars = model_deploy.optimize_clones(clones, sgd)
        self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
        self.assertEqual(total_loss.op.name, 'total_loss')
        for grad, var in grads_and_vars:
            self.assertDeviceEqual(grad.device, '/job:worker/device:GPU:0')
            self.assertDeviceEqual(var.device, '/job:ps/task:0/CPU:0')
def testCreateMulticloneWithPS(self):
    """Creates two BatchNormClassifier clones with two PS tasks.

    Checks that 5 variables are created, that the i-th variable reports the
    device of PS task ``i % 2`` (and that its value lives on the same
    device), and that each clone has the expected output op name, scope and
    worker GPU device.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=2)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, BatchNormClassifier, (inputs, labels))

        self.assertEqual(len(slim.get_variables()), 5)
        for index, variable in enumerate(slim.get_variables()):
            ps_task = index % 2
            self.assertDeviceEqual(
                variable.device, '/job:ps/task:%d/device:CPU:0' % ps_task)
            self.assertDeviceEqual(variable.device, variable.value().device)

        self.assertEqual(len(clones), 2)
        for index, clone in enumerate(clones):
            expected_output = (
                'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % index)
            self.assertEqual(clone.outputs.op.name, expected_output)
            self.assertEqual(clone.scope, 'clone_%d/' % index)
            self.assertDeviceEqual(
                clone.device, '/job:worker/device:GPU:%d' % index)
def testCreateMulticlone(self):
    """Creates four BatchNormClassifier clones without parameter servers.

    Checks that 5 variables exist and report 'CPU:0', and that every clone
    has the expected output op name, 2 update ops collected under its
    scope, the expected scope string and a 'GPU:<index>' device.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        num_clones = 4
        config = model_deploy.DeploymentConfig(num_clones=num_clones)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, BatchNormClassifier, (inputs, labels))

        self.assertEqual(len(slim.get_variables()), 5)
        for variable in slim.get_variables():
            self.assertDeviceEqual(variable.device, 'CPU:0')
            self.assertDeviceEqual(variable.value().device, 'CPU:0')

        self.assertEqual(len(clones), num_clones)
        for index, clone in enumerate(clones):
            expected_output = (
                'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % index)
            self.assertEqual(clone.outputs.op.name, expected_output)
            # Restrict the collection lookup to this clone's scope.
            scoped_updates = tf.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS, clone.scope)
            self.assertEqual(len(scoped_updates), 2)
            self.assertEqual(clone.scope, 'clone_%d/' % index)
            self.assertDeviceEqual(clone.device, 'GPU:%d' % index)
def testCreateLogisticClassifier(self):
    """Creates a single LogisticClassifier clone and inspects it.

    Checks that 2 variables are created and report 'CPU:0', that the clone
    has the expected output op name, an empty scope and a 'GPU:0' device,
    that exactly one loss is registered, and that there are no update ops.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        config = model_deploy.DeploymentConfig(num_clones=1)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, LogisticClassifier, (inputs, labels))
        only_clone = clones[0]

        self.assertEqual(len(slim.get_variables()), 2)
        for variable in slim.get_variables():
            self.assertDeviceEqual(variable.device, 'CPU:0')
            self.assertDeviceEqual(variable.value().device, 'CPU:0')

        self.assertEqual(only_clone.outputs.op.name,
                         'LogisticClassifier/fully_connected/Sigmoid')
        self.assertEqual(only_clone.scope, '')
        self.assertDeviceEqual(only_clone.device, 'GPU:0')
        self.assertEqual(len(slim.losses.get_losses()), 1)

        pending_updates = tf.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        self.assertEqual(pending_updates, [])
def testCreateLogisticClassifier(self):
    """Optimizes a single LogisticClassifier clone.

    After creating the clone the test expects 2 variables and no update
    ops. It then runs ``optimize_clones`` and checks the number of
    gradient/variable pairs, the name of the total loss op, and that
    gradients report 'GPU:0' while variables report 'CPU:0'.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        config = model_deploy.DeploymentConfig(num_clones=1)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, LogisticClassifier, (inputs, labels))
        self.assertEqual(len(slim.get_variables()), 2)
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        self.assertEqual(pending_updates, [])

        sgd = tf.train.GradientDescentOptimizer(learning_rate=1.0)
        total_loss, grads_and_vars = model_deploy.optimize_clones(clones, sgd)
        self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
        self.assertEqual(total_loss.op.name, 'total_loss')
        for grad, var in grads_and_vars:
            self.assertDeviceEqual(grad.device, 'GPU:0')
            self.assertDeviceEqual(var.device, 'CPU:0')
def testCreateOnecloneWithPS(self):
    """Creates one BatchNormClassifier clone with a single PS task.

    Checks that exactly one clone is returned with the expected output op
    name, worker GPU device and empty scope, and that the 5 created
    variables report the PS task 0 CPU device, matching the device of
    their values.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(0)
        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)

        config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)

        # The fresh graph must not hold any variables before clone creation.
        self.assertEqual(slim.get_variables(), [])
        clones = model_deploy.create_clones(
            config, BatchNormClassifier, (inputs, labels))

        self.assertEqual(len(clones), 1)
        only_clone = clones[0]
        self.assertEqual(only_clone.outputs.op.name,
                         'BatchNormClassifier/fully_connected/Sigmoid')
        self.assertDeviceEqual(only_clone.device, '/job:worker/device:GPU:0')
        self.assertEqual(only_clone.scope, '')

        self.assertEqual(len(slim.get_variables()), 5)
        for variable in slim.get_variables():
            self.assertDeviceEqual(variable.device, '/job:ps/task:0/CPU:0')
            self.assertDeviceEqual(variable.device, variable.value().device)
def testDefaults(self):
    """Checks the values reported by a default-constructed DeploymentConfig.

    Expects no caching device, clone 0 on 'GPU:0' with an empty scope, and
    'CPU:0' for the optimizer, inputs and variables devices.
    """
    config = model_deploy.DeploymentConfig()

    self.assertEqual(slim.get_variables(), [])
    self.assertEqual(config.caching_device(), None)
    self.assertDeviceEqual(config.clone_device(0), 'GPU:0')
    self.assertEqual(config.clone_scope(0), '')

    # The remaining device getters are all expected to report the CPU.
    cpu_device_getters = (
        config.optimizer_device,
        config.inputs_device,
        config.variables_device,
    )
    for get_device in cpu_device_getters:
        self.assertDeviceEqual(get_device(), 'CPU:0')
def testLocalTrainOp(self):
    """Deploys a two-clone BatchNormClassifier and trains it for 10 steps.

    Checks the deployed model's structure (4 update ops, 2 clones, and the
    names of the total loss, summary and train ops), then runs the train op
    in a session and verifies that the loss drops below a fifth of its
    initial value and that the batch-norm moving mean and variance reach
    the expected values.
    """
    g = tf.Graph()
    with g.as_default():
        tf.compat.v1.set_random_seed(0)
        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
        tf_labels = tf.constant(self._labels, dtype=tf.float32)

        model_fn = BatchNormClassifier
        model_args = (tf_inputs, tf_labels)
        deploy_config = model_deploy.DeploymentConfig(num_clones=2,
                                                      clone_on_cpu=True)

        optimizer = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=1.0)

        # The fresh graph must not hold any variables before deployment.
        self.assertEqual(slim.get_variables(), [])
        model = model_deploy.deploy(deploy_config, model_fn, model_args,
                                    optimizer=optimizer)

        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)
        self.assertEqual(len(update_ops), 4)
        self.assertEqual(len(model.clones), 2)
        self.assertEqual(model.total_loss.op.name, 'total_loss')
        self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op')
        self.assertEqual(model.train_op.op.name, 'train_op')

        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Look the moving statistics up by name; the name argument is
            # required by get_variables_by_name.
            moving_mean = slim.get_variables_by_name('moving_mean')[0]
            moving_variance = slim.get_variables_by_name(
                'moving_variance')[0]

            initial_loss = sess.run(model.total_loss)
            initial_mean, initial_variance = sess.run(
                [moving_mean, moving_variance])
            # Before any training step the statistics hold their initial
            # values.
            self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0])
            self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0])

            for _ in range(10):
                sess.run(model.train_op)

            final_loss = sess.run(model.total_loss)
            self.assertLess(final_loss, initial_loss / 5.0)

            final_mean, final_variance = sess.run(
                [moving_mean, moving_variance])
            expected_mean = np.array([0.125, 0.25, 0.375, 0.25])
            expected_var = np.array([0.109375, 0.1875, 0.234375, 0.1875])
            expected_var = self._addBesselsCorrection(16, expected_var)
            self.assertAllClose(final_mean, expected_mean)
            self.assertAllClose(final_variance, expected_var)
def precompute_gram_matrices(image, final_endpoint='fc8'):
    """Evaluates the Gram matrix of every VGG-16 end point for an image.

    Builds VGG-16 on `image`, restores the 'vgg_16' variables from the VGG
    checkpoint inside a new session, and evaluates one Gram matrix per
    returned end point.

    Args:
      image: 4-D tensor. Input (batch of) image(s).
      final_endpoint: str, name of the final layer to compute Gram matrices
          for. Defaults to 'fc8'.

    Returns:
      dict mapping layer names to their corresponding Gram matrices.
    """
    with tf.Session() as session:
        end_points = vgg.vgg_16(image, final_endpoint=final_endpoint)
        vgg_saver = tf.train.Saver(slim.get_variables('vgg_16'))
        vgg_saver.restore(session, vgg.checkpoint_file())
        # `.eval()` runs in the session made default by the `with` block.
        return {
            layer_name: gram_matrix(activation).eval()
            for layer_name, activation in end_points.items()
        }
def main(unused_argv=None):
    """Builds the MobileNet-based stylization graph and runs slim training.

    Loads content and style image batches, builds the model and its losses
    via `build_mobilenet_model`, registers scalar and image summaries,
    creates an Adam train op, and starts `slim.learning.train` with an init
    function that restores the VGG-16 and MobilenetV2 variables from their
    checkpoints.

    Args:
      unused_argv: ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU
        # for the forward inference and back-propagation.
        input_device = '/job:worker/cpu:0' if FLAGS.ps_tasks else '/cpu:0'
        input_device_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=input_device)
        with tf.device(input_device_setter):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
                FLAGS.style_dataset_file,
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                shuffle=True,
                center_crop=FLAGS.center_crop,
                augment_style_images=FLAGS.augment_style_images,
                random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model.
            model_outputs = build_mobilenet_model.build_mobilenet_model(
                content_inputs_,
                style_inputs_,
                mobilenet_trainable=False,
                style_params_trainable=True,
                transformer_trainable=True,
                mobilenet_end_point='layer_19',
                transformer_alpha=FLAGS.alpha,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )
            stylized_images, total_loss, loss_dict, _ = model_outputs

            # Adding scalar summaries to the tensorboard.
            for loss_name, loss_value in loss_dict.items():
                tf.summary.scalar(loss_name, loss_value)

            # Adding image summaries to the tensorboard.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/2_stylized_images', stylized_images, 3)

            # Set up training.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters.
            restore_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore Mobilenet V2 parameters, keyed by op name.
            mobilenet_vars_by_name = {
                variable.op.name: variable
                for variable in slim.get_model_variables('MobilenetV2')
            }
            restore_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_vars_by_name)

            def restore_sub_networks(session):
                """Restores VGG16 and then Mobilenet V2 parameters."""
                restore_vgg(session)
                restore_mobilenet(session)

            # Run training.
            slim.learning.train(
                train_op=train_op,
                logdir=os.path.expanduser(FLAGS.train_dir),
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                number_of_steps=FLAGS.train_steps,
                init_fn=restore_sub_networks,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)
def main(unused_argv=None):
    """Builds the N-styles fine-tuning graph and runs slim training.

    Loads content images and style labels/Gram matrices, scales the style
    weights by a per-style coefficient, builds the transformer model and
    its losses, and trains only the 'transformer' variables whose name
    contains 'InstanceNorm'. The remaining 'transformer' variables and the
    VGG-16 variables are restored from checkpoints by the init function.

    Args:
      unused_argv: ignored.

    Raises:
      ValueError: if the number of style coefficients differs from
          `FLAGS.num_styles`.
    """
    with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU
        # for the forward inference and back-propagation.
        input_device = '/job:worker/cpu:0' if FLAGS.ps_tasks else '/cpu:0'
        input_device_setter = tf.train.replica_device_setter(
            FLAGS.ps_tasks, worker_device=input_device)
        with tf.device(input_device_setter):
            inputs, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)
            # Load style images and select one at random (for each graph
            # execution, a new random selection occurs).
            style_batch = image_utils.style_image_inputs(
                os.path.expanduser(FLAGS.style_dataset_file),
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                square_crop=True,
                shuffle=True)
            _, style_labels, style_gram_matrices = style_batch

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and weight flags.
            num_styles = FLAGS.num_styles
            if FLAGS.style_coefficients is not None:
                style_coefficients = ast.literal_eval(
                    FLAGS.style_coefficients)
            else:
                style_coefficients = [1.0] * num_styles
            if len(style_coefficients) != num_styles:
                raise ValueError(
                    'number of style coefficients differs from number of styles')
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Rescale style weights dynamically based on the current style
            # image.
            style_coefficient = tf.gather(
                tf.constant(style_coefficients), style_labels)
            style_weights = {
                layer: style_coefficient * weight
                for layer, weight in style_weights.items()
            }

            # Define the model.
            normalizer_params = {
                'labels': style_labels,
                'num_categories': num_styles,
                'center': True,
                'scale': True
            }
            stylized_inputs = model.transform(
                inputs,
                alpha=FLAGS.alpha,
                normalizer_params=normalizer_params)

            # Compute losses.
            total_loss, loss_dict = learning.total_loss(
                inputs, stylized_inputs, style_gram_matrices,
                content_weights, style_weights)
            for loss_name, loss_value in loss_dict.items():
                tf.summary.scalar(loss_name, loss_value)

            # Split the transformer variables: the InstanceNorm ones are
            # trained, the others are restored from the checkpoint.
            instance_norm_vars = [
                variable for variable in slim.get_variables('transformer')
                if 'InstanceNorm' in variable.name
            ]
            other_vars = [
                variable for variable in slim.get_variables('transformer')
                if 'InstanceNorm' not in variable.name
            ]

            # Function to restore VGG16 parameters.
            restore_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore N-styles parameters.
            restore_n_styles = slim.assign_from_checkpoint_fn(
                os.path.expanduser(FLAGS.checkpoint), other_vars)

            def restore_all(session):
                """Restores VGG16 and then N-styles parameters."""
                restore_vgg(session)
                restore_n_styles(session)

            # Set up training.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                variables_to_train=instance_norm_vars,
                summarize_gradients=False)

            # Run training.
            slim.learning.train(
                train_op=train_op,
                logdir=os.path.expanduser(FLAGS.train_dir),
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                number_of_steps=FLAGS.train_steps,
                init_fn=restore_all,
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs)