from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from data import Dictionary
from custom_embedder_recurrent import CustomEmbedder
from optimizer import RAdam
import torch
import tqdm
import transformers

# Load the pretrained GPT-2 tokenizer and model; we only need the (frozen)
# token-embedding matrix and the vocabulary mapping.
tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
gpt2 = transformers.GPT2Model.from_pretrained("gpt2")
embedding = gpt2.wte
vocab = tokenizer.decoder  # {token id: token string}

# Build the project Dictionary from the GPT-2 vocabulary.
dictionary = Dictionary()
dictionary.word2idx = {v: int(k) for k, v in vocab.items()}
dictionary.idx2word = {int(k): v for k, v in vocab.items()}

model = CustomEmbedder(dictionary, 768)
embedding.weight.requires_grad = False  # keep the GPT-2 embeddings fixed
model = model.cuda()
optimizer = RAdam(model.parameters(), lr=0.001)
writer = SummaryWriter()


class EDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.shape[0]
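
# The EDataset class above is truncated in this excerpt; below is a minimal
# sketch of how it might be completed and batched. __getitem__ is an
# assumption (not in the original), presuming `data` is a 2-D tensor of
# token ids with one sequence per row.
class EDatasetSketch(EDataset):
    def __getitem__(self, idx):
        return self.data[idx]  # hypothetical: one token-id sequence per item


loader = DataLoader(EDatasetSketch(torch.randint(0, len(dictionary), (8, 32))),
                    batch_size=4, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 32])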
# Relies on module-level imports of time and torch, plus the project's
# Dictionary, Context, RNNModel, and HuffmanCoding classes.
def decompress(self, compressedfile):
    start = time.time()
    filename_split = compressedfile.split('_')

    # Restore the dictionary and the trained LSTM from the compressed file.
    checkpoint = torch.load(compressedfile, map_location=self.device)
    body = checkpoint['bytes']
    dictionary = Dictionary()
    dictionary.word2idx = checkpoint['word2idx']
    dictionary.idx2word = checkpoint['idx2word']
    context_map = Context(dictionary)
    ntokens = len(dictionary)
    model = RNNModel('LSTM', ntokens, 200, 200, 2, dropout=0.2,
                     tie_weights=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(self.device)
    model.eval()

    # Unpack the stored bytes into a bit string and strip the padding.
    bit_string = ''
    for byte in body:
        bit_string += "{0:08b}".format(byte)
    encoded_text = self.remove_padding(bit_string)

    # Decompression proper:
    #   1. define an initial context
    #   2. predict the Huffman tree for that context
    #   3. read bits until we reach a leaf
    #   4. convert the leaf to a char and append it to the decoded text
    #   5. update the context and repeat
    current_code = ''
    decoded_text = ''
    context = ['<s>'] * 10

    def tree_from_context(context):
        # Build a Huffman tree from the model's next-char distribution
        # conditioned on the current context.
        huffman = HuffmanCoding()
        prob = huffman.make_context_frequency_dict(
            context, model, context_map, self.device,
            threshold=self.args.threshold)
        huffman.make_heap_node(prob)
        huffman.merge_nodes()
        huffman.encode()
        huffman.reverse_mapping = {v: k for k, v in huffman.codes.items()}
        return huffman

    huffman = tree_from_context(context)

    # Rebuild the fixed fallback tree from the stored symbol counts.
    fixed_huffman = HuffmanCoding()
    counts = checkpoint['fixed_huffman_counts']
    fixed_huffman.make_heap_node(counts)
    fixed_huffman.merge_nodes()
    fixed_huffman.encode()
    fixed_huffman.reverse_mapping = {
        v: k for k, v in fixed_huffman.codes.items()
    }

    # Each symbol is prefixed by a flag bit: '0' means its code comes from
    # the adaptive (context-predicted) tree, '1' from the fixed tree.
    flag = None
    for bit in encoded_text:
        if flag == '0':
            current_code += bit
            if current_code in huffman.reverse_mapping:
                next_char = huffman.reverse_mapping[current_code]
                decoded_text += next_char
                current_code = ''
                context = context[1:] + [next_char]
                huffman = tree_from_context(context)
                flag = None
            continue
        elif flag == '1':
            current_code += bit
            if current_code in fixed_huffman.reverse_mapping:
                next_char = fixed_huffman.reverse_mapping[current_code]
                decoded_text += next_char
                current_code = ''
                context = context[1:] + [next_char]
                huffman = tree_from_context(context)
                flag = None
            continue
        else:
            flag = bit

    # Write the decompressed file.
    with open(filename_split[0] + "_decompressed.txt", 'w') as f:
        f.writelines(decoded_text)

    print('Decompression Done!')
    end = time.time()
    print(round((end - start), 3), "s")
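
# A minimal, standalone sketch of the flag-bit scheme decoded above. The
# tables and payload here are hypothetical (not from any checkpoint); only
# the control flow mirrors decompress(): each symbol is a flag bit choosing
# a Huffman table, followed by a codeword from that table.
adaptive_codes = {'0': 'a', '10': 'b', '11': 'c'}  # assumed code -> char
fixed_codes = {'0': 'x', '1': 'y'}                 # assumed code -> char

encoded = '0' + '10' + '1' + '1' + '0' + '0'  # (flag, code) pairs: b, y, a
decoded, code, flag = '', '', None
for bit in encoded:
    if flag is None:
        flag = bit  # first bit of each symbol picks the table
        continue
    code += bit
    table = adaptive_codes if flag == '0' else fixed_codes
    if code in table:
        decoded += table[code]
        code, flag = '', None  # reset for the next (flag, code) pair
print(decoded)  # -> 'bya'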