def encode(self, model_type, input_path, compressed_file_path, quality_level):  # with TOP N dimensions
    """Compress the image at ``input_path`` into ``compressed_file_path``.

    File layout written here:
      * 1 byte (uint8): ``(quality_level << 1) + model_type``;
      * 2 x uint16 in native byte order (``ndarray.tofile``): width, height;
      * the arithmetic-coded stream: every element of ``z_hat`` (channel by
        channel), then ``y_hat`` position by position (all ``self.M``
        channels at each position), then the symbol 512 as a terminator.

    Each latent value ``v`` is coded as the symbol ``int(round(v)) + 255``.
    Symbols outside [0, 511] are reported on stdout but still written, as
    before.

    Args:
        model_type: Added to ``quality_level << 1`` to form the header
            byte (presumably 0 or 1 -- TODO confirm with callers).
        input_path: Image file to compress (opened with PIL).
        compressed_file_path: Output file; truncated and rewritten.
        quality_level: Stored in the upper bits of the header byte.

    Returns:
        ``compressed_file_path``.

    Raises:
        ValueError: if the image width or height does not fit in uint16.
    """
    progress_kw = {'suffix': 'Complete', 'length': 50}

    # Load the pixels and release the source file. The reshape below assumes
    # an (H, W, 3) array, so force RGB (grayscale/RGBA inputs used to crash).
    with Image.open(input_path) as src:
        img = src.convert('RGB')
    w, h = img.size
    if w > 0xFFFF or h > 0xFFFF:
        raise ValueError(
            'image size %dx%d does not fit in the uint16 header' % (w, h))

    with open(compressed_file_path, mode='wb') as fileobj:
        np.array([(quality_level << 1) + model_type],
                 dtype=np.uint8).tofile(fileobj)
        np.array([w, h], dtype=np.uint16).tofile(fileobj)

    # Pad up to a multiple of 16 on the top/left, so the original pixels end
    # up in the bottom-right corner of the padded image.
    new_w = int(math.ceil(w / 16) * 16)
    new_h = int(math.ceil(h / 16) * 16)
    pad_w = new_w - w
    pad_h = new_h - h
    input_x = np.pad(np.asarray(img), ((pad_h, 0), (pad_w, 0), (0, 0)),
                     mode='reflect')
    # (H, W, 3) -> (1, 3, H, W)
    input_x = input_x.reshape(1, new_h, new_w, 3).transpose([0, 3, 1, 2])

    h_s_out, y_hat, z_hat, sigma_z = self.sess.run(
        [self.h_s_out, self.y_hat, self.z_hat, self.sigma_z],
        feed_dict={self.input_x: input_x})  # NCHW

    # The with-block guarantees the handle is released even if encoding
    # raises; bitout.close() on the success path may close it first, which
    # is harmless.
    with open(compressed_file_path, 'ab+') as raw_out:
        bitout = arithmeticcoding.BitOutputStream(raw_out)
        enc = arithmeticcoding.ArithmeticEncoder(bitout)

        ############### encode z ####################################
        # One frequency table per channel: mean 255 (i.e. value 0), sigma_z[ch].
        n_z_ch = z_hat.shape[1]
        printProgressBar(0, n_z_ch, prefix='Encoding z_hat:', **progress_kw)
        for ch_idx in range(n_z_ch):
            printProgressBar(ch_idx + 1, n_z_ch, prefix='Encoding z_hat:',
                             **progress_kw)
            freq = arithmeticcoding.ModelFrequencyTable(255, sigma_z[ch_idx])
            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    # int(): hand the coder a Python integer, not a NumPy float.
                    symbol = int(np.rint(z_hat[0, ch_idx, h_idx, w_idx] + 255))
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)

        ############### encode y ####################################
        # Zero border: 3 rows on top, 2 columns left and 1 right; the
        # extractors are called with unshifted (h_idx, w_idx) on these
        # padded arrays.
        zero_pad = ((0, 0), (0, 0), (3, 0), (2, 1))
        padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :], zero_pad,
                               mode='constant')
        c_prime = h_s_out[:, :self.M1, :, :]
        sigma2 = h_s_out[:, self.M1:, :, :]
        padded_c_prime = np.pad(c_prime, zero_pad, mode='constant')

        n_rows = y_hat.shape[2]
        printProgressBar(0, n_rows, prefix='Encoding y_hat:', **progress_kw)
        for h_idx in range(n_rows):
            printProgressBar(h_idx + 1, n_rows, prefix='Encoding y_hat:',
                             **progress_kw)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y1_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate(
                    [c_doubleprime_i, c_prime_i], axis=1)
                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})
                # The first channels use the predicted mean/sigma. The next
                # M2 channels use mean 0 and sigma from sigma2 (presumably
                # M == M1 + M2 -- TODO confirm).
                zero_means = np.zeros([
                    pred_mean.shape[0], self.M2,
                    pred_mean.shape[2], pred_mean.shape[3]
                ])
                concat_pred_mean = np.concatenate([pred_mean, zero_means],
                                                  axis=1)
                concat_pred_sigma = np.concatenate([
                    pred_sigma,
                    sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]
                ], axis=1)
                for ch_idx in range(self.M):
                    freq = arithmeticcoding.ModelFrequencyTable(
                        concat_pred_mean[0, ch_idx, 0, 0] + 255,
                        concat_pred_sigma[0, ch_idx, 0, 0])
                    symbol = int(np.rint(y_hat[0, ch_idx, h_idx, w_idx] + 255))
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)

        # Terminator symbol, coded with the last frequency table used above.
        enc.write(freq, 512)
        enc.finish()
        bitout.close()
    return compressed_file_path
def decode(self, compressed_file, recon_path):  # with TOP N dimensions
    """Decompress ``compressed_file`` and save the reconstruction to ``recon_path``.

    Reads the header: 1 byte that is skipped, then width and height as
    native-endian uint16. It then arithmetic-decodes z_hat (channel by
    channel) and y_hat (position by position, all ``self.M`` channels per
    position), runs ``self.recon_image`` on the decoded y_hat, and saves the
    bottom-right ``h x w`` crop.

    Args:
        compressed_file: Path of the compressed stream.
        recon_path: Where the reconstructed image is written (PIL ``save``).

    Raises:
        ValueError: if the file is too short to contain the header.
    """
    progress_kw = {'suffix': 'Complete', 'length': 50}

    with open(compressed_file, mode='rb') as fileobj:
        fileobj.read(1)  # header byte (quality level / model type): unused here
        buf = fileobj.read(4)
        if len(buf) != 4:
            raise ValueError('truncated header in %s' % compressed_file)
        arr = np.frombuffer(buf, dtype=np.uint16)
        w = int(arr[0])
        h = int(arr[1])
        padded_w = int(math.ceil(w / 16) * 16)
        padded_h = int(math.ceil(h / 16) * 16)

        # Run the graph on a dummy all-zero input only to obtain y_hat/z_hat
        # buffers of the right shape, plus sigma_z from the same run.
        y_hat, z_hat, sigma_z = self.sess.run(
            [self.y_hat, self.z_hat, self.sigma_z],
            feed_dict={self.input_x: np.zeros((1, 3, padded_h, padded_w))})  # NCHW

        bitin = arithmeticcoding.BitInputStream(fileobj)
        dec = arithmeticcoding.ArithmeticDecoder(bitin)

        ############### decode zhat ####################################
        n_z_ch = z_hat.shape[1]
        printProgressBar(0, n_z_ch, prefix='Decoding z_hat:', **progress_kw)
        for ch_idx in range(n_z_ch):
            printProgressBar(ch_idx + 1, n_z_ch, prefix='Decoding z_hat:',
                             **progress_kw)
            freq = arithmeticcoding.ModelFrequencyTable(255, sigma_z[ch_idx])
            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        # NOTE(review): only the innermost loop is abandoned.
                        print("EOF symbol")
                        break
                    z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255

        ############### decode yhat ####################################
        h_s_out = self.sess.run(self.h_s_out, feed_dict={self.z_hat: z_hat})
        zero_pad = ((0, 0), (0, 0), (3, 0), (2, 1))
        c_prime = h_s_out[:, :self.M1, :, :]
        sigma2 = h_s_out[:, self.M1:, :, :]
        padded_c_prime = np.pad(c_prime, zero_pad, mode='constant')

        # Start from all zeros. Decoded values are written both into y_hat
        # and (for the first M1 channels) into padded_y1_hat at offset
        # (+3, +2), which is what the context extractor reads.
        y_hat[:, :, :, :] = 0.0
        padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :], zero_pad,
                               mode='constant')

        n_rows = y_hat.shape[2]
        printProgressBar(0, n_rows, prefix='Decoding y_hat:', **progress_kw)
        for h_idx in range(n_rows):
            printProgressBar(h_idx + 1, n_rows, prefix='Decoding y_hat:',
                             **progress_kw)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y1_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate(
                    [c_doubleprime_i, c_prime_i], axis=1)
                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})
                # Channels after the predicted ones: mean 0, sigma from sigma2.
                zero_means = np.zeros([
                    pred_mean.shape[0], self.M2,
                    pred_mean.shape[2], pred_mean.shape[3]
                ])
                concat_pred_mean = np.concatenate([pred_mean, zero_means],
                                                  axis=1)
                concat_pred_sigma = np.concatenate([
                    pred_sigma,
                    sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]
                ], axis=1)
                for ch_idx in range(self.M):
                    freq = arithmeticcoding.ModelFrequencyTable(
                        concat_pred_mean[0, ch_idx, 0, 0] + 255,
                        concat_pred_sigma[0, ch_idx, 0, 0])
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        print("EOF symbol")
                        break
                    if ch_idx < self.M1:
                        padded_y1_hat[:, ch_idx, h_idx + 3, w_idx + 2] = symbol - 255
                    y_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
        bitin.close()

    #################################################
    recon = self.sess.run(self.recon_image, {self.y_hat: y_hat})
    recon = recon[0, -h:, -w:, :]
    # Clamp before casting: astype(np.uint8) does not saturate out-of-range floats.
    im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
    im.save(recon_path)
    return
def decode(self, compressed_file, recon_path):  # with TOP N dimensions
    """Decompress ``compressed_file`` with the staged synthesis path.

    Reads the header: 1 byte that is skipped, then width and height as
    native-endian uint16. It recomputes the per-stage sizes and paddings
    from (w, h), arithmetic-decodes z_hat and then y_hat, and runs the
    synthesis stages gsh1 -> gsh2 -> gsh3 -> recon_image. After each stage
    the padding recorded for it is cropped off, and the result is saved to
    ``recon_path``.

    Args:
        compressed_file: Path of the compressed stream.
        recon_path: Where the reconstructed image is written (PIL ``save``).

    Raises:
        ValueError: if the file is too short to contain the header.
    """
    progress_kw = {'suffix': 'Complete', 'length': 50}

    with open(compressed_file, mode='rb') as fileobj:
        fileobj.read(1)  # header byte (quality level / model type): unused here
        buf = fileobj.read(4)
        if len(buf) != 4:
            raise ValueError('truncated header in %s' % compressed_file)
        arr = np.frombuffer(buf, dtype=np.uint16)
        w = int(arr[0])
        h = int(arr[1])

        # Size bookkeeping. The image is padded to even size. res_*_k is the
        # spatial size at stage k including its padding; pad_*_k (0 or 1)
        # is the padding cropped off again during synthesis below.
        new_w = int(math.ceil(float(w) / 2.0) * 2)
        new_h = int(math.ceil(float(h) / 2.0) * 2)
        pad_w_1 = int((float(new_w) / 2.0) % 2)
        pad_h_1 = int((float(new_h) / 2.0) % 2)
        res_w_1 = math.floor(float(new_w) / 2.0) + pad_w_1
        res_h_1 = math.floor(float(new_h) / 2.0) + pad_h_1
        pad_w_2 = int((float(res_w_1) / 2.0) % 2)
        pad_h_2 = int((float(res_h_1) / 2.0) % 2)
        res_w_2 = math.floor(float(res_w_1) / 2.0) + pad_w_2
        res_h_2 = math.floor(float(res_h_1) / 2.0) + pad_h_2
        pad_w_3 = int((float(res_w_2) / 2.0) % 2)
        pad_h_3 = int((float(res_h_2) / 2.0) % 2)
        res_w_3 = math.floor(float(res_w_2) / 2.0) + pad_w_3
        res_h_3 = math.floor(float(res_h_2) / 2.0) + pad_h_3
        pad_w = new_w - w
        pad_h = new_h - h

        sigma_z = self.sess.run(self.sigma_z)

        # y_hat starts as zeros and is filled in by the entropy decoder below.
        y_hat = np.zeros(
            (1, self.M, int(float(res_h_3) / 2.0), int(float(res_w_3) / 2.0)),
            dtype=np.float32)
        y_h, y_w = y_hat.shape[2], y_hat.shape[3]
        pad_y_h = int(math.ceil(float(y_h) / 4.0) * 4) - y_h
        pad_y_w = int(math.ceil(float(y_w) / 4.0) * 4) - y_w
        pad_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (0, pad_y_h), (0, pad_y_w)),
                           mode='edge')
        # Run only to get a correctly shaped z_hat buffer; every element is
        # overwritten by the decoded symbols.
        z_hat = self.sess.run(self.z_hat, feed_dict={self.y_hat: pad_y_hat})  # NCHW

        bitin = arithmeticcoding.BitInputStream(fileobj)
        dec = arithmeticcoding.ArithmeticDecoder(bitin)

        ############### decode zhat ####################################
        n_z_ch = z_hat.shape[1]
        printProgressBar(0, n_z_ch, prefix='Decoding z_hat:', **progress_kw)
        for ch_idx in range(n_z_ch):
            printProgressBar(ch_idx + 1, n_z_ch, prefix='Decoding z_hat:',
                             **progress_kw)
            freq = arithmeticcoding.ModelFrequencyTable(255, sigma_z[ch_idx])
            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        # NOTE(review): only the innermost loop is abandoned.
                        print("EOF symbol")
                        break
                    z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255

        ############### decode yhat ####################################
        h_s_out = self.sess.run(self.h_s_out, feed_dict={self.z_hat: z_hat})
        zero_pad = ((0, 0), (0, 0), (3, 0), (2, 1))
        c_prime = h_s_out[:, :self.M1, :, :]
        sigma2 = h_s_out[:, self.M1:, :, :]
        padded_c_prime = np.pad(c_prime, zero_pad, mode='constant')
        # y_hat is still all zeros here, so this is a zero buffer with the
        # same border as padded_c_prime.
        padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :], zero_pad,
                               mode='constant')

        n_rows = y_hat.shape[2]
        printProgressBar(0, n_rows, prefix='Decoding y_hat:', **progress_kw)
        for h_idx in range(n_rows):
            printProgressBar(h_idx + 1, n_rows, prefix='Decoding y_hat:',
                             **progress_kw)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y1_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate(
                    [c_doubleprime_i, c_prime_i], axis=1)
                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})
                # Channels after the predicted ones: mean 0, sigma from sigma2.
                zero_means = np.zeros([
                    pred_mean.shape[0], self.M2,
                    pred_mean.shape[2], pred_mean.shape[3]
                ])
                concat_pred_mean = np.concatenate([pred_mean, zero_means],
                                                  axis=1)
                concat_pred_sigma = np.concatenate([
                    pred_sigma,
                    sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]
                ], axis=1)
                for ch_idx in range(self.M):
                    freq = arithmeticcoding.ModelFrequencyTable(
                        concat_pred_mean[0, ch_idx, 0, 0] + 255,
                        concat_pred_sigma[0, ch_idx, 0, 0])
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        print("EOF symbol")
                        break
                    if ch_idx < self.M1:
                        padded_y1_hat[:, ch_idx, h_idx + 3, w_idx + 2] = symbol - 255
                    y_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
        bitin.close()

    ############### staged synthesis: crop each stage's padding back off
    gsh1 = self.sess.run(self.gsh1, feed_dict={self.y_hat: y_hat})
    gsh1 = gsh1[:, :res_h_3 - pad_h_3, :res_w_3 - pad_w_3, :]
    gsh2 = self.sess.run(self.gsh2, feed_dict={self.gsh1: gsh1})
    gsh2 = gsh2[:, :res_h_2 - pad_h_2, :res_w_2 - pad_w_2, :]
    gsh3 = self.sess.run(self.gsh3, feed_dict={self.gsh2: gsh2})
    gsh3 = gsh3[:, :res_h_1 - pad_h_1, :res_w_1 - pad_w_1, :]
    recon = self.sess.run(self.recon_image, feed_dict={self.gsh3: gsh3})
    recon = recon[0, :recon.shape[1] - pad_h, :recon.shape[2] - pad_w, :]
    # Clamp before casting: astype(np.uint8) does not saturate out-of-range floats.
    im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
    im.save(recon_path)
    return
def decode(self, compressed_file, recon_path):  # with TOP N dimensions
    """Decompress ``compressed_file`` using one context model over all channels.

    Reads the header: 1 byte that is skipped, then width and height as
    native-endian uint16. It arithmetic-decodes z_hat channel by channel,
    then y_hat position by position. At each position, all ``self.M``
    channels are decoded with means/sigmas predicted from the
    already-decoded context and c_prime. It then runs ``self.recon_image``
    and saves the bottom-right ``h x w`` crop to ``recon_path``.

    Args:
        compressed_file: Path of the compressed stream.
        recon_path: Where the reconstructed image is written (PIL ``save``).

    Raises:
        ValueError: if the file is too short to contain the header.
    """
    progress_kw = {'suffix': 'Complete', 'length': 50}

    with open(compressed_file, mode='rb') as fileobj:
        fileobj.read(1)  # header byte (quality level / model type): unused here
        buf = fileobj.read(4)
        if len(buf) != 4:
            raise ValueError('truncated header in %s' % compressed_file)
        arr = np.frombuffer(buf, dtype=np.uint16)
        w = int(arr[0])
        h = int(arr[1])
        padded_w = int(math.ceil(w / 16) * 16)
        padded_h = int(math.ceil(h / 16) * 16)

        # Dummy all-zero input: only the shapes of y_hat/z_hat (and sigma_z
        # from the same run) are used.
        y_hat, z_hat, sigma_z = self.sess.run(
            [self.y_hat, self.z_hat, self.sigma_z],
            feed_dict={self.input_x: np.zeros((1, 3, padded_h, padded_w))})  # NCHW

        ############### decode zhat ####################################
        bitin = arithmeticcoding.BitInputStream(fileobj)
        dec = arithmeticcoding.ArithmeticDecoder(bitin)
        z_hat[:, :, :, :] = 0.0
        n_z_ch = z_hat.shape[1]
        printProgressBar(0, n_z_ch, prefix='Decoding z_hat:', **progress_kw)
        for ch_idx in range(n_z_ch):
            printProgressBar(ch_idx + 1, n_z_ch, prefix='Decoding z_hat:',
                             **progress_kw)
            freq = arithmeticcoding.ModelFrequencyTable(255, sigma_z[ch_idx])
            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        # NOTE(review): only the innermost loop is abandoned.
                        print("EOF symbol")
                        break
                    z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255

        ############### decode yhat ####################################
        c_prime = self.sess.run(self.c_prime, feed_dict={self.z_hat: z_hat})
        zero_pad = ((0, 0), (0, 0), (3, 0), (2, 1))
        padded_c_prime = np.pad(c_prime, zero_pad, mode='constant')
        # Decoded values are written at offset (+3, +2) into this zero buffer,
        # which has the same border as padded_c_prime.
        padded_y_hat = np.pad(np.zeros_like(y_hat), zero_pad, mode='constant')

        n_rows = y_hat.shape[2]
        printProgressBar(0, n_rows, prefix='Decoding y_hat:', **progress_kw)
        for h_idx in range(n_rows):
            printProgressBar(h_idx + 1, n_rows, prefix='Decoding y_hat:',
                             **progress_kw)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate(
                    [c_doubleprime_i, c_prime_i], axis=1)
                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})
                for ch_idx in range(self.M):
                    freq = arithmeticcoding.ModelFrequencyTable(
                        pred_mean[0, ch_idx, 0, 0] + 255,
                        pred_sigma[0, ch_idx, 0, 0])
                    symbol = dec.read(freq)
                    if symbol == 512:  # EOF symbol
                        print("EOF symbol")
                        break
                    padded_y_hat[:, ch_idx, h_idx + 3, w_idx + 2] = symbol - 255
        bitin.close()

    # Strip the (3, 0) / (2, 1) border again.
    y_hat = padded_y_hat[:, :, 3:, 2:-1]
    #################################################
    recon = self.sess.run(self.recon_image, {self.y_hat: y_hat})
    recon = recon[0, -h:, -w:, :]
    # Clamp before casting: astype(np.uint8) does not saturate out-of-range floats.
    im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
    im.save(recon_path)
    return
def encode(self, model_type, input_path, compressed_file_path, quality_level):  # with TOP N dimensions
    """Compress the image at ``input_path`` with the staged analysis path.

    File layout written here:
      * 1 byte (uint8): ``(quality_level << 1) + model_type``;
      * 2 x uint16 in native byte order (``ndarray.tofile``): width, height;
      * the arithmetic-coded stream: every element of ``z_hat`` (channel by
        channel), then ``y_hat`` position by position (all ``self.M``
        channels at each position), then the symbol 512 as a terminator.

    Each latent value ``v`` is coded as the symbol ``int(round(v)) + 255``.
    The previous ``np.int(v + 255)`` truncated instead of rounding, and
    ``np.int`` no longer exists in NumPy >= 1.24. Symbols outside [0, 511]
    are reported on stdout but still written, as before.

    Args:
        model_type: Added to ``quality_level << 1`` to form the header
            byte (presumably 0 or 1 -- TODO confirm with callers).
        input_path: Image file to compress (opened with PIL).
        compressed_file_path: Output file; truncated and rewritten.
        quality_level: Stored in the upper bits of the header byte.

    Returns:
        ``compressed_file_path``.

    Raises:
        ValueError: if the image width or height does not fit in uint16.
    """
    progress_kw = {'suffix': 'Complete', 'length': 50}

    # Load the pixels and release the source file. The reshape below assumes
    # an (H, W, 3) array, so force RGB (grayscale/RGBA inputs used to crash).
    with Image.open(input_path) as src:
        img = src.convert('RGB')
    w, h = img.size
    if w > 0xFFFF or h > 0xFFFF:
        raise ValueError(
            'image size %dx%d does not fit in the uint16 header' % (w, h))

    with open(compressed_file_path, mode='wb') as fileobj:
        np.array([(quality_level << 1) + model_type],
                 dtype=np.uint8).tofile(fileobj)
        np.array([w, h], dtype=np.uint16).tofile(fileobj)

    # Pad bottom/right by edge replication up to an even width and height.
    new_w = int(math.ceil(float(w) / 2.0) * 2)
    new_h = int(math.ceil(float(h) / 2.0) * 2)
    pad_w = new_w - w
    pad_h = new_h - h
    img_array = np.pad(np.asarray(img), ((0, pad_h), (0, pad_w), (0, 0)),
                       mode='edge')
    # (H, W, 3) -> (1, 3, H, W)
    input_x = img_array.reshape(1, new_h, new_w, 3).transpose([0, 3, 1, 2])

    def pad_to_even(t):
        """Zero-pad axes 1 and 2 of ``t`` by one element where their size is odd."""
        return np.pad(t, ((0, 0), (0, t.shape[1] % 2), (0, t.shape[2] % 2), (0, 0)),
                      mode='constant')

    # Analysis transform, one stage at a time, padding to even between stages.
    gah1, sigma_z = self.sess.run([self.gah1, self.sigma_z],
                                  feed_dict={self.input_x: input_x})
    gah1 = pad_to_even(gah1)
    gah2 = pad_to_even(self.sess.run(self.gah2, feed_dict={self.gah1: gah1}))
    gah3 = pad_to_even(self.sess.run(self.gah3, feed_dict={self.gah2: gah2}))
    y_hat = self.sess.run(self.y_hat, feed_dict={self.gah3: gah3})

    # z_hat and c_prime are computed from y_hat mirror-padded (bottom/right)
    # to multiples of 4 on its last two axes.
    y_h, y_w = y_hat.shape[2], y_hat.shape[3]
    pad_y_h = int(math.ceil(float(y_h) / 4.0) * 4) - y_h
    pad_y_w = int(math.ceil(float(y_w) / 4.0) * 4) - y_w
    pad_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (0, pad_y_h), (0, pad_y_w)),
                       mode='symmetric')
    z_hat, c_prime = self.sess.run([self.z_hat, self.c_prime],
                                   feed_dict={self.y_hat: pad_y_hat})

    ############### encode zhat ####################################
    n_z_ch = z_hat.shape[1]
    printProgressBar(0, n_z_ch, prefix='Encoding z_hat:', **progress_kw)
    # The with-block guarantees the handle is released even if encoding
    # raises; bitout.close() on the success path may close it first, which
    # is harmless.
    with open(compressed_file_path, "ab+") as raw_out:
        bitout = arithmeticcoding.BitOutputStream(raw_out)
        enc = arithmeticcoding.ArithmeticEncoder(bitout)
        for ch_idx in range(n_z_ch):
            printProgressBar(ch_idx + 1, n_z_ch, prefix='Encoding z_hat:',
                             **progress_kw)
            # One table per channel: mean 255 (i.e. value 0), sigma_z[ch].
            freq = arithmeticcoding.ModelFrequencyTable(255, sigma_z[ch_idx])
            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    symbol = int(np.rint(z_hat[0, ch_idx, h_idx, w_idx] + 255))
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)

        ############### encode yhat ####################################
        # Zero border: 3 rows on top, 2 columns left and 1 right; the
        # extractors are called with unshifted (h_idx, w_idx) on these
        # padded arrays.
        zero_pad = ((0, 0), (0, 0), (3, 0), (2, 1))
        padded_y_hat = np.pad(y_hat, zero_pad, mode='constant')
        padded_c_prime = np.pad(c_prime, zero_pad, mode='constant')

        n_rows = y_hat.shape[2]
        printProgressBar(0, n_rows, prefix='Encoding y_hat:', **progress_kw)
        for h_idx in range(n_rows):
            printProgressBar(h_idx + 1, n_rows, prefix='Encoding y_hat:',
                             **progress_kw)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate(
                    [c_doubleprime_i, c_prime_i], axis=1)
                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})
                for ch_idx in range(self.M):
                    freq = arithmeticcoding.ModelFrequencyTable(
                        pred_mean[0, ch_idx, 0, 0] + 255,
                        pred_sigma[0, ch_idx, 0, 0])
                    symbol = int(np.rint(y_hat[0, ch_idx, h_idx, w_idx] + 255))
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)

        # Terminator symbol, coded with the last frequency table used above.
        enc.write(freq, 512)
        enc.finish()
        bitout.close()
    return compressed_file_path