Esempio n. 1
0
 def restore(self, sess):
     skip_var_names = self.skip_var_names.split(
         ',') if self.skip_var_names else []
     var_list = Saver.get_var_list_of_ckpt_dir(
         self.ckpt_dir, skip_var_names=skip_var_names)
     print('Restoring {} variables...'.format(len(var_list)))
     Saver(self.ckpt_dir,
           var_list=var_list).restore_at_itr(sess, restore_itr=self.itr)
Esempio n. 2
0
def validate(val_dirs: ValidationDirs, images_iterator: ImagesIterator,
             flags: OutputFlags):
    """
    Saves in val_dirs.log_dir/val/dataset_name/measures.csv:
        - `img_name,bpp,psnr,ms-ssim forall img_name`
    """
    print(_VALIDATION_INFO_STR)

    validated_checkpoints = val_dirs.get_validated_checkpoints(
    )  # :: [10000, 18000, ..., 256000], ie, [int]
    all_ckpts = Saver.all_ckpts_with_iterations(val_dirs.ckpt_dir)
    if len(all_ckpts) == 0:
        print('No checkpoints found in {}'.format(val_dirs.ckpt_dir))
        return
    # if ckpt_step is -1, then all_ckpt[:-1:flags.ckpt_step] === [] because of how strides work
    ckpt_to_check = all_ckpts[:-1:flags.ckpt_step] + [
        all_ckpts[-1]
    ]  # every ckpt_step-th checkpoint plus the last one
    if flags.ckpt_step == -1:
        assert len(ckpt_to_check) == 1
    print('Validating {}/{} checkpoints (--ckpt_step {})...'.format(
        len(ckpt_to_check), len(all_ckpts), flags.ckpt_step))

    missing_checkpoints = [(ckpt_itr, ckpt_path)
                           for ckpt_itr, ckpt_path in ckpt_to_check
                           if ckpt_itr not in validated_checkpoints]
    if len(missing_checkpoints) == 0:
        print('All checkpoints validated, stopping...')
        return

    # ---

    # create networks
    autoencoder_config_path, probclass_config_path = logdir_helpers.config_paths_from_log_dir(
        val_dirs.log_dir,
        base_dirs=[constants.CONFIG_BASE_AE, constants.CONFIG_BASE_PC])
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    x_val_ph = tf.placeholder(tf.uint8, (3, None, None), name='x_val_ph')
    x_val_uint8 = tf.expand_dims(x_val_ph, 0, name='batch')
    x_val = tf.to_float(x_val_uint8, name='x_val')

    enc_out_val = ae.encode(x_val, is_training=False)
    x_out_val = ae.decode(enc_out_val.qhard, is_training=False)

    bc_val = pc.bitcost(enc_out_val.qbar,
                        enc_out_val.symbols,
                        is_training=False,
                        pad_value=pc.auto_pad_value(ae))
    bpp_val = bits.bitcost_to_bpp(bc_val, x_val)

    x_out_val_uint8 = tf.cast(x_out_val, tf.uint8, name='x_out_val_uint8')
    # Using numpy implementation due to dynamic shapes
    msssim_val = ms_ssim_np.tf_msssim_np(x_val_uint8,
                                         x_out_val_uint8,
                                         data_format='NCHW')
    psnr_val = psnr_np(x_val_uint8, x_out_val_uint8)

    restorer = Saver(val_dirs.ckpt_dir,
                     var_list=Saver.get_var_list_of_ckpt_dir(
                         val_dirs.ckpt_dir))

    # create fetch_dict
    fetch_dict = {
        'bpp': bpp_val,
        'ms-ssim': msssim_val,
        'psnr': psnr_val,
    }

    if flags.real_bpp:
        fetch_dict['sym'] = enc_out_val.symbols  # NCHW

    if flags.save_ours:
        fetch_dict['img_out'] = x_out_val_uint8

    # ---
    fw = tf.summary.FileWriter(val_dirs.out_dir, graph=tf.get_default_graph())

    def full_summary_tag(summary_name):
        return '/'.join(['val', images_iterator.dataset_name, summary_name])

    # Distance
    try:
        codec_distance_ms_ssim = CodecDistance(images_iterator.dataset_name,
                                               codec='bpg',
                                               metric='ms-ssim')
        codec_distance_psnr = CodecDistance(images_iterator.dataset_name,
                                            codec='bpg',
                                            metric='psnr')
    except CodecDistanceReadException as e:  # no codec distance values stored for the current setup
        print('*** Distance to BPG not available for {}:\n{}'.format(
            images_iterator.dataset_name, str(e)))
        codec_distance_ms_ssim = None
        codec_distance_psnr = None

    # Note that for each checkpoint, the structure of the network will be the same. Thus the pad depending image
    # loading can be cached.

    # create session
    with tf_helpers.create_session() as sess:
        if flags.real_bpp:
            pred = probclass.PredictionNetwork(pc, pc_config,
                                               ae.get_centers_variable(), sess)
            checker = probclass.ProbclassNetworkTesting(pc, ae, sess)
            bpp_fetcher = bpp_helpers.BppFetcher(pred, checker)

        fetcher = sess.make_callable(fetch_dict, feed_list=[x_val_ph])

        last_ckpt_itr = missing_checkpoints[-1][0]
        for ckpt_itr, ckpt_path in missing_checkpoints:
            if not ckpt_still_exists(ckpt_path):
                # May happen if job is still training
                print('Checkpoint disappeared: {}'.format(ckpt_path))
                continue

            print(_CKPT_ITR_INFO_STR.format(ckpt_itr))

            restorer.restore_ckpt(sess, ckpt_path)

            values_aggregator = ValuesAggregator('bpp', 'ms-ssim', 'psnr')

            # truncates the previous measures.csv file! This way, only the last valid checkpoint is saved.
            measures_writer = MeasuresWriter(val_dirs.out_dir)

            # ----------------------------------------
            # iterate over images
            # images are padded to work with current auto encoder
            for img_i, (img_name, img_content) in enumerate(
                    images_iterator.iter_imgs(
                        pad=ae.get_subsampling_factor())):
                otp = fetcher(img_content)
                measures_writer.append(img_name, otp)

                if flags.real_bpp:
                    # Calculate
                    bpp_real, bpp_theory = bpp_fetcher.get_bpp(
                        otp['sym'],
                        bpp_helpers.num_pixels_in_image(img_content))

                    # Logging
                    bpp_loss = otp['bpp']
                    diff_percent_tr = (bpp_theory / bpp_real) * 100
                    diff_percent_lt = (bpp_loss / bpp_theory) * 100
                    print('BPP: Real         {:.5f}\n'
                          '     Theoretical: {:.5f} [{:5.1f}% of real]\n'
                          '     Loss:        {:.5f} [{:5.1f}% of real]'.format(
                              bpp_real, bpp_theory, diff_percent_tr, bpp_loss,
                              diff_percent_lt))
                    assert abs(
                        bpp_theory - bpp_loss
                    ) < 1e-3, 'Expected bpp_theory to match loss! Got {} and {}'.format(
                        bpp_theory, bpp_loss)

                if flags.save_ours and ckpt_itr == last_ckpt_itr:
                    save_img(img_name, otp['img_out'], val_dirs)

                values_aggregator.update(otp)

                print('{: 10d} {img_name} | Mean: {avgs}'.format(
                    img_i,
                    img_name=img_name,
                    avgs=values_aggregator.averages_str()),
                      end=('\r' if not flags.real_bpp else '\n'),
                      flush=True)

            measures_writer.close()

            print()  # add newline
            avgs = values_aggregator.averages()
            avg_bpp, avg_ms_ssim, avg_psnr = avgs['bpp'], avgs[
                'ms-ssim'], avgs['psnr']

            tf_helpers.log_values(
                fw, [(full_summary_tag('avg_bpp'), avg_bpp),
                     (full_summary_tag('avg_ms_ssim'), avg_ms_ssim),
                     (full_summary_tag('avg_psnr'), avg_psnr)],
                iteration=ckpt_itr)

            if codec_distance_ms_ssim and codec_distance_psnr:
                try:
                    d_ms_ssim = codec_distance_ms_ssim.distance(
                        avg_bpp, avg_ms_ssim)
                    d_pnsr = codec_distance_psnr.distance(avg_bpp, avg_psnr)
                    print('Distance to BPG: {:.3f} ms-ssim // {:.3f} psnr'.
                          format(d_ms_ssim, d_pnsr))
                    tf_helpers.log_values(
                        fw,
                        [(full_summary_tag('distance_BPG_MS-SSIM'), d_ms_ssim),
                         (full_summary_tag('distance_BPG_PSNR'), d_pnsr)],
                        iteration=ckpt_itr)
                except ValueError as e:  # out of range errors from distance calls
                    print(e)

            val_dirs.add_validated_checkpoint(ckpt_itr)

    print('Validation completed {}'.format(val_dirs))