Beispiel #1
0
    def generator(self, x_init, reuse=False, scope="generator"):
        """Build the generator network on ``x_init`` and return its ``fake`` output.

        The lite architecture (``generator_lite.G_net``) is used when
        ``self.light`` is truthy, otherwise the full one (``generator.G_net``).
        Variables are created inside ``tf.variable_scope(scope, reuse=reuse)``.
        """
        # Both architectures expose the same interface, so pick the module once.
        net_module = generator_lite if self.light else generator
        with tf.variable_scope(scope, reuse=reuse):
            network = net_module.G_net(x_init)
            return network.fake
Beispiel #2
0
def test(checkpoint_dir,
         style_name,
         test_dir,
         if_adjust_brightness,
         img_size=None):
    """Run the trained generator over every file in ``test_dir`` and save the results.

    Args:
        checkpoint_dir: Directory containing the generator checkpoint. If the
            path contains ``'lite'`` the lite generator graph is built instead
            of the full one.
        style_name: Name of the output sub-directory created under ``results/``.
        test_dir: Directory whose ``*.*`` files are processed.
        if_adjust_brightness: When true, ``save_images`` receives the source
            file path as its third argument (presumably used to match output
            brightness to the input — confirm in ``save_images``); otherwise it
            receives ``None``.
        img_size: Size passed to ``load_test_data``. Defaults to ``[256, 256]``.

    Returns:
        None. Prints a message and returns early when no checkpoint is found.
    """
    if img_size is None:
        # Fresh list per call: avoids the shared mutable-default pitfall.
        img_size = [256, 256]

    result_dir = 'results/' + style_name
    check_folder(result_dir)
    test_files = glob('{}/*.*'.format(test_dir))

    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')

    with tf.variable_scope("generator", reuse=False):
        if 'lite' in checkpoint_dir:
            test_generated = generator_lite.G_net(test_real).fake
        else:
            test_generated = generator.G_net(test_real).fake
    saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    with tf.Session(config=config) as sess:
        # Load model weights from the checkpoint index in checkpoint_dir.
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return
        # Only the basename of the recorded path is kept and re-joined with
        # checkpoint_dir (presumably so checkpoints moved from another machine
        # still resolve).
        ckpt_path = os.path.join(checkpoint_dir,
                                 os.path.basename(ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt_path)
        print(" [*] Success to read {}".format(ckpt_path))

        begin = time.time()
        for sample_file in tqdm(test_files):
            sample_image = np.asarray(load_test_data(sample_file, img_size))
            image_path = os.path.join(result_dir, os.path.basename(sample_file))
            fake_img = sess.run(test_generated,
                                feed_dict={test_real: sample_image})
            brightness_src = sample_file if if_adjust_brightness else None
            save_images(fake_img, image_path, brightness_src)
        elapsed = time.time() - begin

        print(f'test-time: {elapsed} s')
        # Guard against ZeroDivisionError when test_dir matched no files.
        if test_files:
            print(f'one image test time : {elapsed / len(test_files)} s')
        else:
            print(f' [*] No test images found in {test_dir}')
Beispiel #3
0
def main(checkpoint_dir, style_name):
    """Restore generator weights from ``checkpoint_dir`` and re-save only them.

    The generator graph (lite variant when ``checkpoint_dir`` contains
    ``"lite"``) is built under the ``"generator"`` variable scope. The weights
    of the trainable variables whose names start with ``"generator"`` are
    restored from the latest checkpoint. They are then written via ``save`` into
    ``../checkpoint/generator_<style_name>_weight[_lite]``, tagged with
    ``<style_name>-<step>``. The step is the text after the last ``-`` in the
    checkpoint file name.

    Returns:
        None. Prints a message and returns early when no checkpoint is found.
    """
    is_lite = "lite" in checkpoint_dir
    weight_suffix = "_weight_lite" if is_lite else "_weight"
    ckpt_dir = "../checkpoint/" + "generator_" + style_name + weight_suffix
    check_folder(ckpt_dir)

    input_ph = tf.placeholder(tf.float32, [1, None, None, 3],
                              name="generator_input")
    net_module = generator_lite if is_lite else generator
    with tf.variable_scope("generator", reuse=False):
        _ = net_module.G_net(input_ph).fake

    # Restrict the saver to generator weights so only they are restored/saved.
    generator_vars = [
        v for v in tf.trainable_variables() if v.name.startswith("generator")
    ]
    saver = tf.train.Saver(generator_vars)

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print(" [*] Failed to find a checkpoint")
            return

        print(ckpt.model_checkpoint_path)
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
        step_tag = ckpt_name.split("-")[-1]
        print(" [*] Success to read {}".format(ckpt_name))

        info = save(saver, sess, ckpt_dir, style_name + "-" + step_tag)

        print(f"save over : {info} ")
Beispiel #4
0
def cvt2anime_video(video,
                    output,
                    checkpoint_dir,
                    output_format="MP4V",
                    img_size=(256, 256)):
    """Convert every frame of ``video`` with the trained generator and write a new video.

    Args:
        video: Path to the input video (opened with ``cv2.VideoCapture``).
        output: Directory in which the converted video is written; the file
            keeps the input's base name.
        checkpoint_dir: Directory holding the generator checkpoint. If its path
            contains ``"lite"`` the lite generator graph is built.
        output_format: 4-letter code that specify codec to use for specific
            video type. e.g. for mp4 support use "H264", "MP4V", or "X264".
        img_size: Size passed to ``preprocessing`` / ``convert_image``.

    Returns:
        The path of the written video. Returns ``None`` after printing a message
        in three cases: no checkpoint is found, the first frame cannot be read,
        or the output writer cannot be opened.
    """
    gpu_stat = bool(len(tf.config.experimental.list_physical_devices("GPU")))
    if gpu_stat:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    gpu_options = tf.compat.v1.GPUOptions(allow_growth=gpu_stat)

    test_real = tf.compat.v1.placeholder(tf.float32, [1, None, None, 3],
                                         name="test")

    with tf.compat.v1.variable_scope("generator", reuse=False):
        if "lite" in checkpoint_dir:
            test_generated = generator_lite.G_net(test_real).fake
        else:
            test_generated = generator.G_net(test_real).fake

    # load video
    vid = cv2.VideoCapture(video)
    vid_name = os.path.basename(video)
    output_path = os.path.join(output, vid_name)
    total = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
    # Keep the fractional frame rate (e.g. 29.97); truncating it to int would
    # make the output play back at the wrong speed.
    fps = vid.get(cv2.CAP_PROP_FPS)
    codec = cv2.VideoWriter_fourcc(*output_format)

    tfconfig = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                        gpu_options=gpu_options)
    out = None
    pbar = None
    try:
        with tf.compat.v1.Session(config=tfconfig) as sess:
            # load model
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            saver = tf.compat.v1.train.Saver()
            if ckpt and ckpt.model_checkpoint_path:
                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
                print(" [*] Success to read {}".format(
                    os.path.join(checkpoint_dir, ckpt_name)))
            else:
                print(" [*] Failed to find a checkpoint")
                return

            # Read the first frame only to learn the processed output size;
            # the stream is rewound to frame 0 before conversion.
            ret, img = vid.read()
            if img is None:
                print("Error! Failed to determine frame size: frame empty.")
                return
            img = preprocessing(img, img_size)
            height, width = img.shape[:2]
            out = cv2.VideoWriter(output_path, codec, fps, (width, height))
            if not out.isOpened():
                # Without this check every write() would silently be dropped.
                print("Error! Failed to open video writer for {} with codec {}."
                      .format(output_path, output_format))
                return

            pbar = tqdm(total=total)
            vid.set(cv2.CAP_PROP_POS_FRAMES, 0)
            while True:
                ret, frame = vid.read()
                if not ret:
                    break  # normal end of stream
                if frame is None:
                    print("Warning: got empty frame.")
                    continue

                img = convert_image(frame, img_size)
                fake_img = sess.run(test_generated, feed_dict={test_real: img})
                fake_img = inverse_image(fake_img)
                fake_img = adjust_brightness_from_src_to_dst(
                    fake_img, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # fake_img is presumably RGB (the brightness reference is RGB);
                # OpenCV's writer expects BGR, hence the channel swap.
                out.write(cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR))
                pbar.update(1)
    finally:
        # Release everything on every path; an unreleased VideoWriter may
        # leave the output container unfinalised.
        if pbar is not None:
            pbar.close()
        if out is not None:
            out.release()
        vid.release()

    return output_path