コード例 #1
0
    def encode(self, model_type, input_path, compressed_file_path,
               quality_level):  # with TOP N dimensions
        """Compress the image at ``input_path`` into ``compressed_file_path``.

        File layout written by this method:

        * 1 byte (uint8): ``(quality_level << 1) + model_type``.
        * 4 bytes: width and height as two uint16 values in machine byte
          order (``ndarray.tofile`` does not byte-swap).
        * The arithmetic-coded stream: every ``z_hat`` symbol (channel by
          channel, raster order inside a channel), then every ``y_hat``
          symbol (raster order over positions, ``self.M`` channels per
          position), then the EOF symbol 512.

        Latent values are shifted by +255 and rounded, so valid symbols are
        0..511; 512 is reserved for EOF.

        Args:
            model_type: added to the header byte below ``quality_level``;
                presumably 0 or 1 -- TODO confirm.
            input_path: image to compress; converted to 3-channel RGB.
            compressed_file_path: output path; truncated if it exists.
            quality_level: integer stored in the upper bits of the header.

        Returns:
            ``compressed_file_path``.

        Raises:
            ValueError: if the header values do not fit their fixed-width
                fields, or a latent symbol falls outside 0..511 (it could
                not be decoded correctly; 512 would be read as EOF).
        """
        with Image.open(input_path) as img:
            w, h = img.size
            # The graph input is reshaped to 3 channels below; converting
            # keeps grayscale / RGBA / palette images from failing there.
            img_array = np.asarray(img.convert('RGB'))

        header_byte = (quality_level << 1) + model_type
        if not 0 <= header_byte <= 0xFF:
            raise ValueError("header byte out of uint8 range: " +
                             str(header_byte))
        if not (0 <= w <= 0xFFFF and 0 <= h <= 0xFFFF):
            raise ValueError("image size does not fit in uint16: " +
                             str((w, h)))

        with open(compressed_file_path, mode='wb') as fileobj:
            np.array([header_byte], dtype=np.uint8).tofile(fileobj)
            np.array([w, h], dtype=np.uint16).tofile(fileobj)

        # Reflect-pad at the top/left so both sides become multiples of 16.
        new_w = int(math.ceil(w / 16) * 16)
        new_h = int(math.ceil(h / 16) * 16)
        pad_w = new_w - w
        pad_h = new_h - h

        input_x = np.pad(img_array, ((pad_h, 0), (pad_w, 0), (0, 0)),
                         mode='reflect')
        input_x = input_x.reshape(1, new_h, new_w, 3)
        input_x = input_x.transpose([0, 3, 1, 2])  # NHWC -> NCHW

        h_s_out, y_hat, z_hat, sigma_z = self.sess.run(
            [self.h_s_out, self.y_hat, self.z_hat, self.sigma_z],
            feed_dict={self.input_x: input_x})  # NCHW

        def _to_symbol(value):
            # Shift a latent value into the coder alphabet and make it an int.
            symbol = int(np.rint(value + 255))
            if symbol < 0 or symbol > 511:
                raise ValueError("symbol range error: " + str(symbol))
            return symbol

        with open(compressed_file_path, "ab+") as out_file:
            bitout = arithmeticcoding.BitOutputStream(out_file)
            enc = arithmeticcoding.ArithmeticEncoder(bitout)

            # ---- encode z_hat ----
            printProgressBar(0,
                             z_hat.shape[1],
                             prefix='Encoding z_hat:',
                             suffix='Complete',
                             length=50)
            for ch_idx in range(z_hat.shape[1]):
                printProgressBar(ch_idx + 1,
                                 z_hat.shape[1],
                                 prefix='Encoding z_hat:',
                                 suffix='Complete',
                                 length=50)
                # One model per z channel: mu_val 255 (z = 0 after the +255
                # shift) and that channel's sigma_z.
                freq = arithmeticcoding.ModelFrequencyTable(
                    255, sigma_z[ch_idx])
                for h_idx in range(z_hat.shape[2]):
                    for w_idx in range(z_hat.shape[3]):
                        enc.write(freq,
                                  _to_symbol(z_hat[0, ch_idx, h_idx, w_idx]))

            # ---- encode y_hat ----
            # First M1 channels of y_hat and of h_s_out, padded with 3 rows
            # on top and 2 / 1 columns left / right, feed the extractors.
            padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :],
                                   ((0, 0), (0, 0), (3, 0), (2, 1)),
                                   'constant',
                                   constant_values=0)
            c_prime = h_s_out[:, :self.M1, :, :]
            sigma2 = h_s_out[:, self.M1:, :, :]
            padded_c_prime = np.pad(c_prime,
                                    ((0, 0), (0, 0), (3, 0), (2, 1)),
                                    'constant',
                                    constant_values=0)

            printProgressBar(0,
                             y_hat.shape[2],
                             prefix='Encoding y_hat:',
                             suffix='Complete',
                             length=50)
            for h_idx in range(y_hat.shape[2]):
                printProgressBar(h_idx + 1,
                                 y_hat.shape[2],
                                 prefix='Encoding y_hat:',
                                 suffix='Complete',
                                 length=50)
                for w_idx in range(y_hat.shape[3]):
                    c_prime_i = self.extractor_prime(padded_c_prime, h_idx,
                                                     w_idx)
                    c_doubleprime_i = self.extractor_doubleprime(
                        padded_y1_hat, h_idx, w_idx)
                    concatenated_c_i = np.concatenate(
                        [c_doubleprime_i, c_prime_i], axis=1)

                    pred_mean, pred_sigma = self.sess.run(
                        [self.pred_mean, self.pred_sigma],
                        feed_dict={self.concatenated_c_i: concatenated_c_i})

                    # Channels after pred_mean's get a zero mean and the
                    # sigma2 slice of h_s_out at this position (assumes
                    # self.M == pred channels + self.M2 -- TODO confirm).
                    zero_means = np.zeros([
                        pred_mean.shape[0], self.M2, pred_mean.shape[2],
                        pred_mean.shape[3]
                    ])
                    concat_pred_mean = np.concatenate(
                        [pred_mean, zero_means], axis=1)
                    concat_pred_sigma = np.concatenate(
                        [pred_sigma,
                         sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]],
                        axis=1)

                    for ch_idx in range(self.M):
                        mu_val = concat_pred_mean[0, ch_idx, 0, 0] + 255
                        sigma_val = concat_pred_sigma[0, ch_idx, 0, 0]
                        freq = arithmeticcoding.ModelFrequencyTable(
                            mu_val, sigma_val)
                        enc.write(freq,
                                  _to_symbol(y_hat[0, ch_idx, h_idx, w_idx]))

            # EOF marker, coded with the last model used above.
            enc.write(freq, 512)
            enc.finish()
            bitout.close()

        return compressed_file_path
コード例 #2
0
    def decode(self, compressed_file, recon_path):  # with TOP N dimensions
        """Decode ``compressed_file`` and save the reconstruction to ``recon_path``.

        The header is one byte (skipped here) followed by the image width
        and height as two uint16 values. The graph is run once on a zero
        image of the padded size, only to obtain correctly shaped ``y_hat``
        and ``z_hat`` buffers plus ``sigma_z``. ``z_hat`` and then ``y_hat``
        are arithmetic-decoded into those buffers (value = symbol - 255).
        ``self.recon_image`` is then evaluated and its last ``h`` rows and
        ``w`` columns are saved.

        Decoding the EOF symbol (512) before every latent has been read means
        the stream ended early. In that case "EOF symbol" is printed,
        decoding stops, and latents not yet decoded stay at zero.

        Args:
            compressed_file: path of the compressed file.
            recon_path: path the reconstructed image is saved to.

        Raises:
            ValueError: if the file is too short to contain the header.
        """
        with open(compressed_file, mode='rb') as fileobj:
            fileobj.read(1)  # quality/model byte; not needed here
            buf = fileobj.read(4)
            if len(buf) != 4:
                raise ValueError("truncated header in " + str(compressed_file))
            arr = np.frombuffer(buf, dtype=np.uint16)
            w = int(arr[0])
            h = int(arr[1])

            padded_w = int(math.ceil(w / 16) * 16)
            padded_h = int(math.ceil(h / 16) * 16)

            # Only the shapes of y_hat / z_hat are used; clear their values
            # so anything not decoded is zero.
            y_hat, z_hat, sigma_z = self.sess.run(
                [self.y_hat, self.z_hat, self.sigma_z],
                feed_dict={self.input_x: np.zeros(
                    (1, 3, padded_h, padded_w))})  # NCHW
            z_hat[:, :, :, :] = 0.0
            y_hat[:, :, :, :] = 0.0

            bitin = arithmeticcoding.BitInputStream(fileobj)
            dec = arithmeticcoding.ArithmeticDecoder(bitin)
            eof_reached = False

            # ---- decode z_hat ----
            printProgressBar(0,
                             z_hat.shape[1],
                             prefix='Decoding z_hat:',
                             suffix='Complete',
                             length=50)
            for ch_idx in range(z_hat.shape[1]):
                printProgressBar(ch_idx + 1,
                                 z_hat.shape[1],
                                 prefix='Decoding z_hat:',
                                 suffix='Complete',
                                 length=50)
                mu_val = 255
                sigma_val = sigma_z[ch_idx]
                freq = arithmeticcoding.ModelFrequencyTable(mu_val, sigma_val)

                for h_idx in range(z_hat.shape[2]):
                    for w_idx in range(z_hat.shape[3]):
                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
                    if eof_reached:
                        break
                if eof_reached:
                    break

            # ---- decode y_hat ----
            h_s_out = self.sess.run(self.h_s_out,
                                    feed_dict={self.z_hat: z_hat})
            c_prime = h_s_out[:, :self.M1, :, :]
            sigma2 = h_s_out[:, self.M1:, :, :]
            padded_c_prime = np.pad(c_prime,
                                    ((0, 0), (0, 0), (3, 0), (2, 1)),
                                    'constant',
                                    constant_values=0)
            # Decoded first-M1 channels with the same padding; starts as zeros
            # and is filled in as symbols are decoded.
            padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :],
                                   ((0, 0), (0, 0), (3, 0), (2, 1)),
                                   'constant',
                                   constant_values=0)

            printProgressBar(0,
                             y_hat.shape[2],
                             prefix='Decoding y_hat:',
                             suffix='Complete',
                             length=50)
            for h_idx in range(y_hat.shape[2]):
                if eof_reached:
                    break
                printProgressBar(h_idx + 1,
                                 y_hat.shape[2],
                                 prefix='Decoding y_hat:',
                                 suffix='Complete',
                                 length=50)
                for w_idx in range(y_hat.shape[3]):
                    c_prime_i = self.extractor_prime(padded_c_prime, h_idx,
                                                     w_idx)
                    c_doubleprime_i = self.extractor_doubleprime(
                        padded_y1_hat, h_idx, w_idx)
                    concatenated_c_i = np.concatenate(
                        [c_doubleprime_i, c_prime_i], axis=1)

                    pred_mean, pred_sigma = self.sess.run(
                        [self.pred_mean, self.pred_sigma],
                        feed_dict={self.concatenated_c_i: concatenated_c_i})

                    zero_means = np.zeros([
                        pred_mean.shape[0], self.M2, pred_mean.shape[2],
                        pred_mean.shape[3]
                    ])
                    concat_pred_mean = np.concatenate(
                        [pred_mean, zero_means], axis=1)
                    concat_pred_sigma = np.concatenate(
                        [pred_sigma,
                         sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]],
                        axis=1)

                    for ch_idx in range(self.M):
                        mu_val = concat_pred_mean[0, ch_idx, 0, 0] + 255
                        sigma_val = concat_pred_sigma[0, ch_idx, 0, 0]
                        freq = arithmeticcoding.ModelFrequencyTable(
                            mu_val, sigma_val)

                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        if ch_idx < self.M1:
                            padded_y1_hat[:, ch_idx, h_idx + 3,
                                          w_idx + 2] = symbol - 255
                        y_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
                    if eof_reached:
                        break
            bitin.close()

        recon = self.sess.run(self.recon_image, {self.y_hat: y_hat})
        recon = recon[0, -h:, -w:, :]
        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping (assumes recon is on the 0..255 pixel scale).
        im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
        im.save(recon_path)

        return
コード例 #3
0
    def decode(self, compressed_file, recon_path):  # with TOP N dimensions
        """Decode ``compressed_file`` and save the reconstruction to ``recon_path``.

        The header is one byte (skipped here) followed by the image width
        and height as two uint16 values. The latent sizes are derived from
        them through three halving stages (``res_*_k`` / ``pad_*_k``). The
        arithmetic-coded ``z_hat`` and then ``y_hat`` are decoded (value =
        symbol - 255). The image is rebuilt through ``gsh1`` -> ``gsh2`` ->
        ``gsh3`` -> ``recon_image``, and each stage is cropped by its
        ``pad_*_k``.

        Decoding the EOF symbol (512) before every latent has been read means
        the stream ended early. In that case "EOF symbol" is printed,
        decoding stops, and latents not yet decoded stay at zero.

        Args:
            compressed_file: path of the compressed file.
            recon_path: path the reconstructed image is saved to.

        Raises:
            ValueError: if the file is too short to contain the header.
        """
        with open(compressed_file, mode='rb') as fileobj:
            fileobj.read(1)  # quality/model byte; not needed here
            buf = fileobj.read(4)
            if len(buf) != 4:
                raise ValueError("truncated header in " + str(compressed_file))
            arr = np.frombuffer(buf, dtype=np.uint16)
            w = int(arr[0])
            h = int(arr[1])

            # Image padded to even size, then three halving stages. pad_*_k
            # is the extra row/column accounted for at stage k (presumably
            # mirroring the encoder's padding -- TODO confirm).
            new_w = int(math.ceil(float(w) / 2.0) * 2)
            new_h = int(math.ceil(float(h) / 2.0) * 2)

            pad_w_1 = int((float(new_w) / 2.0) % 2)
            pad_h_1 = int((float(new_h) / 2.0) % 2)
            res_w_1 = math.floor(float(new_w) / 2.0) + pad_w_1
            res_h_1 = math.floor(float(new_h) / 2.0) + pad_h_1

            pad_w_2 = int((float(res_w_1) / 2.0) % 2)
            pad_h_2 = int((float(res_h_1) / 2.0) % 2)
            res_w_2 = math.floor(float(res_w_1) / 2.0) + pad_w_2
            res_h_2 = math.floor(float(res_h_1) / 2.0) + pad_h_2

            pad_w_3 = int((float(res_w_2) / 2.0) % 2)
            pad_h_3 = int((float(res_h_2) / 2.0) % 2)
            res_w_3 = math.floor(float(res_w_2) / 2.0) + pad_w_3
            res_h_3 = math.floor(float(res_h_2) / 2.0) + pad_h_3

            pad_w = new_w - w
            pad_h = new_h - h

            sigma_z = self.sess.run(self.sigma_z)
            y_hat = np.zeros(
                (1, self.M, int(float(res_h_3) / 2.0),
                 int(float(res_w_3) / 2.0)),
                dtype=np.float32)

            # z_hat is only needed for its shape: run the graph on y_hat
            # padded up to multiples of 4, then clear it before decoding.
            y_w = y_hat.shape[3]
            y_h = y_hat.shape[2]
            new_y_w = int(math.ceil(float(y_w) / 4.0) * 4)
            new_y_h = int(math.ceil(float(y_h) / 4.0) * 4)
            pad_y_w = new_y_w - y_w
            pad_y_h = new_y_h - y_h
            pad_y_hat = np.pad(y_hat,
                               ((0, 0), (0, 0), (0, pad_y_h), (0, pad_y_w)),
                               mode='edge')
            z_hat = self.sess.run(self.z_hat,
                                  feed_dict={self.y_hat: pad_y_hat})  # NCHW
            z_hat[:, :, :, :] = 0.0

            # y_hat is all zeros here, so this padded context starts as zeros
            # and is filled in as symbols are decoded.
            padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :],
                                   ((0, 0), (0, 0), (3, 0), (2, 1)),
                                   'constant',
                                   constant_values=0)

            bitin = arithmeticcoding.BitInputStream(fileobj)
            dec = arithmeticcoding.ArithmeticDecoder(bitin)
            eof_reached = False

            # ---- decode z_hat ----
            printProgressBar(0,
                             z_hat.shape[1],
                             prefix='Decoding z_hat:',
                             suffix='Complete',
                             length=50)
            for ch_idx in range(z_hat.shape[1]):
                printProgressBar(ch_idx + 1,
                                 z_hat.shape[1],
                                 prefix='Decoding z_hat:',
                                 suffix='Complete',
                                 length=50)
                mu_val = 255
                sigma_val = sigma_z[ch_idx]
                freq = arithmeticcoding.ModelFrequencyTable(mu_val, sigma_val)

                for h_idx in range(z_hat.shape[2]):
                    for w_idx in range(z_hat.shape[3]):
                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
                    if eof_reached:
                        break
                if eof_reached:
                    break

            # ---- decode y_hat ----
            h_s_out = self.sess.run(self.h_s_out,
                                    feed_dict={self.z_hat: z_hat})
            c_prime = h_s_out[:, :self.M1, :, :]
            sigma2 = h_s_out[:, self.M1:, :, :]
            padded_c_prime = np.pad(c_prime,
                                    ((0, 0), (0, 0), (3, 0), (2, 1)),
                                    'constant',
                                    constant_values=0)

            printProgressBar(0,
                             y_hat.shape[2],
                             prefix='Decoding y_hat:',
                             suffix='Complete',
                             length=50)
            for h_idx in range(y_hat.shape[2]):
                if eof_reached:
                    break
                printProgressBar(h_idx + 1,
                                 y_hat.shape[2],
                                 prefix='Decoding y_hat:',
                                 suffix='Complete',
                                 length=50)
                for w_idx in range(y_hat.shape[3]):
                    c_prime_i = self.extractor_prime(padded_c_prime, h_idx,
                                                     w_idx)
                    c_doubleprime_i = self.extractor_doubleprime(
                        padded_y1_hat, h_idx, w_idx)
                    concatenated_c_i = np.concatenate(
                        [c_doubleprime_i, c_prime_i], axis=1)

                    pred_mean, pred_sigma = self.sess.run(
                        [self.pred_mean, self.pred_sigma],
                        feed_dict={self.concatenated_c_i: concatenated_c_i})

                    zero_means = np.zeros([
                        pred_mean.shape[0], self.M2, pred_mean.shape[2],
                        pred_mean.shape[3]
                    ])
                    concat_pred_mean = np.concatenate(
                        [pred_mean, zero_means], axis=1)
                    concat_pred_sigma = np.concatenate(
                        [pred_sigma,
                         sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]],
                        axis=1)

                    for ch_idx in range(self.M):
                        mu_val = concat_pred_mean[0, ch_idx, 0, 0] + 255
                        sigma_val = concat_pred_sigma[0, ch_idx, 0, 0]
                        freq = arithmeticcoding.ModelFrequencyTable(
                            mu_val, sigma_val)

                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        if ch_idx < self.M1:
                            padded_y1_hat[:, ch_idx, h_idx + 3,
                                          w_idx + 2] = symbol - 255
                        y_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
                    if eof_reached:
                        break
            bitin.close()

        # ---- synthesis: crop each stage back by its padding ----
        gsh1 = self.sess.run(self.gsh1, feed_dict={self.y_hat: y_hat})
        gsh1 = gsh1[:, :res_h_3 - pad_h_3, :res_w_3 - pad_w_3, :]

        gsh2 = self.sess.run(self.gsh2, feed_dict={self.gsh1: gsh1})
        gsh2 = gsh2[:, :res_h_2 - pad_h_2, :res_w_2 - pad_w_2, :]

        gsh3 = self.sess.run(self.gsh3, feed_dict={self.gsh2: gsh2})
        gsh3 = gsh3[:, :res_h_1 - pad_h_1, :res_w_1 - pad_w_1, :]

        recon = self.sess.run(self.recon_image, feed_dict={self.gsh3: gsh3})
        recon = recon[0, :recon.shape[1] - pad_h, :recon.shape[2] - pad_w, :]

        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping (assumes recon is on the 0..255 pixel scale).
        im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
        im.save(recon_path)

        return
コード例 #4
0
    def decode(self, compressed_file, recon_path):  # with TOP N dimensions
        """Decode ``compressed_file`` and save the reconstruction to ``recon_path``.

        The header is one byte (skipped here) followed by the image width
        and height as two uint16 values. The graph is run once on a zero
        image of the padded size, only to obtain correctly shaped ``y_hat``
        and ``z_hat`` plus ``sigma_z``. ``z_hat`` is arithmetic-decoded
        first. ``y_hat`` is then decoded position by position into a padded
        buffer that doubles as the context for later positions. Finally
        ``self.recon_image`` is evaluated and its last ``h`` rows and ``w``
        columns are saved.

        Decoding the EOF symbol (512) before every latent has been read means
        the stream ended early. In that case "EOF symbol" is printed,
        decoding stops, and latents not yet decoded stay at zero.

        Args:
            compressed_file: path of the compressed file.
            recon_path: path the reconstructed image is saved to.

        Raises:
            ValueError: if the file is too short to contain the header.
        """
        with open(compressed_file, mode='rb') as fileobj:
            fileobj.read(1)  # quality/model byte; not needed here
            buf = fileobj.read(4)
            if len(buf) != 4:
                raise ValueError("truncated header in " + str(compressed_file))
            arr = np.frombuffer(buf, dtype=np.uint16)
            w = int(arr[0])
            h = int(arr[1])

            padded_w = int(math.ceil(w / 16) * 16)
            padded_h = int(math.ceil(h / 16) * 16)

            y_hat, z_hat, sigma_z = self.sess.run(
                [self.y_hat, self.z_hat, self.sigma_z],
                feed_dict={self.input_x: np.zeros(
                    (1, 3, padded_h, padded_w))})  # NCHW

            bitin = arithmeticcoding.BitInputStream(fileobj)
            dec = arithmeticcoding.ArithmeticDecoder(bitin)
            eof_reached = False

            # ---- decode z_hat ----
            z_hat[:, :, :, :] = 0.0

            printProgressBar(0, z_hat.shape[1], prefix='Decoding z_hat:',
                             suffix='Complete', length=50)
            for ch_idx in range(z_hat.shape[1]):
                printProgressBar(ch_idx + 1, z_hat.shape[1],
                                 prefix='Decoding z_hat:', suffix='Complete',
                                 length=50)
                mu_val = 255
                sigma_val = sigma_z[ch_idx]
                freq = arithmeticcoding.ModelFrequencyTable(mu_val, sigma_val)

                for h_idx in range(z_hat.shape[2]):
                    for w_idx in range(z_hat.shape[3]):
                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        z_hat[:, ch_idx, h_idx, w_idx] = symbol - 255
                    if eof_reached:
                        break
                if eof_reached:
                    break

            # ---- decode y_hat ----
            c_prime = self.sess.run(self.c_prime,
                                    feed_dict={self.z_hat: z_hat})
            padded_c_prime = np.pad(c_prime,
                                    ((0, 0), (0, 0), (3, 0), (2, 1)),
                                    'constant',
                                    constant_values=0)

            # Zero buffer shaped like y_hat padded with 3 rows on top and
            # 2 / 1 columns left / right; decoded values land at
            # [h_idx + 3, w_idx + 2].
            padded_y_hat = np.pad(np.zeros_like(y_hat),
                                  ((0, 0), (0, 0), (3, 0), (2, 1)),
                                  'constant',
                                  constant_values=0)

            printProgressBar(0, y_hat.shape[2], prefix='Decoding y_hat:',
                             suffix='Complete', length=50)
            for h_idx in range(y_hat.shape[2]):
                if eof_reached:
                    break
                printProgressBar(h_idx + 1, y_hat.shape[2],
                                 prefix='Decoding y_hat:', suffix='Complete',
                                 length=50)
                for w_idx in range(y_hat.shape[3]):
                    c_prime_i = self.extractor_prime(padded_c_prime, h_idx,
                                                     w_idx)
                    c_doubleprime_i = self.extractor_doubleprime(
                        padded_y_hat, h_idx, w_idx)
                    concatenated_c_i = np.concatenate(
                        [c_doubleprime_i, c_prime_i], axis=1)

                    pred_mean, pred_sigma = self.sess.run(
                        [self.pred_mean, self.pred_sigma],
                        feed_dict={self.concatenated_c_i: concatenated_c_i})

                    for ch_idx in range(self.M):
                        mu_val = pred_mean[0, ch_idx, 0, 0] + 255
                        sigma_val = pred_sigma[0, ch_idx, 0, 0]
                        freq = arithmeticcoding.ModelFrequencyTable(
                            mu_val, sigma_val)

                        symbol = dec.read(freq)
                        if symbol == 512:  # EOF symbol
                            print("EOF symbol")
                            eof_reached = True
                            break
                        padded_y_hat[:, ch_idx, h_idx + 3,
                                     w_idx + 2] = symbol - 255
                    if eof_reached:
                        break

            bitin.close()

        # Strip the context padding to recover y_hat.
        y_hat = padded_y_hat[:, :, 3:, 2:-1]

        recon = self.sess.run(self.recon_image, {self.y_hat: y_hat})
        recon = recon[0, -h:, -w:, :]

        # Clip before the uint8 cast so out-of-range values saturate instead
        # of wrapping (assumes recon is on the 0..255 pixel scale).
        im = Image.fromarray(np.clip(recon, 0, 255).astype(np.uint8))
        im.save(recon_path)

        return
コード例 #5
0
    def encode(self, model_type, input_path, compressed_file_path,
               quality_level):  # with TOP N dimensions
        """Compress the image at ``input_path`` into ``compressed_file_path``.

        File layout written by this method:

        * 1 byte (uint8): ``(quality_level << 1) + model_type``.
        * 4 bytes: width and height as two uint16 values in machine byte
          order (``ndarray.tofile`` does not byte-swap).
        * The arithmetic-coded stream: every ``z_hat`` symbol (channel by
          channel, raster order inside a channel), then every ``y_hat``
          symbol (raster order over positions, ``self.M`` channels per
          position), then the EOF symbol 512.

        The image is edge-padded to even size and pushed through
        ``gah1`` -> ``gah2`` -> ``gah3`` -> ``y_hat``. Before each later
        stage, axes 1 and 2 are zero-padded to even length. Latent values
        are shifted by +255 and rounded, so valid symbols are 0..511; 512 is
        reserved for EOF.

        Args:
            model_type: added to the header byte below ``quality_level``;
                presumably 0 or 1 -- TODO confirm.
            input_path: image to compress; converted to 3-channel RGB.
            compressed_file_path: output path; truncated if it exists.
            quality_level: integer stored in the upper bits of the header.

        Returns:
            ``compressed_file_path``.

        Raises:
            ValueError: if the header values do not fit their fixed-width
                fields, or a latent symbol falls outside 0..511 (it could
                not be decoded correctly; 512 would be read as EOF).
        """
        with Image.open(input_path) as img:
            w, h = img.size
            # The graph input is reshaped to 3 channels below; converting
            # keeps grayscale / RGBA / palette images from failing there.
            img_array = np.asarray(img.convert('RGB'))

        header_byte = (quality_level << 1) + model_type
        if not 0 <= header_byte <= 0xFF:
            raise ValueError("header byte out of uint8 range: " +
                             str(header_byte))
        if not (0 <= w <= 0xFFFF and 0 <= h <= 0xFFFF):
            raise ValueError("image size does not fit in uint16: " +
                             str((w, h)))

        with open(compressed_file_path, mode='wb') as fileobj:
            np.array([header_byte], dtype=np.uint8).tofile(fileobj)
            np.array([w, h], dtype=np.uint16).tofile(fileobj)

        # Edge-pad at the bottom/right up to even width and height.
        new_w = int(math.ceil(float(w) / 2.0) * 2)
        new_h = int(math.ceil(float(h) / 2.0) * 2)
        pad_w = new_w - w
        pad_h = new_h - h

        img_array = np.pad(img_array, ((0, pad_h), (0, pad_w), (0, 0)),
                           mode='edge')
        input_x = img_array.reshape(1, new_h, new_w, 3)
        input_x = input_x.transpose([0, 3, 1, 2])  # NHWC -> NCHW

        def _pad_to_even(x):
            # Zero-pad axes 1 and 2 (presumably H and W) by one when odd.
            return np.pad(x, ((0, 0), (0, x.shape[1] % 2),
                              (0, x.shape[2] % 2), (0, 0)),
                          mode='constant')

        gah1, sigma_z = self.sess.run([self.gah1, self.sigma_z],
                                      feed_dict={self.input_x: input_x})
        gah2 = self.sess.run(self.gah2,
                             feed_dict={self.gah1: _pad_to_even(gah1)})
        gah3 = self.sess.run(self.gah3,
                             feed_dict={self.gah2: _pad_to_even(gah2)})
        y_hat = self.sess.run(self.y_hat,
                              feed_dict={self.gah3: _pad_to_even(gah3)})

        # z_hat / c_prime are computed from y_hat padded (symmetric) up to
        # multiples of 4 on its last two axes.
        y_w = y_hat.shape[3]
        y_h = y_hat.shape[2]
        new_y_w = int(math.ceil(float(y_w) / 4.0) * 4)
        new_y_h = int(math.ceil(float(y_h) / 4.0) * 4)
        pad_y_w = new_y_w - y_w
        pad_y_h = new_y_h - y_h
        pad_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (0, pad_y_h), (0, pad_y_w)),
                           mode='symmetric')
        z_hat, c_prime = self.sess.run([self.z_hat, self.c_prime],
                                       feed_dict={self.y_hat: pad_y_hat})

        def _to_symbol(value):
            # Shift a latent value into the coder alphabet and make it an int
            # (np.int no longer exists in NumPy >= 1.24, and it truncated).
            symbol = int(np.rint(value + 255))
            if symbol < 0 or symbol > 511:
                raise ValueError("symbol range error: " + str(symbol))
            return symbol

        # ---- encode z_hat ----
        printProgressBar(0,
                         z_hat.shape[1],
                         prefix='Encoding z_hat:',
                         suffix='Complete',
                         length=50)
        with open(compressed_file_path, "ab+") as out_file:
            bitout = arithmeticcoding.BitOutputStream(out_file)
            enc = arithmeticcoding.ArithmeticEncoder(bitout)

            for ch_idx in range(z_hat.shape[1]):
                printProgressBar(ch_idx + 1,
                                 z_hat.shape[1],
                                 prefix='Encoding z_hat:',
                                 suffix='Complete',
                                 length=50)
                # One model per z channel: mu_val 255 (z = 0 after the +255
                # shift) and that channel's sigma_z.
                freq = arithmeticcoding.ModelFrequencyTable(
                    255, sigma_z[ch_idx])
                for h_idx in range(z_hat.shape[2]):
                    for w_idx in range(z_hat.shape[3]):
                        enc.write(freq,
                                  _to_symbol(z_hat[0, ch_idx, h_idx, w_idx]))

            # ---- encode y_hat ----
            # y_hat and c_prime padded with 3 rows on top and 2 / 1 columns
            # left / right, as expected by the extractors.
            padded_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (3, 0), (2, 1)),
                                  'constant',
                                  constant_values=0)
            padded_c_prime = np.pad(c_prime,
                                    ((0, 0), (0, 0), (3, 0), (2, 1)),
                                    'constant',
                                    constant_values=0)

            printProgressBar(0,
                             y_hat.shape[2],
                             prefix='Encoding y_hat:',
                             suffix='Complete',
                             length=50)
            for h_idx in range(y_hat.shape[2]):
                printProgressBar(h_idx + 1,
                                 y_hat.shape[2],
                                 prefix='Encoding y_hat:',
                                 suffix='Complete',
                                 length=50)
                for w_idx in range(y_hat.shape[3]):
                    c_prime_i = self.extractor_prime(padded_c_prime, h_idx,
                                                     w_idx)
                    c_doubleprime_i = self.extractor_doubleprime(
                        padded_y_hat, h_idx, w_idx)
                    concatenated_c_i = np.concatenate(
                        [c_doubleprime_i, c_prime_i], axis=1)

                    pred_mean, pred_sigma = self.sess.run(
                        [self.pred_mean, self.pred_sigma],
                        feed_dict={self.concatenated_c_i: concatenated_c_i})

                    for ch_idx in range(self.M):
                        mu_val = pred_mean[0, ch_idx, 0, 0] + 255
                        sigma_val = pred_sigma[0, ch_idx, 0, 0]
                        freq = arithmeticcoding.ModelFrequencyTable(
                            mu_val, sigma_val)
                        enc.write(freq,
                                  _to_symbol(y_hat[0, ch_idx, h_idx, w_idx]))

            # EOF marker, coded with the last model used above.
            enc.write(freq, 512)
            enc.finish()
            bitout.close()

        return compressed_file_path