Code example #1
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            [style_inputs_, _, _] = image_utils.arbitrary_style_image_inputs(
                FLAGS.style_dataset_file,
                batch_size=FLAGS.batch_size,
                image_size=FLAGS.image_size,
                shuffle=True,
                center_crop=FLAGS.center_crop,
                augment_style_images=FLAGS.augment_style_images,
                random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            stylized_images, total_loss, loss_dict, _ = \
                build_mobilenet_model.build_mobilenet_model(
                      content_inputs_,
                      style_inputs_,
                      mobilenet_trainable=False,
                      style_params_trainable=True,
                      transformer_trainable=True,
                      mobilenet_end_point='layer_19',
                      transformer_alpha=FLAGS.alpha,
                      style_prediction_bottleneck=100,
                      adds_losses=True,
                      content_weights=content_weights,
                      style_weights=style_weights,
                      total_variation_weight=FLAGS.total_variation_weight,
                  )

            # Adds scalar summaries to TensorBoard.
            for key in loss_dict:
                tf.summary.scalar(key, loss_dict[key])

            # Adds image summaries to TensorBoard.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/2_stylized_images', stylized_images, 3)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore Mobilenet V2 parameters.
            mobilenet_variables_dict = {
                var.op.name: var
                for var in slim.get_model_variables('MobilenetV2')
            }
            init_fn_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_dict)

            # Function to restore VGG16 and Mobilenet V2 parameters.
            def init_sub_networks(session):
                init_fn_vgg(session)
                init_fn_mobilenet(session)

            # Run training
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
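
Every script in this section reads its configuration from TF 1.x command-line flags. As a minimal sketch, assuming the standard tf.app.flags mechanism, the definitions below show the pattern these scripts rely on; the flag names mirror the FLAGS.* references above, but the default values are illustrative rather than the originals, and the remaining flags follow the same pattern.

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer('batch_size', 8, 'Batch size for training.')
flags.DEFINE_integer('image_size', 256, 'Size of content and style images.')
flags.DEFINE_float('learning_rate', 1e-5, 'Learning rate for the Adam optimizer.')
flags.DEFINE_float('clip_gradient_norm', 0.0, 'Gradient clip norm; 0 disables clipping.')
flags.DEFINE_string('train_dir', None, 'Directory for checkpoints and summaries.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    tf.app.run(main)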
Code example #2
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Loads content images.
        eval_content_inputs_, _ = image_utils.imagenet_inputs(
            FLAGS.batch_size, FLAGS.image_size)

        # Process style and content weight flags.
        content_weights = ast.literal_eval(FLAGS.content_weights)
        style_weights = ast.literal_eval(FLAGS.style_weights)

        # Loads evaluation style images.
        eval_style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
            FLAGS.eval_style_dataset_file,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            center_crop=True,
            shuffle=True,
            augment_style_images=False,
            random_style_image_size=False)

        # Computes stylized noise.
        stylized_noise, _, _, _ = build_model.build_model(
            tf.random_uniform([
                min(4, FLAGS.batch_size), FLAGS.image_size, FLAGS.image_size, 3
            ]),
            tf.slice(eval_style_inputs_, [0, 0, 0, 0],
                     [min(4, FLAGS.batch_size), -1, -1, -1]),
            trainable=False,
            is_training=False,
            reuse=None,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        # Computes stylized images.
        stylized_images, _, loss_dict, _ = build_model.build_model(
            eval_content_inputs_,
            eval_style_inputs_,
            trainable=False,
            is_training=False,
            reuse=True,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=True,
            content_weights=content_weights,
            style_weights=style_weights,
            total_variation_weight=FLAGS.total_variation_weight)

        # Adds image summaries to TensorBoard.
        tf.summary.image(
            'image/{}/0_eval_content_inputs'.format(FLAGS.eval_name),
            eval_content_inputs_, 3)
        tf.summary.image(
            'image/{}/1_eval_style_inputs'.format(FLAGS.eval_name),
            eval_style_inputs_, 3)
        tf.summary.image(
            'image/{}/2_eval_stylized_images'.format(FLAGS.eval_name),
            stylized_images, 3)
        tf.summary.image('image/{}/3_stylized_noise'.format(FLAGS.eval_name),
                         stylized_noise, 3)

        metrics = {}
        for key, value in loss_dict.items():
            metrics[key] = tf.metrics.mean(value)

        names_values, names_updates = slim.metrics.aggregate_metric_map(
            metrics)
        for name, value in names_values.items():
            slim.summaries.add_scalar_summary(value, name, print_summary=True)
        eval_op = list(names_updates.values())
        # Integer division so num_evals is an int, as the evaluation loop expects.
        num_evals = FLAGS.num_evaluation_styles // FLAGS.batch_size

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,
            eval_op=eval_op,
            num_evals=num_evals,
            eval_interval_secs=FLAGS.eval_interval_secs)
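
Each tf.metrics.mean call above returns a (value_op, update_op) pair, and slim.evaluation.evaluation_loop runs the update ops num_evals times against the newest checkpoint in checkpoint_dir, then sleeps and repeats every eval_interval_secs, so this script is meant to run alongside the trainer. A self-contained sketch of the metric mechanics in TF 1.x:

import tensorflow as tf

losses = tf.constant([1.0, 2.0, 3.0])
mean_value, mean_update = tf.metrics.mean(losses)

with tf.Session() as sess:
    # The metric's running total and count live in local variables.
    sess.run(tf.local_variables_initializer())
    sess.run(mean_update)
    print(sess.run(mean_value))  # 2.0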
Code example #3
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default():
    # Loads content images.
    eval_content_inputs_, _ = image_utils.imagenet_inputs(
        FLAGS.batch_size, FLAGS.image_size)

    # Process style and content weight flags.
    content_weights = ast.literal_eval(FLAGS.content_weights)
    style_weights = ast.literal_eval(FLAGS.style_weights)

    # Loads evaluation style images.
    eval_style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
        FLAGS.eval_style_dataset_file,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        center_crop=True,
        shuffle=True,
        augment_style_images=False,
        random_style_image_size=False)

    # Computes stylized noise.
    stylized_noise, _, _, _ = build_model.build_model(
        tf.random_uniform(
            [min(4, FLAGS.batch_size), FLAGS.image_size, FLAGS.image_size, 3]),
        tf.slice(eval_style_inputs_, [0, 0, 0, 0],
                 [min(4, FLAGS.batch_size), -1, -1, -1]),
        trainable=False,
        is_training=False,
        reuse=None,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    # Computes stylized images.
    stylized_images, _, loss_dict, _ = build_model.build_model(
        eval_content_inputs_,
        eval_style_inputs_,
        trainable=False,
        is_training=False,
        reuse=True,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=True,
        content_weights=content_weights,
        style_weights=style_weights,
        total_variation_weight=FLAGS.total_variation_weight)

    # Adds image summaries to TensorBoard.
    tf.summary.image('image/{}/0_eval_content_inputs'.format(FLAGS.eval_name),
                     eval_content_inputs_, 3)
    tf.summary.image('image/{}/1_eval_style_inputs'.format(FLAGS.eval_name),
                     eval_style_inputs_, 3)
    tf.summary.image('image/{}/2_eval_stylized_images'.format(FLAGS.eval_name),
                     stylized_images, 3)
    tf.summary.image('image/{}/3_stylized_noise'.format(FLAGS.eval_name),
                     stylized_noise, 3)

    metrics = {}
    for key, value in loss_dict.items():
      metrics[key] = tf.metrics.mean(value)

    names_values, names_updates = slim.metrics.aggregate_metric_map(metrics)
    for name, value in names_values.items():
      slim.summaries.add_scalar_summary(value, name, print_summary=True)
    eval_op = list(names_updates.values())
    # Integer division so num_evals is an int, as the evaluation loop expects.
    num_evals = FLAGS.num_evaluation_styles // FLAGS.batch_size

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        eval_op=eval_op,
        num_evals=num_evals,
        eval_interval_secs=FLAGS.eval_interval_secs)
Code example #4
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            [style_inputs_, _,
             style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
                 FLAGS.style_dataset_file,
                 batch_size=FLAGS.batch_size,
                 image_size=FLAGS.image_size,
                 shuffle=True,
                 center_crop=FLAGS.center_crop,
                 augment_style_images=FLAGS.augment_style_images,
                 random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            stylized_images, true_loss, _, bottleneck_feat = \
                build_mobilenet_model.build_mobilenet_model(
                content_inputs_,
                style_inputs_,
                mobilenet_trainable=True,
                style_params_trainable=False,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )

            _, inception_bottleneck_feat = build_model.style_prediction(
                style_inputs_,
                [],
                [],
                is_training=False,
                trainable=False,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                reuse=None,
            )

            print('PRINTING TRAINABLE VARIABLES')
            for x in tf.trainable_variables():
                print(x)

            mse_loss = tf.losses.mean_squared_error(inception_bottleneck_feat,
                                                    bottleneck_feat)
            total_loss = mse_loss
            if FLAGS.use_true_loss:
                true_loss = FLAGS.true_loss_weight * true_loss
                total_loss += true_loss
                tf.summary.scalar('mse', mse_loss)
                tf.summary.scalar('true_loss', true_loss)

            tf.summary.scalar('total_loss', total_loss)
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_,
                             3)
            tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            mobilenet_variables_to_restore = slim.get_variables_to_restore(
                include=['MobilenetV2'], exclude=['global_step'])

            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.initial_checkpoint,
                slim.get_variables_to_restore(
                    exclude=['MobilenetV2', 'mobilenet_conv', 'global_step']))
            init_pretrained_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_to_restore)

            def init_sub_networks(session):
                init_fn(session)
                init_pretrained_mobilenet(session)

            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
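
This example is a distillation setup: the InceptionV3-based style prediction network (built with trainable=False and restored from a pretrained checkpoint) acts as the teacher, and the MobilenetV2-based model is trained so that its 100-dimensional style bottleneck matches the teacher's under a mean squared error loss. When --use_true_loss is set, the full stylization loss is mixed in as well, scaled by --true_loss_weight.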
Code example #5
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Forces all input processing onto CPU in order to reserve the GPU for the
    # forward inference and back-propagation.
    device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, worker_device=device)):
      # Loads content images.
      content_inputs_, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                       FLAGS.image_size)

      # Loads style images.
      [style_inputs_, _,
       style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
           FLAGS.style_dataset_file,
           batch_size=FLAGS.batch_size,
           image_size=FLAGS.image_size,
           shuffle=True,
           center_crop=FLAGS.center_crop,
           augment_style_images=FLAGS.augment_style_images,
           random_style_image_size=FLAGS.random_style_image_size)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Process style and content weight flags.
      content_weights = ast.literal_eval(FLAGS.content_weights)
      style_weights = ast.literal_eval(FLAGS.style_weights)

      # Define the model
      stylized_images, total_loss, loss_dict, _ = build_model.build_model(
          content_inputs_,
          style_inputs_,
          trainable=True,
          is_training=True,
          inception_end_point='Mixed_6e',
          style_prediction_bottleneck=100,
          adds_losses=True,
          content_weights=content_weights,
          style_weights=style_weights,
          total_variation_weight=FLAGS.total_variation_weight)

      # Adds scalar summaries to TensorBoard.
      for key, value in loss_dict.items():
        tf.summary.scalar(key, value)

      # Adds image summaries to TensorBoard.
      tf.summary.image('image/0_content_inputs', content_inputs_, 3)
      tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
      tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
      tf.summary.image('image/3_stylized_images', stylized_images, 3)

      # Set up training
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      train_op = slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.clip_gradient_norm,
          summarize_gradients=False)

      # Function to restore VGG16 parameters.
      init_fn_vgg = slim.assign_from_checkpoint_fn(vgg.checkpoint_file(),
                                                   slim.get_variables('vgg_16'))

      # Function to restore Inception_v3 parameters.
      inception_variables_dict = {
          var.op.name: var
          for var in slim.get_model_variables('InceptionV3')
      }
      init_fn_inception = slim.assign_from_checkpoint_fn(
          FLAGS.inception_v3_checkpoint, inception_variables_dict)

      # Function to restore VGG16 and Inception_v3 parameters.
      def init_sub_networks(session):
        init_fn_vgg(session)
        init_fn_inception(session)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=os.path.expanduser(FLAGS.train_dir),
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.train_steps,
          init_fn=init_sub_networks,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
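
One detail worth noting: when clip_gradient_norm > 0, slim.learning.create_train_op clips each gradient tensor individually rather than by global norm. A rough sketch of that behavior (the real implementation also handles IndexedSlices):

import tensorflow as tf

def clip_gradient_norms(grads_and_vars, max_norm):
    # Clips every gradient to max_norm one tensor at a time, mirroring
    # slim's per-gradient (not global-norm) clipping.
    clipped = []
    for grad, var in grads_and_vars:
        if grad is not None:
            grad = tf.clip_by_norm(grad, max_norm)
        clipped.append((grad, var))
    return clipped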
Code example #6
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Forces all input processing onto CPU in order to reserve the GPU for the
    # forward inference and back-propagation.
    device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, worker_device=device)):
      # Loads content images.
      content_inputs_, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                       FLAGS.image_size)

      # Loads style images.
      [style_inputs_, _,
       style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
           FLAGS.style_dataset_file,
           batch_size=FLAGS.batch_size,
           image_size=FLAGS.image_size,
           shuffle=True,
           center_crop=FLAGS.center_crop,
           augment_style_images=FLAGS.augment_style_images,
           random_style_image_size=FLAGS.random_style_image_size)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Process style and content weight flags.
      content_weights = ast.literal_eval(FLAGS.content_weights)
      style_weights = ast.literal_eval(FLAGS.style_weights)

      # Define the model
      stylized_images, true_loss, _, bottleneck_feat = \
          build_mobilenet_model.build_mobilenet_model(
          content_inputs_,
          style_inputs_,
          mobilenet_trainable=True,
          style_params_trainable=False,
          style_prediction_bottleneck=100,
          adds_losses=True,
          content_weights=content_weights,
          style_weights=style_weights,
          total_variation_weight=FLAGS.total_variation_weight,
      )

      _, inception_bottleneck_feat = build_model.style_prediction(
          style_inputs_,
          [],
          [],
          is_training=False,
          trainable=False,
          inception_end_point='Mixed_6e',
          style_prediction_bottleneck=100,
          reuse=None,
      )

      print('PRINTING TRAINABLE VARIABLES')
      for x in tf.trainable_variables():
        print(x)

      mse_loss = tf.losses.mean_squared_error(
          inception_bottleneck_feat, bottleneck_feat)
      total_loss = mse_loss
      if FLAGS.use_true_loss:
        true_loss = FLAGS.true_loss_weight * true_loss
        total_loss += true_loss
        tf.summary.scalar('mse', mse_loss)
        tf.summary.scalar('true_loss', true_loss)

      tf.summary.scalar('total_loss', total_loss)
      tf.summary.image('image/0_content_inputs', content_inputs_, 3)
      tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
      tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
      tf.summary.image('image/3_stylized_images', stylized_images, 3)

      mobilenet_variables_to_restore = slim.get_variables_to_restore(
          include=['MobilenetV2'],
          exclude=['global_step'])

      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      train_op = slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.clip_gradient_norm,
          summarize_gradients=False
      )

      init_fn = slim.assign_from_checkpoint_fn(
          FLAGS.initial_checkpoint,
          slim.get_variables_to_restore(
              exclude=['MobilenetV2', 'mobilenet_conv', 'global_step']))
      init_pretrained_mobilenet = slim.assign_from_checkpoint_fn(
          FLAGS.mobilenet_checkpoint, mobilenet_variables_to_restore)

      def init_sub_networks(session):
        init_fn(session)
        init_pretrained_mobilenet(session)

      slim.learning.train(
          train_op=train_op,
          logdir=os.path.expanduser(FLAGS.train_dir),
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.train_steps,
          init_fn=init_sub_networks,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
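
Note the two-stage restore: init_fn loads everything except the MobilenetV2/mobilenet_conv variables and the global step from --initial_checkpoint (a fully trained InceptionV3-based model), while init_pretrained_mobilenet loads the MobilenetV2 weights from --mobilenet_checkpoint. Since slim.learning.train invokes init_sub_networks exactly once after variable initialization, both restores complete before the first training step.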