def compress(inp, bitout):
    """
    Compress the bytes read from ``inp`` with a PPM model and an arithmetic encoder.

    Symbol 256 is the EOF marker: it has frequency 1 in the order -1 context and
    frequency 0 in every other (non-negative order) context. The context history
    is kept newest-first and holds at most ``model.model_order`` symbols.
    """
    encoder = arithmeticcoding.ArithmeticEncoder(32, bitout)
    ppm = ppmmodel.PpmModel(MODEL_ORDER, 257, 256)
    context = []

    chunk = inp.read(1)
    while len(chunk) != 0:
        byte = chunk[0]
        encode_symbol(ppm, context, byte, encoder)
        ppm.increment_contexts(context, byte)

        if ppm.model_order >= 1:
            # Keep the window bounded: evict the oldest (last) entry, then
            # put the newest symbol at the front.
            if len(context) == ppm.model_order:
                del context[-1]
            context.insert(0, byte)

        chunk = inp.read(1)

    # Terminate the stream with EOF and flush the remaining code bits.
    encode_symbol(ppm, context, 256, encoder)
    encoder.finish()
Exemple #2
0
def entropy_coding(frame_index, lat, path_bin, latent, sigma, mu):
    """
    Arithmetic-code one latent tensor into its own file and return its size in bits.

    The output file is ``path_bin + 'f<frame_index zero-padded to 3>_<lat>.bin'``.
    Every element ``latent[0, h, w, ch]`` is shifted by ``bias`` (50 when
    ``lat == 'mv'``, otherwise 100) and encoded with a table built by
    ``arithmeticcoding.logFrequencyTable_exp`` from the matching ``mu`` (also
    shifted by ``bias``) and ``sigma`` entries, over ``2 * bias + 1`` symbols.

    Args:
        frame_index: Frame number used in the output file name.
        lat: Latent name used in the file name; ``'mv'`` selects the smaller bias.
        path_bin: Prefix (directory path) concatenated in front of the file name.
        latent, sigma, mu: Arrays indexed as ``[0, h, w, ch]``.

    Returns:
        Size of the written file in bits (``os.path.getsize(...) * 8``).
    """
    bias = 50 if lat == 'mv' else 100
    # Alphabet size. A plain int is used because the ``np.int`` alias was
    # removed from NumPy (it was only ever an alias of the builtin ``int``).
    num_symbols = int(bias * 2 + 1)

    bin_name = 'f' + str(frame_index).zfill(3) + '_' + lat + '.bin'
    out_path = path_bin + bin_name
    bitout = arithmeticcoding.BitOutputStream(open(out_path, "wb"))
    try:
        enc = arithmeticcoding.ArithmeticEncoder(32, bitout)

        for h in range(latent.shape[1]):
            for w in range(latent.shape[2]):
                for ch in range(latent.shape[3]):
                    mu_val = mu[0, h, w, ch] + bias
                    sigma_val = sigma[0, h, w, ch]
                    symbol = latent[0, h, w, ch] + bias

                    freq = arithmeticcoding.logFrequencyTable_exp(
                        mu_val, sigma_val, num_symbols)
                    enc.write(freq, symbol)

        enc.finish()
    finally:
        # Close the stream even when encoding fails so the handle is not leaked.
        bitout.close()

    return os.path.getsize(out_path) * 8
def main(args):
    """
    Command-line entry point: compress ``InputFile`` into ``OutputFile``.

    Args:
        args: Exactly three strings: input path, output path and the number of
            bits for the arithmetic encoder (an integer in ``[1, 124]``).

    Exits via ``sys.exit`` with a message when the argument count is wrong or
    the bit count is out of range.
    """
    # Argument handling
    if len(args) != 3:
        sys.exit("Usage: python ac-compress.py InputFile OutputFile BitsNum")

    input_file, output_file, bits_len = args[0], args[1], int(args[2])

    if bits_len < 1 or bits_len > 124:
        # Pre-format the diagnostic so it prints the same on Python 2 and 3
        # (the former bare ``print a, b, c`` statement is a syntax error on 3).
        print("%s %s %s" % (bits_len, bits_len < 1, bits_len > 124))
        sys.exit("Neleistina numBits reiksme, galimos reiksmes: [1 .. 124]")

    # First pass over the file: build the frequency table
    freqs = get_frequencies(input_file)

    enc = arithmeticcoding.ArithmeticEncoder(bits_len)

    max_total = enc.maximum_total
    freq_total = freqs.get_total()

    if freq_total > max_total:
        # The cumulative total does not fit the chosen coding interval, so the
        # table is scaled down instead of aborting.
        print(
            "Kumuliatyvus dazniu lenteles dydis yra perdidelis nuruodytam kodavimo intervalui."
        )
        print("Pradedama dazniu lenteles normalizacija.")
        freqs.normalize_freq_table(max_total)

    # Second pass over the file: write header and frequency table, then the
    # arithmetic-coded data, into the output file
    with open(input_file, "rb") as inp, \
            contextlib.closing(bitIO.BitOutputStream(open(output_file, "wb"))) as outp:
        write_header(outp, bits_len, freqs.get_bits_length())
        write_frequencies(outp, freqs,
                          freqs.get_bits_length())  # write the frequency table
        enc.set_bitout(outp)
        compress(freqs, inp, enc)  # compress the data
def compress(inp, bitout):
    """
    Compress the bytes read from ``inp`` with a PPM model and an arithmetic encoder.

    Symbol 256 is the EOF marker: frequency 1 in the order -1 context and 0 in
    all contexts of non-negative order. The history list is oldest-first and
    holds at most ``model.model_order`` symbols.
    """
    encoder = arithmeticcoding.ArithmeticEncoder(bitout)
    ppm = ppmmodel.PpmModel(MODEL_ORDER, 257, 256)
    context = []

    chunk = inp.read(1)
    while len(chunk) != 0:
        # Indexing bytes yields an int on Python 3; Python 2 needs ord().
        if python3:
            value = chunk[0]
        else:
            value = ord(chunk)
        encode_symbol(ppm, context, value, encoder)
        ppm.increment_contexts(context, value)

        if ppm.model_order >= 1:
            # Slide the window: drop the oldest symbol once it is full.
            if len(context) == ppm.model_order:
                context.pop(0)
            context.append(value)

        chunk = inp.read(1)

    # EOF marker, then flush the remaining code bits.
    encode_symbol(ppm, context, 256, encoder)
    encoder.finish()
 def compress(quantized, output_file):
     """
     Pickle ``quantized`` and arithmetic-code the resulting bytes into a file.

     The pickled byte string is encoded one byte at a time with an adaptive
     frequency table that starts flat over 257 symbols; symbol 256 is written
     last as the EOF marker.

     Input:
     quantized : Any picklable object to compress
     output_file : Path of the file to create (opened in "wb" mode)

     Output:
     None; the compressed data is written to ``output_file``
     """
     data = pickle.dumps(quantized)
     with open(output_file, "wb") as out_file:
         bitout = arithmeticcoding.BitOutputStream(out_file)
         initfreqs = arithmeticcoding.FlatFrequencyTable(257)
         freqs = arithmeticcoding.SimpleFrequencyTable(initfreqs)
         enc = arithmeticcoding.ArithmeticEncoder(32, bitout)
         for symbol in data:
             # Encode one byte, then let the model adapt to it
             enc.write(freqs, symbol)
             freqs.increment(symbol)
         enc.write(freqs, 256)  # EOF
         enc.finish()  # Flush remaining code bits
         # Close the bit stream as the other encoders here do after finish();
         # presumably this is what pushes a buffered partial byte to the file.
         bitout.close()
 def start(self, dictionary_size=256):
     """
     Prepare this object for encoding.

     Stores ``dictionary_size``, opens ``self.outputfile`` for binary writing
     wrapped in a bit output stream, builds a flat frequency table with
     ``dictionary_size + 1`` symbols and creates the arithmetic encoder
     (constructed with 32 as its first argument) on that stream.
     """
     self.dictionary_size = dictionary_size
     raw_output = open(self.outputfile, "wb")
     self.bitout = arithmeticcoding.BitOutputStream(raw_output)
     symbol_count = self.dictionary_size + 1
     self.freqsTable = arithmeticcoding.FlatFrequencyTable(symbol_count)
     self.encoder = arithmeticcoding.ArithmeticEncoder(32, self.bitout)
Exemple #7
0
def compress(freqs, inp, bitout):
    """
    Encode ``inp`` byte by byte against the fixed table ``freqs``.

    After the last byte, symbol 256 is written as the EOF marker and the
    encoder is finished so the remaining code bits are flushed to ``bitout``.
    """
    encoder = arithmeticcoding.ArithmeticEncoder(32, bitout)
    chunk = inp.read(1)
    while len(chunk) != 0:
        encoder.write(freqs, chunk[0])
        chunk = inp.read(1)
    encoder.write(freqs, 256)
    encoder.finish()
Exemple #8
0
def compress(model):
    """
    Arithmetic-code the lines of ``./result/test.qs`` using ``model`` predictions.

    Each line read is echoed to ``./result/old.txt`` and mapped to integers via
    ``char2int``. The first character of a line is coded with
    ``generate_freqs(pro=1, first_step=True)``; every later character is coded
    with a table built from ``predict(model, previous_char, hidden)``, with
    ``hidden`` reset at the start of each line. Reading stops at the first
    empty line (including end of file) or once more than 10000 lines have been
    processed. Finally symbol ``len(characters)`` is written with the last
    table used and the encoder is finished. Output goes to
    ``./result/data.bin``.

    Prints ``acc_num / time_num`` and ``sum_all / time_num``, where
    ``time_num`` grows by 8.0 per character, ``acc_num`` counts argmax hits
    and ``sum_all`` accumulates ``-log2`` of the probability assigned to each
    character (``1/35`` is hard-coded for the first character of a line).
    """
    bit_out = arithmeticcoding.BitOutputStream(open('./result/data.bin', "wb"))
    enc = arithmeticcoding.ArithmeticEncoder(bit_out)
    num_line = 0
    sum_all = 0
    time_num = 0
    acc_num = 0
    end_freq = generate_freqs(pro=1, first_step=True)
    try:
        with open('./result/old.txt', 'w') as z, open('./result/test.qs') as f:
            while True:
                text = f.readline().replace('\n', '')
                z.write(text)
                z.write('\n')
                if not text:
                    # Empty string: end of file or a blank line; both stop here.
                    break
                encode_text = [char2int[char] for char in text]
                num_line += 1
                hidden = None
                for char_index, target in enumerate(encode_text):
                    if char_index == 0:
                        freq = generate_freqs(pro=1, first_step=True)
                        sum_all += -np.log2(1 / 35.)
                        time_num += 8.0
                    else:
                        target_char = np.array(target)
                        context_char = np.array(encode_text[char_index - 1])
                        out, hidden = predict(model, context_char, hidden)
                        out = out[0]  # (35, )
                        sum_all += -np.log2(out[target_char])
                        time_num += 8.0
                        freq = generate_freqs(pro=out, first_step=False)

                        # Builtin int replaces the removed ``np.int`` alias.
                        if np.argmax(out) == target_char.astype(int):
                            acc_num += 1
                    enc.write(freq, target)
                    end_freq = freq

                if num_line % 100 == 0:
                    print(num_line)
                if num_line > 10000:
                    break
        enc.write(end_freq, len(characters))
        enc.finish()
    finally:
        # Close the bit stream (previously never closed) so the output file
        # handle is released, matching the other encoders in this file.
        bit_out.close()
    print(acc_num / time_num, sum_all / time_num)
def compress(inp, bitout):
    """
    Adaptively arithmetic-code the bytes of ``inp`` into ``bitout``.

    The frequency table starts flat over 257 symbols and is incremented after
    each encoded byte; symbol 256 is written last as the EOF marker.
    """
    table = arithmeticcoding.SimpleFrequencyTable(
        arithmeticcoding.FlatFrequencyTable(257))
    encoder = arithmeticcoding.ArithmeticEncoder(32, bitout)

    chunk = inp.read(1)
    while len(chunk) != 0:
        byte = chunk[0]
        encoder.write(table, byte)
        table.increment(byte)
        chunk = inp.read(1)

    encoder.write(table, 256)
    encoder.finish()
Exemple #10
0
def comparess(file1, model, indices_char):
    """
    Compress the text file ``file1`` into ``file1 + '.comp'`` using model guesses.

    For the first ``maxlen`` characters a 0 ("zero correct guesses") followed
    by the character code is written. After that, ``predict`` is asked for the
    next character from the previous ``maxlen`` characters: correct guesses are
    only counted (up to 255 in a row); on a miss, or when the counter is
    already 255, the count and then the actual character code are written and
    the counter resets. A pending count is written at the end, followed by
    ``MAGIC_EOF``. All symbols share one adaptive table of ``AE_SIZE`` entries.

    Known to be slow: roughly 0.02 s per predicted character (about 16 minutes
    for a 50k-character file); ideally the prediction would run on the GPU.

    Returns:
        None.
    """
    # Read the input through a context manager so the handle is released.
    with open(file1, 'r') as src:
        f1 = src.read()

    bitout = arithmeticcoding.BitOutputStream(open(file1 + '.comp', "wb"))
    try:
        initfreqs = arithmeticcoding.FlatFrequencyTable(AE_SIZE)
        freqs = arithmeticcoding.SimpleFrequencyTable(initfreqs)
        enc = arithmeticcoding.ArithmeticEncoder(bitout)
        guesses_right = 0

        for i, char in enumerate(f1):
            current = ord(char)
            if i < maxlen:
                enc.write(freqs,
                          0)  # Always 'guessing' zero correctly before maxlen
                freqs.increment(0)
                enc.write(freqs, current)
                freqs.increment(current)
            else:
                guess = predict(f1[(i - maxlen):i], model, indices_char)
                if (char == guess and guesses_right < 255):
                    guesses_right += 1
                    print("Guessed", char, "correctly")
                else:
                    enc.write(freqs, guesses_right)
                    print("Wrong guess. Outputing", guesses_right,
                          "correct guesses")
                    freqs.increment(guesses_right)
                    print(i, "Outputing char", current)
                    enc.write(freqs, current)
                    freqs.increment(current)
                    guesses_right = 0

            if (i % 100 == 0): print("i:", i)

        if guesses_right > 0:
            enc.write(freqs, guesses_right)
        enc.write(freqs, MAGIC_EOF)
        print("out eof sanity check")
        enc.finish()
    finally:
        # Close the output stream even if prediction or encoding fails.
        bitout.close()
    return None
def compress(inp, bitout):
    """
    Encode the symbols of ``inp`` with a context frequency table.

    The table is created from a quantized five-entry starting distribution.
    After every encoded symbol, ``increment`` is called on it with a second
    quantized distribution. Symbol 4 is written last as the EOF marker.
    """
    initial_pmf = pmf_quantization([0.2, 0.1, 0.3, 0.2, 0.2])
    updated_pmf = pmf_quantization([0.2, 0.1, 0.1, 0.3, 0.3])

    table = arithmeticcoding.ContextFrequencyTable(initial_pmf)
    encoder = arithmeticcoding.ArithmeticEncoder(32, bitout)

    for item in inp:
        encoder.write(table, item)
        table.increment(updated_pmf)

    # EOF marker, then flush the remaining code bits.
    encoder.write(table, 4)
    encoder.finish()
Exemple #12
0
def compress(snp, numsymbol):
	"""
	Adaptively arithmetic-code a 3-D symbol array and return the encoder's bit count.

	``snp`` is squeezed and must then have exactly three axes (rows, cols,
	channels). Values are visited channel by channel, row by row within a
	channel, and encoded with a table that starts flat over ``numsymbol``
	symbols and is incremented after each value. ``numsymbol - 1`` is written
	last as the EOF marker. Returns ``enc.bit_nums``.
	"""
	table = arithmeticcoding.SimpleFrequencyTable(
		arithmeticcoding.FlatFrequencyTable(numsymbol))
	encoder = arithmeticcoding.ArithmeticEncoder(32)

	volume = np.squeeze(snp)
	height, width, depth = volume.shape

	for plane in range(depth):
		for row in range(height):
			for col in range(width):
				value = volume[row, col, plane]
				encoder.write(table, value)
				table.increment(value)

	# EOF marker, then flush the remaining code bits.
	encoder.write(table, numsymbol - 1)
	encoder.finish()
	return encoder.bit_nums
Exemple #13
0
	def compress(self, inp, freqs, num_symbols):
		"""
		Encode every element of ``inp`` with the fixed table ``freqs``.

		``num_symbols`` is written last as the EOF marker. The value returned
		by the encoder's ``finish()`` is stored in ``self.s`` and returned.
		"""
		encoder = arithmeticcoding.ArithmeticEncoder()
		position = 0
		total = len(inp)
		while position < total:
			encoder.write(freqs, inp[position])
			position += 1
		encoder.write(freqs, num_symbols)
		# finish() flushes the remaining code bits and yields the result.
		self.s = encoder.finish()
		return self.s


	# Writes an unsigned integer of the given bit width to the given stream.
#	def write_int(bitout, numbits, value):
#		for i in reversed(range(numbits)):
#			bitout.write((value >> i) & 1)  # Big endian


	# Main launcher
	#if __name__ == "__main__":
	#	main(sys.argv[1 : ])
Exemple #14
0
    def compressTree(
        self, node, overall_freqs, N
    ):
        """
        Arithmetic-code a tree breadth-first, starting at ``node``.

        Nodes are taken from a FIFO queue. For a node with ``v > 1``, each
        non-``None`` child is enqueued and its ``v`` is encoded with a table
        built from ``ec().binomial_encoder_frequencies(overall_freqs[i:],
        remaining)``, where ``remaining`` starts at the parent's ``v`` and is
        reduced by each encoded child's ``v``, and ``i`` counts the children
        encoded so far; children are skipped once ``remaining`` is no longer
        positive. For a node with ``v == 1``, each non-``None`` child whose
        ``v`` is 1 is enqueued and its ``c`` is encoded with a table built
        directly from ``overall_freqs``.

        Args:
            node: Root node; nodes expose ``v``, ``c`` and ``childNodes``.
            overall_freqs: Frequencies used to build the coding tables.
            N: Not used by this method.

        Returns:
            Whatever the encoder's ``finish()`` returns.
        """
        enc = arithmeticcoding.ArithmeticEncoder()
        q = deque([node])
        while q:
            temp = q.popleft()
            if temp.v > 1:
                tempValue = temp.v
                i = 0
                for child in temp.childNodes:
                    # Skip missing children and anything after the parent's
                    # count has been fully distributed.
                    if child is None or not tempValue > 0:
                        continue
                    q.append(child)
                    binomial_frequencies = ec().binomial_encoder_frequencies(
                        overall_freqs[i:], tempValue)
                    freqs = arithmeticcoding.SimpleFrequencyTable(
                        binomial_frequencies)
                    enc.write(freqs, child.v)
                    tempValue = tempValue - child.v
                    i += 1
            elif temp.v == 1:
                for child in temp.childNodes:
                    if child is not None and child.v == 1:
                        symbol = child.c
                        q.append(child)
                        freqs = arithmeticcoding.SimpleFrequencyTable(
                            overall_freqs)
                        enc.write(freqs, symbol)

        return enc.finish()
def compress(inp, bitout):
    """
    PPM-compress the bytes read from ``inp`` into ``bitout``.

    The model is built with 257 symbols, of which 256 is the EOF marker
    written after the last byte. The context history is oldest-first and is
    capped at ``model.model_order`` entries.
    """
    coder = arithmeticcoding.ArithmeticEncoder(bitout)
    ppm = ppmmodel.PpmModel(MODEL_ORDER, 257, 256)
    recent = []

    raw = inp.read(1)
    while len(raw) != 0:
        # bytes[0] is already an int on Python 3; Python 2 needs ord().
        code = raw[0] if python3 else ord(raw)
        encode_symbol(ppm, recent, code, coder)
        ppm.increment_contexts(recent, code)

        if ppm.model_order >= 1:
            # Drop the oldest symbol once the window is full, then append.
            if len(recent) == ppm.model_order:
                recent.pop(0)
            recent.append(code)

        raw = inp.read(1)

    # EOF marker, then flush the remaining code bits.
    encode_symbol(ppm, recent, 256, coder)
    coder.finish()
Exemple #16
0
 def openFileLeft(self):
     """Create and return a new arithmetic encoder writing to ``self.bitoutL``."""
     return arithmeticcoding.ArithmeticEncoder(self.bitoutL)
    def encode(self, model_type, input_path, compressed_file_path,
               quality_level):  # with TOP N dimensions
        """
        Compress the image at ``input_path`` into ``compressed_file_path``.

        File layout as written here: one uint8 header byte holding
        ``(quality_level << 1) + model_type``, then the original width and
        height as two uint16 values, then one arithmetic-coded stream that
        contains all ``z_hat`` symbols, all ``y_hat`` symbols and a final
        symbol 512.

        The image is padded on the top and left (reflect mode) to multiples
        of 16 before being fed to the network. Symbols are the latent values
        plus 255, rounded with ``np.rint``; values outside ``[0, 511]`` are
        only reported with a message and are still passed to the encoder.

        Returns:
            ``compressed_file_path``.
        """
        img = Image.open(input_path)
        w, h = img.size

        # Write the fixed-size header; the context manager closes the file
        # even if packing the values fails.
        with open(compressed_file_path, mode='wb') as fileobj:
            buf = quality_level << 1
            buf = buf + model_type
            arr = np.array([0], dtype=np.uint8)
            arr[0] = buf
            arr.tofile(fileobj)

            arr = np.array([w, h], dtype=np.uint16)
            arr.tofile(fileobj)

        # Float division so the round-up to a multiple of 16 also works where
        # ``/`` on two ints truncates (Python 2).
        new_w = int(math.ceil(w / 16.0) * 16)
        new_h = int(math.ceil(h / 16.0) * 16)

        pad_w = new_w - w
        pad_h = new_h - h

        # Pad top/left, then go from (H, W, 3) to (1, 3, H, W).
        input_x = np.asarray(img)
        input_x = np.pad(input_x, ((pad_h, 0), (pad_w, 0), (0, 0)),
                         mode='reflect')
        input_x = input_x.reshape(1, new_h, new_w, 3)
        input_x = input_x.transpose([0, 3, 1, 2])

        h_s_out, y_hat, z_hat, sigma_z = self.sess.run(
            [self.h_s_out, self.y_hat, self.z_hat, self.sigma_z],
            feed_dict={self.input_x: input_x})  # NCHW

        ############### encode z ####################################
        # Append to the header written above; z and y share this one stream.
        bitout = arithmeticcoding.BitOutputStream(
            open(compressed_file_path, "ab+"))
        enc = arithmeticcoding.ArithmeticEncoder(bitout)

        printProgressBar(0,
                         z_hat.shape[1],
                         prefix='Encoding z_hat:',
                         suffix='Complete',
                         length=50)
        for ch_idx in range(z_hat.shape[1]):
            printProgressBar(ch_idx + 1,
                             z_hat.shape[1],
                             prefix='Encoding z_hat:',
                             suffix='Complete',
                             length=50)
            # One table per z channel: fixed mean 255, per-channel sigma.
            mu_val = 255
            sigma_val = sigma_z[ch_idx]

            freq = arithmeticcoding.ModelFrequencyTable(mu_val, sigma_val)

            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    symbol = np.rint(z_hat[0, ch_idx, h_idx, w_idx] + 255)
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))

                    enc.write(freq, symbol)

        ############### encode y ####################################
        # Zero-pad the first M1 channels of y_hat and of h_s_out so the
        # extractors can be called at every (h_idx, w_idx) position.
        padded_y1_hat = np.pad(y_hat[:, :self.M1, :, :],
                               ((0, 0), (0, 0), (3, 0), (2, 1)),
                               'constant',
                               constant_values=((0, 0), (0, 0), (0, 0), (0,
                                                                         0)))

        c_prime = h_s_out[:, :self.M1, :, :]
        sigma2 = h_s_out[:, self.M1:, :, :]
        padded_c_prime = np.pad(c_prime, ((0, 0), (0, 0), (3, 0), (2, 1)),
                                'constant',
                                constant_values=((0, 0), (0, 0), (0, 0), (0,
                                                                          0)))

        printProgressBar(0,
                         y_hat.shape[2],
                         prefix='Encoding y_hat:',
                         suffix='Complete',
                         length=50)
        for h_idx in range(y_hat.shape[2]):
            printProgressBar(h_idx + 1,
                             y_hat.shape[2],
                             prefix='Encoding y_hat:',
                             suffix='Complete',
                             length=50)
            for w_idx in range(y_hat.shape[3]):
                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y1_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate([c_doubleprime_i, c_prime_i],
                                                  axis=1)

                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})

                # Channels beyond the predicted ones get mean 0 and take
                # their sigma from the second part of h_s_out.
                zero_means = np.zeros([
                    pred_mean.shape[0], self.M2, pred_mean.shape[2],
                    pred_mean.shape[3]
                ])

                concat_pred_mean = np.concatenate([pred_mean, zero_means],
                                                  axis=1)
                concat_pred_sigma = np.concatenate([
                    pred_sigma, sigma2[:, :, h_idx:h_idx + 1, w_idx:w_idx + 1]
                ],
                                                   axis=1)

                for ch_idx in range(self.M):
                    mu_val = concat_pred_mean[0, ch_idx, 0, 0] + 255
                    sigma_val = concat_pred_sigma[0, ch_idx, 0, 0]

                    freq = arithmeticcoding.ModelFrequencyTable(
                        mu_val, sigma_val)

                    symbol = np.rint(y_hat[0, ch_idx, h_idx, w_idx] + 255)
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)
        # Symbol 512 ends the stream; it is coded with the last table used.
        enc.write(freq, 512)
        enc.finish()
        bitout.close()

        return compressed_file_path
Exemple #18
0
    def inferenceNN(
        self, x, M, N, overall_freqs, L, activationFunction
    ):  #N is the number of hidden nodes, the weights are of dimension MxN
        """
        Decode a compressed weight structure from ``L`` on the fly and apply it to ``x``.

        Node values are processed from a FIFO queue that starts with ``N``.
        For a value greater than 1, child counts are decoded one colour ``c``
        at a time with binomial tables from
        ``ec().binomial_encoder_frequencies``; for a value of exactly 1 the
        colour itself is decoded with a table built from ``overall_freqs``.
        Each decoded symbol is re-encoded into a fresh encoder, and
        ``uc().index_to_weight(c) * x[level]`` is accumulated into the
        affected entries of the output ``y``. Processing stops when the queue
        is empty or ``level`` reaches ``M``. ``self.w`` counts the decoded
        symbols.

        Args:
            x: Input values, indexed by ``level``.
            M, N: The weights are of dimension M x N; N is the number of
                hidden nodes and the length of the output.
            overall_freqs: Frequencies for the colours.
            L: Passed to ``arithmeticcoding.ArithmeticDecoder``.
            activationFunction: ``'ReLU'``, ``'sigmoid'`` or ``None``.

        Returns:
            Tuple ``(y, avg_queue_length, max_queue_length)`` with ``y`` as a
            NumPy array after the optional activation.

        Raises:
            ZeroDivisionError: If no symbol was decoded (``self.w`` stays 0).
        """
        y = [0 for i in range(N)]
        enc = arithmeticcoding.ArithmeticEncoder()
        dec = arithmeticcoding.ArithmeticDecoder(L)
        q = deque([N])
        self.w = 0
        # Queue-length bookkeeping charges floor(2*log2(v + 1) + 1) per
        # queued value v.
        tot_queue_length = floor(2 * log2(N + 1) + 1)
        max_queue_length = floor(2 * log2(N + 1) + 1)
        current_queue_length = floor(2 * log2(N + 1) + 1)
        j = 0
        level = 0
        flag = 0
        flagp = 0
        k = len(overall_freqs)
        print('M:', M, 'N:', N)
        while len(q) != 0 and level < M:
            currentNodeValue = q.popleft()
            current_queue_length -= floor(2 * log2(currentNodeValue + 1) + 1)

            if flagp == 0:
                # Printed once, right after the first pop.
                print('current_queue_length', current_queue_length)
                flagp = 1
            if currentNodeValue > 1:
                c = 0  #colour initialized with 0
                while c <= k - 1 and currentNodeValue > 0:  #kth colour need not be encoded
                    binomial_frequencies = ec().binomial_encoder_frequencies(
                        overall_freqs[c:], currentNodeValue)
                    freqs = arithmeticcoding.SimpleFrequencyTable(
                        binomial_frequencies)
                    childNodeValue = dec.read(freqs)
                    enc.write(freqs, childNodeValue)
                    currentNodeValue -= childNodeValue
                    q.append(childNodeValue)
                    current_queue_length += floor(2 *
                                                  log2(childNodeValue + 1) + 1)
                    max_queue_length = max(max_queue_length,
                                           current_queue_length)
                    tot_queue_length += current_queue_length
                    self.w += 1
                    if childNodeValue > 0:
                        flag = 1
                    # The next childNodeValue outputs all get colour c.
                    for i in range(childNodeValue):
                        y[j + i] += uc().index_to_weight(c) * x[level]
                    c = c + 1
                    j = (j + childNodeValue) % N
                    if j == 0 and flag:
                        level = level + 1
                        flag = 0
            elif currentNodeValue == 1:
                freqs = arithmeticcoding.SimpleFrequencyTable(overall_freqs)
                c = dec.read(freqs)
                enc.write(freqs, c)
                q.append(1)
                current_queue_length += 3
                max_queue_length = max(max_queue_length, current_queue_length)
                tot_queue_length += current_queue_length

                self.w += 1
                # A single node affects exactly output j. The previous
                # ``y[j + i]`` reused a stale loop index from the other branch
                # (or an undefined name if that loop had never run).
                y[j] += uc().index_to_weight(c) * x[level]
                j = (j + 1) % N
                if j == 0:
                    level += 1

        avg_queue_length = tot_queue_length / self.w

        enc.finish()  # the re-encoded stream it returns is not used here
        y = np.array(y)
        if activationFunction == 'ReLU':
            y = uc().ReLU(y)
        elif activationFunction == 'sigmoid':
            y = uc().sigmoid(y)
        # activationFunction None (or anything else): y is returned unchanged.
        return y, avg_queue_length, max_queue_length
    def encode(self, model_type, input_path, compressed_file_path,
               quality_level):  # with TOP N dimensions
        """
        Compress the image at ``input_path`` into ``compressed_file_path``.

        File layout as written here: one uint8 header byte holding
        ``(quality_level << 1) + model_type``, then the original width and
        height as two uint16 values, then one arithmetic-coded stream that
        contains all ``z_hat`` symbols, all ``y_hat`` symbols and a final
        symbol 512.

        The image is padded on the bottom and right (edge mode) to even
        dimensions. The analysis network is run in stages (``gah1``, ``gah2``,
        ``gah3``, ``y_hat``), with each intermediate result zero-padded before
        the next stage. Symbols are the latent values plus 255, truncated to
        int; values outside ``[0, 511]`` are only reported with a message and
        are still passed to the encoder.

        Returns:
            ``compressed_file_path``.
        """
        img = Image.open(input_path)
        w, h = img.size

        # Write the fixed-size header; the context manager closes the file
        # even if packing the values fails.
        with open(compressed_file_path, mode='wb') as fileobj:
            buf = quality_level << 1
            buf = buf + model_type
            arr = np.array([0], dtype=np.uint8)
            arr[0] = buf
            arr.tofile(fileobj)

            arr = np.array([w, h], dtype=np.uint16)
            arr.tofile(fileobj)

        # Round both dimensions up to the next multiple of 2.
        new_w = int(math.ceil(float(w) / 2.0) * 2)
        new_h = int(math.ceil(float(h) / 2.0) * 2)

        pad_w = new_w - w
        pad_h = new_h - h

        # Pad bottom/right, then go from (H, W, 3) to (1, 3, H, W).
        img_array = np.asarray(img)
        img_array = np.pad(img_array, ((0, pad_h), (0, pad_w), (0, 0)),
                           mode='edge')
        input_x = img_array

        input_x = input_x.reshape(1, new_h, new_w, 3)
        input_x = input_x.transpose([0, 3, 1, 2])

        gah1, sigma_z = self.sess.run([self.gah1, self.sigma_z],
                                      feed_dict={self.input_x: input_x})

        # Before each following stage, axes 1 and 2 are zero-padded by one
        # element when their length is odd.
        gah1 = np.pad(gah1, ((0, 0), (0, gah1.shape[1] % 2),
                             (0, gah1.shape[2] % 2), (0, 0)),
                      mode='constant')
        gah2 = self.sess.run(self.gah2, feed_dict={self.gah1: gah1})

        gah2 = np.pad(gah2, ((0, 0), (0, gah2.shape[1] % 2),
                             (0, gah2.shape[2] % 2), (0, 0)),
                      mode='constant')
        gah3 = self.sess.run(self.gah3, feed_dict={self.gah2: gah2})

        gah3 = np.pad(gah3, ((0, 0), (0, gah3.shape[1] % 2),
                             (0, gah3.shape[2] % 2), (0, 0)),
                      mode='constant')
        y_hat = self.sess.run(self.y_hat, feed_dict={self.gah3: gah3})

        # The hyper-network input is y_hat padded (symmetric mode) to
        # multiples of 4 on its last two axes.
        y_w = y_hat.shape[3]
        y_h = y_hat.shape[2]
        new_y_w = int(math.ceil(float(y_w) / 4.0) * 4)
        new_y_h = int(math.ceil(float(y_h) / 4.0) * 4)
        pad_y_w = new_y_w - y_w
        pad_y_h = new_y_h - y_h
        pad_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (0, pad_y_h), (0, pad_y_w)),
                           mode='symmetric')
        z_hat, c_prime = self.sess.run([self.z_hat, self.c_prime],
                                       feed_dict={self.y_hat: pad_y_hat})

        ############### encode zhat ####################################
        printProgressBar(0,
                         z_hat.shape[1],
                         prefix='Encoding z_hat:',
                         suffix='Complete',
                         length=50)
        # Append to the header written above; z and y share this one stream.
        bitout = arithmeticcoding.BitOutputStream(
            open(compressed_file_path, "ab+"))
        enc = arithmeticcoding.ArithmeticEncoder(bitout)

        for ch_idx in range(z_hat.shape[1]):
            printProgressBar(ch_idx + 1,
                             z_hat.shape[1],
                             prefix='Encoding z_hat:',
                             suffix='Complete',
                             length=50)
            # One table per z channel: fixed mean 255, per-channel sigma.
            mu_val = 255
            sigma_val = sigma_z[ch_idx]

            freq = arithmeticcoding.ModelFrequencyTable(mu_val, sigma_val)

            for h_idx in range(z_hat.shape[2]):
                for w_idx in range(z_hat.shape[3]):
                    # Builtin int replaces the removed ``np.int`` alias.
                    symbol = int(z_hat[0, ch_idx, h_idx, w_idx] + 255)
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))
                    enc.write(freq, symbol)

        ############### encode yhat ####################################
        # Zero-pad y_hat and c_prime so the extractors can be called at every
        # (h_idx, w_idx) position.
        padded_y_hat = np.pad(y_hat, ((0, 0), (0, 0), (3, 0), (2, 1)),
                              'constant',
                              constant_values=((0, 0), (0, 0), (0, 0), (0, 0)))

        padded_c_prime = np.pad(c_prime, ((0, 0), (0, 0), (3, 0), (2, 1)),
                                'constant',
                                constant_values=((0, 0), (0, 0), (0, 0), (0,
                                                                          0)))

        printProgressBar(0,
                         y_hat.shape[2],
                         prefix='Encoding y_hat:',
                         suffix='Complete',
                         length=50)
        for h_idx in range(y_hat.shape[2]):
            printProgressBar(h_idx + 1,
                             y_hat.shape[2],
                             prefix='Encoding y_hat:',
                             suffix='Complete',
                             length=50)
            for w_idx in range(y_hat.shape[3]):

                c_prime_i = self.extractor_prime(padded_c_prime, h_idx, w_idx)
                c_doubleprime_i = self.extractor_doubleprime(
                    padded_y_hat, h_idx, w_idx)
                concatenated_c_i = np.concatenate([c_doubleprime_i, c_prime_i],
                                                  axis=1)

                pred_mean, pred_sigma = self.sess.run(
                    [self.pred_mean, self.pred_sigma],
                    feed_dict={self.concatenated_c_i: concatenated_c_i})

                for ch_idx in range(self.M):
                    mu_val = pred_mean[0, ch_idx, 0, 0] + 255
                    sigma_val = pred_sigma[0, ch_idx, 0, 0]

                    freq = arithmeticcoding.ModelFrequencyTable(
                        mu_val, sigma_val)

                    symbol = int(y_hat[0, ch_idx, h_idx, w_idx] + 255)
                    if symbol < 0 or symbol > 511:
                        print("symbol range error: " + str(symbol))

                    enc.write(freq, symbol)

        # Symbol 512 ends the stream; it is coded with the last table used.
        enc.write(freq, 512)
        enc.finish()
        bitout.close()

        return compressed_file_path
Exemple #20
0
 def openFileRight(self):
     """Create and return a new arithmetic encoder writing to ``self.bitoutR``."""
     return arithmeticcoding.ArithmeticEncoder(self.bitoutR)