Example #1
    def decompressDump(self, compression_output):
        """
        * Recover z* from compressed message.
        * Pass recovered hyperlatents through the mean-scale hyperprior decoder to obtain
          the mean and scale over latents: z -> hyperdecoder() -> (mu, sigma).
        * Use latent entropy model to recover y* from compressed image.
        * Pass quantized latent through generator to obtain the reconstructed image.
          y* -> Generator() -> x*.
        """

        assert self.model_mode == ModelModes.EVALUATION and (
            self.training is False
        ), (f'Set model mode to {ModelModes.EVALUATION} for decompression.')

        latents_decoded, hyper_decoded = self.Hyperprior.decompress_forwardDump(
            compression_output, device=utils.get_device())

        # Use quantized latents as input to G
        reconstruction = self.Generator(latents_decoded)

        if self.args.normalize_input_image is True:
            reconstruction = torch.tanh(reconstruction)

        # Undo padding
        image_dims = compression_output.spatial_shape
        reconstruction = reconstruction[:, :, :image_dims[0], :image_dims[1]]

        if self.args.normalize_input_image is True:
            # [-1.,1.] -> [0.,1.]
            reconstruction = (reconstruction + 1.) / 2.

        reconstruction = torch.clamp(reconstruction, min=0., max=1.)

        return reconstruction, latents_decoded, hyper_decoded
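A usage sketch of the round trip, assuming `model` is an evaluation-mode instance of the class this method belongs to (its `compress()` appears in the later examples) and `data` is a preprocessed image batch on the right device:

with torch.no_grad():
    # Entropy-code latents and hyperlatents, then decode back to an image.
    compression_output = model.compress(data)
    reconstruction, latents_decoded, hyper_decoded = model.decompressDump(compression_output)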
Example #2
def estimate_tails(cdf, target, shape, dtype=torch.float32, extra_counts=24):
    """
    Estimates approximate tail quantiles.
    This runs a simple Adam iteration to determine tail quantiles. The
    objective is to find an `x` such that `cdf(x) == target`.

    Note that `cdf` is assumed to be monotonic. When each tail estimate has passed the
    optimal value of `x`, the algorithm does `extra_counts` (default 24) additional
    iterations and then stops.

    This operation is vectorized. The tensor shape of `x` is given by `shape`, and
    `target` must have a shape that is broadcastable to the output of `cdf(x)`.

    Arguments:
    cdf: A callable that computes the cumulative distribution function, survival
         function, or similar.
    target: The desired target value.
    shape: The shape of the tensor representing `x`.
    dtype: The `torch.dtype` used for the estimate (default `torch.float32`).
    extra_counts: Number of additional Adam iterations to run after each tail
         estimate has passed its optimal value (default 24).
    Returns:
    A `torch.Tensor` representing the solution (`x`).
    """
    # A bit hacky
    lr, eps = 1e-2, 1e-8
    beta_1, beta_2 = 0.9, 0.99

    # Tails should be monotonically increasing
    device = utils.get_device()
    tails = torch.zeros(shape, dtype=dtype, requires_grad=True, device=device)

    m = torch.zeros(shape, dtype=dtype)
    v = torch.ones(shape, dtype=dtype)
    counts = torch.zeros(shape, dtype=torch.int32)

    while torch.min(counts) < extra_counts:
        loss = abs(cdf(tails) - target)
        loss.backward(torch.ones_like(tails))

        tgrad = tails.grad.cpu()

        with torch.no_grad():
            m = beta_1 * m + (1. - beta_1) * tgrad
            v = beta_2 * v + (1. - beta_2) * torch.square(tgrad)
            tails -= (lr * m / (torch.sqrt(v) + eps)).to(device)

        # Condition assumes tails init'd at zero
        counts = torch.where(
            torch.logical_or(counts > 0,
                             tgrad.cpu() * tails.cpu() > 0), counts + 1,
            counts)

        tails.grad.zero_()

    return tails
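A usage sketch with a standard Normal CDF standing in for a learned distribution (the distribution, tail mass, and channel count here are illustrative only, not taken from the source):

import torch

standard_normal = torch.distributions.Normal(0., 1.)
tail_mass = 2 ** -8
# Per-channel estimate of x such that P(X < x) is roughly tail_mass / 2.
lower_tails = estimate_tails(standard_normal.cdf, target=tail_mass / 2, shape=(64,))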
def prepare_model(ckpt_path, input_dir):

    make_deterministic()
    device = utils.get_device()
    logger = utils.logger_setup(logpath=os.path.join(input_dir, f'logs_{time.time()}'), filepath=os.path.abspath(__file__))
    loaded_args, model, _ = utils.load_model(ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
        current_args_d=None, prediction=True, strict=False, silent=True)
    model.logger.info('Model loaded from disk.')

    # Build probability tables
    model.logger.info('Building hyperprior probability tables...')
    model.Hyperprior.hyperprior_entropy_model.build_tables()
    model.logger.info('All tables built.')

    return model, loaded_args
def compress_and_save(model, args, data_loader, output_dir):
    # Compress and save compressed format to disk

    device = utils.get_device()
    model.logger.info('Starting compression...')

    with torch.no_grad():
        for idx, (data, bpp, filenames) in enumerate(tqdm(data_loader), 0):
            data = data.to(device, dtype=torch.float)
            assert data.size(0) == 1, 'Currently only supports saving single images.'

            # Perform entropy coding
            compressed_output = model.compress(data)

            out_path = os.path.join(output_dir, f"{filenames[0]}_compressed.hfc")
            actual_bpp, theoretical_bpp = compression_utils.save_compressed_format(compressed_output,
                out_path=out_path)
            model.logger.info(f'Attained: {actual_bpp:.3f} bpp vs. theoretical: {theoretical_bpp:.3f} bpp.')
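A usage sketch tying `prepare_model` and `compress_and_save` together; the paths are placeholders, and the `get_dataloaders` call mirrors the evaluation loader used in the final example below:

ckpt_path = 'path/to/checkpoint.pt'       # placeholder
input_dir = 'path/to/input_images'        # placeholder
output_dir = 'path/to/compressed_output'  # placeholder

model, loaded_args = prepare_model(ckpt_path, input_dir)
eval_loader = datasets.get_dataloaders('evaluation', root=input_dir, batch_size=1,
                                       logger=model.logger, shuffle=False,
                                       normalize=loaded_args.normalize_input_image)
compress_and_save(model, loaded_args, eval_loader, output_dir)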
        if (self.step_counter % self.log_interval == 1):
            self.store_loss('weighted_compression_loss',
                            compression_model_loss.item())

        if return_intermediates is True:
            return losses, intermediates
        else:
            return losses


if __name__ == '__main__':

    logger = utils.logger_setup(logpath=os.path.join(directories.experiments,
                                                     'logs'),
                                filepath=os.path.abspath(__file__))
    device = utils.get_device()
    logger.info(f'Using device {device}')
    storage_train = defaultdict(list)
    storage_test = defaultdict(list)
    model = Model(hific_args,
                  logger,
                  storage_train,
                  storage_test,
                  model_type=ModelTypes.COMPRESSION_GAN)
    model.to(device)

    logger.info(model)

    transform_param_names = list()
    transform_params = list()
    logger.info('ALL PARAMETERS')
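    # The example is truncated here. A plausible continuation (hypothetical, not
    # from the source) would log each parameter and collect the transform
    # (Encoder/Generator) parameters just announced:
    for name, param in model.named_parameters():
        logger.info(f'{name} - {tuple(param.shape)}')
        if 'Encoder' in name or 'Generator' in name:
            transform_param_names.append(name)
            transform_params.append(param)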
Example #6
    cmd_args = parser.parse_args()

    if (cmd_args.gpu != 0) or (cmd_args.force_set_gpu is True):
        torch.cuda.set_device(cmd_args.gpu)

    if cmd_args.model_type == ModelTypes.COMPRESSION:
        args = mse_lpips_args
    elif cmd_args.model_type == ModelTypes.COMPRESSION_GAN:
        args = hific_args
    elif cmd_args.model_type == ModelTypes.CLASSI_ONLY:
        args = classi_only

    start_time = time.time()
    is_gpu = True
    device = utils.get_device(is_gpu=is_gpu)

    # Override default arguments from config file with provided command line arguments
    dictify = lambda x: dict((n, getattr(x, n)) for n in dir(x)
                             if not (n.startswith('__') or 'logger' in n))
    args_d, cmd_args_d = dictify(args), vars(cmd_args)
    args_d.update(cmd_args_d)
    args = utils.Struct(**args_d)
    args = utils.setup_generic_signature(args, special_info=args.model_type)
    args.target_rate = args.target_rate_map[args.regime]
    args.lambda_A = args.lambda_A_map[args.regime]
    args.n_steps = int(args.n_steps)

    storage = defaultdict(list)
    storage_test = defaultdict(list)
    logger = utils.logger_setup(logpath=os.path.join(args.snapshot, 'logs'),
                                filepath=os.path.abspath(__file__))
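To make the override order concrete, here is a minimal stand-alone sketch with invented values: command-line entries replace config-file defaults on conflict, and unspecified keys keep their defaults.

config_defaults = {'batch_size': 8, 'n_steps': 1_000_000}  # stands in for args_d
command_line = {'batch_size': 16}                          # stands in for cmd_args_d
merged = {**config_defaults, **command_line}
assert merged == {'batch_size': 16, 'n_steps': 1_000_000}  # command line wins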
Example #7
    def build_tables(self, **kwargs):

        offsets = 0.

        lower_tail = self.distribution.lower_tail(self.tail_mass).cpu()
        upper_tail = self.distribution.upper_tail(self.tail_mass).cpu()

        self.compute_medians()
        medians = torch.squeeze(self.medians)

        # Largest distance observed between lower tail and median,
        # and between median and upper tail.
        minima = offsets - lower_tail
        minima = torch.ceil(minima).to(torch.int32)
        minima = torch.clamp(minima, min=0)

        maxima = upper_tail - offsets
        maxima = torch.ceil(maxima).to(torch.int32)
        maxima = torch.clamp(maxima, min=0)

        # PMF starting positions and lengths
        # pmf_start = offsets - minima.to(self.distribution.dtype)
        pmf_start = offsets - minima.to(torch.float32)
        pmf_length = maxima + minima + 1  # Symmetric for Gaussian

        max_length = pmf_length.max()
        samples = torch.arange(max_length, dtype=self.distribution.dtype)

        # Broadcast to [n_channels,1,*] format
        device = utils.get_device()
        samples = samples.view(1, -1) + pmf_start.view(-1, 1, 1)
        pmf = self.distribution.likelihood(samples.to(device),
                                           collapsed_format=True).cpu()

        # [n_channels, max_length]
        pmf = torch.squeeze(pmf)

        cdf_length = pmf_length + 2
        cdf_offset = -minima

        cdf_length = cdf_length.to(torch.int32)
        cdf_offset = cdf_offset.to(torch.int32)

        # CDF shape [n_channels, max_length + 2] - account for fenceposts + overflow
        CDF = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
        for n, (pmf_, pmf_length_) in enumerate(zip(tqdm(pmf), pmf_length)):
            pmf_ = pmf_[:pmf_length_]  # [max_length]
            overflow = torch.clamp(1. - torch.sum(pmf_, dim=0, keepdim=True),
                                   min=0.)
            pmf_ = torch.cat((pmf_, overflow), dim=0)

            cdf_ = maths.pmf_to_quantized_cdf(pmf_, self.precision)
            cdf_ = F.pad(cdf_, (0, max_length - pmf_length_),
                         mode='constant',
                         value=0)
            CDF[n] = cdf_

        # Serialize, compression method responsible for identifying which
        # CDF to use during compression
        self.CDF = nn.Parameter(CDF, requires_grad=False)
        self.CDF_offset = nn.Parameter(cdf_offset, requires_grad=False)
        self.CDF_length = nn.Parameter(cdf_length, requires_grad=False)

        compression_utils.check_argument_shapes(self.CDF, self.CDF_length,
                                                self.CDF_offset)

        self.register_parameter('CDF', self.CDF)
        self.register_parameter('CDF_offset', self.CDF_offset)
        self.register_parameter('CDF_length', self.CDF_length)
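For reference, a generic sketch of what a `pmf_to_quantized_cdf`-style helper computes (this is an illustration of the idea, not the `maths.pmf_to_quantized_cdf` used above): scale the PMF to integer frequencies summing to roughly `2**precision`, keep every symbol at least one count wide so it remains encodable, and take the cumulative sum to produce the table the range coder consumes.

import torch

def quantized_cdf_sketch(pmf: torch.Tensor, precision: int = 16) -> torch.Tensor:
    """Quantize a 1-D PMF into an integer CDF with total mass 2**precision."""
    freqs = torch.clamp((pmf * (1 << precision)).round().to(torch.int64), min=1)
    cdf = torch.zeros(len(freqs) + 1, dtype=torch.int64)
    cdf[1:] = torch.cumsum(freqs, dim=0)
    # Rescale so the final entry is exactly 2**precision; a production
    # implementation redistributes the rounding error more carefully so that
    # no interval collapses back to zero width.
    cdf = (cdf.to(torch.float64) * (1 << precision) / cdf[-1].item()).round().to(torch.int64)
    return cdf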
def compress_and_decompress(args):

    # Reproducibility
    make_deterministic()
    perceptual_loss_fn = ps.PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())

    # Load model
    device = utils.get_device()
    logger = utils.logger_setup(logpath=os.path.join(args.image_dir, 'logs'), filepath=os.path.abspath(__file__))
    loaded_args, model, _ = utils.load_model(args.ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
        current_args_d=None, prediction=True, strict=False)

    # Override current arguments with recorded
    dictify = lambda x: dict((n, getattr(x, n)) for n in dir(x) if not (n.startswith('__') or 'logger' in n))
    loaded_args_d, args_d = dictify(loaded_args), dictify(args)
    loaded_args_d.update(args_d)
    args = utils.Struct(**loaded_args_d)
    logger.info(loaded_args_d)

    # Build probability tables
    logger.info('Building hyperprior probability tables...')
    model.Hyperprior.hyperprior_entropy_model.build_tables()
    logger.info('All tables built.')


    eval_loader = datasets.get_dataloaders('evaluation', root=args.image_dir, batch_size=args.batch_size,
                                           logger=logger, shuffle=False, normalize=args.normalize_input_image)

    n, N = 0, len(eval_loader.dataset)
    input_filenames_total = list()
    output_filenames_total = list()
    bpp_total, q_bpp_total, LPIPS_total = torch.Tensor(N), torch.Tensor(N), torch.Tensor(N)
    utils.makedirs(args.output_dir)
    
    logger.info('Starting compression...')
    start_time = time.time()

    with torch.no_grad():

        for idx, (data, bpp, filenames) in enumerate(tqdm(eval_loader), 0):
            data = data.to(device, dtype=torch.float)
            B = data.size(0)
            input_filenames_total.extend(filenames)

            if args.reconstruct is True:
                # Reconstruction without compression
                reconstruction, q_bpp = model(data, writeout=False)
            else:
                # Perform entropy coding
                compressed_output = model.compress(data)

                if args.save is True:
                    assert B == 1, 'Currently only supports saving single images.'
                    compression_utils.save_compressed_format(compressed_output, 
                        out_path=os.path.join(args.output_dir, f"{filenames[0]}_compressed.hfc"))

                reconstruction = model.decompress(compressed_output)
                q_bpp = compressed_output.total_bpp

            if args.normalize_input_image is True:
                # [-1., 1.] -> [0., 1.]
                data = (data + 1.) / 2.

            perceptual_loss = perceptual_loss_fn.forward(reconstruction, data, normalize=True)


            for subidx in range(reconstruction.shape[0]):
                if B > 1:
                    q_bpp_per_im = float(q_bpp.cpu().numpy()[subidx])
                else:
                    q_bpp_per_im = float(q_bpp.item()) if type(q_bpp) == torch.Tensor else float(q_bpp)

                fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], q_bpp_per_im))
                torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
                output_filenames_total.append(fname)

            bpp_total[n:n + B] = bpp.data
            q_bpp_total[n:n + B] = q_bpp.data if type(q_bpp) == torch.Tensor else q_bpp
            LPIPS_total[n:n + B] = perceptual_loss.data
            n += B

    df = pd.DataFrame([input_filenames_total, output_filenames_total]).T
    df.columns = ['input_filename', 'output_filename']
    df['bpp_original'] = bpp_total.cpu().numpy()
    df['q_bpp'] = q_bpp_total.cpu().numpy()
    df['LPIPS'] = LPIPS_total.cpu().numpy()

    df_path = os.path.join(args.output_dir, 'compression_metrics.h5')
    df.to_hdf(df_path, key='df')

    pprint(df)

    logger.info('Complete. Reconstructions saved to {}. Output statistics saved to {}'.format(args.output_dir, df_path))
    delta_t = time.time() - start_time
    logger.info('Time elapsed: {:.3f} s'.format(delta_t))
    logger.info('Rate: {:.3f} Images / s'.format(float(N) / delta_t))
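A hypothetical driver for `compress_and_decompress`; the flag names below are assumptions for illustration, not the project's actual CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_path', required=True)
    parser.add_argument('--image_dir', required=True)
    parser.add_argument('--output_dir', required=True)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--reconstruct', action='store_true')
    parser.add_argument('--save', action='store_true')
    compress_and_decompress(parser.parse_args())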