Esempio n. 1
0
def decompress_less_mem(y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape, model, ckpt_dir):
    """Decode a compressed bitstream back into cubes, one element at a time.

    Every network / range-decoder stage is driven through ``tf.map_fn`` with
    ``parallel_iterations=1`` so that a single element is processed per step.

    Input: compressed bitstream. latent representations (y) and hyper prior (z),
      plus ``model`` (provides SynthesisTransform / HyperEncoder / HyperDecoder)
      and ``ckpt_dir`` (looked up with ``tf.train.latest_checkpoint``).
    Output: cubes with shape [batch size, length, width, height, channel(1)]
    """

    print('===== Decompress =====')

    # Build the networks and entropy models (construction order kept as-is).
    synthesis_net = model.SynthesisTransform()
    hyper_enc = model.HyperEncoder()
    hyper_dec = model.HyperDecoder()
    bottleneck = EntropyBottleneck()
    conditional_model = SymmetricConditional()

    # The keyword names below are the keys stored in the checkpoint.
    ckpt = tf.train.Checkpoint(
        synthesis_transform=synthesis_net,
        hyper_encoder=hyper_enc,
        hyper_decoder=hyper_dec,
        estimator=bottleneck,
    )
    restore_status = ckpt.restore(tf.train.latest_checkpoint(ckpt_dir))

    # ---- per-element helpers handed to tf.map_fn ----
    def _hyper_decode_one(z_item):
        # Add a leading batch axis for the network, then strip it again.
        batched_loc, batched_scale = hyper_dec(tf.expand_dims(z_item, 0))
        return tf.squeeze(batched_loc, [0]), tf.squeeze(batched_scale, [0])

    def _range_decode_one(packed):
        bitstring, loc_item, scale_item, min_v, max_v = packed
        decoded = conditional_model.decompress(
            bitstring,
            tf.expand_dims(loc_item, 0),
            tf.expand_dims(scale_item, 0),
            min_v,
            max_v,
            y_shape,
        )
        return tf.squeeze(decoded, 0)

    def _synthesize_one(y_item):
        reconstructed = synthesis_net(tf.expand_dims(y_item, 0))
        return tf.squeeze(reconstructed, [0])

    # 1) Entropy-decode the hyper prior.
    tic = time.time()
    zs = bottleneck.decompress(z_strings, z_min_v, z_max_v, z_shape, z_shape[-1])
    print("Entropy Decoder (Hyper): {}s".format(round(time.time() - tic, 4)))

    # 2) Hyper decoder -> (loc, scale) per element; scale is clamped from below.
    tic = time.time()
    locs, scales = tf.map_fn(
        _hyper_decode_one,
        zs,
        dtype=(tf.float32, tf.float32),
        parallel_iterations=1,
        back_prop=False,
    )
    min_scale = 1e-9  # TODO
    scales = tf.maximum(scales, min_scale)
    print("Hyper Decoder: {}s".format(round(time.time() - tic, 4)))

    # 3) Range-decode the latents conditioned on (loc, scale).
    tic = time.time()
    ys = tf.map_fn(
        _range_decode_one,
        (y_strings, locs, scales, y_min_vs, y_max_vs),
        dtype=tf.float32,
        parallel_iterations=1,
        back_prop=False,
    )
    print("Entropy Decoder: {}s".format(round(time.time() - tic, 4)))

    # 4) Synthesis transform back to cubes.
    tic = time.time()
    xs = tf.map_fn(
        _synthesize_one,
        ys,
        dtype=tf.float32,
        parallel_iterations=1,
        back_prop=False,
    )
    print("Synthesis Transform: {}s".format(round(time.time() - tic, 4)))

    return xs
Esempio n. 2
0
# Training-script setup: hyper-parameters are read from the parsed
# command-line namespace `args` (defined outside this fragment), then the
# model components, global step and optimizer are created.
RATIO_EVAL = 9 # NOTE(review): its use is not visible in this fragment.
# NUM_ITEATION = 2e5 
# Number of iterations, taken from the command line and coerced to int.
NUM_ITEATION = int(args.num_iteration)
# NOTE(review): alpha/beta are presumably loss weights like gamma/delta
# below; their use is not visible here -- confirm.
alpha = args.alpha
beta = args.beta
gamma = args.gamma # weight of hyper prior.
delta = args.delta # weight of latent representation.
# Directory of the checkpoint used for initialisation.
init_ckpt_dir = args.init_ckpt_dir
lower_bound = args.lower_bound
print('lower bound of scale:', lower_bound)
# bool() of a non-empty string is always True, so this assumes
# args.reset_optimizer is numeric (it is compared with 0 further down).
reset_optimizer = bool(args.reset_optimizer)
print('reset_optimizer:::', reset_optimizer)

# Define variables
analysis_transform, synthesis_transform = model.AnalysisTransform(), model.SynthesisTransform()
hyper_encoder, hyper_decoder = model.HyperEncoder(), model.HyperDecoder()

# Entropy model instances: one plain bottleneck, one conditional model.
entropy_bottleneck = EntropyBottleneck()
conditional_entropy_model = SymmetricConditional()

global_step = tf.train.get_or_create_global_step()

# Earlier learning-rate choices kept for reference; the value now comes
# from the command line.
# lr = tf.train.exponential_decay(1e-4, global_step, 20000, 0.75, staircase=True)
# lr = 1e-5
lr = args.lr
main_optimizer = tf.train.AdamOptimizer(learning_rate = lr)
########## Define checkpoint ########## 
if args.reset_optimizer == 0:
  checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                              synthesis_transform=synthesis_transform,
Esempio n. 3
0
def compress_hyper(cubes, model, ckpt_dir, decompress=False):
    """Compress cubes to bitstream, processing one cube per ``tf.map_fn`` step.

    Args:
      cubes: cubes with shape [batch size, length, width, height, channel(1)].
      model: object exposing ``AnalysisTransform``, ``SynthesisTransform``,
        ``HyperEncoder`` and ``HyperDecoder`` constructors.
      ckpt_dir: directory passed to ``tf.train.latest_checkpoint`` to locate
        the weights to restore.
      decompress: when True, the latents are additionally range-decoded and
        passed through the synthesis transform, and the reconstruction is
        appended to the returned tuple.

    Returns:
      ``(y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v,
      z_shape)`` and, when ``decompress`` is True, ``x_decodeds`` as a ninth
      element. ``y_shape`` is the shape of a single latent with a leading
      batch dimension of 1.
    """

    print('===== Compress =====')
    # load model.
    #model = importlib.import_module(model)
    analysis_transform = model.AnalysisTransform()
    synthesis_transform = model.SynthesisTransform()
    hyper_encoder = model.HyperEncoder()
    hyper_decoder = model.HyperDecoder()
    entropy_bottleneck = EntropyBottleneck()
    conditional_entropy_model = SymmetricConditional()

    # The keyword names are the keys under which weights are stored.
    checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                                     synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

    x = tf.convert_to_tensor(cubes, "float32")

    def loop_analysis(x):
        x = tf.expand_dims(x, 0)
        y = analysis_transform(x)
        # Remove only the batch axis added above. A bare tf.squeeze(y) would
        # also drop any other size-1 dimension and change the tensor rank.
        return tf.squeeze(y, [0])

    start = time.time()
    ys = tf.map_fn(loop_analysis,
                   x,
                   dtype=tf.float32,
                   parallel_iterations=1,
                   back_prop=False)
    print("Analysis Transform: {}s".format(round(time.time() - start, 4)))

    def loop_hyper_encoder(y):
        y = tf.expand_dims(y, 0)
        z = hyper_encoder(y)
        # Same as above: strip the batch axis only.
        return tf.squeeze(z, [0])

    start = time.time()
    zs = tf.map_fn(loop_hyper_encoder,
                   ys,
                   dtype=tf.float32,
                   parallel_iterations=1,
                   back_prop=False)
    print("Hyper Encoder: {}s".format(round(time.time() - start, 4)))

    # Second positional argument False is forwarded unchanged to the
    # bottleneck (presumably the training flag -- confirm against its API).
    z_hats, _ = entropy_bottleneck(zs, False)
    print("Quantize hyperprior.")

    def loop_hyper_deocder(z):
        z = tf.expand_dims(z, 0)
        loc, scale = hyper_decoder(z)
        return tf.squeeze(loc, [0]), tf.squeeze(scale, [0])

    start = time.time()
    locs, scales = tf.map_fn(loop_hyper_deocder,
                             z_hats,
                             dtype=(tf.float32, tf.float32),
                             parallel_iterations=1,
                             back_prop=False)
    # Clamp the predicted scales from below before entropy coding.
    lower_bound = 1e-9  # TODO
    scales = tf.maximum(scales, lower_bound)
    print("Hyper Decoder: {}s".format(round(time.time() - start, 4)))

    start = time.time()
    z_strings, z_min_v, z_max_v = entropy_bottleneck.compress(zs)
    z_shape = tf.shape(zs)[:]
    print("Entropy Encode (Hyper): {}s".format(round(time.time() - start, 4)))

    start = time.time()

    # y_strings, y_min_v, y_max_v = conditional_entropy_model.compress(ys, locs, scales)
    # y_shape = tf.shape(ys)[:]
    def loop_range_encode(args):
        y, loc, scale = args
        y = tf.expand_dims(y, 0)
        loc = tf.expand_dims(loc, 0)
        scale = tf.expand_dims(scale, 0)
        y_string, y_min_v, y_max_v = conditional_entropy_model.compress(
            y, loc, scale)
        return y_string, y_min_v, y_max_v

    args = (ys, locs, scales)
    y_strings, y_min_vs, y_max_vs = tf.map_fn(loop_range_encode,
                                              args,
                                              dtype=(tf.string, tf.int32,
                                                     tf.int32),
                                              parallel_iterations=1,
                                              back_prop=False)
    # Per-element latent shape: the batch dimension of `ys` is replaced by 1
    # because every element is coded (and decoded) on its own.
    y_shape = tf.convert_to_tensor(np.insert(tf.shape(ys)[1:].numpy(), 0, 1))

    print("Entropy Encode: {}s".format(round(time.time() - start, 4)))

    if decompress:
        start = time.time()

        def loop_range_decode(args):
            y_string, loc, scale, y_min_v, y_max_v = args
            loc = tf.expand_dims(loc, 0)
            scale = tf.expand_dims(scale, 0)
            y_decoded = conditional_entropy_model.decompress(
                y_string, loc, scale, y_min_v, y_max_v, y_shape)
            return tf.squeeze(y_decoded, 0)

        args = (y_strings, locs, scales, y_min_vs, y_max_vs)
        y_decodeds = tf.map_fn(loop_range_decode,
                               args,
                               dtype=tf.float32,
                               parallel_iterations=1,
                               back_prop=False)

        print("Entropy Decode: {}s".format(round(time.time() - start, 4)))

        def loop_synthesis(y):
            y = tf.expand_dims(y, 0)
            x = synthesis_transform(y)
            return tf.squeeze(x, [0])

        start = time.time()
        x_decodeds = tf.map_fn(loop_synthesis,
                               y_decodeds,
                               dtype=tf.float32,
                               parallel_iterations=1,
                               back_prop=False)
        print("Synthesis Transform: {}s".format(round(time.time() - start, 4)))
        return y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape, x_decodeds

    return y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape
Esempio n. 4
0
def compress_more_less_mem(cubes, model, ckpt_dir, decompress=False):
    """Compress cubes to bitstream, one cube at a time in a Python loop.

    Each cube is encoded, immediately decoded again, and its reconstruction is
    moved to a NumPy array before the next cube is processed.

    Args:
      cubes: cubes with shape [batch size, length, width, height, channel(1)].
      model: object exposing ``AnalysisTransform``, ``SynthesisTransform``,
        ``HyperEncoder`` and ``HyperDecoder`` constructors.
      ckpt_dir: directory passed to ``tf.train.latest_checkpoint`` to locate
        the weights to restore.
      decompress: accepted for signature parity with the other compress
        functions but never read; the reconstruction is always computed and
        returned.

    Returns:
      ``(y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v,
      z_shape, x_decodeds)``. ``y_shape`` is the latent shape of the last cube
      processed.

    Raises:
      ValueError: if ``cubes`` yields no elements.
    """

    print('===== Compress =====')
    # load model.
    #model = importlib.import_module(model)
    analysis_transform = model.AnalysisTransform()
    synthesis_transform = model.SynthesisTransform()
    hyper_encoder = model.HyperEncoder()
    hyper_decoder = model.HyperDecoder()
    entropy_bottleneck = EntropyBottleneck()
    conditional_entropy_model = SymmetricConditional()

    # The keyword names are the keys under which weights are stored.
    checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                                     synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

    start = time.time()
    y_strings = []
    y_min_vs = []
    y_max_vs = []
    x_decodeds = []
    zs = []
    y_shape = None
    # Lower clamp for the predicted scales; constant across cubes, so it is
    # set once instead of on every iteration.
    lower_bound = 1e-9  # TODO
    for idx, cube in enumerate(cubes):
        if idx % 1000 == 0:
            print(idx)
        x = tf.convert_to_tensor(cube, "float32")
        x = tf.expand_dims(x, 0)
        y = analysis_transform(x)
        z = hyper_encoder(y)
        zs.append(z)
        z_hat, _ = entropy_bottleneck(z, False)
        loc, scale = hyper_decoder(z_hat)
        scale = tf.maximum(scale, lower_bound)
        y_string, y_min_v, y_max_v = conditional_entropy_model.compress(y, loc, scale)
        y_strings.append(y_string)
        y_min_vs.append(y_min_v)
        y_max_vs.append(y_max_v)

        # Decode right away so only one cube's intermediates are alive at a
        # time; the reconstruction is kept as a NumPy array.
        y_shape = tf.shape(y)
        y_dec = conditional_entropy_model.decompress(y_string, loc, scale, y_min_v, y_max_v, y_shape)
        x_dec = synthesis_transform(y_dec)
        x_dec = x_dec.numpy()
        x_decodeds.append(x_dec)

    if not y_strings:
        # Without any cube there is no latent shape to report and nothing to
        # concatenate; fail with a clear message instead of an opaque error.
        raise ValueError("compress_more_less_mem: 'cubes' is empty, nothing to compress")

    y_strings = tf.convert_to_tensor(y_strings, dtype='string')
    y_min_vs = tf.convert_to_tensor(y_min_vs)
    y_max_vs = tf.convert_to_tensor(y_max_vs)
    zs = tf.concat(zs, axis=0)
    x_decodeds = np.concatenate(x_decodeds, 0)
    x_decodeds = tf.convert_to_tensor(x_decodeds, 'float32')

    # The hyper prior is entropy-coded once for the whole batch.
    z_strings, z_min_v, z_max_v = entropy_bottleneck.compress(zs)
    z_shape = tf.shape(zs)[:]

    return y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape, x_decodeds