Пример #1
0
def main():
    """Entry point: build the U-GAT-IT graph and run the phase chosen on the CLI.

    Recognised phases are ``train``, ``test`` and ``export``; any other value
    only builds the graph and shows the network architecture.
    """
    cli_args = parse_args()
    if cli_args is None:
        exit()

    # Allow soft device placement and let TF grow GPU memory on demand.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True

    # (phase, UGATIT method name, completion message). Methods are looked up
    # by name only when their phase matches.
    phase_table = (
        ('train', 'train', " [*] Training finished!"),
        ('test', 'test', " [*] Test finished!"),
        ("export", 'export_saved_model', " [*] Export model finished!"),
    )

    with tf.Session(config=session_config) as session:
        model = UGATIT(session, cli_args)
        model.build_model()
        # Display the network architecture.
        show_all_variables()

        for phase, method_name, done_message in phase_table:
            if cli_args.phase == phase:
                getattr(model, method_name)()
                print(done_message)
Пример #2
0
def selfie2anime():
    """Download one image, translate it with U-GAT-IT and upload the result.

    The input image id is read from the ``id`` environment variable and the
    result id from ``result``. Uses the module-level ``img_path``,
    ``record_path`` and ``images_bucket`` defined elsewhere in the file.

    Returns:
        tuple: ``(download_path, img)``, the local path of the downloaded
        input and the image returned by ``gan.infer``.

    Raises:
        KeyError: if ``id`` or ``result`` is not set in the environment.
    """
    img_id = os.environ['id']
    result_id = os.environ['result']

    # Use a fixed test-phase argument list instead of sys.argv.
    parser = get_parser()
    args = parser.parse_args("--phase test".split())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = UGATIT(sess, args)
        gan.build_model()

        # Fetch the target image into img_path.
        download_path = os.path.join(img_path, img_id)
        download_image(images_bucket, img_id, dest=download_path)
        # NOTE(review): presumably regenerates the dataset record from
        # img_path so the new image is included; confirm with dataset_tool.
        dataset_tool.create_from_images(record_path, img_path, True)

        img = gan.infer(download_path)

        # Called for its side effect only; the returned URL was never used.
        upload_image(img, result_id)

    return download_path, img
Пример #3
0
def main():
    """Parse CLI arguments, build U-GAT-IT and run the train/test/export phase."""
    options = parse_args()
    if options is None:
        exit()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        model = UGATIT(session, options)
        model.build_model()
        # Display the network architecture.
        show_all_variables()

        # Each phase is checked in turn; the method is resolved by name only
        # when its phase matches.
        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
            ('export', 'export', " [*] Export finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #4
0
def main():
    """Run U-GAT-IT under PaddlePaddle dygraph in the phase given on the CLI.

    ``args.device == 'cuda'`` selects ``CUDAPlace(0)``; any other value uses
    the CPU. The chosen place is also stored on ``args.place``.
    """
    options = parse_args()
    if options is None:
        exit()

    if options.device == 'cuda':
        device_place = fluid.CUDAPlace(0)
    else:
        device_place = fluid.CPUPlace()
    options.place = device_place

    with fluid.dygraph.guard(device_place):
        model = UGATIT(options)
        model.build_model()

        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
            ('deploy', 'deploy', " [*] Deploy finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #5
0
def main():
    """Train or test U-GAT-IT with some hyper-parameters pinned in code.

    Whatever was passed on the command line, ``epoch`` is forced to 1000 and
    ``batch_size`` to 1.
    """
    options = parse_args()
    if options is None:
        exit()

    # Hard-coded overrides of the parsed arguments.
    options.epoch, options.batch_size = 1000, 1
    # Alternative overrides used previously:
    # options.dataset = "s2a4zhengsheng"
    # options.gan_type = "dragan"

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        model = UGATIT(session, options)
        model.build_model()
        # Display the network architecture.
        show_all_variables()

        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #6
0
def main():
    """Build U-GAT-IT in a memory-growth TF session and run train or test."""
    options = parse_args()
    if options is None:
        exit()

    # Allocate GPU memory on demand rather than reserving it all up front.
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True),
    )

    with tf.Session(config=session_config) as session:
        model = UGATIT(session, options)

        # Log start/end of graph construction ("model build started/finished").
        print('[Info] 构建模型开始!')
        model.build_model()
        print('[Info] 构建模型完成!')

        # Display the network architecture.
        show_all_variables()

        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #7
0
def main():
    """Train U-GAT-IT, or translate one hard-coded test image."""
    options = parse_args()
    if options is None:
        exit()

    # Image converted in the 'test' phase; change it to convert another file.
    single_image_path = "UGATIT/dataset/selfie2anime/testA/conv.JPG"

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        model = UGATIT(session, options)
        model.build_model()
        # Display the network architecture.
        show_all_variables()

        if options.phase == 'train':
            model.train()
            print(" [*] Training finished!")

        if options.phase == 'test':
            # Single-image conversion replaces the full-dataset model.test().
            model.test_single_img(sample_file=single_image_path)
            print(" [*] Test finished!")
Пример #8
0
def main():
    """Restrict CUDA to ``--gpu_ids``, build U-GAT-IT and run train or test.

    ``CUDA_VISIBLE_DEVICES`` is set before the model is constructed so that
    only the requested device(s) are exposed to CUDA.
    """
    import os

    args = parse_args()
    # Validate the parse result before touching any attribute. The old code
    # read args.gpu_ids first and raised AttributeError when args was None.
    if args is None:
        exit()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)

    gan = UGATIT(args)
    gan.build_model()

    if args.phase == 'train':
        gan.train()
        print(" [*] Training finished!")

    if args.phase == 'test':
        gan.test()
        print(" [*] Test finished!")
Пример #9
0
def main():
    """Run U-GAT-IT under PaddlePaddle dygraph, optionally data-parallel."""
    options = parse_args()
    if options is None:
        exit()

    from paddle import fluid

    # On CUDA, use the device id reported by fluid.dygraph.parallel.Env().
    device_place = (fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
                    if options.device == "cuda" else fluid.CPUPlace())

    with fluid.dygraph.guard(place=device_place):
        # prepare_context() is only called when --parallel is set.
        options.strategy = (fluid.dygraph.parallel.prepare_context()
                            if options.parallel else None)
        model = UGATIT(options)
        model.build_model()

        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #10
0
def main():
    """Build U-GAT-IT for a weather translation, then train or generate.

    ``--weather Day2rain`` / ``Day2night`` selects the ``day2rain`` /
    ``day2night`` dataset. Any other value leaves ``args.dataset`` as parsed.
    """
    options = parse_args()
    if options is None:
        exit()

    weather_to_dataset = {"Day2rain": "day2rain", "Day2night": "day2night"}
    dataset_name = weather_to_dataset.get(options.weather)
    if dataset_name is not None:
        options.dataset = dataset_name

    model = UGATIT(options)
    model.build_model()

    if options.phase == 'train':
        model.train()
        print(" [*] Training finished!")

    if options.phase == 'test':
        # 'test' calls generate() with the configured input and output paths.
        model.generate(options.dataset_path, options.output_path)
        print(" [*] Test finished!")
Пример #11
0
def process():
    """Flask handler: translate the image at ``request.json["url"]``.

    Returns:
        ``(response, 200)`` carrying the generated JPEG on success, or
        ``({'message': 'input error'}, 400)`` if any step raises.

    ``clean_all`` is always called on both temporary paths afterwards.
    """
    input_path = generate_random_filename(upload_directory, "jpg")
    output_path = generate_random_filename(result_directory, "jpg")

    try:
        url = request.json["url"]
        download(url, input_path)

        # A new session is opened and the model rebuilt on every request;
        # `args` is module-level.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            gan = UGATIT(sess, args)
            gan.build_model()
            gan.test_endpoint_init()
            gan.test_endpoint(input_path, output_path)

        return send_file(output_path, mimetype='image/jpeg'), 200

    # `except Exception` rather than a bare `except:`, so SystemExit and
    # KeyboardInterrupt propagate instead of being reported as a 400.
    except Exception:
        traceback.print_exc()
        return {'message': 'input error'}, 400

    finally:
        clean_all([input_path, output_path])
Пример #12
0
def index():
    """Flask view: on POST, translate every uploaded image with U-GAT-IT.

    Each decodable upload goes through ``gan.test_endpoint`` and is written
    to ``uploads/<uuid1>.png``. Its ``photos.url`` is appended to the
    ``file_urls`` list kept in the Flask session. Non-POST requests render
    ``index.html``.
    """
    # Per-session list of URLs of the images produced so far.
    if "file_urls" not in session:
        session['file_urls'] = []
    file_urls = session['file_urls']

    if request.method == 'POST':
        args = None
        for field_name in request.files:
            upload = request.files.get(field_name)

            # Decode the raw upload bytes. np.frombuffer works on any
            # file-like object; np.fromfile needs a real OS-level file.
            nparr = np.frombuffer(upload.read(), np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                # cv2.imdecode returns None for undecodable data; skip the
                # file rather than crash inside the model.
                continue

            # Parse the CLI arguments once per request, not once per file.
            if args is None:
                args = parse_args()
                if args is None:
                    # Never call exit() from a request handler: it raises
                    # SystemExit inside the server worker.
                    return "invalid server arguments", 500

            # NOTE(review): the model is rebuilt per file in the default TF
            # graph; confirm UGATIT tolerates repeated build_model() calls.
            with tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True)) as sess:
                gan = UGATIT(sess, args)
                gan.build_model()
                # Display the network architecture.
                show_all_variables()

                fake_img = gan.test_endpoint(img)

                # Save the result under a unique name and remember its URL.
                filename = str(uuid.uuid1()) + '.png'
                cv2.imwrite(os.path.join('uploads', filename), fake_img)
                file_urls.append(photos.url(filename))

        session['file_urls'] = file_urls
        return "uploading..."
    return render_template('index.html')
Пример #13
0
def video(args):
    """Run U-GAT-IT on live frames and show the results in the 'Output' window.

    ``read_frame_thread`` is started in a background thread. The loop reads
    ``state.frame`` (presumably filled by that thread; confirm) until
    ``state.running`` is false, calling ``process_events`` after each frame.
    """
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = UGATIT(sess, args)
        gan.build_model()
        # Display the network architecture.
        show_all_variables()

        # Load the model for use in the video stream.
        gan.video_inference_init()

        threading.Thread(target=read_frame_thread).start()

        while state.running:
            # Read the shared frame exactly once. Testing state.frame and
            # then re-reading it lets another thread swap in None between
            # the two accesses.
            frame = state.frame
            if frame is None:
                time.sleep(0.01)
                continue

            gen_image = gan.video_inference(frame)
            cv2.imshow('Output', gen_image)

            # Handle key press events.
            process_events()
Пример #14
0
def getUGATITTransform(input_image, dataset):
    """Translate an encoded image with the newest light U-GAT-IT checkpoint.

    Args:
        input_image: encoded image bytes (anything ``imageio.imread`` can
            read from a ``BytesIO``). Must decode to an (H, W, C) array.
        dataset: dataset name; checkpoints are looked up in
            ``<result_dir>/<dataset>/model/*.pt``.

    Returns:
        The ``genA2B`` output as an ``(H, W, C)`` numpy array, or ``None``
        if no numbered ``*.pt`` checkpoint is found.
    """
    image = imageio.imread(io.BytesIO(input_image))
    h, w, _ = np.shape(image)

    # Shrink by an integer factor when either side exceeds 512 px.
    scale = 1
    if h > 512 or w > 512:
        scale = max(h, w) // 512 + 1
    image = cv2.resize(image, dsize=(int(w / scale), int(h / scale)))

    # HWC -> 1xCxHxW float tensor on the GPU.
    # NOTE(review): pixels are fed unnormalised (0-255); confirm the
    # generator was not trained on [-1, 1] inputs.
    t_image = torch.from_numpy(
        np.expand_dims(np.transpose(image, (2, 0, 1)), 0)).cuda().float()

    args = parse_args()
    args.light = True
    args.dataset = dataset
    gan = UGATIT(args)
    gan.build_model()

    model_dir = os.path.join(gan.result_dir, gan.dataset, 'model')
    # Pick the highest iteration by its parsed integer, not by file-name
    # order: lexicographic sorting ranks "_999.pt" above "_1000.pt" when
    # numbers are not zero-padded.
    iterations = []
    for path in glob(os.path.join(model_dir, '*.pt')):
        suffix = path.split('_')[-1].split('.')[0]
        if suffix.isdigit():
            iterations.append(int(suffix))
    if not iterations:
        print(" [*] Load FAILURE")
        return None

    latest_iter = max(iterations)
    gan.load(model_dir, latest_iter)
    print(" [*] Load SUCCESS")

    gan.genA2B.eval()
    gan.genB2A.eval()
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        fake_A2B, _, _ = gan.genA2B(t_image)
    # 1xCxHxW -> HxWxC numpy array.
    return np.transpose(np.squeeze(fake_A2B.cpu().numpy()), (1, 2, 0))
Пример #15
0
def main(args):
    """Stream the webcam through U-GAT-IT; display and optionally record it.

    Args:
        args: namespace with ``checkpoint`` (TF checkpoint to restore),
            ``sidebyside`` (show input next to output) and ``save`` (record
            the displayed frames to a timestamped mp4 in
            ``_DEFAULT_OUTPUT_DIR``).

    Raises:
        Exception: if the camera cannot be opened.

    Press the space bar in the preview window to stop.
    """
    # _FRAME_SIZE is passed to cv2.resize, which takes (width, height) and
    # returns an array of shape (height, width, 3). Size the placeholder and
    # the writer to match, so non-square sizes work too.
    w, h = _FRAME_SIZE
    x = tf.placeholder(tf.float32, [None, h, w, 3])
    model = UGATIT()
    y = model.generate(x)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().restore(sess, args.checkpoint)

        cap = cv2.VideoCapture(_DEFAULT_CAMERA_ID)
        video_out = None
        try:
            if not cap.isOpened():
                raise Exception("Unable to read camera feed")
            if args.save:
                # Only create an output file when recording was requested.
                # Name it after the current timestamp.
                video_path = os.path.join(
                    _DEFAULT_OUTPUT_DIR,
                    '{}.mp4'.format(datetime.now().strftime("%Y-%m-%d-%H%M%S")))
                video_out = cv2.VideoWriter(
                    filename=video_path,
                    fourcc=cv2.VideoWriter_fourcc(*_DEFAULT_CODEC),
                    fps=_DEFAULT_FRAMES_PER_SECOND,
                    frameSize=(w * 2, h) if args.sidebyside else (w, h))

            window_message = "Press [SPACE BAR] to stop"
            while cap.isOpened():
                ret, orig = cap.read()
                if not ret:
                    break
                orig = cv2.resize(orig, _FRAME_SIZE).astype(np.float32)
                # BGR [0, 255] -> RGB [-1, 1].
                out = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB) / 127.5 - 1.
                out = sess.run(y, feed_dict={x: np.expand_dims(out, 0)})
                # [-1, 1] -> [0, 1], back to BGR for OpenCV display.
                out = (np.squeeze(out) + 1.) / 2.
                out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
                frame = np.hstack(
                    (orig / 255., out)) if args.sidebyside else out
                cv2.imshow(window_message, frame)
                if video_out is not None:
                    video_out.write((frame * 255).astype(np.uint8))
                if cv2.waitKey(1) & 0xFF == ord(" "):
                    break
        finally:
            # Release devices even when inference raises.
            if video_out is not None:
                video_out.release()
            cap.release()
            cv2.destroyAllWindows()
Пример #16
0
def main():
    """Entry point: run the video, web, train or test phase of U-GAT-IT.

    ``video`` is delegated entirely to ``video(args)``. ``web`` stores the
    initialised model in the module-level ``gan_ref`` and serves the Flask
    ``app`` on 0.0.0.0:5000.
    """
    args = parse_args()
    if args is None:
        exit()

    # Show the 'Output' window full-screen.
    cv2.namedWindow('Output', cv2.WND_PROP_FULLSCREEN)
    cv2.setWindowProperty('Output', cv2.WND_PROP_FULLSCREEN,
                          cv2.WINDOW_FULLSCREEN)

    if args.phase == 'video':
        video(args)
        # No branch below handles 'video'. Previously execution fell through
        # and built a second, unused model here.
        return

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = UGATIT(sess, args)
        gan.build_model()
        # Display the network architecture.
        show_all_variables()

        if args.phase == 'web':
            gan.video_inference_init()

            # Expose the model to the Flask request handlers.
            global gan_ref
            gan_ref = gan

            app.run(host="0.0.0.0", port=5000)
            exit()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")
Пример #17
0
def main():
    """Build U-GAT-IT from the ``args()`` configuration and train it."""
    model = UGATIT(args())
    model.build_model()
    model.train()
Пример #18
0
def setup(opts):
    """Build a U-GAT-IT model for the 'portrait' dataset and load its weights.

    Args:
        opts: mapping whose ``'checkpoint'`` entry is passed to
            ``load_from_latest``.

    Returns:
        The loaded ``UGATIT`` instance. Uses the module-level ``sess``.
    """
    options = parse_args()
    options.dataset = 'portrait'
    model = UGATIT(sess, options)
    model.build_model()
    model.load_from_latest(opts['checkpoint'])
    return model
Пример #19
0
def setup(opts):
    """Build a 256-px U-GAT-IT model in test phase and load its weights.

    Args:
        opts: mapping whose ``'checkpoint'`` entry is passed to
            ``load_from_latest``.

    Returns:
        The loaded ``UGATIT`` instance. Uses the module-level ``sess``.
    """
    options = parse_args()
    options.phase, options.img_size = 'test', 256
    model = UGATIT(sess, options)
    model.build_model()
    model.load_from_latest(opts['checkpoint'])
    return model
Пример #20
0
def main(args):
    """Translate one image with U-GAT-IT, display it and optionally save it.

    Args:
        args: namespace with ``input`` (image path), ``checkpoint`` (TF
            checkpoint to restore), ``sidebyside`` and ``save``.

    Raises:
        FileNotFoundError: if ``args.input`` does not exist.
        ValueError: if OpenCV cannot decode ``args.input``.
    """
    if not os.path.exists(args.input):
        raise FileNotFoundError("Input image does not exist")
    # _IMAGE_SIZE is passed to cv2.resize, which takes (width, height) and
    # returns shape (height, width, 3). The placeholder must match that.
    w, h = _IMAGE_SIZE
    x = tf.placeholder(tf.float32, [None, h, w, 3])
    model = UGATIT()
    y = model.generate(x)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().restore(sess, args.checkpoint)

        orig = cv2.imread(args.input)
        if orig is None:
            # cv2.imread signals unreadable/unsupported files with None.
            raise ValueError(
                "Unable to decode input image: {}".format(args.input))
        orig = cv2.resize(orig, _IMAGE_SIZE).astype(np.float32)
        # BGR [0, 255] -> RGB [-1, 1].
        out = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB) / 127.5 - 1.
        out = sess.run(y, feed_dict={x: np.expand_dims(out, 0)})
        # [-1, 1] -> [0, 1], back to BGR for OpenCV.
        out = (np.squeeze(out) + 1.) / 2.
        out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)

        window_message = "Press any key to close"
        if args.sidebyside:
            cv2.imshow(window_message, np.hstack((orig / 255., out)))
        else:
            cv2.imshow(window_message, out)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

        if args.save:
            # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"),
            # whereas split(".")[0] truncated it to "a".
            base_filename = os.path.splitext(os.path.basename(args.input))[0]
            out_path = os.path.join(_DEFAULT_OUTPUT_DIR,
                                    base_filename + "_out.png")
            save_out = (out * 255).astype(np.uint8)
            cv2.imwrite(out_path, save_out)
            print("output image saved to: {}".format(out_path))
            if args.sidebyside:  # also save the side-by-side image
                sbs_path = os.path.join(_DEFAULT_OUTPUT_DIR,
                                        base_filename + "_sbs.png")
                cv2.imwrite(sbs_path,
                            np.hstack((orig.astype(np.uint8), save_out)))
                print("side-by-side image saved to: {}".format(sbs_path))
Пример #21
0
def main():
    """Build U-GAT-IT and run its train, test or val phase."""
    options = parse_args()
    if options is None:
        exit()

    model = UGATIT(options)
    model.build_model()

    for phase, method_name, message in (
        ('train', 'train', " [*] Training finished!"),
        ('test', 'test', " [*] Test finished!"),
        ('val', 'val', " [*] Val finished!"),
    ):
        if options.phase == phase:
            getattr(model, method_name)()
            print(message)
Пример #22
0
def main():
    """Entry point: run the video, web, train or test phase of U-GAT-IT.

    ``video`` is delegated entirely to ``video(args)``. ``web`` stores the
    initialised model in the module-level ``gan_ref`` and serves the Flask
    ``app`` on 0.0.0.0:5000.
    """
    args = parse_args()

    if args is None:
        exit()

    if args.phase == 'video':
        video(args)
        # No branch below handles 'video'. Previously execution fell through
        # and built a second, unused model here.
        return

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = UGATIT(sess, args)
        gan.build_model()
        # Display the network architecture.
        show_all_variables()

        if args.phase == 'web':
            gan.video_inference_init()

            # Expose the model to the Flask request handlers.
            global gan_ref
            gan_ref = gan

            app.run(host="0.0.0.0", port=5000)
            exit()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")
Пример #23
0
def main():
    """Construct U-GAT-IT via ``build()`` and run the train or test phase."""
    options = parse_args()
    if options is None:
        exit()

    model = UGATIT(options)
    # This variant constructs its graph with build(), not build_model().
    model.build()

    for phase, method_name, message in (
        ('train', 'train', " [*] Training finished!"),
        ('test', 'test', " [*] Test finished!"),
    ):
        if options.phase == phase:
            getattr(model, method_name)()
            print(message)
Пример #24
0
def main():
    """Run ``loop_on_input`` over the INPUT_DIR / RESULT_DIR directories.

    Both directories come from environment variables and are ``None`` when
    unset. The model is configured from ``FakeArgs()``.
    """
    input_dir, result_dir = os.getenv("INPUT_DIR"), os.getenv("RESULT_DIR")

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        model = UGATIT(session, FakeArgs())
        model.build_model()

        # Initialise every graph variable in the default (this) session.
        tf.global_variables_initializer().run()

        model.loop_on_input(input_dir, result_dir)
Пример #25
0
def runner(args):
    """Open a session, build U-GAT-IT and call ``test_endpoint_init``.

    NOTE(review): the session closes and the model is dropped when this
    function returns; confirm nothing is expected to use them afterwards.
    """
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as session:
        model = UGATIT(session, args)
        model.build_model()
        # Display the network architecture.
        show_all_variables()
        model.test_endpoint_init()
def load_resources(mode):
    """Load the face cascade and, in anime mode, the U-GAT-IT model.

    Always sets the module-level ``face_classifier_classifier``. When
    ``mode == modes.ANIME_MODE`` it also sets ``anime_session`` and
    ``anime_model``. The session is stored globally and not closed here.
    """
    global face_classifier_classifier, anime_session, anime_model

    face_classifier_classifier = cv2.CascadeClassifier(face_cascade_path)

    if mode != modes.ANIME_MODE:
        return

    anime_session = tf.Session(
        config=tf.ConfigProto(allow_soft_placement=True))
    anime_model = UGATIT(anime_session, args)
    anime_model.build_model()
    anime_model.load_model(anime_session)
Пример #27
0
def main():
    """Build U-GAT-IT and run its train, test or get_results phase."""
    options = parse_args()
    if options is None:
        exit()

    model = UGATIT(options)
    model.build_model()

    for phase, method_name, message in (
        ('train', 'train', " [*] Training finished!"),
        ('test', 'test', " [*] Test finished!"),
        ('get_results', 'GetTest_result', " [*] get your results!"),
    ):
        if options.phase == phase:
            getattr(model, method_name)()
            print(message)
Пример #28
0
def main():
    """Run U-GAT-IT under PaddlePaddle dygraph on GPU 0 (train or test)."""
    options = parse_args()
    if options is None:
        exit()

    # Always GPU 0; swap in fluid.CPUPlace() to run on the CPU instead.
    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        model = UGATIT(options)
        model.build_model()

        for phase, method_name, message in (
            ('train', 'train', " [*] Training finished!"),
            ('test', 'test', " [*] Test finished!"),
        ):
            if options.phase == phase:
                getattr(model, method_name)()
                print(message)
Пример #29
0
def main():
    """Train or test U-GAT-IT after ``bps.init()`` and per-rank GPU pinning.

    Torch CPU and CUDA RNGs are seeded with 1, and this process is bound to
    GPU ``bps.local_rank()`` before the arguments are parsed.
    """
    bps.init()
    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.set_device(bps.local_rank())

    options = parse_args()
    if options is None:
        exit()

    model = UGATIT(options)
    model.build_model()

    for phase, method_name, message in (
        ('train', 'train', " [*] Training finished!"),
        ('test', 'test', " [*] Test finished!"),
    ):
        if options.phase == phase:
            getattr(model, method_name)()
            print(message)
Пример #30
0
def main():
    """Train U-GAT-IT and plot its recorded real/fake logits, or test it."""
    options = parse_args()
    if options is None:
        exit()

    model = UGATIT(options)
    model.build_model()

    if options.phase == 'train':
        model.train()
        print(" [*] Training finished!")
        # Overlay the logit histories collected during training.
        for attr_name, label in (("logit_list_real", "real_logit"),
                                 ("logit_list_fake", "fake_logit")):
            plt.plot(getattr(model, attr_name), label=label, alpha=0.3)
        plt.legend()
        plt.show()

    if options.phase == 'test':
        model.test(show_real=options.show_real)
        print(" [*] Test finished!")