예제 #1
0
def run_rmbe_model(patch_list):
    """Push ``patch_list`` through ``rmbe_model`` and return the stacked output.

    Every call builds a private graph and session, restores the parameters
    stored under ``rmbe/rmbe_params/params`` and evaluates the model on the
    patches in batches of 64.  The session is closed before returning so
    repeated calls do not accumulate open sessions.

    Args:
        patch_list: Patches fed to a float32 placeholder of shape
            ``[None, patch_size, patch_size, 3]``; ``patch_size`` is a name
            resolved from the enclosing module.

    Returns:
        The per-batch outputs of ``rmbe_model.model`` concatenated along
        axis 0.

    Raises:
        ValueError: Raised by ``np.concatenate`` when no batch was produced.
    """
    batch_size = 64
    params_file = 'rmbe/rmbe_params/params'

    g2 = tf.Graph()
    sess = tf.Session(graph=g2)
    try:
        # Build the model inside the private graph so it cannot collide with
        # whatever graph the caller currently has installed as default.
        with sess.as_default(), g2.as_default():
            patches_placeholder = tf.placeholder(
                tf.float32, shape=[None, patch_size, patch_size, 3])
            patch_batch, iterator = data_loader.get_patch_batch(
                batch_size, patches_placeholder)
            output_op = rmbe_model.model(patch_batch)

            saver = tf.train.Saver()
            saver.restore(sess, params_file)

        sess.run(iterator.initializer,
                 feed_dict={patches_placeholder: patch_list})

        # Pull batches until the iterator signals that it is exhausted.
        new_patch_list = []
        while True:
            try:
                new_patch_list.append(sess.run(output_op))
            except tf.errors.OutOfRangeError:
                break
    finally:
        # The session is owned by this function; release it on every path,
        # including failures while building or restoring the model.
        sess.close()

    return np.concatenate(new_patch_list, axis=0)
예제 #2
0
def compress(sess, model, args):
    """Encode each listed image and pass the codes to the range encoder.

    The configuration is loaded from ``model_<model_num>/config.json``.  For
    every path returned by ``utils.read_image_list(args.data_list)`` the
    image is cropped into patches, run through ``model.encoder`` in batches
    of 64, flattened to a list of Python ints and handed to
    ``apply_range_encoder`` together with the path computed by
    ``get_encodepath``.

    Args:
        sess: Session used to restore parameters and evaluate the encoder.
        model: Object exposing ``encoder(patch_batch, patch_size, quan_scale)``.
        args: Options object; ``model_num``, ``data_list`` and ``output_dir``
            are read here, and the object is forwarded to the helpers.

    Returns:
        None.
    """
    print(args)

    with open('model_{}/config.json'.format(args.model_num), 'r') as config_file:
        config = json.load(config_file)

    print(config)

    image_path_list = utils.read_image_list(args.data_list)

    batch_size = 64
    patch_size = config['patch_size']
    patches_placeholder = tf.placeholder(
        tf.float32, shape=[None, patch_size, patch_size, 3])
    patch_batch, iterator = data_loader.get_patch_batch(
        batch_size, patches_placeholder)

    quan_scale = config['quan_scale']
    encoder_output_op = model.encoder(patch_batch, patch_size, quan_scale)

    utils.restore_params(sess, args)

    # To be paralleled
    for image_path in image_path_list:
        image = io.imread(image_path)
        input_patches = utils.crop_image_input_patches(image, patch_size)

        sess.run(iterator.initializer,
                 feed_dict={patches_placeholder: input_patches})

        # Drain the iterator; OutOfRangeError marks the end of the patches.
        encoded_batches = []
        try:
            while True:
                encoded_batches.append(sess.run(encoder_output_op))
        except tf.errors.OutOfRangeError:
            pass

        # Shape of a single encoded patch (batch dimension stripped).
        encoded_patches_shape = encoded_batches[0][0].shape

        per_patch_rows = np.concatenate(encoded_batches).reshape(
            -1, np.prod(encoded_patches_shape))
        seq_data = per_patch_rows.reshape(-1).astype(int).tolist()

        encodepath = get_encodepath(image_path, image, seq_data, args, config,
                                    encoded_patches_shape)

        encoded_save_dir = args.output_dir.format(args.model_num)
        if not os.path.exists(encoded_save_dir):
            os.makedirs(encoded_save_dir)

        apply_range_encoder(seq_data, encodepath, args, config)

        print('encodepath: {}'.format(encodepath))
예제 #3
0
def compress_and_uncompress(sess, model, args):
    """Round-trip the validation images through the encoder and decoder.

    Configuration and parameters come from fixed paths under
    ``rm_block_effect/``.  Each image listed in
    ``data_info/ori_valid_data_list.txt`` is cropped into patches, encoded
    and decoded in batches of 64, reassembled with ``utils.concat_patches``,
    rounded to ``uint8`` and saved to the input path with ``'ori'`` replaced
    by ``'recons'``.

    Args:
        sess: Session used to restore parameters and evaluate the model.
        model: Object exposing ``encoder`` and ``decoder``.
        args: Only printed by this function.

    Returns:
        None.
    """
    print(args)

    with open('rm_block_effect/recons_model/config.json', 'r') as config_file:
        config = json.load(config_file)

    print(config)

    image_path_list = utils.read_image_list('data_info/ori_valid_data_list.txt')

    batch_size = 64
    patch_size = config['patch_size']
    patches_placeholder = tf.placeholder(
        tf.float32, shape=[None, patch_size, patch_size, 3])
    patch_batch, iterator = data_loader.get_patch_batch(
        batch_size, patches_placeholder)

    quan_scale = config['quan_scale']

    # The decoder consumes the encoder output directly, so one run of the
    # decoder op performs the whole round trip for a batch.
    encoder_output_op = model.encoder(patch_batch, patch_size, quan_scale)
    decoder_output_op = model.decoder(encoder_output_op, quan_scale)

    tf.train.Saver().restore(sess, 'rm_block_effect/recons_params/params')

    for image_path in image_path_list:
        image = io.imread(image_path)
        input_patches = utils.crop_image_input_patches(image, patch_size)

        sess.run(iterator.initializer,
                 feed_dict={patches_placeholder: input_patches})

        # Drain the iterator; OutOfRangeError marks the end of the patches.
        decoded_batches = []
        try:
            while True:
                decoded_batches.append(sess.run(decoder_output_op))
        except tf.errors.OutOfRangeError:
            pass

        decoded_patches = np.concatenate(decoded_batches, axis=0)

        # Three-way unpack: the image is expected to have exactly 3 dims.
        height, width, _ = image.shape
        stitched = utils.concat_patches(
            decoded_patches, height, width, patch_size)
        recons_image = np.around(stitched).astype(np.uint8)

        recons_image_path = image_path.replace('ori', 'recons')
        print('recons_image_path: {}'.format(recons_image_path))

        io.imsave(recons_image_path, recons_image)
예제 #4
0
def uncompress(sess, model, args):
    """Decode every file in the input directory back into an image.

    The configuration is loaded from ``model_<model_num>/config.json``.  Each
    entry of ``os.listdir(input_dir)`` is range-decoded with
    ``apply_range_decoder``, reshaped to encoded patches, run through
    ``model.decoder`` in batches of 64, reassembled with
    ``utils.concat_patches``, rounded to ``uint8`` and written to the path
    given by ``get_recons_image_path``.

    Args:
        sess: Session used to restore parameters and evaluate the decoder.
        model: Object exposing ``decoder(patch_batch, quan_scale)``.
        args: Options object; ``model_num``, ``input_dir`` and ``output_dir``
            are read here, and the object is forwarded to the helpers.

    Returns:
        None.
    """
    print(args)

    with open('model_{}/config.json'.format(args.model_num), 'r') as config_file:
        config = json.load(config_file)

    print(config)

    patch_size = config['patch_size']
    input_dir = args.input_dir.format(args.model_num)

    encoded_height, encoded_width, encoded_channel = get_encoded_shape(
        input_dir, config)
    encoded_dims = [encoded_height, encoded_width, encoded_channel]
    patches_placeholder = tf.placeholder(
        tf.float32, shape=[None] + encoded_dims)

    batch_size = 64
    patch_batch, iterator = data_loader.get_patch_batch(
        batch_size, patches_placeholder)

    quan_scale = config['quan_scale']
    decoder_output_op = model.decoder(patch_batch, quan_scale)

    utils.restore_params(sess, args)

    for filename in os.listdir(input_dir):
        filepath = str(Path(input_dir) / filename)

        seq_data_len, height, width = get_img_info(filename, config)
        decoded_symbols = apply_range_decoder(
            seq_data_len, filepath, args, config)

        # Back to float32 patches shaped like the decoder's placeholder.
        encoded_patches = np.asarray(decoded_symbols).astype(
            np.float32).reshape(-1, encoded_height, encoded_width,
                                encoded_channel)

        sess.run(iterator.initializer,
                 feed_dict={patches_placeholder: encoded_patches})

        # Drain the iterator; OutOfRangeError marks the end of the patches.
        decoded_batches = []
        try:
            while True:
                decoded_batches.append(sess.run(decoder_output_op))
        except tf.errors.OutOfRangeError:
            pass

        decoded_patches = np.concatenate(decoded_batches, axis=0)

        recons_image = utils.concat_patches(decoded_patches, height, width,
                                            patch_size)

        recons_image_path = get_recons_image_path(filename, args, config)

        output_dir = args.output_dir.format(args.model_num)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        print('recons_image_path: {}'.format(recons_image_path))

        # Round to the nearest integer before the uint8 cast.
        recons_image = np.around(recons_image).astype(np.uint8)

        io.imsave(recons_image_path, recons_image)
예제 #5
0
def compress_and_uncompress(sess, model, args):

    print(args)

    config_path = 'model_{}/config.json'.format(args.model_num)
    with open(config_path, 'r') as f:
        config = json.load(f)

    print(config)

    data_list = args.data_list
    image_path_list = utils.read_image_list(data_list)

    batch_size = 64
    patch_size = config['patch_size']
    patches_placeholder = tf.placeholder(
        tf.float32, shape=[None, patch_size, patch_size, 3])
    patch_batch, iterator = data_loader.get_patch_batch(
        batch_size, patches_placeholder)

    quan_scale = config['quan_scale']

    encoder_output_op = model.encoder(patch_batch, patch_size, quan_scale)
    decoder_output_op = model.decoder(encoder_output_op, quan_scale)

    utils.restore_params(sess, args)

    for image_path in image_path_list:
        image = io.imread(image_path)
        image_patches = utils.crop_image_input_patches(image)

        sess.run(iterator.initializer,
                 feed_dict={patches_placeholder: image_patches})

        # decoded_patches = image_patches

        decoded_patches_list = []
        while True:
            try:
                decoded_output = sess.run(decoder_output_op)
                decoded_patches_list.append(decoded_output)
            except tf.errors.OutOfRangeError:
                break

        # print('len(decoded_patches): {}'.format(len(decoded_patches)))

        decoded_patches = np.concatenate(decoded_patches_list, axis=0)

        height, width, channel = image.shape
        recons_image = utils.concat_patches(decoded_patches, height, width)

        filename = image_path.split('/')[-1][:-4]
        recons_image_path = get_recons_image_path(filename, args, config)

        output_dir = args.output_dir.format(args.model_num)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        print('recons_image_path: {}'.format(recons_image_path))

        io.imsave(recons_image_path, recons_image)

        image_read = np.asarray(Image.open(recons_image_path),
                                dtype=np.float32)