def decompressDump(self, compression_output):
    """
    Decode a compressed representation back into an image, also returning
    the intermediate decoded tensors.

    Pipeline:
    * `self.Hyperprior.decompress_forwardDump` recovers the latents (and the
      hyper-decoder output) from the compressed message.
    * The decoded latents are passed through `self.Generator` to obtain the
      reconstruction: y* -> Generator() -> x*.
    * Padding is removed by cropping to `compression_output.spatial_shape`,
      and the result is clamped to [0, 1].

    Arguments:
        compression_output: Compressed representation; must expose
                            `spatial_shape` (first two entries are used as
                            the crop height and width).

    Returns:
        Tuple `(reconstruction, latents_decoded, hyper_decoded)`.

    Raises:
        AssertionError: If the model is not in evaluation mode.
    """
    assert self.model_mode == ModelModes.EVALUATION and (self.training is False), (
        f'Set model mode to {ModelModes.EVALUATION} for decompression.')

    latents_decoded, hyper_decoded = self.Hyperprior.decompress_forwardDump(
        compression_output, device=utils.get_device())

    # The generator consumes the quantized latents directly
    image = self.Generator(latents_decoded)

    inputs_normalized = self.args.normalize_input_image is True
    if inputs_normalized:
        image = torch.tanh(image)

    # Crop away the padding
    target_dims = compression_output.spatial_shape
    image = image[:, :, :target_dims[0], :target_dims[1]]

    if inputs_normalized:
        # Map [-1.,1.] back onto [0.,1.]
        image = (image + 1.) / 2.

    image = torch.clamp(image, min=0., max=1.)

    return image, latents_decoded, hyper_decoded
def estimate_tails(cdf, target, shape, dtype=torch.float32, extra_counts=24, device=None):
    """
    Estimates approximate tail quantiles.

    This runs a simple Adam iteration to determine tail quantiles. The
    objective is to find an `x` such that:

        cdf(x) == target

    Note that `cdf` is assumed to be monotonic. When each tail estimate has
    passed the optimal value of `x`, the algorithm does `extra_counts`
    (default 24) additional iterations and then stops.

    This operation is vectorized. The tensor shape of `x` is given by `shape`,
    and `target` must have a shape that is broadcastable to the output of
    `cdf(x)`.

    NOTE: the loop only terminates once every element has been detected as
    having passed its optimum, which is detected through the sign of
    `grad * x`. An element whose gradient stays zero (e.g. the initial
    `x = 0` already satisfies the objective) never triggers this condition.

    Arguments:
        cdf:          A callable that computes cumulative distribution
                      function, survival function, or similar.
        target:       The desired target value.
        shape:        The shape of the tensor representing `x`.
        dtype:        Floating point dtype of `x` and of the optimizer state.
        extra_counts: Number of iterations performed after an element first
                      passes its optimum.
        device:       Device on which to run the search. When `None`
                      (default), `utils.get_device()` is used.

    Returns:
        A `torch.Tensor` representing the solution (`x`). It is the leaf
        tensor that was optimized, so it still has `requires_grad=True`.
    """
    # Hand-rolled Adam hyperparameters
    lr, eps = 1e-2, 1e-8
    beta_1, beta_2 = 0.9, 0.99

    if device is None:
        device = utils.get_device()

    # Keep the iterate and all optimizer state on the same device so that no
    # host/device transfers are needed inside the loop.
    tails = torch.zeros(shape, dtype=dtype, requires_grad=True, device=device)
    m = torch.zeros(shape, dtype=dtype, device=device)
    v = torch.ones(shape, dtype=dtype, device=device)
    counts = torch.zeros(shape, dtype=torch.int32, device=device)

    while torch.min(counts) < extra_counts:
        loss = torch.abs(cdf(tails) - target)
        # `loss` is not a scalar: seed backward with ones to get the
        # element-wise gradient d loss_i / d tails_i.
        loss.backward(torch.ones_like(tails))
        grad = tails.grad

        with torch.no_grad():
            m = beta_1 * m + (1. - beta_1) * grad
            v = beta_2 * v + (1. - beta_2) * torch.square(grad)
            tails -= lr * m / (torch.sqrt(v) + eps)

            # Condition assumes tails init'd at zero: once the gradient and
            # the iterate share a sign, the iterate has overshot the optimum.
            # From then on the element's counter advances every iteration.
            passed = torch.logical_or(counts > 0, grad * tails > 0)
            counts = torch.where(passed, counts + 1, counts)

        tails.grad.zero_()

    return tails
def prepare_model(ckpt_path, input_dir):
    """
    Load a model from a checkpoint in evaluation mode and build the
    hyperprior probability tables.

    A timestamped log location (`logs_<time.time()>`) is created under
    `input_dir` for the logger handed to `utils.load_model`.

    Arguments:
        ckpt_path: Path of the checkpoint passed to `utils.load_model`.
        input_dir: Directory under which the log path is placed.

    Returns:
        Tuple `(model, loaded_args)` as returned by `utils.load_model`
        (its third return value is discarded).
    """
    make_deterministic()
    device = utils.get_device()

    log_location = os.path.join(input_dir, f'logs_{time.time()}')
    logger = utils.logger_setup(logpath=log_location, filepath=os.path.abspath(__file__))

    load_options = dict(
        model_mode=ModelModes.EVALUATION,
        current_args_d=None,
        prediction=True,
        strict=False,
        silent=True,
    )
    loaded_args, model, _ = utils.load_model(ckpt_path, logger, device, **load_options)
    model.logger.info('Model loaded from disk.')

    # The entropy coder needs its probability tables before (de)compression
    model.logger.info('Building hyperprior probability tables...')
    entropy_model = model.Hyperprior.hyperprior_entropy_model
    entropy_model.build_tables()
    model.logger.info('All tables built.')

    return model, loaded_args
def compress_and_save(model, args, data_loader, output_dir):
    """
    Compress every image yielded by `data_loader` and write the compressed
    representation to `output_dir`.

    Each batch must contain exactly one image. The output file is named
    `<filename>_compressed.hfc`, where `<filename>` is the first entry of the
    batch's filenames. The attained and theoretical bitrates returned by
    `compression_utils.save_compressed_format` are logged per image.

    Arguments:
        model:       Model exposing `compress(data)` and a `logger`.
        args:        Unused here; kept for interface compatibility.
        data_loader: Iterable yielding `(data, bpp, filenames)` batches.
        output_dir:  Directory receiving the compressed files. It is created
                     if it does not exist yet.

    Raises:
        AssertionError: If a batch holds more than one image.
    """
    device = utils.get_device()

    # Make sure the destination exists before the first write
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model.logger.info('Starting compression...')

    with torch.no_grad():
        # The per-batch bpp of the source image is not needed for saving
        for data, _bpp, filenames in tqdm(data_loader):
            data = data.to(device, dtype=torch.float)
            assert data.size(0) == 1, 'Currently only supports saving single images.'

            # Perform entropy coding
            compressed_output = model.compress(data)

            out_path = os.path.join(output_dir, f"{filenames[0]}_compressed.hfc")
            actual_bpp, theoretical_bpp = compression_utils.save_compressed_format(
                compressed_output, out_path=out_path)
            model.logger.info(f'Attained: {actual_bpp:.3f} bpp vs. theoretical: {theoretical_bpp:.3f} bpp.')
if (self.step_counter % self.log_interval == 1): self.store_loss('weighted_compression_loss', compression_model_loss.item()) if return_intermediates is True: return losses, intermediates else: return losses if __name__ == '__main__': logger = utils.logger_setup(logpath=os.path.join(directories.experiments, 'logs'), filepath=os.path.abspath(__file__)) device = utils.get_device() logger.info(f'Using device {device}') storage_train = defaultdict(list) storage_test = defaultdict(list) model = Model(hific_args, logger, storage_train, storage_test, model_type=ModelTypes.COMPRESSION_GAN) model.to(device) logger.info(model) transform_param_names = list() transform_params = list() logger.info('ALL PARAMETERS')
cmd_args = parser.parse_args() if (cmd_args.gpu != 0) or (cmd_args.force_set_gpu is True): torch.cuda.set_device(cmd_args.gpu) if cmd_args.model_type == ModelTypes.COMPRESSION: args = mse_lpips_args elif cmd_args.model_type == ModelTypes.COMPRESSION_GAN: args = hific_args elif cmd_args.model_type == ModelTypes.CLASSI_ONLY: args = classi_only start_time = time.time() is_gpu = True device = utils.get_device(is_gpu=is_gpu) # Override default arguments from config file with provided command line arguments dictify = lambda x: dict((n, getattr(x, n)) for n in dir(x) if not (n.startswith('__') or 'logger' in n)) args_d, cmd_args_d = dictify(args), vars(cmd_args) args_d.update(cmd_args_d) args = utils.Struct(**args_d) args = utils.setup_generic_signature(args, special_info=args.model_type) args.target_rate = args.target_rate_map[args.regime] args.lambda_A = args.lambda_A_map[args.regime] args.n_steps = int(args.n_steps) storage = defaultdict(list) storage_test = defaultdict(list) logger = utils.logger_setup(logpath=os.path.join(args.snapshot, 'logs'),
def build_tables(self, **kwargs):
    """
    Build the per-channel quantized CDF tables used for entropy coding and
    store them on the module as the non-trainable parameters `CDF`,
    `CDF_offset` and `CDF_length`.

    The support of each channel's PMF is derived from the lower/upper tail
    estimates of `self.distribution` at `self.tail_mass`. The PMF is
    evaluated on that integer grid, the residual probability mass is appended
    as an overflow bin, and the result is converted to a quantized CDF with
    `self.precision` via `maths.pmf_to_quantized_cdf`.

    Arguments:
        **kwargs: Accepted but not used.
    """
    offsets = 0.

    tail_lo = self.distribution.lower_tail(self.tail_mass).cpu()
    tail_hi = self.distribution.upper_tail(self.tail_mass).cpu()

    self.compute_medians()
    # Not referenced further below
    medians = torch.squeeze(self.medians)

    # Largest distance observed between lower tail and the offset,
    # and between the offset and the upper tail, rounded up and kept >= 0.
    minima = torch.clamp(torch.ceil(offsets - tail_lo).to(torch.int32), min=0)
    maxima = torch.clamp(torch.ceil(tail_hi - offsets).to(torch.int32), min=0)

    # Where each channel's PMF starts and how many bins it spans
    pmf_start = offsets - minima.to(torch.float32)
    pmf_length = maxima + minima + 1  # Symmetric for Gaussian

    max_length = pmf_length.max()
    grid = torch.arange(max_length, dtype=self.distribution.dtype)

    # Broadcast the shared grid against the per-channel start positions
    device = utils.get_device()
    grid = grid.view(1, -1) + pmf_start.view(-1, 1, 1)
    pmf = self.distribution.likelihood(grid.to(device), collapsed_format=True).cpu()

    # [n_channels, max_length]
    pmf = torch.squeeze(pmf)

    cdf_length = (pmf_length + 2).to(torch.int32)
    cdf_offset = (-minima).to(torch.int32)

    # CDF shape [n_channels, max_length + 2] - account for fenceposts + overflow
    CDF = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
    for row, (channel_pmf, channel_len) in enumerate(zip(tqdm(pmf), pmf_length)):
        channel_pmf = channel_pmf[:channel_len]
        # Probability mass not covered by the truncated support
        leftover = torch.clamp(1. - torch.sum(channel_pmf, dim=0, keepdim=True), min=0.)
        channel_pmf = torch.cat((channel_pmf, leftover), dim=0)

        channel_cdf = maths.pmf_to_quantized_cdf(channel_pmf, self.precision)
        # Right-pad shorter channels with zeros up to the common width
        channel_cdf = F.pad(channel_cdf, (0, max_length - channel_len), mode='constant', value=0)
        CDF[row] = channel_cdf

    # Serialize, compression method responsible for identifying which
    # CDF to use during compression
    self.CDF = nn.Parameter(CDF, requires_grad=False)
    self.CDF_offset = nn.Parameter(cdf_offset, requires_grad=False)
    self.CDF_length = nn.Parameter(cdf_length, requires_grad=False)

    compression_utils.check_argument_shapes(self.CDF, self.CDF_length, self.CDF_offset)

    for table_name in ('CDF', 'CDF_offset', 'CDF_length'):
        self.register_parameter(table_name, getattr(self, table_name))
def compress_and_decompress(args):
    """
    Run the evaluation pipeline over a directory of images: load the model,
    compress (or directly reconstruct) every image, save the reconstructions
    and write per-image metrics to disk.

    For each batch the function either
    * calls the model directly (`args.reconstruct is True`), or
    * entropy-codes the batch with `model.compress`, optionally saves the
      compressed format (`args.save is True`, single-image batches only), and
      decodes it again with `model.decompress`.

    Reconstructions are written as `<filename>_RECON_<bpp>bpp.png` into
    `args.output_dir`. A table with input/output filenames, original bpp,
    achieved bpp and LPIPS is stored as `compression_metrics.h5` in the same
    directory.

    Arguments:
        args: Namespace-like object. Attributes read here include
              `image_dir`, `ckpt_path`, `output_dir`, `batch_size`,
              `reconstruct`, `save` and `normalize_input_image` (the last may
              come from the arguments stored in the checkpoint).

    Raises:
        AssertionError: If `args.save` is set and a batch holds more than one
                        image.
    """
    # Reproducibility
    make_deterministic()
    perceptual_loss_fn = ps.PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())

    # Load model
    device = utils.get_device()
    logger = utils.logger_setup(logpath=os.path.join(args.image_dir, 'logs'), filepath=os.path.abspath(__file__))
    loaded_args, model, _ = utils.load_model(args.ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
                                             current_args_d=None, prediction=True, strict=False)

    def _public_attrs(obj):
        # Every attribute except dunders and anything logger-related
        return {name: getattr(obj, name) for name in dir(obj)
                if not (name.startswith('__') or 'logger' in name)}

    # Start from the arguments recorded in the checkpoint and let the
    # arguments supplied for this run take precedence.
    merged_args_d = _public_attrs(loaded_args)
    merged_args_d.update(_public_attrs(args))
    args = utils.Struct(**merged_args_d)
    logger.info(merged_args_d)

    # Build probability tables
    logger.info('Building hyperprior probability tables...')
    model.Hyperprior.hyperprior_entropy_model.build_tables()
    logger.info('All tables built.')

    eval_loader = datasets.get_dataloaders('evaluation', root=args.image_dir, batch_size=args.batch_size,
                                           logger=logger, shuffle=False, normalize=args.normalize_input_image)

    n, N = 0, len(eval_loader.dataset)
    input_filenames_total = list()
    output_filenames_total = list()
    # Zero-initialised so that no slot ever holds uninitialised memory
    bpp_total, q_bpp_total, LPIPS_total = torch.zeros(N), torch.zeros(N), torch.zeros(N)
    utils.makedirs(args.output_dir)

    logger.info('Starting compression...')
    start_time = time.time()

    with torch.no_grad():
        for data, bpp, filenames in tqdm(eval_loader):
            data = data.to(device, dtype=torch.float)
            B = data.size(0)
            input_filenames_total.extend(filenames)

            if args.reconstruct is True:
                # Reconstruction without compression
                reconstruction, q_bpp = model(data, writeout=False)
            else:
                # Perform entropy coding
                compressed_output = model.compress(data)

                if args.save is True:
                    assert B == 1, 'Currently only supports saving single images.'
                    compression_utils.save_compressed_format(
                        compressed_output,
                        out_path=os.path.join(args.output_dir, f"{filenames[0]}_compressed.hfc"))

                reconstruction = model.decompress(compressed_output)
                q_bpp = compressed_output.total_bpp

            # q_bpp may be a tensor or a plain number depending on the branch
            q_bpp_is_tensor = isinstance(q_bpp, torch.Tensor)

            if args.normalize_input_image is True:
                # [-1., 1.] -> [0., 1.]
                data = (data + 1.) / 2.

            perceptual_loss = perceptual_loss_fn.forward(reconstruction, data, normalize=True)

            for subidx in range(reconstruction.shape[0]):
                if B > 1:
                    q_bpp_per_im = float(q_bpp.cpu().numpy()[subidx])
                elif q_bpp_is_tensor:
                    q_bpp_per_im = float(q_bpp.item())
                else:
                    q_bpp_per_im = float(q_bpp)

                fname = os.path.join(args.output_dir,
                                     "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], q_bpp_per_im))
                torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
                output_filenames_total.append(fname)

            bpp_total[n:n + B] = bpp.data
            q_bpp_total[n:n + B] = q_bpp.data if q_bpp_is_tensor else q_bpp
            LPIPS_total[n:n + B] = perceptual_loss.data
            n += B

    df = pd.DataFrame([input_filenames_total, output_filenames_total]).T
    df.columns = ['input_filename', 'output_filename']
    df['bpp_original'] = bpp_total.cpu().numpy()
    df['q_bpp'] = q_bpp_total.cpu().numpy()
    df['LPIPS'] = LPIPS_total.cpu().numpy()

    df_path = os.path.join(args.output_dir, 'compression_metrics.h5')
    df.to_hdf(df_path, key='df')

    pprint(df)

    logger.info('Complete. Reconstructions saved to {}. Output statistics saved to {}'.format(args.output_dir, df_path))
    delta_t = time.time() - start_time
    logger.info('Time elapsed: {:.3f} s'.format(delta_t))
    logger.info('Rate: {:.3f} Images / s:'.format(float(N) / delta_t))