def decompress_less_mem(y_strings, y_min_vs, y_max_vs, y_shape,
                        z_strings, z_min_v, z_max_v, z_shape,
                        model, ckpt_dir):
    """Decompress a bitstream back to cubes.

    Input: compressed bitstreams for the latent representation (y) and the
        hyper prior (z), plus their value ranges and shapes.
    Output: cubes with shape [batch size, length, width, height, channel(1)].
    """
    print('===== Decompress =====')
    # Load model.
    # model = importlib.import_module(model)
    synthesis_transform = model.SynthesisTransform()
    hyper_encoder = model.HyperEncoder()
    hyper_decoder = model.HyperDecoder()
    entropy_bottleneck = EntropyBottleneck()
    conditional_entropy_model = SymmetricConditional()

    checkpoint = tf.train.Checkpoint(synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

    # Entropy-decode the hyper prior first; it parameterizes the conditional model.
    start = time.time()
    zs = entropy_bottleneck.decompress(z_strings, z_min_v, z_max_v,
                                       z_shape, z_shape[-1])
    print("Entropy Decoder (Hyper): {}s".format(round(time.time() - start, 4)))

    def loop_hyper_decoder(z):
        z = tf.expand_dims(z, 0)
        loc, scale = hyper_decoder(z)
        return tf.squeeze(loc, [0]), tf.squeeze(scale, [0])

    # Recover the location/scale parameters of the conditional entropy model.
    start = time.time()
    locs, scales = tf.map_fn(loop_hyper_decoder, zs,
                             dtype=(tf.float32, tf.float32),
                             parallel_iterations=1, back_prop=False)
    lower_bound = 1e-9  # TODO
    scales = tf.maximum(scales, lower_bound)
    print("Hyper Decoder: {}s".format(round(time.time() - start, 4)))

    # Entropy-decode the latents one cube at a time to limit peak memory.
    start = time.time()
    # ys = conditional_entropy_model.decompress(y_strings, locs, scales, y_min_v, y_max_v, y_shape)

    def loop_range_decode(args):
        y_string, loc, scale, y_min_v, y_max_v = args
        loc = tf.expand_dims(loc, 0)
        scale = tf.expand_dims(scale, 0)
        y_decoded = conditional_entropy_model.decompress(
            y_string, loc, scale, y_min_v, y_max_v, y_shape)
        return tf.squeeze(y_decoded, 0)

    args = (y_strings, locs, scales, y_min_vs, y_max_vs)
    ys = tf.map_fn(loop_range_decode, args, dtype=tf.float32,
                   parallel_iterations=1, back_prop=False)
    print("Entropy Decoder: {}s".format(round(time.time() - start, 4)))

    def loop_synthesis(y):
        y = tf.expand_dims(y, 0)
        x = synthesis_transform(y)
        return tf.squeeze(x, [0])

    # Synthesize the decoded latents back into voxel cubes.
    start = time.time()
    xs = tf.map_fn(loop_synthesis, ys, dtype=tf.float32,
                   parallel_iterations=1, back_prop=False)
    print("Synthesis Transform: {}s".format(round(time.time() - start, 4)))

    return xs
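
# A minimal sketch added for illustration (not in the original file): estimating
# the payload size in bits from the same bitstreams decompress_less_mem()
# consumes. It assumes y_strings is a 1-D tf.string tensor (one range-coded
# string per cube) and that z_strings converts to a string tensor; the helper
# name is hypothetical, and side information (min/max values, shapes) is not
# counted.
def _example_payload_bits(y_strings, z_strings):
    """Hypothetical helper: total payload in bits, excluding side information."""
    y_bits = sum(len(s) for s in y_strings.numpy()) * 8
    z_flat = tf.reshape(tf.convert_to_tensor(z_strings), [-1]).numpy()
    z_bits = sum(len(s) for s in z_flat) * 8
    return y_bits + z_bits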
RATIO_EVAL = 9
# NUM_ITERATION = 2e5
NUM_ITERATION = int(args.num_iteration)
alpha = args.alpha
beta = args.beta
gamma = args.gamma  # weight of hyper prior.
delta = args.delta  # weight of latent representation.
init_ckpt_dir = args.init_ckpt_dir
lower_bound = args.lower_bound
print('lower bound of scale:', lower_bound)
reset_optimizer = bool(args.reset_optimizer)
print('reset_optimizer:', reset_optimizer)

# Define variables.
analysis_transform, synthesis_transform = model.AnalysisTransform(), model.SynthesisTransform()
hyper_encoder, hyper_decoder = model.HyperEncoder(), model.HyperDecoder()
entropy_bottleneck = EntropyBottleneck()
conditional_entropy_model = SymmetricConditional()

global_step = tf.train.get_or_create_global_step()
# lr = tf.train.exponential_decay(1e-4, global_step, 20000, 0.75, staircase=True)
# lr = 1e-5
lr = args.lr
main_optimizer = tf.train.AdamOptimizer(learning_rate=lr)

########## Define checkpoint ##########
if args.reset_optimizer == 0:
    checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                                     synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck,
                                     # Assumed continuation: the source was cut off
                                     # after synthesis_transform. The remaining model
                                     # parts follow the Checkpoint pattern used in the
                                     # compress/decompress functions; tracking the
                                     # optimizer here is an assumption implied by the
                                     # reset_optimizer == 0 branch.
                                     main_optimizer=main_optimizer)
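
# A hedged sketch of one training step, added for illustration; the original
# training loop is not shown in this section. Assumptions are marked inline:
# the call signature of conditional_entropy_model during training, the use of
# MSE as a placeholder distortion (the real code likely weights an occupancy
# classification loss with alpha/beta), and gamma/delta weighting the
# hyper-prior and latent rate terms, matching the comments above.
def _example_train_step(x):
    """Hypothetical helper: one rate-distortion training step on a batch x."""
    with tf.GradientTape() as tape:
        y = analysis_transform(x)
        z = hyper_encoder(y)
        z_hat, z_likelihoods = entropy_bottleneck(z, True)  # True: training mode
        loc, scale = hyper_decoder(z_hat)
        scale = tf.maximum(scale, lower_bound)
        # Assumed call signature for the conditional model in training mode.
        y_hat, y_likelihoods = conditional_entropy_model(y, loc, scale, True)
        x_hat = synthesis_transform(y_hat)
        num_voxels = tf.cast(tf.reduce_prod(tf.shape(x)), tf.float32)
        # Rate terms in bits per voxel.
        z_bits = tf.reduce_sum(-tf.log(z_likelihoods)) / tf.log(2.) / num_voxels
        y_bits = tf.reduce_sum(-tf.log(y_likelihoods)) / tf.log(2.) / num_voxels
        distortion = tf.reduce_mean((x - x_hat) ** 2)  # placeholder distortion
        loss = alpha * distortion + gamma * z_bits + delta * y_bits
    variables = (analysis_transform.variables + synthesis_transform.variables +
                 hyper_encoder.variables + hyper_decoder.variables +
                 entropy_bottleneck.variables)
    gradients = tape.gradient(loss, variables)
    main_optimizer.apply_gradients(zip(gradients, variables),
                                   global_step=global_step)
    return loss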
def compress_hyper(cubes, model, ckpt_dir, decompress=False):
    """Compress cubes to a bitstream.

    Input: cubes with shape [batch size, length, width, height, channel(1)].
    Output: compressed bitstream; if decompress=True, also the encoder-side
        reconstructions.
    """
    print('===== Compress =====')
    # Load model.
    # model = importlib.import_module(model)
    analysis_transform = model.AnalysisTransform()
    synthesis_transform = model.SynthesisTransform()
    hyper_encoder = model.HyperEncoder()
    hyper_decoder = model.HyperDecoder()
    entropy_bottleneck = EntropyBottleneck()
    conditional_entropy_model = SymmetricConditional()

    checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                                     synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

    x = tf.convert_to_tensor(cubes, "float32")

    def loop_analysis(x):
        x = tf.expand_dims(x, 0)
        y = analysis_transform(x)
        return tf.squeeze(y)

    start = time.time()
    ys = tf.map_fn(loop_analysis, x, dtype=tf.float32,
                   parallel_iterations=1, back_prop=False)
    print("Analysis Transform: {}s".format(round(time.time() - start, 4)))

    def loop_hyper_encoder(y):
        y = tf.expand_dims(y, 0)
        z = hyper_encoder(y)
        return tf.squeeze(z)

    start = time.time()
    zs = tf.map_fn(loop_hyper_encoder, ys, dtype=tf.float32,
                   parallel_iterations=1, back_prop=False)
    print("Hyper Encoder: {}s".format(round(time.time() - start, 4)))

    z_hats, _ = entropy_bottleneck(zs, False)
    print("Quantize hyperprior.")

    def loop_hyper_decoder(z):
        z = tf.expand_dims(z, 0)
        loc, scale = hyper_decoder(z)
        return tf.squeeze(loc, [0]), tf.squeeze(scale, [0])

    start = time.time()
    locs, scales = tf.map_fn(loop_hyper_decoder, z_hats,
                             dtype=(tf.float32, tf.float32),
                             parallel_iterations=1, back_prop=False)
    lower_bound = 1e-9  # TODO
    scales = tf.maximum(scales, lower_bound)
    print("Hyper Decoder: {}s".format(round(time.time() - start, 4)))

    start = time.time()
    z_strings, z_min_v, z_max_v = entropy_bottleneck.compress(zs)
    z_shape = tf.shape(zs)[:]
    print("Entropy Encode (Hyper): {}s".format(round(time.time() - start, 4)))

    # Range-encode each latent cube separately, keeping per-cube min/max values.
    start = time.time()
    # y_strings, y_min_v, y_max_v = conditional_entropy_model.compress(ys, locs, scales)
    # y_shape = tf.shape(ys)[:]

    def loop_range_encode(args):
        y, loc, scale = args
        y = tf.expand_dims(y, 0)
        loc = tf.expand_dims(loc, 0)
        scale = tf.expand_dims(scale, 0)
        y_string, y_min_v, y_max_v = conditional_entropy_model.compress(y, loc, scale)
        return y_string, y_min_v, y_max_v

    args = (ys, locs, scales)
    y_strings, y_min_vs, y_max_vs = tf.map_fn(loop_range_encode, args,
                                              dtype=(tf.string, tf.int32, tf.int32),
                                              parallel_iterations=1, back_prop=False)
    # Per-cube shape: prepend a batch dimension of 1 to the latent shape.
    y_shape = tf.convert_to_tensor(np.insert(tf.shape(ys)[1:].numpy(), 0, 1))
    print("Entropy Encode: {}s".format(round(time.time() - start, 4)))

    if decompress:
        start = time.time()

        def loop_range_decode(args):
            y_string, loc, scale, y_min_v, y_max_v = args
            loc = tf.expand_dims(loc, 0)
            scale = tf.expand_dims(scale, 0)
            y_decoded = conditional_entropy_model.decompress(
                y_string, loc, scale, y_min_v, y_max_v, y_shape)
            return tf.squeeze(y_decoded, 0)

        args = (y_strings, locs, scales, y_min_vs, y_max_vs)
        y_decodeds = tf.map_fn(loop_range_decode, args, dtype=tf.float32,
                               parallel_iterations=1, back_prop=False)
        print("Entropy Decode: {}s".format(round(time.time() - start, 4)))

        def loop_synthesis(y):
            y = tf.expand_dims(y, 0)
            x = synthesis_transform(y)
            return tf.squeeze(x, [0])

        start = time.time()
        x_decodeds = tf.map_fn(loop_synthesis, y_decodeds, dtype=tf.float32,
                               parallel_iterations=1, back_prop=False)
        print("Synthesis Transform: {}s".format(round(time.time() - start, 4)))

        return (y_strings, y_min_vs, y_max_vs, y_shape,
                z_strings, z_min_v, z_max_v, z_shape, x_decodeds)

    return y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape
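
# A hedged usage sketch added for illustration (not in the original file):
# running compress_hyper with decompress=True to sanity-check the encoder-side
# reconstruction against the input. The helper name is hypothetical; the model
# module and checkpoint directory come from the caller.
def _example_compress_check(cubes, model, ckpt_dir):
    """Hypothetical helper: compress and report encoder-side reconstruction MSE."""
    outputs = compress_hyper(cubes, model, ckpt_dir, decompress=True)
    x_decodeds = outputs[-1]
    x = tf.convert_to_tensor(cubes, 'float32')
    mse = tf.reduce_mean((x - x_decodeds) ** 2)
    print('encoder-side MSE:', float(mse))
    return outputs[:-1]  # the bitstream tuple, without the reconstructions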
def compress_more_less_mem(cubes, model, ckpt_dir, decompress=False):
    """Compress cubes to a bitstream with a lower memory footprint.

    Input: cubes with shape [batch size, length, width, height, channel(1)].
    Output: compressed bitstream, plus the encoder-side reconstructions.
    Note: `decompress` is accepted for API parity with compress_hyper, but this
    path always decodes each cube as it goes.
    """
    print('===== Compress =====')
    # Load model.
    # model = importlib.import_module(model)
    analysis_transform = model.AnalysisTransform()
    synthesis_transform = model.SynthesisTransform()
    hyper_encoder = model.HyperEncoder()
    hyper_decoder = model.HyperDecoder()
    entropy_bottleneck = EntropyBottleneck()
    conditional_entropy_model = SymmetricConditional()

    checkpoint = tf.train.Checkpoint(analysis_transform=analysis_transform,
                                     synthesis_transform=synthesis_transform,
                                     hyper_encoder=hyper_encoder,
                                     hyper_decoder=hyper_decoder,
                                     estimator=entropy_bottleneck)
    status = checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

    start = time.time()
    y_strings = []
    y_min_vs = []
    y_max_vs = []
    x_decodeds = []
    zs = []
    # Process one cube at a time so only a single latent is held in memory.
    for idx, cube in enumerate(cubes):
        if idx % 1000 == 0:
            print(idx)
        x = tf.convert_to_tensor(cube, "float32")
        x = tf.expand_dims(x, 0)
        y = analysis_transform(x)
        z = hyper_encoder(y)
        zs.append(z)
        z_hat, _ = entropy_bottleneck(z, False)
        loc, scale = hyper_decoder(z_hat)
        lower_bound = 1e-9  # TODO
        scale = tf.maximum(scale, lower_bound)
        y_string, y_min_v, y_max_v = conditional_entropy_model.compress(y, loc, scale)
        y_strings.append(y_string)
        y_min_vs.append(y_min_v)
        y_max_vs.append(y_max_v)
        y_shape = tf.shape(y)
        # Decode immediately so the reconstruction is available without a second pass.
        y_dec = conditional_entropy_model.decompress(y_string, loc, scale,
                                                     y_min_v, y_max_v, y_shape)
        x_dec = synthesis_transform(y_dec)
        x_dec = x_dec.numpy()
        x_decodeds.append(x_dec)

    y_strings = tf.convert_to_tensor(y_strings, dtype='string')
    y_min_vs = tf.convert_to_tensor(y_min_vs)
    y_max_vs = tf.convert_to_tensor(y_max_vs)
    zs = tf.concat(zs, axis=0)
    x_decodeds = np.concatenate(x_decodeds, 0)
    x_decodeds = tf.convert_to_tensor(x_decodeds, 'float32')

    # The hyper prior is entropy-coded for the whole batch at once.
    z_strings, z_min_v, z_max_v = entropy_bottleneck.compress(zs)
    z_shape = tf.shape(zs)[:]

    return (y_strings, y_min_vs, y_max_vs, y_shape,
            z_strings, z_min_v, z_max_v, z_shape, x_decodeds)
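
# A hedged round-trip sketch added for illustration (not in the original file):
# the low-memory encoder pairs naturally with decompress_less_mem(), since both
# keep per-cube min/max values. The helper name is hypothetical; variable names
# match the tuples returned above.
def _example_low_mem_roundtrip(cubes, model, ckpt_dir):
    """Hypothetical helper: encode with the low-memory path, then decode."""
    (y_strings, y_min_vs, y_max_vs, y_shape,
     z_strings, z_min_v, z_max_v, z_shape, x_enc_side) = \
        compress_more_less_mem(cubes, model, ckpt_dir)
    x_dec = decompress_less_mem(y_strings, y_min_vs, y_max_vs, y_shape,
                                z_strings, z_min_v, z_max_v, z_shape,
                                model, ckpt_dir)
    # The encoder-side reconstructions and the decoder output should match.
    return x_enc_side, x_dec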