Example #1
    def _create_placeholder(self, square_crop, img_ph, size):
        """Create input image placeholder

        Returns a subgraph with resize steps for the image

        Parameters
        ----------
        square_crop : bool
            Perfom square cropping or not
        img_ph : tf.Tensor
            Input image placeholder

        Returns
        -------
        tf.Tensor
           Image placeholder with preprocessing step
        """
        if square_crop:
            self.logger.debug('Cropping image')
            img_preprocessed = image_utils.center_crop_resize_image(
                img_ph, size
            )
        else:
            img_preprocessed = image_utils.resize_image(img_ph, size)
        return img_preprocessed
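A minimal usage sketch for this helper, assuming image_utils is Magenta's
magenta.models.image_stylization.image_utils module (an assumption; the
example does not show its imports) and the TF 1.x API used throughout this
page:

import tensorflow as tf
# Assumed import path for the resize/crop helpers used by these examples.
from magenta.models.image_stylization import image_utils

img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
# Inline equivalent of _create_placeholder(True, img_ph, 256):
img_preprocessed = image_utils.center_crop_resize_image(img_ph, 256)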
Example #2
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MkDir(FLAGS.output_dir)

  with tf.Graph().as_default(), tf.Session() as sess:
    # Defines a placeholder for the style image.
    style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.style_square_crop:
      style_img_preprocessed = image_utils.center_crop_resize_image(
          style_img_ph, FLAGS.style_image_size)
    else:
      style_img_preprocessed = image_utils.resize_image(style_img_ph,
                                                        FLAGS.style_image_size)

    # Defines a placeholder for the content image.
    content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.content_square_crop:
      content_img_preprocessed = image_utils.center_crop_resize_image(
          content_img_ph, FLAGS.image_size)
    else:
      content_img_preprocessed = image_utils.resize_image(
          content_img_ph, FLAGS.image_size)

    # Defines the model.
    stylized_images, _, _, bottleneck_feat = build_model.build_model(
        content_img_preprocessed,
        style_img_preprocessed,
        trainable=False,
        is_training=False,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    if tf.gfile.IsDirectory(FLAGS.checkpoint):
      checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
    else:
      checkpoint = FLAGS.checkpoint
    tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

    init_fn = slim.assign_from_checkpoint_fn(checkpoint,
                                             slim.get_variables_to_restore())
    sess.run([tf.local_variables_initializer()])
    init_fn(sess)

    # Gets the list of the input style images.
    style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
    if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
      np.random.seed(1234)
      style_img_list = np.random.permutation(style_img_list)
      style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

    # Gets list of input content images.
    content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

    for content_i, content_img_path in enumerate(content_img_list):
      content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :3]
      content_img_name = os.path.basename(content_img_path)[:-4]

      # Saves preprocessed content image.
      inp_img_cropped_resized_np = sess.run(
          content_img_preprocessed, feed_dict={
              content_img_ph: content_img_np
          })
      image_utils.save_np_image(inp_img_cropped_resized_np,
                                os.path.join(FLAGS.output_dir,
                                             '%s.jpg' % content_img_name))

      # Computes bottleneck features of the style prediction network for the
      # identity transform.
      identity_params = sess.run(
          bottleneck_feat, feed_dict={style_img_ph: content_img_np})

      for style_i, style_img_path in enumerate(style_img_list):
        if style_i >= FLAGS.maximum_styles_to_evaluate:
          break
        style_img_name = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :3]

        if style_i % 10 == 0:
          tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                          (content_i, content_img_name, style_i,
                           style_img_name))

        # Saves preprocessed style image.
        style_img_cropped_resized_np = sess.run(
            style_img_preprocessed, feed_dict={
                style_img_ph: style_image_np
            })
        image_utils.save_np_image(style_img_cropped_resized_np,
                                  os.path.join(FLAGS.output_dir,
                                               '%s.jpg' % style_img_name))

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(
            bottleneck_feat, feed_dict={style_img_ph: style_image_np})

        interpolation_weights = ast.literal_eval(FLAGS.interpolation_weights)
        # Interpolates between the parameters of the identity transform and
        # style parameters of the given style image.
        for interp_i, wi in enumerate(interpolation_weights):
          stylized_image_res = sess.run(
              stylized_images,
              feed_dict={
                  bottleneck_feat:
                      identity_params * (1 - wi) + style_params * wi,
                  content_img_ph:
                      content_img_np
              })

          # Saves stylized image.
          image_utils.save_np_image(
              stylized_image_res,
              os.path.join(FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                           (content_img_name, style_img_name, interp_i)))
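The core of this example is the linear interpolation between the identity
bottleneck (computed by feeding the content image into the style prediction
network) and the style bottleneck. A standalone NumPy sketch of that step,
assuming (1, 100) bottleneck arrays (100 follows from
style_prediction_bottleneck=100):

import numpy as np

# Placeholder values standing in for sess.run(bottleneck_feat, ...) outputs.
identity_params = np.zeros((1, 100), dtype=np.float32)
style_params = np.ones((1, 100), dtype=np.float32)

# wi = 0.0 reproduces the content image (identity transform);
# wi = 1.0 applies the full style.
for interp_i, wi in enumerate([0.0, 0.5, 1.0]):
    blended = identity_params * (1 - wi) + style_params * wi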
Example #3
def main(unused_argv=None):
    print('timer start')
    start = time.time()

    # Material names (in Japanese): cloth, plant, glass, leather, metal, paper,
    # plastic, stone, water, wood, resin, acrylic, aluminum, cowhide, brick,
    # silk. Kept in Japanese because they key into the pickle file names below.
    words = [
        '布', '植物', 'ガラス', '革', '金属', '紙', 'プラスチック', '石', '水', '木', '樹脂',
        'アクリル', 'アルミニウム', '牛皮', 'レンガ', '絹'
    ]

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    # Create the session inside the graph context so it runs ops from this
    # graph; a session created beforehand is bound to the old default graph.
    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
        if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
            np.random.seed(1234)
            style_img_list = np.random.permutation(style_img_list)
            style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

        # Gets list of input content images.
        content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

        for content_i, content_img_path in enumerate(content_img_list):
            content_img_np = image_utils.load_np_image_uint8(
                content_img_path)[:, :, :3]
            content_img_name = os.path.basename(content_img_path)[:-4]

            # Saves preprocessed content image.
            inp_img_croped_resized_np = sess.run(
                content_img_preprocessed,
                feed_dict={content_img_ph: content_img_np})
            image_utils.save_np_image(
                inp_img_croped_resized_np,
                os.path.join(FLAGS.output_dir, '%s.jpg' % (content_img_name)))

            if FLAGS.color_preserve:
                print('color preserve mode!')
                # Convert the content image from RGB to YCrCb.
                height, width, channels = inp_img_croped_resized_np[0].shape[:3]
                # print(inp_img_croped_resized_np)
                # The decoder pads dimensions to a multiple of 4, which leaves
                # gaps; record them so they can be cropped off later.
                wgap = 4 - (width % 4)
                hgap = 4 - (height % 4)
                inp_img_croped_resized_np = inp_img_croped_resized_np * 255
                content_img_np_ycc = cv2.cvtColor(inp_img_croped_resized_np[0],
                                                  cv2.COLOR_RGB2YCR_CB)
                # print(content_img_np_ycc)
                zeros = np.zeros((height, width), content_img_np_ycc.dtype)
                zeros = zeros + 128  # YCC's zero is center of 255
                tmp = cv2.cvtColor(content_img_np_ycc, cv2.COLOR_YCR_CB2BGR)
                cv2.imwrite("gray.jpg", tmp)
                # print(zeros)
                Ycontent, Crcontent, Cbcontent = cv2.split(content_img_np_ycc)
                # print(Ycontent.shape, Crcontent.shape, Cbcontent.shape, zeros.shape)
                # print(Crcontent)
                # print(content_img_np_ycc)
                # content_img_np_ycc_y = cv2.merge((Y, zeros, zeros))
                # content_img_np_gry = cv2.cvtColor(content_img_np_ycc_y, cv2.COLOR_YCR_CB2RGB)
                # print(content_img_np_gry)
                # cv2.imwrite("gray.jpg", content_img_np_gry)
                # print(np.shape(content_img_np))
                # content_img_np = content_img_np_gry

            # Computes bottleneck features of the style prediction network for the
            # identity transform.

            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            for word in words:
                print(word)
                # if style_i > FLAGS.maximum_styles_to_evaluate:
                # break
                # style_img_name = os.path.basename(style_img_path)[:-4]
                # style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :
                # 3]

                # if style_i % 10 == 0:
                # tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                # (content_i, content_img_name, style_i,
                # style_img_name))

                # Saves preprocessed style image.
                # style_img_croped_resized_np = sess.run(
                # style_img_preprocessed, feed_dict={
                # style_img_ph: style_image_np
                # })
                # image_utils.save_np_image(style_img_croped_resized_np,
                # os.path.join(FLAGS.output_dir,
                # '%s.jpg' % (style_img_name)))

                # Computes bottleneck features of the style prediction network for the
                # given style image.
                # style_params_ori = sess.run(
                # bottleneck_feat, feed_dict={style_img_ph: style_image_np})

                # print(np.shape(style_params))
                picklename = 'params/{}_{}.pickle'.format(word, content_i)
                # Pickle files must be opened in binary mode.
                with open(picklename, 'rb') as f:
                    style_params = pickle.load(f)
                # print(style_params)

                # print('diff of original para and made para:')
                # print(style_params_ori - style_params)

                interpolation_weights = ast.literal_eval(
                    FLAGS.interpolation_weights)
                # Interpolates between the parameters of the identity transform and
                # style parameters of the given style image.
                for interp_i, wi in enumerate(interpolation_weights):
                    stylized_image_res = sess.run(
                        stylized_images,
                        feed_dict={
                            bottleneck_feat:
                            identity_params * (1 - wi) + style_params * wi,
                            content_img_ph:
                            content_img_np
                        })

                    if FLAGS.color_preserve:
                        # print(stylized_image_res[0].shape)
                        stylized_image_res_ycc = cv2.cvtColor(
                            stylized_image_res[0], cv2.COLOR_RGB2YCR_CB)
                        Ystylized, Crstylized, Cbstylized = cv2.split(
                            stylized_image_res_ycc)
                        Ystylized_crop = Ystylized * 255
                        if wgap != 4:  # crop off the width padding
                            Ystylized_crop = Ystylized_crop[:, :-wgap]
                        if hgap != 4:  # crop off the height padding
                            Ystylized_crop = Ystylized_crop[:-hgap, :]
                        print(Ystylized_crop.shape, Cbcontent.shape)
                        # print(wgap)
                        swapped_ycc = cv2.merge(
                            (Ystylized_crop, Crcontent, Cbcontent))
                        # print(swapped_ycc)
                        stylized_image_res = cv2.cvtColor(
                            swapped_ycc, cv2.COLOR_YCR_CB2BGR)
                        # print(stylized_image_res)
                        cv2.imwrite(
                            os.path.join(
                                FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                                (content_img_name, word, interp_i)),
                            stylized_image_res)

                    else:
                        # Saves stylized image.
                        image_utils.save_np_image(
                            stylized_image_res,
                            os.path.join(
                                FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                                (content_img_name, word, interp_i)))

    elapsed_time = time.time() - start
    print("timer stop")
    print(elapsed_time)
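The color-preservation branch keeps the content image's chroma channels and
takes only the luminance from the stylized result. A minimal sketch of that
idea, assuming same-sized uint8 RGB inputs (the example above additionally
handles the decoder's padding gaps):

import cv2

def preserve_color(content_rgb, stylized_rgb):
    # Split both images into luminance (Y) and chroma (Cr, Cb).
    content_ycc = cv2.cvtColor(content_rgb, cv2.COLOR_RGB2YCR_CB)
    stylized_ycc = cv2.cvtColor(stylized_rgb, cv2.COLOR_RGB2YCR_CB)
    Ystylized = stylized_ycc[:, :, 0]
    Crcontent = content_ycc[:, :, 1]
    Cbcontent = content_ycc[:, :, 2]
    # Recombine stylized luminance with the original chroma.
    merged = cv2.merge((Ystylized, Crcontent, Cbcontent))
    return cv2.cvtColor(merged, cv2.COLOR_YCR_CB2RGB)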
Example #4
def styleParam(style_img_list):
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    slim = tf.contrib.slim
    checkpoint_in = 'arbitrary_style_transfer/model.ckpt'
    output_dir = 'outputs_style'
    image_size = 256
    content_square_crop = False
    style_image_size = 256
    style_square_crop = False
    maximum_styles_to_evaluate = 1024
    interpolation_weights_in = '[1.0]'

    tf.logging.set_verbosity(tf.logging.INFO)
    style_param_matrix = []
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MkDir(output_dir)

    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(checkpoint_in):
            checkpoint = tf.train.latest_checkpoint(checkpoint_in)
        else:
            checkpoint = checkpoint_in
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        # style_img_list = tf.gfile.Glob(style_images_paths)
        # if len(style_img_list) > maximum_styles_to_evaluate:
        # np.random.seed(1234)
        # style_img_list = np.random.permutation(style_img_list)
        # style_img_list = style_img_list[:maximum_styles_to_evaluate]

        count = 0
        for style_i, style_img_path in enumerate(style_img_list):
            # if style_i > maximum_styles_to_evaluate:
            # break
            style_image_np = image_utils.load_np_image_uint8(
                style_img_path)[:, :, :3]
            # Computes bottleneck features of the style prediction network for the
            # given style image.
            style_params = sess.run(bottleneck_feat,
                                    feed_dict={style_img_ph: style_image_np})
            style_param_matrix.append(style_params)
            count += 1
            # if count % 100 == 0:
            #     print(count, '/', len(style_img_list))

    # style_param_matrix holds one 100-dimensional bottleneck array per image.
    return style_param_matrix
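A usage sketch tying this helper to the pickle files consumed by Example #3
(that example's 'params/{word}_{index}.pickle' naming is its own convention;
the glob pattern and file names here are hypothetical):

import glob
import os
import pickle

style_img_list = sorted(glob.glob('styles/*.jpg'))  # hypothetical path
style_param_matrix = styleParam(style_img_list)

# Persist each bottleneck so it can be reloaded without running the network.
os.makedirs('params', exist_ok=True)
for idx, params in enumerate(style_param_matrix):
    with open('params/style_{}.pickle'.format(idx), 'wb') as f:
        pickle.dump(params, f)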
Example #5
import tensorflow as tf

import runway

# Assumed imports for the helpers used below (Magenta's arbitrary image
# stylization model); the original snippet does not show them.
from magenta.models.arbitrary_image_stylization import arbitrary_image_stylization_build_model as build_model
from magenta.models.image_stylization import image_utils

slim = tf.contrib.slim

style_image_size = 1024
image_size = 1024
style_square_crop = False
content_square_crop = False

sess = tf.InteractiveSession()
style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])

if style_square_crop:
    style_img_preprocessed = image_utils.center_crop_resize_image(
        style_img_ph, style_image_size)
else:
    style_img_preprocessed = image_utils.resize_image(
        style_img_ph, style_image_size)

content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])

if content_square_crop:
    content_img_preprocessed = image_utils.center_crop_resize_image(
        content_img_ph, image_size)
else:
    content_img_preprocessed = image_utils.resize_image(
        content_img_ph, image_size)

stylized_images, _, _, bottleneck_feat = build_model.build_model(
    content_img_preprocessed,
    style_img_preprocessed,
    trainable=False,
    is_training=False,
    inception_end_point='Mixed_6e',
    style_prediction_bottleneck=100,
    adds_losses=False)
Example #6
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    #if not tf.gfile.Exists(FLAGS.output_dir):
    #    tf.gfile.MkDir(FLAGS.output_dir)

    if FLAGS.tensorrt:
        print(trt.trt_convert.get_linked_tensorrt_version())
        gpu_options = cpb2.GPUOptions(
            per_process_gpu_memory_fraction=_GPU_MEM_FRACTION)
        sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
    else:
        sessconfig = None

    # Instantiate video capture object.
    cap = cv2.VideoCapture(1)

    # Set resolution.
    x_length, y_length = (1024, 1280)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, x_length)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, y_length)
    cap.read()
    x_new = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    y_new = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print('Resolution is: {0} by {1}'.format(x_new, y_new))

    with tf.Graph().as_default(), tf.Session(config=sessconfig) as sess:

        # TODO: calculate these dimensions dynamically (they can't use None,
        # since TensorRT needs precalculated dimensions).

        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32,
                                      shape=[200, 1200, 3],
                                      name="style_img_ph")
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32,
                                        shape=[200, 1200, 3],
                                        name="content_img_ph")
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        print(stylized_images)
        print(bottleneck_feat)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        tf.train.write_graph(sess.graph_def, '.', 'model.pbtxt')

        if FLAGS.tensorrt:
            # Use a built-in TF helper to export variables to constants.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess,  # The session is used to retrieve the weights.
                # The graph_def is used to retrieve the nodes.
                tf.get_default_graph().as_graph_def(),
                # The output node names select the useful nodes.
                ['transformer/expand/conv3/conv/Sigmoid'])

            trt_graph = trt.create_inference_graph(
                input_graph_def=output_graph_def,
                outputs=["transformer/expand/conv3/conv/Sigmoid"],
                max_workspace_size_bytes=5 << 30,
                max_batch_size=1,
                # TRT engine precision: "FP32", "FP16" or "INT8".
                precision_mode="FP16",
                minimum_segment_size=10)

            bottleneck_feat_O, content_img_ph_O, stylized_images_O = importer.import_graph_def(
                graph_def=trt_graph,
                return_elements=[
                    "Conv/BiasAdd", "content_img_ph",
                    "transformer/expand/conv3/conv/Sigmoid"
                ])
            bottleneck_feat_O = bottleneck_feat_O.outputs[0]
            content_img_ph_O = content_img_ph_O.outputs[0]
            stylized_images_O = stylized_images_O.outputs[0]

            print("bottleneck opt:" + str(bottleneck_feat_O))
            print(content_img_ph_O)
            print(stylized_images_O)

        # Gets the list of the input style images.
        #style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
        # if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
        #    np.random.seed(1234)
        #    style_img_list = np.random.permutation(style_img_list)
        #    style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

        # Gets list of input content images.
        # content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

        # if style_i % 10 == 0:
        # tf.logging.info('Stylizing  %s with (%d) %s' %
        #                        ( content_img_name, style_i,
        #                         style_img_name))

        # for style_i, style_img_path in enumerate(style_img_list):
        # if style_i > FLAGS.maximum_styles_to_evaluate:
        #    break
        interpolation_weight = FLAGS.interpolation_weight
        activate_style = None

        while True:
            start = timer()
            #calculating style isn't the major FPS bottleneck
            current_style = Style.objects.filter(is_active=True).first()
            if (activate_style != current_style):
                activate_style = current_style
                style_img_path = activate_style.source_file.path
                print("current image is " + style_img_path)
                style_img_name = "bricks"
                style_image_np = image_utils.load_np_image_uint8(
                    style_img_path)[:, :, :3]
                style_image_np = cv2.resize(style_image_np, (1200, 200))

                # Saves preprocessed style image.
                style_img_croped_resized_np = sess.run(
                    style_img_preprocessed,
                    feed_dict={style_img_ph: style_image_np})
                #image_utils.save_np_image(style_img_croped_resized_np,
                #                          os.path.join(FLAGS.output_dir,
                #                                       '%s.jpg' % (style_img_name)))

                # Computes bottleneck features of the style prediction network for the
                # given style image.
                style_params = sess.run(
                    bottleneck_feat, feed_dict={style_img_ph: style_image_np})

            # for content_i, content_img_path in enumerate(content_img_list):
            ret, frame = cap.read()
            print("webcam image: " + str(frame.shape))
            #crop to get the weird 1200x200 format
            content_img_np = frame[500:700, 80:1280]
            #content_img_np = frame
            print("cropped image:" + str(content_img_np.shape))
            # content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :
            #                                                                        3]

            # content_img_name = os.path.basename(content_img_path)[:-4]
            content_img_name = "webcam"

            # Saves preprocessed content image.
            print("Input image:" + str(content_img_np.shape))
            inp_img_croped_resized_np = sess.run(
                content_img_preprocessed,
                feed_dict={content_img_ph: content_img_np})
            # image_utils.save_np_image(inp_img_croped_resized_np,
            #                          os.path.join(FLAGS.output_dir,
            #                                       '%s.jpg' % (content_img_name)))

            # Computes bottleneck features of the style prediction network for the
            # identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            # Interpolates between the parameters of the identity transform and
            # style parameters of the given style image.
            wi = interpolation_weight
            style_np = identity_params * (1 - wi) + style_params * wi
            if FLAGS.tensorrt:
                style_np = np.reshape(style_np, (1, 100, 1, 1))

                stylized_image_res = sess.run(stylized_images_O,
                                              feed_dict={
                                                  bottleneck_feat_O: style_np,
                                                  content_img_ph_O:
                                                  content_img_np
                                              })
            else:
                stylized_image_res = sess.run(stylized_images,
                                              feed_dict={
                                                  bottleneck_feat: style_np,
                                                  content_img_ph:
                                                  content_img_np
                                              })

            end = timer()
            print(end - start)
            print(stylized_image_res.shape)
            # Saves stylized image.
            # image_utils.save_np_image(
            #  stylized_image_res,
            #  os.path.join(FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
            #               (content_img_name, style_img_name, interp_i)))
            display_np_image(stylized_image_res, FLAGS.showFullScreen)
            print(stylized_image_res.shape)
            # if cv2.waitKey(1) & 0xFF == ord('q'):
            #  break
            #img_out = np.squeeze(stylized_image_res).astype(np.uint8)
            #img_out = cv2.cvtColor(img_out, cv2.COLOR_BGR2RGB)
            #cv2.imshow('frame', img_out)

            key = cv2.waitKey(10)
            print("Key " + str(key))
            if key == 27:
                break
            elif key == 192:
                FLAGS.showFullScreen = False
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_NORMAL)
            elif (key == 233 or key == 193):
                FLAGS.showFullScreen = True
                cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN,
                                      cv2.WINDOW_FULLSCREEN)
            elif key == 60:  # '<' decreases the style weight
                interpolation_weight -= 0.25
            elif key == 62:  # '>' increases the style weight
                interpolation_weight += 0.25

            #if cv2.waitKey(1) & 0xFF == ord('q'):
            #    break

    cap.release()
    cv2.destroyAllWindows()
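The TensorRT path above follows the standard TF 1.x flow: freeze variables to
constants, then hand the frozen graph_def to the converter. A toy sketch of
that flow (a trivial graph stands in for the stylization network; requires a
TensorRT-enabled TF 1.x build):

import tensorflow as tf
from tensorflow.contrib import tensorrt as trt  # TF 1.x contrib API

# A toy graph standing in for the stylization network.
x = tf.placeholder(tf.float32, shape=[1, 8], name='input')
w = tf.Variable(tf.ones([8, 8]), name='w')
y = tf.sigmoid(tf.matmul(x, w), name='output_node')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Freeze: replace variables with constants so the graph is self-contained.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['output_node'])

# Build a TensorRT-optimized copy of the frozen graph.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen,
    outputs=['output_node'],
    max_batch_size=1,
    max_workspace_size_bytes=1 << 30,
    precision_mode='FP16')  # "FP32", "FP16" or "INT8"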
Example #7
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MkDir(FLAGS.output_dir)

    with tf.Graph().as_default(), tf.Session() as sess:
        # Defines a placeholder for the style image.
        style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.style_square_crop:
            style_img_preprocessed = image_utils.center_crop_resize_image(
                style_img_ph, FLAGS.style_image_size)
        else:
            style_img_preprocessed = image_utils.resize_image(
                style_img_ph, FLAGS.style_image_size)

        #video input
        capture = cv2.VideoCapture(FLAGS.video_path)
        fps = capture.get(cv2.CAP_PROP_FPS)
        video_name = os.path.basename(FLAGS.video_path)[:-4]

        in_width = capture.get(cv2.CAP_PROP_FRAME_WIDTH)
        in_height = capture.get(cv2.CAP_PROP_FRAME_HEIGHT)

        FLAGS.image_size = int(min(in_width, in_height, FLAGS.image_size))

        # Defines a placeholder for the content image.
        content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
        if FLAGS.content_square_crop:
            content_img_preprocessed = image_utils.center_crop_resize_image(
                content_img_ph, FLAGS.image_size)
        else:
            content_img_preprocessed = image_utils.resize_image(
                content_img_ph, FLAGS.image_size)

        # Defines the model.
        stylized_images, _, _, bottleneck_feat = build_model.build_model(
            content_img_preprocessed,
            style_img_preprocessed,
            trainable=False,
            is_training=False,
            inception_end_point='Mixed_6e',
            style_prediction_bottleneck=100,
            adds_losses=False)

        if tf.gfile.IsDirectory(FLAGS.checkpoint):
            checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
        else:
            checkpoint = FLAGS.checkpoint
        tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint, slim.get_variables_to_restore())
        sess.run([tf.local_variables_initializer()])
        init_fn(sess)

        # Gets the list of the input style images.
        style_img_path = tf.gfile.Glob(FLAGS.style_images_path)[0]
        print("\nstyling using " + os.path.basename(style_img_path) + " at " +
              str(FLAGS.image_size) + "p")
        style = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(
            style_img_path)[:, :, :3]

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(bottleneck_feat,
                                feed_dict={style_img_ph: style_image_np})

        #video output
        width = int(FLAGS.image_size * in_width / in_height)
        codec = cv2.VideoWriter_fourcc(*"MJPG")
        out_fps = fps / (1 + FLAGS.frame_skips)
        out_file = os.path.join(
            FLAGS.output_dir,
            video_name + "_" + style + "_" + str(FLAGS.image_size) + '.avi')
        out = cv2.VideoWriter(out_file, codec, out_fps,
                              (width, FLAGS.image_size), True)

        #audio input
        cmd = "ffmpeg -y -loglevel quiet -i {} -ab 160k -ac 2 -ar 44100 -vn {}.wav".format(
            FLAGS.video_path, video_name)
        subprocess.call(cmd, shell=True)
        y, sr = librosa.load(video_name + ".wav")
        # tempo, beats = librosa.beat.beat_track(y=y, sr=sr, units="time", tightness=10)
        feature_split = 1
        # rms = librosa.feature.rmse(y=y, frame_length=int(sr/out_fps/feature_split))[0]
        bins_per_octave = 4
        hop_length = int(sr / out_fps)
        n_bins = bins_per_octave * 5
        # cqt = np.abs(librosa.core.cqt(y, sr=sr, fmin=30, n_bins=n_bins, bins_per_octave=bins_per_octave, hop_length=hop_length))
        # Despite the name, this is an STFT magnitude spectrogram (the CQT
        # variant above is commented out).
        cqt = np.abs(librosa.core.stft(y, hop_length=hop_length))
        cqt_sr = sr / hop_length

        output_files = []
        hasFrame = capture.isOpened()
        i = 0
        start = time.time()
        maxWeight = 1
        lastWeight = 0
        while True:
            frame_start = time.time()
            # skip frames
            for skip in range(FLAGS.frame_skips):
                capture.grab()

            hasFrame, frame = capture.read()

            if not hasFrame:
                break

            inp = cv2.resize(frame, (FLAGS.image_size, FLAGS.image_size))
            content_img_np = inp[:, :, [2, 1, 0]]
            content_img_name = video_name + "_" + str(i)

            # for content_i, content_img_path in content_enum:
            # if video:
            #   content_img_np = video[content_i]
            #   content_img_name = str(content_i)
            # else:
            #   content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :3]
            #   content_img_name = os.path.basename(content_img_path)[:-4]

            # Saves preprocessed content image.
            # inp_img_croped_resized_np = sess.run(
            #     content_img_preprocessed, feed_dict={
            #         content_img_ph: content_img_np
            #     })
            # image_utils.save_np_image(inp_img_croped_resized_np,
            #                           os.path.join(FLAGS.output_dir,
            #                                       '%s.jpg' % (content_img_name)))

            # Computes bottleneck features of the style prediction network for the
            # identity transform.
            identity_params = sess.run(
                bottleneck_feat, feed_dict={style_img_ph: content_img_np})

            duration = time.time() - start

            # while beats[0] < duration:
            #     beats = beats[1:]
            # weight = max(0, min(1, abs(beats[0]-duration)))

            weight = 0

            # print(cqt.shape[0])
            bin_start = int(cqt.shape[0] * 0.001)
            bins = int(cqt.shape[0] * 0.25)
            for bin_i in range(bin_start, bin_start + bins):
                cur = cqt[bin_i, int(cqt_sr * i / out_fps)]
                weight += cur

            weight = weight / bins
            # weight = min(1, min(1, weight/bins)*FLAGS.interpolation_weight)

            # weight = min(1, cqt[])
            # weight = 1 - FLAGS.interpolation_weight * (1+sin(i/7*pi))/2

            # print(weight)
            maxWeight = max(maxWeight, weight)
            weight /= maxWeight
            weight *= weight
            weight = max(weight, lastWeight - 0.1)

            lastWeight = weight

            stylized_image_res = sess.run(stylized_images,
                                          feed_dict={
                                              bottleneck_feat:
                                              identity_params * (1 - weight) +
                                              style_params * weight,
                                              content_img_ph:
                                              content_img_np
                                          })

            # output_filename = '%s_stylized_%s.jpg' % (content_img_name, style)
            # output_file = os.path.join(FLAGS.output_dir, output_filename)

            # Saves stylized image.
            # image_utils.save_np_image(stylized_image_res, output_file)
            # output_files += output_file

            #writes image to video output
            sqr_output_frame = np.uint8(stylized_image_res * 255)[0][:, :, [2, 1, 0]]
            out.write(cv2.resize(sqr_output_frame, (width, FLAGS.image_size)))

            tf.logging.info('Stylized %s with weight %.2f at %.1f fps' %
                            (content_img_name, weight, 1 /
                             (time.time() - frame_start)))

            # print("Outputted " + '%s_stylized_%s.jpg' %
            #     (content_img_name, style))

            i += 1

        out.release()
        capture.release()
        cmd = "ffmpeg -i {} -i {}.wav -c:v copy -shortest -map 0:v:0 -map 1:a:0 temp.avi".format(
            out_file, video_name)
        print(cmd)
        subprocess.call(cmd, shell=True)
        subprocess.call("mv -f temp.avi {}".format(out_file), shell=True)
        subprocess.call("rm {}.wav".format(video_name), shell=True)

        print("Average fps: " + str(i / (time.time() - start)))
        return 0
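The audio-reactive weighting in this example is a small envelope follower:
average a band of spectrogram magnitudes per frame, normalize by the running
maximum, square for contrast, and cap how fast the weight may fall. A
standalone sketch of that logic (the 0.1 decay and band fractions mirror the
loop above; the input is any (n_bins, n_frames) magnitude array):

import numpy as np

def weight_envelope(spec, decay=0.1):
    n_bins = spec.shape[0]
    bin_start = int(n_bins * 0.001)
    bins = int(n_bins * 0.25)
    band = spec[bin_start:bin_start + bins].mean(axis=0)

    max_weight, last_weight = 1.0, 0.0
    weights = []
    for w in band:
        max_weight = max(max_weight, w)
        w = (w / max_weight) ** 2
        w = max(w, last_weight - decay)  # limit how fast the weight drops
        last_weight = w
        weights.append(w)
    return np.array(weights)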