Example #1
def build_network(content_img, style_img):
  """Builds the neural network for image stylization."""
  stylize_op, _, _, _ = arbitrary_image_stylization_build_model.build_model(
      content_img,
      style_img,
      trainable=False,
      is_training=False,
      adds_losses=False)
  return stylize_op
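
Example #1 only defines the stylization op. Below is a minimal sketch of how it might be driven at inference time, assuming the Magenta arbitrary_image_stylization package is importable; the placeholder shapes and checkpoint path are illustrative, not part of the original example:

import tensorflow as tf
import tensorflow.contrib.slim as slim

from magenta.models.arbitrary_image_stylization import arbitrary_image_stylization_build_model

# Illustrative shapes; build_model expects batched float images.
content_ph = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])
style_ph = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])
stylize_op = build_network(content_ph, style_ph)

with tf.Session() as sess:
  init_fn = slim.assign_from_checkpoint_fn(
      'arbitrary_style_transfer/model.ckpt',  # Illustrative path.
      slim.get_variables_to_restore())
  init_fn(sess)
  # stylized = sess.run(stylize_op,
  #                     feed_dict={content_ph: ..., style_ph: ...})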
Example #2
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MkDir(FLAGS.output_dir)

  with tf.Graph().as_default(), tf.Session() as sess:
    # Defines the placeholder for the style image.
    style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.style_square_crop:
      style_img_preprocessed = image_utils.center_crop_resize_image(
          style_img_ph, FLAGS.style_image_size)
    else:
      style_img_preprocessed = image_utils.resize_image(style_img_ph,
                                                        FLAGS.style_image_size)

    # Defines the placeholder for the content image.
    content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.content_square_crop:
      content_img_preprocessed = image_utils.center_crop_resize_image(
          content_img_ph, FLAGS.image_size)
    else:
      content_img_preprocessed = image_utils.resize_image(
          content_img_ph, FLAGS.image_size)

    # Defines the model.
    stylized_images, _, _, bottleneck_feat = build_model.build_model(
        content_img_preprocessed,
        style_img_preprocessed,
        trainable=False,
        is_training=False,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    if tf.gfile.IsDirectory(FLAGS.checkpoint):
      checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
    else:
      checkpoint = FLAGS.checkpoint
    tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

    init_fn = slim.assign_from_checkpoint_fn(checkpoint,
                                             slim.get_variables_to_restore())
    sess.run([tf.local_variables_initializer()])
    init_fn(sess)

    # Gets the list of the input style images.
    style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
    if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
      np.random.seed(1234)
      style_img_list = np.random.permutation(style_img_list)
      style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

    # Gets list of input content images.
    content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

    for content_i, content_img_path in enumerate(content_img_list):
      content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :3]
      content_img_name = os.path.basename(content_img_path)[:-4]

      # Saves preprocessed content image.
      inp_img_croped_resized_np = sess.run(
          content_img_preprocessed, feed_dict={
              content_img_ph: content_img_np
          })
      image_utils.save_np_image(inp_img_croped_resized_np,
                                os.path.join(FLAGS.output_dir,
                                             '%s.jpg' % (content_img_name)))

      # Computes bottleneck features of the style prediction network for the
      # identity transform.
      identity_params = sess.run(
          bottleneck_feat, feed_dict={style_img_ph: content_img_np})

      for style_i, style_img_path in enumerate(style_img_list):
        if style_i >= FLAGS.maximum_styles_to_evaluate:
          break
        style_img_name = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :3]

        if style_i % 10 == 0:
          tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                          (content_i, content_img_name, style_i,
                           style_img_name))

        # Saves preprocessed style image.
        style_img_croped_resized_np = sess.run(
            style_img_preprocessed, feed_dict={
                style_img_ph: style_image_np
            })
        image_utils.save_np_image(style_img_croped_resized_np,
                                  os.path.join(FLAGS.output_dir,
                                               '%s.jpg' % (style_img_name)))

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(
            bottleneck_feat, feed_dict={style_img_ph: style_image_np})

        interpolation_weights = ast.literal_eval(FLAGS.interpolation_weights)
        # Interpolates between the parameters of the identity transform and
        # style parameters of the given style image.
        for interp_i, wi in enumerate(interpolation_weights):
          stylized_image_res = sess.run(
              stylized_images,
              feed_dict={
                  bottleneck_feat:
                      identity_params * (1 - wi) + style_params * wi,
                  content_img_ph:
                      content_img_np
              })

          # Saves stylized image.
          image_utils.save_np_image(
              stylized_image_res,
              os.path.join(FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                           (content_img_name, style_img_name, interp_i)))
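
The loop above blends two bottleneck activations with a plain convex combination: with wi = 0 the transform reproduces the content image's own style (the identity), and with wi = 1 it applies the full style. A standalone sketch of the arithmetic (hypothetical helper, NumPy only):

import numpy as np


def interpolate_bottlenecks(identity_params, style_params, weights):
  """Yields convex combinations of two style-bottleneck activations.

  Hypothetical helper mirroring the arithmetic in the loop above.
  """
  for wi in weights:
    yield identity_params * (1 - wi) + style_params * wi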
Example #3
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            [style_inputs_, _,
             style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
                 FLAGS.style_dataset_file,
                 batch_size=FLAGS.batch_size,
                 image_size=FLAGS.image_size,
                 shuffle=True,
                 center_crop=FLAGS.center_crop,
                 augment_style_images=FLAGS.augment_style_images,
                 random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            stylized_images, total_loss, loss_dict, _ = build_model.build_model(
                content_inputs_,
                style_inputs_,
                trainable=True,
                is_training=True,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight)

            # Adding scalar summaries to the tensorboard.
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

            # Adding Image summaries to the tensorboard.
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_,
                             3)
            tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            # Set up training
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            # Function to restore Inception_v3 parameters.
            inception_variables_dict = {
                var.op.name: var
                for var in slim.get_model_variables('InceptionV3')
            }
            init_fn_inception = slim.assign_from_checkpoint_fn(
                FLAGS.inception_v3_checkpoint, inception_variables_dict)

            # Function to restore VGG16 and Inception_v3 parameters.
            def init_sub_networks(session):
                init_fn_vgg(session)
                init_fn_inception(session)

            # Run training
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
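
The script assumes module-level flags defined elsewhere in the file. A sketch of plausible definitions for a subset of the flags used above, for readers who want to run it standalone; the default values are assumptions, not the original ones:

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer('ps_tasks', 0, 'Number of parameter server tasks.')
flags.DEFINE_integer('batch_size', 8, 'Batch size.')
flags.DEFINE_integer('image_size', 256, 'Content/style image size.')
flags.DEFINE_string('style_dataset_file', None, 'Style dataset file.')
flags.DEFINE_string('content_weights', '{"vgg_16/conv3": 2.0}',
                    'Content-loss weights, parsed with ast.literal_eval.')
flags.DEFINE_string('style_weights', '{"vgg_16/conv2": 0.5e-3}',
                    'Style-loss weights, parsed with ast.literal_eval.')
flags.DEFINE_float('total_variation_weight', 1e4, 'Total-variation weight.')
flags.DEFINE_float('learning_rate', 1e-5, 'Adam learning rate.')
flags.DEFINE_float('clip_gradient_norm', 0.0, 'Gradient clipping norm.')
flags.DEFINE_string('train_dir', None, 'Directory for checkpoints.')
FLAGS = flags.FLAGS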
Example #4
    def _run_tf_graph(
        self,
        sess,
        content_path,
        style_path,
        content_size,
        style_size,
        content_square_crop,
        style_square_crop,
        interp_weights,
        max_styles_to_evaluate,
    ):
        """Create a Tensorflow static graph that defines all computation needed

        Parameters
        ----------
        sess : tf.Session
            Session to run this graph on
        content_path : str
            Path to content images
        style_path : str
            Path to style images
        content_size : int
            Size to resize the content image into
        style_size : int
            Size to resize the style image into
        content_square_crop : bool
            Whether to crop the content image into a square
        style_square_crop : bool
            Whether to crop the style image into a square
        interp_weights : list
            Interpolation weights
        max_styles_to_evaluate : int
            Maximum number of styles to evaluate
        """
        # Define placeholder for style image
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        style_img_preprocessed = self._create_placeholder(
            style_square_crop, style_img_ph, style_size
        )
        # Define placeholder for content image
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        content_img_preprocessed = self._create_placeholder(
            content_square_crop, content_img_ph, content_size
        )
        # Define the model
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False,
        )

        # Load checkpoint
        if tf.gfile.IsDirectory(self.checkpoint):
            checkpoint = tf.train.latest_checkpoint(self.checkpoint)
        else:
            checkpoint = self.checkpoint
        self.logger.info(
            'Loading checkpoint: {}'.format(checkpoint)
        )

        init_fn = slim.assign_from_checkpoint_fn(
            model_path=checkpoint, var_list=slim.get_variables_to_restore()
        )
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Get list of content and style images
        content_img_list, style_img_list = self._get_img_lists(
            content_path, style_path, max_styles_to_evaluate
        )

        for content_i, content_img_path in enumerate(content_img_list):
            content_img_np, content_img_name = self._get_data_and_name(
                img_path=content_img_path
            )

            # Compute bottleneck features of the style prediction network
            # for the identity transform
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np}
            )

            for style_i, style_img_path in enumerate(style_img_list):
                if style_i >= max_styles_to_evaluate:
                    break
                style_img_np, style_img_name = self._get_data_and_name(
                    img_path=style_img_path
                )

                if style_i % 10 == 0:
                    self.logger.info(
                        'Stylizing ({}) {} with ({}) {}'.format(
                            content_i,
                            content_img_name,
                            style_i,
                            style_img_name,
                        )
                    )

                # Compute bottleneck features of the style prediction
                style_params = sess.run(
                    bottleneck_feat, feed_dict={style_img_ph: style_img_np}
                )
                for interp_i, wi in enumerate(interp_weights):
                    stylized_image_res = sess.run(
                        stylized_images,
                        feed_dict={
                            bottleneck_feat: identity_params * (1 - wi)
                            + style_params * wi,
                            content_img_ph: content_img_np,
                        },
                    )

                    # Save stylized image
                    fname = os.path.join(
                        self.output,
                        '{}_stylized_{}_{}.png'.format(
                            content_img_name, style_img_name, interp_i
                        ),
                    )
                    self._save_image(stylized_image_res[0], fname)
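
The helpers _create_placeholder, _get_img_lists, _get_data_and_name, and _save_image are defined elsewhere on this class. Inferred from its call sites above, _create_placeholder plausibly looks like the following sketch (despite its name, it preprocesses an already-created placeholder rather than creating one):

    def _create_placeholder(self, square_crop, img_ph, size):
        """Preprocess an image placeholder by cropping and/or resizing.

        Sketch inferred from the call sites above, not the original code.
        """
        if square_crop:
            return image_utils.center_crop_resize_image(img_ph, size)
        return image_utils.resize_image(img_ph, size)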
Example #5
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default():
        # Loads content images.
        eval_content_inputs_, _ = image_utils.imagenet_inputs(
            FLAGS.batch_size, FLAGS.image_size)

        # Process style and content weight flags.
        content_weights = ast.literal_eval(FLAGS.content_weights)
        style_weights = ast.literal_eval(FLAGS.style_weights)

        # Loads evaluation style images.
        eval_style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
            FLAGS.eval_style_dataset_file,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            center_crop=True,
            shuffle=True,
            augment_style_images=False,
            random_style_image_size=False)

        # Computes stylized noise.
        stylized_noise, _, _, _ = build_model.build_model(
            tf.random_uniform([
                min(4, FLAGS.batch_size), FLAGS.image_size, FLAGS.image_size, 3
            ]),
            tf.slice(eval_style_inputs_, [0, 0, 0, 0],
                     [min(4, FLAGS.batch_size), -1, -1, -1]),
            trainable=False,
            is_training=False,
            reuse=None,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        # Computes stylized images.
        stylized_images, _, loss_dict, _ = build_model.build_model(
            eval_content_inputs_,
            eval_style_inputs_,
            trainable=False,
            is_training=False,
            reuse=True,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=True,
            content_weights=content_weights,
            style_weights=style_weights,
            total_variation_weight=FLAGS.total_variation_weight)

        # Adds Image summaries to the tensorboard.
        tf.summary.image(
            'image/{}/0_eval_content_inputs'.format(FLAGS.eval_name),
            eval_content_inputs_, 3)
        tf.summary.image(
            'image/{}/1_eval_style_inputs'.format(FLAGS.eval_name),
            eval_style_inputs_, 3)
        tf.summary.image(
            'image/{}/2_eval_stylized_images'.format(FLAGS.eval_name),
            stylized_images, 3)
        tf.summary.image('image/{}/3_stylized_noise'.format(FLAGS.eval_name),
                         stylized_noise, 3)

        metrics = {}
        for key, value in loss_dict.items():
            metrics[key] = tf.metrics.mean(value)

        names_values, names_updates = slim.metrics.aggregate_metric_map(
            metrics)
        for name, value in names_values.items():
            slim.summaries.add_scalar_summary(value, name, print_summary=True)
        eval_op = list(names_updates.values())
        num_evals = FLAGS.num_evaluation_styles // FLAGS.batch_size

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,
            eval_op=eval_op,
            num_evals=num_evals,
            eval_interval_secs=FLAGS.eval_interval_secs)
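
Note that the two build_model calls above share one set of weights: the first call (reuse=None) creates the variables and the second (reuse=True) reattaches to them, so the noise batch and the content batch are stylized with identical parameters. A minimal TF1 sketch of the same pattern, with a hypothetical scope name:

import tensorflow as tf


def tiny_net(x, reuse=None):
    # With reuse=True, this reattaches to the variables created earlier.
    with tf.variable_scope('tiny_net', reuse=reuse):
        return tf.layers.dense(x, 4, name='fc')


a = tiny_net(tf.zeros([1, 8]))             # Creates tiny_net/fc variables.
b = tiny_net(tf.ones([1, 8]), reuse=True)  # Reuses the same weights.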
Example #6
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default():
    # Loads content images.
    eval_content_inputs_, _ = image_utils.imagenet_inputs(
        FLAGS.batch_size, FLAGS.image_size)

    # Process style and content weight flags.
    content_weights = ast.literal_eval(FLAGS.content_weights)
    style_weights = ast.literal_eval(FLAGS.style_weights)

    # Loads evaluation style images.
    eval_style_inputs_, _, _ = image_utils.arbitrary_style_image_inputs(
        FLAGS.eval_style_dataset_file,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        center_crop=True,
        shuffle=True,
        augment_style_images=False,
        random_style_image_size=False)

    # Computes stylized noise.
    stylized_noise, _, _, _ = build_model.build_model(
        tf.random_uniform(
            [min(4, FLAGS.batch_size), FLAGS.image_size, FLAGS.image_size, 3]),
        tf.slice(eval_style_inputs_, [0, 0, 0, 0],
                 [min(4, FLAGS.batch_size), -1, -1, -1]),
        trainable=False,
        is_training=False,
        reuse=None,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    # Computes stylized images.
    stylized_images, _, loss_dict, _ = build_model.build_model(
        eval_content_inputs_,
        eval_style_inputs_,
        trainable=False,
        is_training=False,
        reuse=True,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=True,
        content_weights=content_weights,
        style_weights=style_weights,
        total_variation_weight=FLAGS.total_variation_weight)

    # Adds Image summaries to the tensorboard.
    tf.summary.image('image/{}/0_eval_content_inputs'.format(FLAGS.eval_name),
                     eval_content_inputs_, 3)
    tf.summary.image('image/{}/1_eval_style_inputs'.format(FLAGS.eval_name),
                     eval_style_inputs_, 3)
    tf.summary.image('image/{}/2_eval_stylized_images'.format(FLAGS.eval_name),
                     stylized_images, 3)
    tf.summary.image('image/{}/3_stylized_noise'.format(FLAGS.eval_name),
                     stylized_noise, 3)

    metrics = {}
    for key, value in loss_dict.items():
      metrics[key] = tf.metrics.mean(value)

    names_values, names_updates = slim.metrics.aggregate_metric_map(metrics)
    for name, value in names_values.items():
      slim.summaries.add_scalar_summary(value, name, print_summary=True)
    eval_op = list(names_updates.values())
    num_evals = FLAGS.num_evaluation_styles // FLAGS.batch_size

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        eval_op=eval_op,
        num_evals=num_evals,
        eval_interval_secs=FLAGS.eval_interval_secs)
Example #7
style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])

if style_square_crop:
    style_img_preprocessed = image_utils.center_crop_resize_image(
        style_img_ph, style_image_size)
else:
    style_img_preprocessed = image_utils.resize_image(
        style_img_ph, style_image_size)

content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])

if content_square_crop:
    content_img_preprocessed = image_utils.center_crop_resize_image(
        content_img_ph, image_size)
else:
    content_img_preprocessed = image_utils.resize_image(
        content_img_ph, image_size)

stylized_images, _, _, bottleneck_feat = build_model.build_model(
    content_img_preprocessed,
    style_img_preprocessed,
    trainable=False,
    is_training=False,
    inception_end_point='Mixed_6e',
    style_prediction_bottleneck=100,
    adds_losses=False)


# Session created at module scope, shared by the Runway handlers
# (assumed from the usage below).
sess = tf.Session()


@runway.setup
def setup():
    init_fn = slim.assign_from_checkpoint_fn(
        './arbitrary_style_transfer/model.ckpt',
        slim.get_variables_to_restore()
    )
    sess.run([tf.local_variables_initializer()])
    init_fn(sess)
    return sess
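
A sketch of an inference entry point that could sit beside the setup above, feeding one content/style pair through the module-level graph. The function name, the interpolation handling, and the assumption that inputs arrive as RGB arrays are illustrative, not part of the original snippet:

def stylize(content_np, style_np, interpolation_weight=1.0):
    """Runs one content/style pair through the graph above (sketch)."""
    identity_params = sess.run(
        bottleneck_feat, feed_dict={style_img_ph: content_np})
    style_params = sess.run(
        bottleneck_feat, feed_dict={style_img_ph: style_np})
    # Convex combination in the bottleneck space, as in the other examples.
    mixed = (identity_params * (1 - interpolation_weight)
             + style_params * interpolation_weight)
    return sess.run(
        stylized_images,
        feed_dict={bottleneck_feat: mixed, content_img_ph: content_np})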
Example #8
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Forces all input processing onto CPU in order to reserve the GPU for the
    # forward inference and back-propagation.
    device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, worker_device=device)):
      # Loads content images.
      content_inputs_, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                       FLAGS.image_size)

      # Loads style images.
      [style_inputs_, _,
       style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
           FLAGS.style_dataset_file,
           batch_size=FLAGS.batch_size,
           image_size=FLAGS.image_size,
           shuffle=True,
           center_crop=FLAGS.center_crop,
           augment_style_images=FLAGS.augment_style_images,
           random_style_image_size=FLAGS.random_style_image_size)

    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Process style and content weight flags.
      content_weights = ast.literal_eval(FLAGS.content_weights)
      style_weights = ast.literal_eval(FLAGS.style_weights)

      # Define the model
      stylized_images, total_loss, loss_dict, _ = build_model.build_model(
          content_inputs_,
          style_inputs_,
          trainable=True,
          is_training=True,
          inception_end_point='Mixed_6e',
          style_prediction_bottleneck=100,
          adds_losses=True,
          content_weights=content_weights,
          style_weights=style_weights,
          total_variation_weight=FLAGS.total_variation_weight)

      # Adding scalar summaries to the tensorboard.
      for key, value in loss_dict.items():
        tf.summary.scalar(key, value)

      # Adding Image summaries to the tensorboard.
      tf.summary.image('image/0_content_inputs', content_inputs_, 3)
      tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_, 3)
      tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
      tf.summary.image('image/3_stylized_images', stylized_images, 3)

      # Set up training
      optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
      train_op = slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.clip_gradient_norm,
          summarize_gradients=False)

      # Function to restore VGG16 parameters.
      init_fn_vgg = slim.assign_from_checkpoint_fn(vgg.checkpoint_file(),
                                                   slim.get_variables('vgg_16'))

      # Function to restore Inception_v3 parameters.
      inception_variables_dict = {
          var.op.name: var
          for var in slim.get_model_variables('InceptionV3')
      }
      init_fn_inception = slim.assign_from_checkpoint_fn(
          FLAGS.inception_v3_checkpoint, inception_variables_dict)

      # Function to restore VGG16 and Inception_v3 parameters.
      def init_sub_networks(session):
        init_fn_vgg(session)
        init_fn_inception(session)

      # Run training
      slim.learning.train(
          train_op=train_op,
          logdir=os.path.expanduser(FLAGS.train_dir),
          master=FLAGS.master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.train_steps,
          init_fn=init_sub_networks,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
Example #9
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.tensorrt:
        print(trt.trt_convert.get_linked_tensorrt_version())
        gpu_options = cpb2.GPUOptions(
            per_process_gpu_memory_fraction=_GPU_MEM_FRACTION)
        sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
    else:
        sessconfig = None

    # Instantiate video capture object.
    cap = cv2.VideoCapture(1)

    # Set resolution.
    x_length, y_length = (1024, 1280)
    cap.set(3, x_length)  # 3 and 4 are OpenCV property IDs.
    cap.set(4, y_length)
    cap.read()
    x_new = int(cap.get(3))
    y_new = int(cap.get(4))
    print('Resolution is: {0} by {1}'.format(x_new, y_new))

    with tf.Graph().as_default(), tf.Session(config=sessconfig) as sess:

        # TODO: calculate these dimensions dynamically (they can't use None,
        # since TensorRT needs precalculated dimensions).

        # Defines the placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32,
                                      shape=[200, 1200, 3],
                                      name="style_img_ph")
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines the placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32,
                                        shape=[200, 1200, 3],
                                        name="content_img_ph")
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        print(stylized_images)
        print(bottleneck_feat)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        tf.train.write_graph(sess.graph_def, '.', 'model.pbtxt')

        if FLAGS.tensorrt:
            # Use a built-in TF helper to export variables to constants.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,  # The session is used to retrieve the weights.
                tf.get_default_graph().as_graph_def(),  # Provides the nodes.
                # The output node names select the useful nodes.
                ['transformer/expand/conv3/conv/Sigmoid'])

            trt_graph = trt.create_inference_graph(
                input_graph_def=output_graph_def,
                outputs=["transformer/expand/conv3/conv/Sigmoid"],
                max_workspace_size_bytes=5 << 30,
                max_batch_size=1,
                precision_mode="FP16",  # TRT precision: "FP32", "FP16" or "INT8".
                minimum_segment_size=10)

            bottleneck_feat_O, content_img_ph_O, stylized_images_O = importer.import_graph_def(
                graph_def=trt_graph,
                return_elements=[
                    "Conv/BiasAdd", "content_img_ph",
                    "transformer/expand/conv3/conv/Sigmoid"
                ])
            bottleneck_feat_O = bottleneck_feat_O.outputs[0]
            content_img_ph_O = content_img_ph_O.outputs[0]
            stylized_images_O = stylized_images_O.outputs[0]

            print("bottleneck opt:" + str(bottleneck_feat_O))
            print(content_img_ph_O)
            print(stylized_images_O)

        interpolation_weight = FLAGS.interpolation_weight
        activate_style = None

        while True:
            start = timer()
            # Calculating the style isn't the major FPS bottleneck.
            current_style = Style.objects.filter(is_active=True).first()
            if activate_style != current_style:
                activate_style = current_style
                style_img_path = activate_style.source_file.path
                print("current image is " + style_img_path)
                style_img_name = "bricks"
                style_image_np = image_utils.load_np_image_uint8(
                    style_img_path)[:, :, :3]
                style_image_np = cv2.resize(style_image_np, (1200, 200))

                # Preprocesses the style image (saving to disk is disabled).
                style_img_croped_resized_np = sess.run(
                    style_img_preprocessed,
                    feed_dict={style_img_ph: style_image_np})

                # Computes bottleneck features of the style prediction network for the
                # given style image.
                style_params = sess.run(
                    bottleneck_feat, feed_dict={style_img_ph: style_image_np})

            ret, frame = cap.read()
            print("webcam image: " + str(frame.shape))
            # Crop to the 1200x200 letterbox format expected by the model.
            content_img_np = frame[500:700, 80:1280]
            print("cropped image: " + str(content_img_np.shape))
            content_img_name = "webcam"

            # Preprocesses the content image (saving to disk is disabled).
            print("Input image: " + str(content_img_np.shape))
            inp_img_croped_resized_np = sess.run(
                content_img_preprocessed,
                feed_dict={content_img_ph: content_img_np})

            # Computes bottleneck features of the style prediction network for the
            # identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            # Interpolates between the parameters of the identity transform and
            # style parameters of the given style image.
            wi = interpolation_weight
            style_np = identity_params * (1 - wi) + style_params * wi
            if FLAGS.tensorrt:
                style_np = np.reshape(style_np, (1, 100, 1, 1))

                stylized_image_res = sess.run(stylized_images_O,
                                              feed_dict={
                                                  bottleneck_feat_O: style_np,
                                                  content_img_ph_O:
                                                  content_img_np
                                              })
            else:
                stylized_image_res = sess.run(stylized_images,
                                              feed_dict={
                                                  bottleneck_feat: style_np,
                                                  content_img_ph:
                                                  content_img_np
                                              })

            end = timer()
            print(end - start)
            print(stylized_image_res.shape)
            # Displays the stylized frame instead of saving it to disk.
            display_np_image(stylized_image_res, FLAGS.showFullScreen)

            key = cv2.waitKey(10)
            print("Key " + str(key))
            if key == 27:
                break
            elif key == 192:
                FLAGS.showFullScreen = False
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_NORMAL)
            elif key == 233 or key == 193:
                FLAGS.showFullScreen = True
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_FULLSCREEN)
            elif key == 60:  # less
                interpolation_weight -= 0.25
            elif key == 62:  # > more
                interpolation_weight += 0.25


    cap.release()
    cv2.destroyAllWindows()
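
display_np_image and the Style Django model are defined elsewhere in this project. A plausible sketch of the display helper, assuming the stylized result is a float array in [0, 1] with a leading batch dimension (as the np.uint8(... * 255)[0] handling in the next example suggests):

def display_np_image(image_np, full_screen=False):
    """Shows a stylized frame in an OpenCV window (sketch, not the original)."""
    img = np.uint8(np.squeeze(image_np) * 255)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR.
    cv2.namedWindow('window', cv2.WINDOW_NORMAL)
    if full_screen:
        cv2.setWindowProperty('window', cv2.WND_PROP_FULLSCREEN,
                              cv2.WINDOW_FULLSCREEN)
    cv2.imshow('window', img)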
Example #10
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    with tf.Graph().as_default(), tf.Session() as sess:
        # Defines the placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Video input.
        capture = cv2.VideoCapture(FLAGS.video_path)
        fps = capture.get(cv2.CAP_PROP_FPS)
        video_name = os.path.basename(FLAGS.video_path)[:-4]

        in_width = capture.get(cv2.CAP_PROP_FRAME_WIDTH)
        in_height = capture.get(cv2.CAP_PROP_FRAME_HEIGHT)

        FLAGS.image_size = int(min(in_width, in_height, FLAGS.image_size))

        # Defines the placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Picks the first of the matching style images.
        style_img_path = tf.gfile.Glob(FLAGS.style_images_path)[0]
        print("\nstyling using " + os.path.basename(style_img_path) + " at " +
              str(FLAGS.image_size) + "p")
        style = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(
            style_img_path)[:, :, :3]

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(bottleneck_feat,
                                feed_dict={style_img_ph: style_image_np})

        # Video output.
        width = int(FLAGS.image_size * in_width / in_height)
        codec = cv2.VideoWriter_fourcc(*"MJPG")
        out_fps = fps / (1 + FLAGS.frame_skips)
        out_file = os.path.join(
            FLAGS.output_dir,
            video_name + "_" + style + "_" + str(FLAGS.image_size) + '.avi')
        out = cv2.VideoWriter(out_file, codec, out_fps,
                              (width, FLAGS.image_size), True)

        # Audio input: extract the soundtrack with ffmpeg, then compute an
        # STFT whose hop length yields one spectral frame per video frame.
        cmd = "ffmpeg -y -loglevel quiet -i {} -ab 160k -ac 2 -ar 44100 -vn {}.wav".format(
            FLAGS.video_path, video_name)
        subprocess.call(cmd, shell=True)
        y, sr = librosa.load(video_name + ".wav")
        hop_length = int(sr / out_fps)
        cqt = np.abs(librosa.core.stft(y, hop_length=hop_length))
        cqt_sr = sr / hop_length

        hasFrame = capture.isOpened()
        i = 0
        start = time.time()
        maxWeight = 1
        lastWeight = 0
        while True:
            frame_start = time.time()
            # skip frames
            for skip in range(FLAGS.frame_skips):
                capture.grab()

            hasFrame, frame = capture.read()

            if not hasFrame:
                break

            inp = cv2.resize(frame, (FLAGS.image_size, FLAGS.image_size))
            content_img_np = inp[:, :, [2, 1, 0]]
            content_img_name = video_name + "_" + str(i)

            # Computes bottleneck features of the style prediction network for the
            # identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            # Averages the low-frequency STFT magnitudes for this video frame
            # to get a raw style weight.
            weight = 0

            bin_start = int(cqt.shape[0] * 0.001)
            bins = int(cqt.shape[0] * 0.25)
            for bin_i in range(bin_start, bin_start + bins):
                cur = cqt[bin_i, int(cqt_sr * i / out_fps)]
                weight += cur

            weight = weight / bins

            # Normalizes by the running maximum, squares to emphasize peaks,
            # and limits how quickly the weight can fall between frames.
            maxWeight = max(maxWeight, weight)
            weight /= maxWeight
            weight *= weight
            weight = max(weight, lastWeight - 0.1)

            lastWeight = weight

            stylized_image_res = sess.run(stylized_images,
                                          feed_dict={
                                              bottleneck_feat:
                                              identity_params * (1 - weight) +
                                              style_params * weight,
                                              content_img_ph:
                                              content_img_np
                                          })

            # Writes the stylized frame to the video output (RGB -> BGR).
            sqr_output_frame = np.uint8(stylized_image_res * 255)[0][:, :, [2, 1, 0]]
            out.write(cv2.resize(sqr_output_frame, (width, FLAGS.image_size)))

            tf.logging.info('Stylized %s with weight %.2f at %.1f fps' %
                            (content_img_name, weight, 1 /
                             (time.time() - frame_start)))

            # print("Outputted " + '%s_stylized_%s.jpg' %
            #     (content_img_name, style))

            i += 1

        out.release()
        capture.release()
        cmd = "ffmpeg -i {} -i {}.wav -c:v copy -shortest -map 0:v:0 -map 1:a:0 temp.avi".format(
            out_file, video_name)
        print(cmd)
        subprocess.call(cmd, shell=True)
        subprocess.call("mv -f temp.avi {}".format(out_file), shell=True)
        subprocess.call("rm {}.wav".format(video_name), shell=True)

        print("Average fps: " + str(i / (time.time() - start)))
        return 0
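
The per-frame weight logic above mixes spectral averaging, normalization, squaring, and decay limiting inline. A standalone sketch of the same computation (hypothetical refactor, NumPy only):

import numpy as np


def audio_weight(stft_mag, frame_i, out_fps, stft_sr, last_weight, max_weight):
    """Maps low-frequency STFT energy at one video frame to a style weight.

    Hypothetical refactor of the inline logic above; returns the new weight
    and the updated running maximum used for normalization.
    """
    bin_start = int(stft_mag.shape[0] * 0.001)
    bins = int(stft_mag.shape[0] * 0.25)
    col = int(stft_sr * frame_i / out_fps)
    weight = stft_mag[bin_start:bin_start + bins, col].mean()
    max_weight = max(max_weight, weight)
    weight = (weight / max_weight) ** 2        # Normalize, emphasize peaks.
    weight = max(weight, last_weight - 0.1)    # Limit how fast it can fall.
    return weight, max_weight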