def on_epoch_start(self):
        """Reset per-epoch bookkeeping, building the perceptual loss on first use.

        While ``self.init`` compares equal to ``True`` the flag is cleared and
        ``self.perceptual_loss`` is constructed. Every call then resets the
        two epoch loss lists and records the epoch start time.
        """
        first_call = self.init == True
        if first_call:
            # Clear the flag first so later epochs skip the construction.
            self.init = False
            self.perceptual_loss = ps.PerceptualLoss(
                model='net-lin',
                net='alex',
                use_gpu=torch.cuda.is_available(),
                gpu_ids=[self.device],
            )

        # Two distinct lists: one per loss history.
        self.epoch_loss = []
        self.epoch_test_loss = []
        self.epoch_start_time = time.time()
    def __init__(self,
                 args,
                 logger,
                 storage_train=None,
                 storage_test=None,
                 model_mode=ModelModes.TRAINING,
                 model_type=ModelTypes.COMPRESSION):
        """
        Builds hific model from submodels in network.

        Args:
            args: Configuration object. Attributes read here include
                ``log_interval``, ``use_latent_mixture_model``,
                ``latent_channels``, ``image_dims``, ``batch_size``,
                ``use_channel_norm``, ``n_residual_blocks``, ``sample_noise``,
                ``noise_dim``, ``likelihood_type`` and ``gpu``; plus
                ``latent_channels_DLMM`` / ``mixture_components`` when the
                latent mixture model is enabled, and ``discriminator_steps``,
                ``latent_dims`` and ``gan_loss_type`` when the discriminator
                is used. NOTE: ``args.latent_channels`` is overwritten in
                place when ``use_latent_mixture_model`` is True.
            logger: Logger used for informational messages.
            storage_train: Container stored as ``self.storage_train``. When
                omitted, a fresh ``defaultdict(list)`` is created per instance.
            storage_test: Container stored as ``self.storage_test``. When
                omitted, a fresh ``defaultdict(list)`` is created per instance.
            model_mode: Name that must exist (upper-cased) on ``ModelModes``.
            model_type: Name that must exist (upper-cased) on ``ModelTypes``.

        Raises:
            ValueError: If ``model_type`` or ``model_mode`` is not recognised.
            AssertionError: If the discriminator is enabled but
                ``args.discriminator_steps`` is not positive.
        """
        super(Model, self).__init__()
        self.args = args
        self.model_mode = model_mode
        self.model_type = model_type
        self.logger = logger
        self.log_interval = args.log_interval
        # A default built in the signature would be one object shared by every
        # instance; build a new container per instance instead.
        self.storage_train = (defaultdict(list)
                              if storage_train is None else storage_train)
        self.storage_test = (defaultdict(list)
                             if storage_test is None else storage_test)
        self.step_counter = 0

        if self.args.use_latent_mixture_model is True:
            self.args.latent_channels = self.args.latent_channels_DLMM

        if not hasattr(ModelTypes, self.model_type.upper()):
            raise ValueError("Invalid model_type: [{}]".format(
                self.model_type))
        if not hasattr(ModelModes, self.model_mode.upper()):
            raise ValueError("Invalid model_mode: [{}]".format(
                self.model_mode))

        self.image_dims = self.args.image_dims  # Assign from dataloader
        self.batch_size = self.args.batch_size

        self.Encoder = encoder.Encoder(self.image_dims,
                                       self.batch_size,
                                       C=self.args.latent_channels,
                                       channel_norm=self.args.use_channel_norm)
        self.Generator = generator.Generator(
            self.image_dims,
            self.batch_size,
            C=self.args.latent_channels,
            n_residual_blocks=self.args.n_residual_blocks,
            channel_norm=self.args.use_channel_norm,
            sample_noise=self.args.sample_noise,
            noise_dim=self.args.noise_dim)

        if self.args.use_latent_mixture_model is True:
            self.Hyperprior = hyperprior.HyperpriorDLMM(
                bottleneck_capacity=self.args.latent_channels,
                likelihood_type=self.args.likelihood_type,
                mixture_components=self.args.mixture_components)
        else:
            self.Hyperprior = hyperprior.Hyperprior(
                bottleneck_capacity=self.args.latent_channels,
                likelihood_type=self.args.likelihood_type)

        self.amortization_models = [self.Encoder, self.Generator]
        self.amortization_models.extend(self.Hyperprior.amortization_models)

        # Use discriminator if GAN mode enabled and in training/validation
        self.use_discriminator = (self.model_type == ModelTypes.COMPRESSION_GAN
                                  and
                                  (self.model_mode != ModelModes.EVALUATION))

        if self.use_discriminator is True:
            # Kept as an assert so the exception type seen by callers does not
            # change; note that it is skipped when Python runs with -O.
            assert self.args.discriminator_steps > 0, 'Must specify nonzero training steps for D!'
            self.discriminator_steps = self.args.discriminator_steps
            self.logger.info(
                'GAN mode enabled. Training discriminator for {} steps.'.
                format(self.discriminator_steps))
            self.Discriminator = discriminator.Discriminator(
                image_dims=self.image_dims,
                context_dims=self.args.latent_dims,
                C=self.args.latent_channels)
            self.gan_loss = partial(losses.gan_loss, args.gan_loss_type)
        else:
            self.discriminator_steps = 0
            self.Discriminator = None

        self.squared_difference = torch.nn.MSELoss(reduction='none')
        # Expects [-1,1] images or [0,1] with normalize=True flag
        self.perceptual_loss = ps.PerceptualLoss(
            model='net-lin',
            net='alex',
            use_gpu=torch.cuda.is_available(),
            gpu_ids=[args.gpu])
Example #3
0
def _log_mean_std(logger, label, values):
    """Log the mean and standard deviation of the 1-D tensor ``values``."""
    logger.info(
        f'{label}: mean={values.mean(dim=0):.3f}, std={values.std(dim=0):.3f}'
    )


def _collect_batch_classifier_accuracy(model, data_loader, device):
    """Return ``losses['classi_acc']`` as a float for each batch of the loader."""
    batch_accuracies = []
    with torch.no_grad():
        for dataAll, yAll in tqdm(data_loader):
            dataAll = dataAll.to(device, dtype=torch.float)
            yAll = yAll.to(device)
            losses, _ = model(dataAll,
                              yAll,
                              return_intermediates=True,
                              writeout=True)
            batch_accuracies.append(losses['classi_acc'].item())
    return batch_accuracies


def end_of_epoch_metrics(args, model, data_loader, device, logger):
    """Evaluate ``model`` on ``data_loader`` and log summary statistics.

    Two passes are made over the loader:

    1. A batched forward pass that collects ``losses['classi_acc']`` per batch.
    2. A per-image pass that runs the model in its original mode for the
       losses, then switches to ``ModelModes.EVALUATION`` to compress and
       decompress the image, saves the reconstruction as a PNG in
       ``args.output_dir`` and records bpp, LPIPS, PSNR, SSIM and the losses.

    Mean and standard deviation of every recorded metric are then logged.
    Statistics cover only the images actually processed. The classifier
    loss/accuracy lines are logged only if those values were recorded, which
    happens when ``model.use_classiOnly`` is True.

    Args:
        args: Configuration object; ``output_dir``, ``save`` and
            ``normalize_input_image`` are read.
        model: Model exposing ``model_mode``, ``set_model_mode``, ``compress``,
            ``decompress``, ``use_classiOnly`` and ``Hyperprior``.
        data_loader: Iterable of ``(data, target)`` batches with a ``dataset``
            attribute supporting ``len``.
        device: Device the batches are moved to.
        logger: Logger receiving the summary.

    Returns:
        None. The model mode saved on entry is restored once the per-image
        pass ends, even if that pass raises.
    """
    model.eval()
    old_mode = model.model_mode
    model.training = False

    N = len(data_loader.dataset)
    # torch.zeros rather than torch.Tensor(N): the latter returns
    # uninitialised memory, so any slot left unfilled would hold garbage.
    q_bpp_total, q_bpp_total_attained, LPIPS_total = (torch.zeros(N),
                                                      torch.zeros(N),
                                                      torch.zeros(N))
    SSIM_total, PSNR_total = torch.zeros(N), torch.zeros(N)
    comp_loss_total = torch.zeros(N)
    classi_loss_total, classi_acc_total1 = torch.zeros(N), torch.zeros(N)

    classi_acc_total = _collect_batch_classifier_accuracy(
        model, data_loader, device)

    # Reproducibility
    make_deterministic()
    perceptual_loss_fn = ps.PerceptualLoss(model='net-lin',
                                           net='alex',
                                           use_gpu=torch.cuda.is_available())

    # Build probability tables
    logger.info('Building hyperprior probability tables...')
    model.Hyperprior.hyperprior_entropy_model.build_tables()
    logger.info('All tables built.')

    max_value = 255.
    SSIM_func = metrics.SSIM(data_range=max_value)
    utils.makedirs(args.output_dir)

    logger.info('Starting compression...')
    start_time = time.time()

    thisIndx = 0  # number of images processed so far / next slot to fill
    classifier_recorded = False
    try:
        with torch.no_grad():
            for dataAll, yAll in tqdm(data_loader):
                dataAll = dataAll.to(device, dtype=torch.float)
                yAll = yAll.to(device)
                B = dataAll.size(0)
                for idxB in range(B):
                    # Re-add the batch axis so the model sees one image.
                    data = dataAll[idxB, :, :, :].unsqueeze(0)
                    y = yAll[idxB].unsqueeze(0)

                    # Losses are computed in the mode the model arrived in.
                    model.set_model_mode(old_mode)
                    model.training = False
                    losses = model(data, y, train_generator=False)
                    compression_loss = losses['compression']

                    if model.use_classiOnly is True:
                        classi_loss = losses['classi']
                        classi_acc = losses['classi_acc']

                    model.set_model_mode(ModelModes.EVALUATION)
                    model.training = False
                    # Perform entropy coding
                    q_bpp_attained, compressed_output = model.compress(
                        data, silent=True)

                    if args.save is True:
                        # NOTE(review): every image is written to the same
                        # path, so only the last one survives - confirm
                        # whether that is intended.
                        compression_utils.save_compressed_format(
                            compressed_output,
                            out_path=os.path.join(args.output_dir,
                                                  "compressed.hfc"))

                    reconstruction = model.decompress(compressed_output)
                    q_bpp = compressed_output.total_bpp

                    if args.normalize_input_image is True:
                        # [-1., 1.] -> [0., 1.]
                        data = (data + 1.) / 2.

                    perceptual_loss = perceptual_loss_fn.forward(
                        reconstruction, data, normalize=True)

                    # [0., 1.] -> [0., 255.]
                    psnr = metrics.psnr(
                        reconstruction.cpu().numpy() * max_value,
                        data.cpu().numpy() * max_value, max_value)
                    ms_ssim = SSIM_func(reconstruction * max_value,
                                        data * max_value)
                    # as_tensor accepts a scalar as well as an array;
                    # torch.Tensor(...) rejects scalars.
                    PSNR_total[thisIndx] = torch.as_tensor(psnr)
                    SSIM_total[thisIndx] = ms_ssim.data

                    q_bpp_is_tensor = isinstance(q_bpp, torch.Tensor)
                    q_bpp_per_im = (float(q_bpp.item())
                                    if q_bpp_is_tensor else float(q_bpp))

                    fname = os.path.join(
                        args.output_dir,
                        "{}_RECON_{:.3f}bpp.png".format(thisIndx,
                                                        q_bpp_per_im))
                    torchvision.utils.save_image(reconstruction,
                                                 fname,
                                                 normalize=True)

                    q_bpp_total[thisIndx] = (q_bpp.data
                                             if q_bpp_is_tensor else q_bpp)
                    q_bpp_total_attained[thisIndx] = (
                        q_bpp_attained.data if isinstance(
                            q_bpp_attained, torch.Tensor) else q_bpp_attained)
                    LPIPS_total[thisIndx] = perceptual_loss.data
                    comp_loss_total[thisIndx] = compression_loss.data
                    if model.use_classiOnly is True:
                        classi_loss_total[thisIndx] = classi_loss.data
                        classi_acc_total1[thisIndx] = classi_acc.data
                        classifier_recorded = True
                    thisIndx = thisIndx + 1
    finally:
        # Hand the model back in the mode it arrived in, even on failure.
        model.set_model_mode(old_mode)

    # Restrict statistics to the slots that were actually filled.
    done = slice(0, thisIndx)
    _log_mean_std(logger, 'BPP', q_bpp_total[done])
    _log_mean_std(logger, 'BPPA', q_bpp_total_attained[done])
    _log_mean_std(logger, 'LPIPS', LPIPS_total[done])
    _log_mean_std(logger, 'PSNR', PSNR_total[done])
    _log_mean_std(logger, 'SSIM', SSIM_total[done])
    _log_mean_std(logger, 'CompLoss', comp_loss_total[done])
    if classifier_recorded:
        _log_mean_std(logger, 'ClassiLoss', classi_loss_total[done])
        _log_mean_std(logger, 'ClassiAcc1', classi_acc_total1[done])
    logger.info(f'ClassiAcc2: mean={np.mean(classi_acc_total):.3f}')

    delta_t = time.time() - start_time
    logger.info('Time elapsed: {:.3f} s'.format(delta_t))
    logger.info('Rate: {:.3f} Images / s:'.format(float(thisIndx) / delta_t))
def _q_bpp_for_image(q_bpp, index):
    """Return the bits-per-pixel value belonging to image ``index`` as a float.

    ``q_bpp`` may be a plain number, a single-element tensor (one value used
    for every image) or a tensor indexed by image along its first axis.
    """
    if not isinstance(q_bpp, torch.Tensor):
        return float(q_bpp)
    if q_bpp.numel() == 1:
        return float(q_bpp.item())
    return float(q_bpp.detach().cpu()[index])


def compress_and_decompress(args):
    """Compress and reconstruct every image under ``args.image_dir``.

    Loads the checkpoint at ``args.ckpt_path``, overlays the current ``args``
    on the arguments stored with it, builds the hyperprior probability tables
    and iterates the evaluation loader. Each batch is either reconstructed
    directly (``args.reconstruct``) or entropy-coded and decoded. The
    reconstructions are written to ``args.output_dir`` as PNG files, and a
    table of per-image bpp / LPIPS values is saved to
    ``compression_metrics.h5`` in the same directory.

    Args:
        args: Configuration object; ``image_dir``, ``ckpt_path``,
            ``output_dir``, ``batch_size``, ``normalize_input_image``,
            ``reconstruct`` and ``save`` are read (after the merge with the
            checkpoint's arguments).

    Raises:
        AssertionError: If ``args.save`` is True and a batch holds more than
            one image.
    """
    # Reproducibility
    make_deterministic()
    perceptual_loss_fn = ps.PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())

    # Load model
    device = utils.get_device()
    logger = utils.logger_setup(logpath=os.path.join(args.image_dir, 'logs'), filepath=os.path.abspath(__file__))
    loaded_args, model, _ = utils.load_model(args.ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
        current_args_d=None, prediction=True, strict=False)

    # Override current arguments with recorded
    def dictify(namespace):
        """Collect attributes, skipping dunder names and anything containing 'logger'."""
        return {
            name: getattr(namespace, name)
            for name in dir(namespace)
            if not (name.startswith('__') or 'logger' in name)
        }

    loaded_args_d, args_d = dictify(loaded_args), dictify(args)
    loaded_args_d.update(args_d)
    args = utils.Struct(**loaded_args_d)
    logger.info(loaded_args_d)

    # Build probability tables
    logger.info('Building hyperprior probability tables...')
    model.Hyperprior.hyperprior_entropy_model.build_tables()
    logger.info('All tables built.')

    eval_loader = datasets.get_dataloaders('evaluation', root=args.image_dir, batch_size=args.batch_size,
                                           logger=logger, shuffle=False, normalize=args.normalize_input_image)

    n, N = 0, len(eval_loader.dataset)
    input_filenames_total = list()
    output_filenames_total = list()
    # torch.zeros rather than torch.Tensor(N): the latter returns
    # uninitialised memory, so any slot left unfilled would hold garbage.
    bpp_total, q_bpp_total, LPIPS_total = torch.zeros(N), torch.zeros(N), torch.zeros(N)
    utils.makedirs(args.output_dir)

    logger.info('Starting compression...')
    start_time = time.time()

    with torch.no_grad():

        for data, bpp, filenames in tqdm(eval_loader):
            data = data.to(device, dtype=torch.float)
            B = data.size(0)
            input_filenames_total.extend(filenames)

            if args.reconstruct is True:
                # Reconstruction without compression
                reconstruction, q_bpp = model(data, writeout=False)
            else:
                # Perform entropy coding
                compressed_output = model.compress(data)

                if args.save is True:
                    assert B == 1, 'Currently only supports saving single images.'
                    compression_utils.save_compressed_format(compressed_output,
                        out_path=os.path.join(args.output_dir, f"{filenames[0]}_compressed.hfc"))

                reconstruction = model.decompress(compressed_output)
                q_bpp = compressed_output.total_bpp

            if args.normalize_input_image is True:
                # [-1., 1.] -> [0., 1.]
                data = (data + 1.) / 2.

            perceptual_loss = perceptual_loss_fn.forward(reconstruction, data, normalize=True)

            for subidx in range(reconstruction.shape[0]):
                q_bpp_per_im = _q_bpp_for_image(q_bpp, subidx)

                fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], q_bpp_per_im))
                torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
                output_filenames_total.append(fname)

            bpp_total[n:n + B] = bpp.data
            q_bpp_total[n:n + B] = q_bpp.data if isinstance(q_bpp, torch.Tensor) else q_bpp
            # Flatten so a loss with trailing singleton axes still fits the
            # 1-D slice. NOTE(review): the loss shape is not visible here;
            # presumably one value per image - confirm.
            LPIPS_total[n:n + B] = perceptual_loss.data.reshape(-1)
            n += B

    df = pd.DataFrame([input_filenames_total, output_filenames_total]).T
    df.columns = ['input_filename', 'output_filename']
    df['bpp_original'] = bpp_total.cpu().numpy()
    df['q_bpp'] = q_bpp_total.cpu().numpy()
    df['LPIPS'] = LPIPS_total.cpu().numpy()

    df_path = os.path.join(args.output_dir, 'compression_metrics.h5')
    df.to_hdf(df_path, key='df')

    pprint(df)

    logger.info('Complete. Reconstructions saved to {}. Output statistics saved to {}'.format(args.output_dir, df_path))
    delta_t = time.time() - start_time
    logger.info('Time elapsed: {:.3f} s'.format(delta_t))
    logger.info('Rate: {:.3f} Images / s:'.format(float(N) / delta_t))
 def on_pre_performance_check(self):
     """Build ``self.perceptual_loss`` the first time this hook runs.

     Does nothing unless ``self.init`` compares equal to ``True``; otherwise
     the flag is cleared and the perceptual loss is constructed.
     """
     if not (self.init == True):
         return

     # Clear the flag first so later calls return early.
     self.init = False
     self.perceptual_loss = ps.PerceptualLoss(
         model='net-lin',
         net='alex',
         use_gpu=torch.cuda.is_available(),
         gpu_ids=[self.device],
     )