def compression_forward(self, x, yclass):
    """
    Run encoder, hyperprior, classifier and generator on a batch.

    When the model is in evaluation mode (``self.model_mode`` equals
    ``ModelModes.EVALUATION`` and ``self.training`` is False) the input and
    the encoder output are each padded with ``utils.pad_factor`` to a
    multiple of 2**(number of downsampling layers) of the network that
    consumes them, and the reconstruction is cropped back to the spatial
    size the input had on entry.

    Inputs
    x:      Input image batch. Per the original notes: format (N,C,H,W),
            range [0,1], or [-1,1] when input normalization is enabled.
            torch.Tensor
    yclass: Passed through unchanged into the returned intermediates
            (presumably the class labels for `yhat_class` - confirm
            against the caller).

    Outputs
    intermediates: `Intermediates` record built from (x, yhat_class, yclass,
                   reconstruction, latents_quantized, total_nbpp,
                   total_qbpp). Note `x` is the padded input in
                   evaluation mode.
    hyperinfo:     The object returned by `self.Hyperprior`.
    """
    # (C,H,W) of the input before any padding is applied.
    input_chw = tuple(x.size()[1:])
    pad_for_eval = (self.model_mode == ModelModes.EVALUATION
                    and self.training is False)

    if pad_for_eval:
        encoder_factor = 2**self.Encoder.n_downsampling_layers
        x = utils.pad_factor(x, x.size()[2:], encoder_factor)

    # Encoder forward pass
    y = self.Encoder(x)

    if pad_for_eval:
        hyper_factor = 2**self.Hyperprior.analysis_net.n_downsampling_layers
        y = utils.pad_factor(y, y.size()[2:], hyper_factor)

    hyperinfo = self.Hyperprior(y, spatial_shape=x.size()[2:])
    latents_quantized = hyperinfo.decoded
    total_nbpp = hyperinfo.total_nbpp
    total_qbpp = hyperinfo.total_qbpp

    # Both the classifier and the generator consume the quantized latents.
    yhat_class = self.Classi(latents_quantized)
    reconstruction = self.Generator(latents_quantized)

    if self.args.normalize_input_image is True:
        reconstruction = torch.tanh(reconstruction)

    if pad_for_eval:
        # Undo padding: crop back to the entry-time height and width.
        height, width = input_chw[1], input_chw[2]
        reconstruction = reconstruction[:, :, :height, :width]

    intermediates = Intermediates(x, yhat_class, yclass, reconstruction,
                                  latents_quantized, total_nbpp, total_qbpp)
    return intermediates, hyperinfo
def compress(self, x, silent=False):
    """
    Compress an image batch and report the attained bitrate.

    Pipeline as visible here:
    * Pad `x` and run it through the encoder: x -> Encoder() -> y
    * Pad `y` and hand it, with the unpadded spatial shape of `x`, to
      `self.Hyperprior.compress_forward`.
    * Per the original notes, that call encodes the hyperlatents
      z = hyperencoder(y) with a nonparametric entropy model and the latents
      with an entropy model derived from the (mean, scale) output of the
      hyperprior decoder.

    Inputs
    x:      Input image batch; dimensions 2 onward are treated as spatial.
    silent: Bitrates are logged only when this is exactly `False`.

    Outputs
    attained_bpp:       32 * (number of encoded hyperlatent + latent
                        elements) / number of spatial positions
                        (presumably 32-bit words - confirm against the
                        coder).
    compression_output: The object returned by
                        `self.Hyperprior.compress_forward`.

    Raises
    AssertionError: when the model is not in evaluation mode.
    """
    in_eval_mode = (self.model_mode == ModelModes.EVALUATION
                    and self.training is False)
    assert in_eval_mode, (
        f'Set model mode to {ModelModes.EVALUATION} for compression.')

    # Spatial size of the input before padding; used for all bpp figures.
    spatial_shape = tuple(x.size()[2:])

    if in_eval_mode:
        encoder_factor = 2**self.Encoder.n_downsampling_layers
        x = utils.pad_factor(x, x.size()[2:], encoder_factor)

    # Encoder forward pass
    y = self.Encoder(x)

    if in_eval_mode:
        hyper_factor = 2**self.Hyperprior.analysis_net.n_downsampling_layers
        y = utils.pad_factor(y, y.size()[2:], hyper_factor)

    compression_output = self.Hyperprior.compress_forward(y, spatial_shape)

    n_positions = np.prod(spatial_shape)
    n_hyperlatent_elems = len(compression_output.hyperlatents_encoded)
    n_latent_elems = len(compression_output.latents_encoded)

    attained_hbpp = 32 * n_hyperlatent_elems / n_positions
    attained_lbpp = 32 * n_latent_elems / n_positions
    attained_bpp = 32 * ((n_hyperlatent_elems + n_latent_elems) / n_positions)

    if silent is not False:
        return attained_bpp, compression_output

    log = self.logger.info
    log('[ESTIMATED]')
    log(f'BPP: {compression_output.total_bpp:.3f}')
    log(f'HL BPP: {compression_output.hyperlatent_bpp:.3f}')
    log(f'L BPP: {compression_output.latent_bpp:.3f}')

    log('[ATTAINED]')
    log(f'BPP: {attained_bpp:.3f}')
    log(f'HL BPP: {attained_hbpp:.3f}')
    log(f'L BPP: {attained_lbpp:.3f}')

    return attained_bpp, compression_output
def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset,
                          precision, coding_shape,
                          overflow_width=OVERFLOW_WIDTH, **kwargs):
    """
    Reverse op of `vec_ans_index_encoder`. Decodes ans-encoded bitstring into
    a decoded message tensor, including escaped (overflow) symbols.

    Arguments (`indices`, `cdf`, `cdf_length`, `cdf_offset`, `precision`) must
    be identical to the inputs to `vec_ans_index_encoder` used to generate the
    encoded tensor.

    Overflow layout read back at each coding step, for every position whose
    decoded value equals the escape value ``cdf_length - 2``:
      1. a chunk count, extended by further pops for as long as the popped
         value equals ``(1 << overflow_width) - 1``;
      2. that many `overflow_width`-bit chunks, least-significant chunk
         first, which are OR-ed together into the overflow integer.
    The overflow integer is then mapped back to a shifted value: odd ->
    ``-(overflow >> 1) - 1``, even -> ``(overflow >> 1) + max_value``.

    Inputs
    encoded:        Encoded message; reshaped via `vrans.unflatten`.
    indices:        Integer array selecting a row of `cdf`, `cdf_length` and
                    `cdf_offset` per element. First two axes are unpacked as
                    (B, n_channels); axes 2 and 3 are treated as spatial
                    when B == 1.
    cdf:            CDF table indexed by `indices` along its first axis.
    cdf_length:     Per-row length table; ``cdf_length - 2`` is the escape
                    value.
    cdf_offset:     Per-row offset added back onto each decoded value.
    precision:      Forwarded to `base_codec` for the symbol codec.
    coding_shape:   Shape handed to `vrans.unflatten`.
    overflow_width: Number of bits per overflow chunk.
    kwargs:         Ignored.

    Outputs
    decoded: For B == 1 the patch-wise symbols reassembled with
             `compression_utils.reconstitute` and cropped to `indices.shape`;
             otherwise the per-step symbols stacked along axis 0.

    Raises
    AssertionError: if an index or CDF length is out of range.
    """
    original_shape = indices.shape
    B, n_channels, *_ = original_shape
    message = vrans.unflatten(encoded, coding_shape)
    indices = indices.astype(np.int32)
    cdf_index = indices

    # Codec over [0, 2**overflow_width) with unit frequency per symbol; used
    # for both the chunk counts and the chunks themselves.
    max_overflow = (1 << overflow_width) - 1
    overflow_cdf_size = (1 << overflow_width) + 1
    overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, :]
    enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
    dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(
        overflow_cdf, np.ones_like(overflow_cdf) * len(overflow_cdf))
    overflow_codec = base_codec(enc_statfun_overflow, dec_statfun_overflow,
                                overflow_width)

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches
        # with batch elements for B > 1 ...
        if ((original_shape[2] % PATCH_SIZE[0] == 0) and
                (original_shape[3] % PATCH_SIZE[1] == 0)) is False:
            indices = utils.pad_factor(
                torch.Tensor(indices), original_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

        padded_shape = indices.shape
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (
            indices.shape[3] % PATCH_SIZE[1] == 0)
        cdf_index, unfolded_shape = compression_utils.decompose(
            indices, n_channels)
        coding_shape = cdf_index.shape[1:]

    symbols = []
    _, overflow_pop = substack(codec=overflow_codec, view_fun=overflow_view)

    # One coding step per entry along the leading axis of `cdf_index`.
    for cdf_index_i in cdf_index:
        cdf_i = cdf[cdf_index_i]
        cdf_length_i = cdf_length[cdf_index_i]

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)

        message, value = symbol_pop(message)

        # Positions that decoded to the escape value carry an overflow.
        max_value_i = cdf_length_i - 2
        of_mask = value == max_value_i

        if np.any(of_mask):
            # Step 1: number of chunks per escaped position.
            message, val = overflow_pop(message, overflow_width, of_mask)
            val = cast2u64(val)
            widths = val

            cond_mask = val == max_overflow
            while np.any(cond_mask):
                message, val = overflow_pop(message, overflow_width, of_mask)
                val = cast2u64(val)
                widths = np.where(cond_mask, widths + val, widths)
                cond_mask = val == max_overflow

            # Step 2: reassemble the overflow from its chunks. The chunk
            # counter must persist across iterations so that chunk k lands
            # at bit offset k * overflow_width.
            overflow = np.zeros_like(val)
            cond_mask = widths != 0
            counter = 0

            while np.any(cond_mask):
                message, val = overflow_pop(message, overflow_width, of_mask)
                val = cast2u64(val)
                assert np.all(val <= max_overflow)

                op = overflow | (val << (counter * overflow_width))
                overflow = np.where(cond_mask, op, overflow)
                widths = np.where(cond_mask, widths - 1, widths)
                cond_mask = widths != 0
                counter += 1

            # Scatter the per-escaped-position overflows back onto the full
            # coding shape (this writes into `value` in place).
            overflow_broadcast = value
            overflow_broadcast[of_mask] = overflow
            overflow = overflow_broadcast

            # Undo the sign interleaving: odd -> below range, even -> above.
            value = np.where(of_mask, overflow >> 1, value)
            cond_mask = np.logical_and(of_mask, overflow & 1)
            value = np.where(cond_mask, -value - 1, value)
            cond_mask = np.logical_and(of_mask, np.logical_not(overflow & 1))
            value = np.where(cond_mask, value + max_value_i, value)

        symbol = value + cdf_offset[cdf_index_i]
        symbols.append(symbol)

    if B == 1:
        decoded = compression_utils.reconstitute(np.stack(symbols, axis=0),
                                                 padded_shape, unfolded_shape)

        # Drop any patch padding added above.
        if tuple(decoded.shape) != tuple(original_shape):
            decoded = decoded[:, :, :original_shape[2], :original_shape[3]]
    else:
        decoded = np.stack(symbols, axis=0)

    return decoded
def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length,
                                   cdf_offset, precision, coding_shape,
                                   overflow_width=OVERFLOW_WIDTH, **kwargs):
    """
    Vectorized version of `ans_index_encoder`. Incurs constant bit overhead,
    but is faster.

    ANS-encodes unbounded integer data using an indexed probability table.
    Nothing is written to a bitstream here: the function returns a buffer of
    encoding instructions for the caller to apply.

    Symbols are shifted by `cdf_offset`; shifted values outside
    ``[0, max_value)`` (``max_value = cdf_length - 2``) are replaced by the
    escape value `max_value` and their distance is stored as a non-negative
    overflow integer (``-2*v - 1`` below range, ``2*(v - max_value)`` at or
    above it). Per coding step the buffer receives, in order:
      1. the symbol instruction ``(start, freq, False, precision, 0)``;
      2. if any position escaped: the chunk count, then the overflow split
         into `overflow_width`-bit chunks, least-significant chunk first,
         each as ``(start, freq, True, int(overflow_width), of_mask)``.

    Inputs
    symbols:        Integer array to encode. First two axes are unpacked as
                    (B, n_channels); axes 2 and 3 are treated as spatial
                    when B == 1.
    indices:        Integer array selecting a row of `cdf`, `cdf_length` and
                    `cdf_offset` per element.
    cdf:            CDF table indexed by `indices` along its first axis.
    cdf_length:     Per-row length table.
    cdf_offset:     Per-row offset subtracted from each symbol.
    precision:      Stored in each symbol instruction and forwarded to
                    `base_codec`.
    coding_shape:   Returned unchanged when B != 1; recomputed from the
                    decomposed values when B == 1.
    overflow_width: Number of bits per overflow chunk.
    kwargs:         Ignored.

    Outputs
    instructions: List of instruction tuples as described above.
    coding_shape: Shape of one coding step.

    Raises
    AssertionError: if an index, CDF length or shifted value is out of range.
    """
    instructions = []

    symbols_shape = symbols.shape
    B, n_channels = symbols_shape[:2]
    symbols = symbols.astype(np.int32)
    indices = indices.astype(np.int32)
    cdf_index = indices

    max_overflow = (1 << overflow_width) - 1
    overflow_cdf_size = (1 << overflow_width) + 1
    overflow_cdf = np.arange(overflow_cdf_size,
                             dtype=np.uint64)[None, None, None, :]

    enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
    dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(
        overflow_cdf, np.ones_like(overflow_cdf) * len(overflow_cdf))
    # NOTE(review): only `enc_statfun_overflow` is used below; the codec pair
    # is built but never called in this function.
    overflow_push, overflow_pop = base_codec(enc_statfun_overflow,
                                             dec_statfun_overflow,
                                             overflow_width)

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    # Map values with tracked probabilities to range [0, ..., max_value]
    values = symbols - cdf_offset[cdf_index]

    # If outside of this range, map value to non-negative integer overflow.
    # Both masks are taken from the unclamped values so neither case can
    # overwrite the other.
    overflow = np.zeros_like(values)

    of_mask_lower = values < 0
    overflow = np.where(of_mask_lower, -2 * values - 1, overflow)

    of_mask_upper = values >= max_value
    overflow = np.where(of_mask_upper, 2 * (values - max_value), overflow)
    values = np.where(np.logical_or(of_mask_lower, of_mask_upper), max_value,
                      values)

    assert bool(np.all(values >= 0)), (
        "Invalid shifted value for current symbol - values must be non-negative."
    )

    assert bool(np.all(values < cdf_length[cdf_index] - 1)), (
        "Invalid shifted value for current symbol - outside cdf index bounds.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches
        # with batch elements for B > 1 ...
        if ((symbols_shape[2] % PATCH_SIZE[0] == 0) and
                (symbols_shape[3] % PATCH_SIZE[1] == 0)) is False:
            values = utils.pad_factor(
                torch.Tensor(values), symbols_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            indices = utils.pad_factor(
                torch.Tensor(indices), symbols_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            overflow = utils.pad_factor(
                torch.Tensor(overflow), symbols_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

        assert (values.shape[2] % PATCH_SIZE[0] == 0) and (
            values.shape[3] % PATCH_SIZE[1] == 0)
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (
            indices.shape[3] % PATCH_SIZE[1] == 0)

        values, _ = compression_utils.decompose(values, n_channels)
        overflow, _ = compression_utils.decompose(overflow, n_channels)
        cdf_index, unfolded_shape = compression_utils.decompose(
            indices, n_channels)
        coding_shape = values.shape[1:]
        assert coding_shape == cdf_index.shape[1:]

    # LIFO - last item in buffer is first item decompressed
    # One coding step per entry along the leading axis.
    for i in range(len(cdf_index)):
        # Bin of discrete CDF that value belongs to
        value_i = values[i]
        cdf_index_i = cdf_index[i]
        cdf_i = cdf[cdf_index_i]
        cdf_length_i = cdf_length[cdf_index_i]
        max_value_i = cdf_length_i - 2

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
        # NOTE(review): the per-step codec is built but not used when
        # buffering instructions; only `enc_statfun` is needed.
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)
        start, freq = enc_statfun(value_i)
        instructions.append((start, freq, False, precision, 0))

        # Encode overflows for positions that hold the escape value.
        overflow_i = overflow[i]
        of_mask = value_i == max_value_i

        if np.any(of_mask):
            # Number of overflow_width-bit chunks needed per position.
            widths = np.zeros_like(value_i)
            cond_mask = (overflow_i >> (widths * overflow_width)) != 0

            while np.any(cond_mask):
                widths = np.where(cond_mask, widths + 1, widths)
                cond_mask = (overflow_i >> (widths * overflow_width)) != 0

            # Chunk counts that do not fit in one overflow symbol are
            # escaped with `max_overflow`.
            # NOTE(review): this path announces undefined behaviour and
            # pushes a scalar; it has not been validated against the decoder.
            val = widths
            cond_mask = val >= max_overflow
            while np.any(cond_mask):
                print('Warning: Undefined behaviour.')
                val_push = cast2u64(max_overflow)
                overflow_start, overflow_freq = enc_statfun_overflow(val_push)
                start = overflow_start[of_mask]
                freq = overflow_freq[of_mask]
                instructions.append(
                    (start, freq, True, int(overflow_width), of_mask))
                val = np.where(cond_mask, val - max_overflow, val)
                cond_mask = val >= max_overflow

            val_push = cast2u64(val)
            overflow_start, overflow_freq = enc_statfun_overflow(val_push)
            start = overflow_start[of_mask]
            freq = overflow_freq[of_mask]
            instructions.append(
                (start, freq, True, int(overflow_width), of_mask))

            # Emit the overflow chunk by chunk, least-significant first. The
            # chunk counter must persist across iterations so that iteration
            # k emits bits [k*overflow_width, (k+1)*overflow_width).
            cond_mask = widths != 0
            counter = 0
            while np.any(cond_mask):
                encoding = (overflow_i >>
                            (counter * overflow_width)) & max_overflow
                val = np.where(cond_mask, encoding, val)
                val_push = cast2u64(val)
                overflow_start, overflow_freq = enc_statfun_overflow(val_push)
                start = overflow_start[of_mask]
                freq = overflow_freq[of_mask]
                instructions.append(
                    (start, freq, True, int(overflow_width), of_mask))
                widths = np.where(cond_mask, widths - 1, widths)
                cond_mask = widths != 0
                counter += 1

    return instructions, coding_shape
def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset,
                          precision, coding_shape,
                          overflow_width=OVERFLOW_WIDTH, **kwargs):
    """
    Inverse of `vec_ans_index_encoder`: turns an ans-encoded bitstring back
    into a tensor of symbols.

    `indices`, `cdf`, `cdf_length`, `cdf_offset` and `precision` have to be
    the same values that were given to `vec_ans_index_encoder` when the
    message was produced. This variant performs no overflow decoding;
    `overflow_width` and `kwargs` are accepted but not used.

    Inputs
    encoded:      Encoded message; reshaped via `vrans.unflatten`.
    indices:      Integer array selecting a row of `cdf`, `cdf_length` and
                  `cdf_offset` per element. First two axes are unpacked as
                  (B, n_channels); axes 2 and 3 are treated as spatial when
                  B == 1.
    coding_shape: Shape handed to `vrans.unflatten`.

    Outputs
    decoded: For B == 1 the patch-wise symbols reassembled with
             `compression_utils.reconstitute` and cropped to `indices.shape`;
             otherwise the per-step symbols stacked along axis 0.

    Raises
    AssertionError: if an index or CDF length is out of range.
    """
    target_shape = indices.shape
    B, n_channels, *_ = target_shape
    message = vrans.unflatten(encoded, coding_shape)
    indices = indices.astype(np.int32)
    cdf_index = indices

    assert (bool(np.all(cdf_index >= 0))
            and bool(np.all(cdf_index < cdf.shape[0]))), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert (bool(np.all(max_value >= 0))
            and bool(np.all(max_value < cdf.shape[1] - 1))), (
                "Invalid max length.")

    if B == 1:
        # Single-element batches are coded patch by patch, so the spatial
        # axes must be multiples of the patch size.
        rows, cols = target_shape[2], target_shape[3]
        if ((rows % PATCH_SIZE[0] == 0) and
                (cols % PATCH_SIZE[1] == 0)) is False:
            padded = utils.pad_factor(torch.Tensor(indices),
                                      target_shape[2:], factor=PATCH_SIZE)
            indices = padded.cpu().numpy().astype(np.int32)

        padded_shape = indices.shape
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (
            indices.shape[3] % PATCH_SIZE[1] == 0)
        cdf_index, unfolded_shape = compression_utils.decompose(
            indices, n_channels)
        coding_shape = cdf_index.shape[1:]

    # Pop one block of symbols per entry along the leading axis.
    decoded_steps = []
    for step_index in cdf_index:
        step_cdf = cdf[step_index]
        step_cdf_length = cdf_length[step_index]

        _, symbol_pop = base_codec(
            _vec_indexed_cdf_to_enc_statfun(step_cdf),
            _vec_indexed_cdf_to_dec_statfun(step_cdf, step_cdf_length),
            precision)

        message, value = symbol_pop(message)
        decoded_steps.append(value + cdf_offset[step_index])

    stacked = np.stack(decoded_steps, axis=0)
    if B != 1:
        return stacked

    decoded = compression_utils.reconstitute(stacked, padded_shape,
                                             unfolded_shape)
    # Remove any patch padding introduced above.
    if tuple(decoded.shape) != tuple(target_shape):
        decoded = decoded[:, :, :target_shape[2], :target_shape[3]]
    return decoded
def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length,
                                   cdf_offset, precision, coding_shape,
                                   overflow_width=OVERFLOW_WIDTH, **kwargs):
    """
    Vectorized version of `ans_index_encoder`. Incurs constant bit overhead,
    but is faster.

    ANS-encodes integer data using an indexed probability table. Nothing is
    written to a bitstream here: the function returns a buffer of
    ``(start, freq, False)`` instructions, one per coding step.

    NOTE: this variant emits no overflow instructions. Symbols whose shifted
    value falls outside ``[0, max_value)`` (``max_value = cdf_length - 2``)
    are replaced by `max_value` before encoding, so they are not recoverable
    from the instruction buffer alone. `overflow_width` and `kwargs` are
    accepted but not used.

    Inputs
    symbols:      Integer array to encode. First two axes are unpacked as
                  (B, n_channels); axes 2 and 3 are treated as spatial when
                  B == 1.
    indices:      Integer array selecting a row of `cdf`, `cdf_length` and
                  `cdf_offset` per element.
    cdf:          CDF table indexed by `indices` along its first axis.
    cdf_length:   Per-row length table.
    cdf_offset:   Per-row offset subtracted from each symbol.
    precision:    Forwarded to `base_codec`.
    coding_shape: Returned unchanged when B != 1; recomputed from the
                  decomposed values when B == 1.

    Outputs
    instructions: List of ``(start, freq, False)`` tuples.
    coding_shape: Shape of one coding step.

    Raises
    AssertionError: if an index, CDF length or shifted value is out of range.
    """
    instructions = []

    symbols_shape = symbols.shape
    B, n_channels = symbols_shape[:2]
    symbols = symbols.astype(np.int32)
    indices = indices.astype(np.int32)
    cdf_index = indices

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    # Map values with tracked probabilities to range [0, ..., max_value]
    values = symbols - cdf_offset[cdf_index]

    # Anything outside the tracked range collapses onto the escape value.
    # The mask is taken from the unclamped values in a single step.
    out_of_range = np.logical_or(values < 0, values >= max_value)
    values = np.where(out_of_range, max_value, values)

    assert bool(np.all(values >= 0)), (
        "Invalid shifted value for current symbol - values must be non-negative."
    )

    assert bool(np.all(values < cdf_length[cdf_index] - 1)), (
        "Invalid shifted value for current symbol - outside cdf index bounds.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches
        # with batch elements for B > 1 ...
        if ((symbols_shape[2] % PATCH_SIZE[0] == 0) and
                (symbols_shape[3] % PATCH_SIZE[1] == 0)) is False:
            values = utils.pad_factor(
                torch.Tensor(values), symbols_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            indices = utils.pad_factor(
                torch.Tensor(indices), symbols_shape[2:],
                factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

        assert (values.shape[2] % PATCH_SIZE[0] == 0) and (
            values.shape[3] % PATCH_SIZE[1] == 0)
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (
            indices.shape[3] % PATCH_SIZE[1] == 0)

        values, _ = compression_utils.decompose(values, n_channels)
        cdf_index, unfolded_shape = compression_utils.decompose(
            indices, n_channels)
        coding_shape = values.shape[1:]

    # LIFO - last item in buffer is first item decompressed
    # One coding step per entry along the leading axis.
    for value_i, cdf_index_i in zip(values, cdf_index):
        # Bin of discrete CDF that value belongs to
        cdf_i = cdf[cdf_index_i]
        cdf_i_length = cdf_length[cdf_index_i]

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_i_length)
        # NOTE(review): the per-step codec is built but not used when
        # buffering instructions; only `enc_statfun` is needed.
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)
        start, freq = enc_statfun(value_i)
        instructions.append((start, freq, False))

    return instructions, coding_shape