def decompress(quantbits, nz, gpu, state, nblocks):
    """Decode ``nblocks`` 32x32 RGB pixel-blocks from an ANS ``state``.

    For every block the top latent is first decoded under a logistic prior
    with location 0 and scale 1. Then, walking the layers from ``nz - 1``
    down to 0, the layer below (or the pixels, at layer 0) is decoded with the
    generative model and the layer above is encoded back onto ``state`` with
    the inference model.

    Args:
        quantbits: Forwarded to ``discretize`` and used as the ``quantbits``
            argument of ``ANS`` for latent symbols (pixel symbols use 8).
        nz: Number of latent layers; passed to ``Model`` and iterated over.
        gpu: CUDA device index; a negative value selects the CPU.
        state: ANS state to decode from. It is threaded through every
            ``ANS.decode`` / ``ANS.encode`` call.
        nblocks: Number of blocks to decode.

    Returns:
        ``np.uint8`` array of shape ``(nblocks, 32, 32, 3)``. Blocks are
        stored from the last index to the first, i.e. in the reverse of the
        order in which they are decoded.
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"

    # set up the different channel dimension
    reswidth = 256

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4,
                  kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE: the checkpoint path is fixed to the nz4 parameters regardless of
    # the ``nz`` argument.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenetcrop")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    def bin_pmfs(endpoints, mu, scale):
        # Evaluate the logistic CDF at the bin endpoints; differences of
        # neighbouring endpoints give the inner bins, and the mass below the
        # first / above the last endpoint forms the two outer bins.
        cdfs = logistic_cdf(endpoints.t(), mu, scale).t()
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat(
            (cdfs[:, 0].unsqueeze(1), inner, 1. - cdfs[:, -1].unsqueeze(1)),
            dim=1)

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # compression experiment params
    blocks = np.zeros((nblocks, 32, 32, 3), dtype=np.uint8)

    # <===== RECEIVER =====>
    for xi in tqdm(range(nblocks), desc="Decompression"):
        # prior
        pmfs = bin_pmfs(zendpoints[-1],
                        torch.zeros(1, device=device, dtype=dtype),
                        torch.ones(1, device=device, dtype=dtype))

        # decode top-layer z
        state, zsymtop = ANS(pmfs, bits=ansbits,
                             quantbits=quantbits).decode(state)

        # < ===== Bit-Swap ====>
        # inference and generative model, top layer first
        for zi in reversed(range(nz)):
            # generative model
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            pmfs = bin_pmfs(zendpoints[zi - 1] if zi > 0 else xendpoints,
                            mu, scale)

            # decode z or x
            state, sym = ANS(pmfs, bits=ansbits,
                             quantbits=quantbits if zi > 0 else 8).decode(state)

            # inference model
            given = (zcentres[zi - 1, zrange, sym] if zi > 0
                     else xcentres[xrange, sym])
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_pmfs(zendpoints[zi], mu, scale)

            # encode z
            state = ANS(pmfs, bits=ansbits,
                        quantbits=quantbits).encode(state, zsymtop)

            zsymtop = sym

        # after the layer loop ``zsymtop`` holds the pixel symbols;
        # reshape to 32x32 pixel-block with 3 color channels (channel-last)
        im = zsymtop.clone().view(3, 32, 32).detach().cpu()
        blocks[blocks.shape[0] - xi - 1] = np.array(
            im, dtype=np.uint8).transpose((1, 2, 0))

    return blocks
def compress(quantbits, nz, gpu, blocks):
    """Encode image ``blocks`` onto an ANS state with Bit-Swap.

    The state starts as 10000 random words. For each block and each layer
    ``zi`` the latent of that layer is decoded from the state with the
    inference model, and the layer below (the pixels at ``zi == 0``) is
    encoded with the generative model. After the last layer the top latent is
    encoded under a logistic prior with location 0 and scale 1.

    Args:
        quantbits: Forwarded to ``discretize`` and used as the ``quantbits``
            argument of ``ANS`` for latent symbols (pixel symbols use 8).
        nz: Number of latent layers; passed to ``Model`` and iterated over.
        gpu: CUDA device index; a negative value selects the CPU.
        blocks: Array whose first axis indexes the blocks; each
            ``blocks[xi]`` is handed to ``Image.fromarray`` and flattened to
            ``32 * 32 * 3`` values.

    Returns:
        The final ``state`` with its first ``m - 1`` entries deleted, where
        ``m`` is the smallest state length observed right after any decode
        step (capped at the initial 10000).
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"

    # set up the different channel dimension
    reswidth = 256

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4,
                  kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE: the checkpoint path is fixed to the nz4 parameters regardless of
    # the ``nz`` argument.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenetcrop")

    class ToInt:
        # multiply the ToTensor output by 255 so pixel values can later be
        # truncated to integer symbols with ``.long()``
        def __call__(self, pic):
            return pic * 255

    transform_ops = transforms.Compose([transforms.ToTensor(), ToInt()])

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    def bin_pmfs(endpoints, mu, scale):
        # Evaluate the logistic CDF at the bin endpoints; differences of
        # neighbouring endpoints give the inner bins, and the mass below the
        # first / above the last endpoint forms the two outer bins.
        cdfs = logistic_cdf(endpoints.t(), mu, scale).t()
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat(
            (cdfs[:, 0].unsqueeze(1), inner, 1. - cdfs[:, -1].unsqueeze(1)),
            dim=1)

    # compression experiment params
    nblocks = blocks.shape[0]

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # fill state list with 'random' bits
    excess_state_len = 10000
    state = list(
        map(int,
            np.random.randint(low=1 << 16, high=(1 << 32) - 1,
                              size=excess_state_len, dtype=np.uint32)))
    state[-1] = state[-1] << 32

    # <===== SENDER =====>
    for xi in tqdm(range(nblocks), desc="Bit-Swap"):
        x = transform_ops(Image.fromarray(blocks[xi])).to(device).view(xdim)
        xsym = x.long()  # pixel values as integer symbols

        # < ===== Bit-Swap ====>
        # inference and generative model, bottom layer first
        for zi in range(nz):
            # inference model
            given = (zcentres[zi - 1, zrange, zsym] if zi > 0
                     else xcentres[xrange, xsym])
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_pmfs(zendpoints[zi], mu, scale)

            # decode z
            state, zsymtop = ANS(pmfs, bits=ansbits,
                                 quantbits=quantbits).decode(state)

            # remember the shortest state seen after a decode; it determines
            # how much of the random initial state is stripped at the end
            excess_state_len = min(excess_state_len, len(state))

            # generative model
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            pmfs = bin_pmfs(zendpoints[zi - 1] if zi > 0 else xendpoints,
                            mu, scale)

            # encode z or x
            state = ANS(pmfs, bits=ansbits,
                        quantbits=(quantbits if zi > 0 else 8)).encode(
                            state, zsym if zi > 0 else xsym)

            zsym = zsymtop

        # prior
        pmfs = bin_pmfs(zendpoints[-1],
                        torch.zeros(1, device=device, dtype=dtype),
                        torch.ones(1, device=device, dtype=dtype))

        # encode prior
        state = ANS(pmfs, bits=ansbits,
                    quantbits=quantbits).encode(state, zsymtop)

    # remove excess streams
    del state[0:excess_state_len - 1]

    return state
def compress(quantbits, nz, gpu, blocks):
    """Encode image ``blocks`` with Bit-Swap and report the bits per dimension.

    The RNGs are seeded, the state is initialised with 10000 random words and
    every block is pushed onto it: per layer ``zi`` the latent is decoded from
    the state with the inference model and the layer below (the pixels at
    ``zi == 0``) is encoded with the generative model; finally the top latent
    is encoded under a logistic prior with location 0 and scale 1.

    Args:
        quantbits: Forwarded to ``discretize`` and used as the ``quantbits``
            argument of ``ANS`` for latent symbols (pixel symbols use 8).
        nz: Number of latent layers; passed to ``Model`` and iterated over.
        gpu: CUDA device index; a negative value selects the CPU.
        blocks: 4-D array unpacked as ``(nblocks, h, w, c)``; each
            ``blocks[xi]`` is handed to ``Image.fromarray``. Must contain at
            least one block, otherwise the bit count cannot be computed.

    Returns:
        Tuple ``(bitsperdim, state)``: ``state`` is the final ANS state and
        ``bitsperdim`` is ``32 * (len(state) - (r - 1))`` divided by
        ``nblocks * h * w * c``, where ``r`` is the state length right after
        the very first decode step.
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    # negative index selects the CPU instead of building an invalid
    # "cuda:-1" device string
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"

    # set up the different channel dimension
    reswidth = 256

    # seed for replicating experiment and stability
    np.random.seed(100)
    random.seed(50)
    torch.manual_seed(50)
    torch.cuda.manual_seed(50)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark mode lets cuDNN auto-select algorithms, which
    # PyTorch's reproducibility notes say can vary between runs. Left as is
    # because sender and receiver must use the same setting - confirm before
    # changing either side.
    torch.backends.cudnn.benchmark = True

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4,
                  kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE: the checkpoint path is fixed to the nz4 parameters regardless of
    # the ``nz`` argument.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenet")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    def bin_pmfs(endpoints, mu, scale):
        # Evaluate the logistic CDF at the bin endpoints; differences of
        # neighbouring endpoints give the inner bins, and the mass below the
        # first / above the last endpoint forms the two outer bins.
        cdfs = logistic_cdf(endpoints.t(), mu, scale).t()
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat(
            (cdfs[:, 0].unsqueeze(1), inner, 1. - cdfs[:, -1].unsqueeze(1)),
            dim=1)

    # <=== DATA ===>
    class ToInt:
        # multiply the ToTensor output by 255 so pixel values can later be
        # truncated to integer symbols with ``.long()``
        def __call__(self, pic):
            return pic * 255

    transform_ops = transforms.Compose([transforms.ToTensor(), ToInt()])

    # compression experiment params
    nblocks, h, w, c = blocks.shape

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # fill state list with 'random' bits
    state = list(
        map(int,
            np.random.randint(low=1 << 16, high=(1 << 32) - 1, size=10000,
                              dtype=np.uint32)))
    state[-1] = state[-1] << 32

    # state length right after the very first decode; only the length is
    # needed for the bit count, so the state itself is not copied
    restlen = None

    # <===== SENDER =====>
    for xi in tqdm(range(nblocks), desc="Compression"):
        x = transform_ops(Image.fromarray(blocks[xi])).to(device).view(xdim)
        xsym = x.long()  # pixel values as integer symbols

        # < ===== Bit-Swap ====>
        # inference and generative model, bottom layer first
        for zi in range(nz):
            # inference model
            given = (zcentres[zi - 1, zrange, zsym] if zi > 0
                     else xcentres[xrange, xsym])
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_pmfs(zendpoints[zi], mu, scale)

            # decode z
            state, zsymtop = ANS(pmfs, bits=ansbits,
                                 quantbits=quantbits).decode(state)

            # save excess bits for calculations
            if xi == zi == 0:
                restlen = len(state)
                # otherwise initial state consists of too few bits
                assert restlen > 1, "too few initial bits"

            # generative model
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            pmfs = bin_pmfs(zendpoints[zi - 1] if zi > 0 else xendpoints,
                            mu, scale)

            # encode z or x
            state = ANS(pmfs, bits=ansbits,
                        quantbits=(quantbits if zi > 0 else 8)).encode(
                            state, zsym if zi > 0 else xsym)

            zsym = zsymtop

        # prior
        pmfs = bin_pmfs(zendpoints[-1],
                        torch.zeros(1, device=device, dtype=dtype),
                        torch.ones(1, device=device, dtype=dtype))

        # encode prior
        state = ANS(pmfs, bits=ansbits,
                    quantbits=quantbits).encode(state, zsymtop)

    # calculating bits: every state entry is counted as 32 bits
    totalbits = (len(state) - (restlen - 1)) * 32
    bitsperdim = totalbits / (nblocks * h * w * c)
    return bitsperdim, state
def decompress(quantbits, nz, gpu, state, nblocks=30):
    """Decode ``nblocks`` 32x32 RGB pixel-blocks from an ANS ``state``.

    For every block the top latent is first decoded under a logistic prior
    with location 0 and scale 1. Then, walking the layers from ``nz - 1``
    down to 0, the layer below (or the pixels, at layer 0) is decoded with the
    generative model and the layer above is encoded back onto ``state`` with
    the inference model.

    Args:
        quantbits: Forwarded to ``discretize`` and used as the ``quantbits``
            argument of ``ANS`` for latent symbols (pixel symbols use 8).
        nz: Number of latent layers; passed to ``Model`` and iterated over.
        gpu: CUDA device index; a negative value selects the CPU.
        state: ANS state to decode from. It is threaded through every
            ``ANS.decode`` / ``ANS.encode`` call.
        nblocks: Number of blocks to decode. Defaults to 30.

    Returns:
        NumPy array of shape ``(nblocks, 32, 32, 3)`` with torch's default
        floating dtype. The first decoded block is stored at the last index,
        the next one before it, and so on.
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    # negative index selects the CPU instead of building an invalid
    # "cuda:-1" device string
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"

    # set up the different channel dimension
    reswidth = 256

    # seed for replicating experiment and stability
    np.random.seed(100)
    random.seed(50)
    torch.manual_seed(50)
    torch.cuda.manual_seed(50)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark mode lets cuDNN auto-select algorithms, which
    # PyTorch's reproducibility notes say can vary between runs. Left as is
    # because sender and receiver must use the same setting - confirm before
    # changing either side.
    torch.backends.cudnn.benchmark = True

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4,
                  kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE: the checkpoint path is fixed to the nz4 parameters regardless of
    # the ``nz`` argument.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenet")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    def bin_pmfs(endpoints, mu, scale):
        # Evaluate the logistic CDF at the bin endpoints; differences of
        # neighbouring endpoints give the inner bins, and the mass below the
        # first / above the last endpoint forms the two outer bins.
        cdfs = logistic_cdf(endpoints.t(), mu, scale).t()
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat(
            (cdfs[:, 0].unsqueeze(1), inner, 1. - cdfs[:, -1].unsqueeze(1)),
            dim=1)

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # compression experiment params
    h, w, c = 32, 32, 3
    blocks = torch.zeros(nblocks, h, w, c)

    # <===== RECEIVER =====>
    for xi in tqdm(range(nblocks), desc="Decompression"):
        # The top latent has to be decoded once per block: after the layer
        # loop ``zsymtop`` holds pixel symbols (xdim entries), which cannot be
        # used to index ``zcentres`` together with ``zrange`` (zdim entries).
        pmfs = bin_pmfs(zendpoints[-1],
                        torch.zeros(1, device=device, dtype=dtype),
                        torch.ones(1, device=device, dtype=dtype))

        # decode top-layer z
        state, zsymtop = ANS(pmfs, bits=ansbits,
                             quantbits=quantbits).decode(state)

        # < ===== Bit-Swap ====>
        # inference and generative model, top layer first
        for zi in reversed(range(nz)):
            # generative model
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            pmfs = bin_pmfs(zendpoints[zi - 1] if zi > 0 else xendpoints,
                            mu, scale)

            # decode z or x
            state, sym = ANS(pmfs, bits=ansbits,
                             quantbits=quantbits if zi > 0 else 8).decode(state)

            # inference model
            given = (zcentres[zi - 1, zrange, sym] if zi > 0
                     else xcentres[xrange, sym])
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_pmfs(zendpoints[zi], mu, scale)

            # encode z
            state = ANS(pmfs, bits=ansbits,
                        quantbits=quantbits).encode(state, zsymtop)

            zsymtop = sym

        # The pixel symbols follow the model's channel-first (c, h, w) layout,
        # so reshape accordingly and then move channels last. Blocks come off
        # the state in the reverse of the order they were pushed, hence the
        # back-to-front fill.
        blocks[nblocks - xi - 1] = zsymtop.view(c, h, w).permute(1, 2, 0)

    # return as numpy array
    return blocks.cpu().detach().numpy()