示例#1
0
    def feed_transform(self, data_in, paths_out):
        """Stylize a single image with a trained transfer net and save it.

        Builds a fresh graph/session, restores the checkpoint for
        ``self.style_img_name``, runs one forward pass, and writes the result.

        Args:
            data_in: list whose first element is the input image path.
            paths_out: list whose first element is the output image path.

        Raises:
            Exception: if ``checkpoint_dir`` is a directory but contains
                no usable checkpoint.
        """
        checkpoint_dir = os.path.join(self.flags.checkpoint_dir,
                                      self.style_img_name, 'model')
        # Read the image once and reuse it for both the placeholder shape
        # and the feed value (the original read the file from disk twice).
        img = utils.imread(data_in[0])
        img_shape = img.shape

        g = tf.Graph()
        soft_config = tf.ConfigProto(allow_soft_placement=True)
        soft_config.gpu_options.allow_growth = True  # don't grab all GPU memory

        with g.as_default(), tf.Session(config=soft_config) as sess:
            img_placeholder = tf.placeholder(tf.float32,
                                             shape=[None, *img_shape],
                                             name='img_placeholder')

            model = Transfer()
            pred = model(img_placeholder)

            saver = tf.train.Saver()
            if os.path.isdir(checkpoint_dir):
                ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    raise Exception('No checkpoint found...')
            else:
                # checkpoint_dir is a direct path to a checkpoint file.
                saver.restore(sess, checkpoint_dir)

            # Add a leading batch axis and cast for the float32 placeholder.
            batch = np.asarray([img]).astype(np.float32)
            _pred = sess.run(pred, feed_dict={img_placeholder: batch})
            utils.imsave(paths_out[0], _pred[0])  # paths_out and _pred are lists
示例#2
0
def feed_transform(data_in, paths_out, checkpoint_dir):
    """Stylize a single image with a trained transfer net and save it.

    Builds a fresh graph/session, restores the given checkpoint, runs one
    forward pass (printing the inference time), and writes the result.

    Args:
        data_in: list whose first element is the input image path.
        paths_out: list whose first element is the output image path.
        checkpoint_dir: directory containing a TF checkpoint, or a direct
            path to a checkpoint file.

    Raises:
        Exception: if ``checkpoint_dir`` is a directory but contains no
            usable checkpoint.
    """
    # Read the image once and reuse it for both the placeholder shape and
    # the feed value (the original read the file from disk twice).
    img = utils.imread(data_in[0])
    img_shape = img.shape

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True  # don't grab all GPU memory

    with g.as_default(), tf.Session(config=soft_config) as sess:
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=[None, *img_shape],
                                         name='img_placeholder')

        model = Transfer()
        pred = model(img_placeholder)

        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found...')
        else:
            # checkpoint_dir is a direct path to a checkpoint file.
            saver.restore(sess, checkpoint_dir)

        # Add a leading batch axis and cast for the float32 placeholder.
        batch = np.asarray([img]).astype(np.float32)
        start_tic = time.time()
        _pred = sess.run(pred, feed_dict={img_placeholder: batch})
        end_toc = time.time()
        print('PT: {:.2f} msec.\n'.format((end_toc - start_tic) * 1000))
        utils.imsave(paths_out[0], _pred[0])  # paths_out and _pred are lists
示例#3
0
def feed_transform(style_image, paths_out, checkpoint_dir):
    """Stylize an in-memory image with a trained transfer net.

    Unlike the path-based variants, this takes the image array directly and
    returns the stylized frame instead of saving it.

    Args:
        style_image: input image as an H x W x C array.
        paths_out: unused here; kept for signature parity with the
            path-based variants.
        checkpoint_dir: directory containing a TF checkpoint, or a direct
            path to a checkpoint file.

    Returns:
        The stylized image as a uint8 array clipped to [0, 255].

    Raises:
        Exception: if ``checkpoint_dir`` is a directory but contains no
            usable checkpoint.
    """
    # NOTE: removed leftover debug prints ('manan', img_shape).
    img_shape = style_image.shape

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True  # don't grab all GPU memory

    with g.as_default(), tf.Session(config=soft_config) as sess:
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=[None, *img_shape],
                                         name='img_placeholder')

        model = Transfer()
        pred = model(img_placeholder)

        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found...')
        else:
            # checkpoint_dir is a direct path to a checkpoint file.
            saver.restore(sess, checkpoint_dir)

        # Add a leading batch axis and cast for the float32 placeholder.
        img = np.asarray([style_image]).astype(np.float32)
        start_tic = time.time()
        _pred = sess.run(pred, feed_dict={img_placeholder: img})
        end_toc = time.time()
        print('PT: {:.2f} msec.\n'.format((end_toc - start_tic) * 1000))
        # Clamp to valid pixel range and drop the batch axis.
        img = np.clip(_pred[0], 0, 255).astype(np.uint8)
        return img
示例#4
0
def feed_forward_video(path_in, path_out, checkpoint_dir):
    """Apply the style-transfer network frame-by-frame to a video.

    Args:
        path_in: path to the source video (also used as the audio source
            for the output via ``audiofile``).
        path_out: path where the stylized video is written.
        checkpoint_dir: directory containing a TF checkpoint, or a direct
            path to a checkpoint file.

    Raises:
        Exception: if ``checkpoint_dir`` is a directory but contains no
            usable checkpoint.
    """
    # initialize video cap
    video_cap = VideoFileClip(path_in, audio=False)
    # initialize writer
    video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out,
                                                    video_cap.size,
                                                    video_cap.fps,
                                                    codec='libx264',
                                                    preset='medium',
                                                    bitrate='2000k',
                                                    audiofile=path_in,
                                                    threads=None,
                                                    ffmpeg_params=None)

    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True  # don't grab all GPU memory

    with g.as_default(), tf.Session(config=soft_config) as sess:
        # video_cap.size is (width, height); the placeholder wants (H, W, C).
        batch_shape = (None, video_cap.size[1], video_cap.size[0], 3)
        img_placeholder = tf.placeholder(tf.float32,
                                         shape=batch_shape,
                                         name='img_placeholder')

        model = Transfer()
        pred = model(img_placeholder)
        saver = tf.train.Saver()

        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception('No checkpoint found...')
        else:
            # checkpoint_dir is a direct path to a checkpoint file.
            saver.restore(sess, checkpoint_dir)

        try:
            for frame_id, frame in enumerate(video_cap.iter_frames()):
                print('frame id: {}'.format(frame_id))
                _pred = sess.run(pred,
                                 feed_dict={
                                     img_placeholder:
                                     np.asarray([frame]).astype(np.float32)
                                 })
                # BUG FIX: write the single frame _pred[0] (H, W, 3), not the
                # whole (1, H, W, 3) batch, matching the image variants above.
                video_writer.write_frame(
                    np.clip(_pred[0], 0, 255).astype(np.uint8))
        finally:
            # Always finalize the output file, even if inference fails
            # partway through the video.
            video_writer.close()
                        help = 'Path to save the output.')

    return parser

'''
style_src_path = r'E:\python.file\git-Repository\TF--style-transfer-collection\style_src\star.jpg'
content_src_path = r'E:\python.file\git-Repository\TF--style-transfer-collection\content_src\girl.jpg'

style_src = get_img(style_src_path,max_size = 512)
content_src = get_img(content_src_path,max_size = 800)

transfer_op = Transfer(content_src,style_src,content_layers = ['conv2_2','conv3_2'],pool_method = 'avg')
img = transfer_op.transfer()
'''

if __name__ == '__main__':
    # Parse CLI arguments and run a single style-transfer job.
    args = Argparser().parse_args()

    # Load the content and style images, capped at their configured sizes.
    content_img = utils.get_img_PIL(args.content_img, max_size=args.content_shape)
    style_img = utils.get_img_PIL(args.style_img, max_size=args.style_shape)

    transfer_op = Transfer(
        content_img=content_img,
        style_img=style_img,
        loss_para=args.loss_para,
        pool_method=args.pool_method,
        content_layers=args.content_layers,
        style_layers=args.style_layers,
        epoches=args.epoches,
        save_path=args.save_path,
    )
    transfer_op.transfer()