示例#1
0
def get_run_params(current_directory):
    """Assemble the run-parameter dictionary for a session.

    Parses the autoencoder and probability-classifier configs whose paths are
    taken from the global ``args`` (``args.ae_config_path`` and
    ``args.pc_config_path``), then combines their settings with output
    locations derived from ``current_directory``.

    Args:
        current_directory: Base directory string. ``'/weights/'`` and
            ``'/images/'`` are appended to it by plain string concatenation
            (no path normalisation).

    Returns:
        dict of run parameters. Both parsed configs and their relative paths
        are included; most scalar settings are copied from the autoencoder
        config, and the plot/save switches are fixed values.
    """
    # NOTE(review): depends on a module-level ``args`` object that is not a
    # parameter of this function -- confirm it is defined before calling.
    ae_cfg, ae_rel_path = config_parser.parse(args.ae_config_path)
    pc_cfg, pc_rel_path = config_parser.parse(args.pc_config_path)

    params = dict(
        ae_config=ae_cfg,
        ae_config_rel_path=ae_rel_path,
        pc_config=pc_cfg,
        pc_config_rel_path=pc_rel_path,
        total_iterations=ae_cfg.iterations,
        batch_size=ae_cfg.batch_size,
        root_weights=current_directory + '/weights/',
        root_save_img=current_directory + '/images/',
        show_every=ae_cfg.show_every,
        # max number of iterations between validation steps
        validate_every=ae_cfg.validate_every,
        decrease_val_steps=ae_cfg.decrease_val_steps,
        load_model_name=ae_cfg.load_model_name,
        load_model=ae_cfg.load_model,
        train_model=ae_cfg.train_model,
        test_model=ae_cfg.test_model,
        save_model=ae_cfg.save_model,
        # Fixed output switches; not read from either config.
        plot_test_img=False,
        save_test_img=True,
        plot_loss_graph=False,
        save_loss_graph=False,
        create_loss_list=False,
        save_config=True,
    )
    return params
示例#2
0
def load_classifier(clf_ckpt_p):
    """Build a ClassifierNetwork and restore its weights from a checkpoint.

    The classifier config path is derived from ``clf_ckpt_p`` through
    ``_parse_clf_ckpt_p`` and parsed with ``config_parser``. The network is
    moved to ``pe.DEVICE`` before the checkpoint's ``'net'`` state dict is
    loaded into it.

    Args:
        clf_ckpt_p: Path of the classifier checkpoint file.

    Returns:
        The restored ClassifierNetwork.
    """
    config_path, _postfix = _parse_clf_ckpt_p(clf_ckpt_p)
    print(f'Using classifier with config {config_path}')
    config, _ = config_parser.parse(config_path)
    # NOTE: the postfix returned by _parse_clf_ckpt_p is currently not
    # applied to the parsed config.

    network = ClassifierNetwork(config)
    network.to(pe.DEVICE)
    print(network)

    # Without CUDA, remap tensors stored on the GPU onto the CPU.
    location = 'cpu' if not pe.CUDA_AVAILABLE else None
    print('Restoring', clf_ckpt_p)
    checkpoint = torch.load(clf_ckpt_p, map_location=location)
    network.load_state_dict(checkpoint['net'])
    print('Loaded!')
    return network
示例#3
0
def validate(val_dirs: ValidationDirs, images_iterator: ImagesIterator,
             flags: OutputFlags):
    """
    Validate all checkpoints of a run that have not been validated yet.

    The candidate checkpoints are every `flags.ckpt_step`-th checkpoint in
    `val_dirs.ckpt_dir` plus the last one. Checkpoints already returned by
    `val_dirs.get_validated_checkpoints()` are skipped. For each remaining
    checkpoint, the weights are restored and every image of `images_iterator`
    is encoded and decoded; images are padded to the autoencoder's
    subsampling factor. Average bpp / MS-SSIM / PSNR, plus the distance to
    BPG when available, are logged as TF summaries to `val_dirs.out_dir`.
    The checkpoint is then recorded with `val_dirs.add_validated_checkpoint`.

    Saves in val_dirs.log_dir/val/dataset_name/measures.csv:
        - `img_name,bpp,psnr,ms-ssim forall img_name`
    The file is truncated for every checkpoint, so only the measures of the
    last validated checkpoint remain.

    flags.real_bpp: also compute the real bitrate from the symbols and assert
        that the theoretical bpp matches the bpp from the loss.
    flags.save_ours: save the reconstructions of the last checkpoint.
    """
    print(_VALIDATION_INFO_STR)

    # :: [10000, 18000, ..., 256000], ie, [int]
    validated_checkpoints = val_dirs.get_validated_checkpoints()
    all_ckpts = Saver.all_ckpts_with_iterations(val_dirs.ckpt_dir)
    if not all_ckpts:
        print('No checkpoints found in {}'.format(val_dirs.ckpt_dir))
        return
    # Every ckpt_step-th checkpoint plus the last one. If ckpt_step is -1,
    # all_ckpts[:-1:-1] == [] because of how negative strides work, so only
    # the last checkpoint is selected.
    ckpt_to_check = all_ckpts[:-1:flags.ckpt_step] + [all_ckpts[-1]]
    if flags.ckpt_step == -1:
        assert len(ckpt_to_check) == 1
    print('Validating {}/{} checkpoints (--ckpt_step {})...'.format(
        len(ckpt_to_check), len(all_ckpts), flags.ckpt_step))

    missing_checkpoints = [(ckpt_itr, ckpt_path)
                           for ckpt_itr, ckpt_path in ckpt_to_check
                           if ckpt_itr not in validated_checkpoints]
    if not missing_checkpoints:
        print('All checkpoints validated, stopping...')
        return

    # ---

    # create networks (graph construction order kept as-is so that variable
    # names keep matching the checkpoints)
    autoencoder_config_path, probclass_config_path = logdir_helpers.config_paths_from_log_dir(
        val_dirs.log_dir,
        base_dirs=[constants.CONFIG_BASE_AE, constants.CONFIG_BASE_PC])
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # One image at a time, fed as uint8 CHW; expanded to a batch of 1.
    x_val_ph = tf.placeholder(tf.uint8, (3, None, None), name='x_val_ph')
    x_val_uint8 = tf.expand_dims(x_val_ph, 0, name='batch')
    x_val = tf.to_float(x_val_uint8, name='x_val')

    enc_out_val = ae.encode(x_val, is_training=False)
    x_out_val = ae.decode(enc_out_val.qhard, is_training=False)

    bc_val = pc.bitcost(enc_out_val.qbar,
                        enc_out_val.symbols,
                        is_training=False,
                        pad_value=pc.auto_pad_value(ae))
    bpp_val = bits.bitcost_to_bpp(bc_val, x_val)

    x_out_val_uint8 = tf.cast(x_out_val, tf.uint8, name='x_out_val_uint8')
    # Using numpy implementation due to dynamic shapes
    msssim_val = ms_ssim_np.tf_msssim_np(x_val_uint8,
                                         x_out_val_uint8,
                                         data_format='NCHW')
    psnr_val = psnr_np(x_val_uint8, x_out_val_uint8)

    restorer = Saver(val_dirs.ckpt_dir,
                     var_list=Saver.get_var_list_of_ckpt_dir(
                         val_dirs.ckpt_dir))

    # create fetch_dict
    fetch_dict = {
        'bpp': bpp_val,
        'ms-ssim': msssim_val,
        'psnr': psnr_val,
    }

    if flags.real_bpp:
        fetch_dict['sym'] = enc_out_val.symbols  # NCHW

    if flags.save_ours:
        fetch_dict['img_out'] = x_out_val_uint8

    def full_summary_tag(summary_name):
        return '/'.join(['val', images_iterator.dataset_name, summary_name])

    # ---
    fw = tf.summary.FileWriter(val_dirs.out_dir, graph=tf.get_default_graph())
    try:
        # Distance
        try:
            codec_distance_ms_ssim = CodecDistance(images_iterator.dataset_name,
                                                   codec='bpg',
                                                   metric='ms-ssim')
            codec_distance_psnr = CodecDistance(images_iterator.dataset_name,
                                                codec='bpg',
                                                metric='psnr')
        except CodecDistanceReadException as e:  # no codec distance values stored for the current setup
            print('*** Distance to BPG not available for {}:\n{}'.format(
                images_iterator.dataset_name, str(e)))
            codec_distance_ms_ssim = None
            codec_distance_psnr = None

        # Note that for each checkpoint, the structure of the network will be
        # the same. Thus the pad depending image loading can be cached.

        # create session
        with tf_helpers.create_session() as sess:
            if flags.real_bpp:
                pred = probclass.PredictionNetwork(pc, pc_config,
                                                   ae.get_centers_variable(),
                                                   sess)
                checker = probclass.ProbclassNetworkTesting(pc, ae, sess)
                bpp_fetcher = bpp_helpers.BppFetcher(pred, checker)

            fetcher = sess.make_callable(fetch_dict, feed_list=[x_val_ph])

            # Reconstructions (flags.save_ours) are only saved for this one.
            last_ckpt_itr = missing_checkpoints[-1][0]
            for ckpt_itr, ckpt_path in missing_checkpoints:
                if not ckpt_still_exists(ckpt_path):
                    # May happen if job is still training
                    print('Checkpoint disappeared: {}'.format(ckpt_path))
                    continue

                print(_CKPT_ITR_INFO_STR.format(ckpt_itr))

                restorer.restore_ckpt(sess, ckpt_path)

                values_aggregator = ValuesAggregator('bpp', 'ms-ssim', 'psnr')

                # truncates the previous measures.csv file! This way, only the
                # last valid checkpoint is saved.
                measures_writer = MeasuresWriter(val_dirs.out_dir)
                try:
                    # iterate over images; images are padded to work with the
                    # current autoencoder
                    for img_i, (img_name, img_content) in enumerate(
                            images_iterator.iter_imgs(
                                pad=ae.get_subsampling_factor())):
                        otp = fetcher(img_content)
                        measures_writer.append(img_name, otp)

                        if flags.real_bpp:
                            # Calculate
                            bpp_real, bpp_theory = bpp_fetcher.get_bpp(
                                otp['sym'],
                                bpp_helpers.num_pixels_in_image(img_content))

                            # Logging
                            bpp_loss = otp['bpp']
                            diff_percent_tr = (bpp_theory / bpp_real) * 100
                            diff_percent_lt = (bpp_loss / bpp_theory) * 100
                            print('BPP: Real         {:.5f}\n'
                                  '     Theoretical: {:.5f} [{:5.1f}% of real]\n'
                                  '     Loss:        {:.5f} [{:5.1f}% of theoretical]'.format(
                                      bpp_real, bpp_theory, diff_percent_tr,
                                      bpp_loss, diff_percent_lt))
                            assert abs(bpp_theory - bpp_loss) < 1e-3, \
                                'Expected bpp_theory to match loss! Got {} and {}'.format(
                                    bpp_theory, bpp_loss)

                        if flags.save_ours and ckpt_itr == last_ckpt_itr:
                            save_img(img_name, otp['img_out'], val_dirs)

                        values_aggregator.update(otp)

                        print('{: 10d} {img_name} | Mean: {avgs}'.format(
                            img_i,
                            img_name=img_name,
                            avgs=values_aggregator.averages_str()),
                              end=('\r' if not flags.real_bpp else '\n'),
                              flush=True)
                finally:
                    # Close the writer even if an image fails (e.g. the bpp
                    # assertion above) so the file handle is not leaked.
                    measures_writer.close()

                print()  # add newline
                avgs = values_aggregator.averages()
                avg_bpp = avgs['bpp']
                avg_ms_ssim = avgs['ms-ssim']
                avg_psnr = avgs['psnr']

                tf_helpers.log_values(
                    fw, [(full_summary_tag('avg_bpp'), avg_bpp),
                         (full_summary_tag('avg_ms_ssim'), avg_ms_ssim),
                         (full_summary_tag('avg_psnr'), avg_psnr)],
                    iteration=ckpt_itr)

                if (codec_distance_ms_ssim is not None
                        and codec_distance_psnr is not None):
                    try:
                        d_ms_ssim = codec_distance_ms_ssim.distance(
                            avg_bpp, avg_ms_ssim)
                        d_psnr = codec_distance_psnr.distance(avg_bpp, avg_psnr)
                    except ValueError as e:  # out of range errors from distance calls
                        print(e)
                    else:
                        print('Distance to BPG: {:.3f} ms-ssim // {:.3f} psnr'.
                              format(d_ms_ssim, d_psnr))
                        tf_helpers.log_values(
                            fw,
                            [(full_summary_tag('distance_BPG_MS-SSIM'),
                              d_ms_ssim),
                             (full_summary_tag('distance_BPG_PSNR'), d_psnr)],
                            iteration=ckpt_itr)

                val_dirs.add_validated_checkpoint(ckpt_itr)
    finally:
        # FileWriter buffers events; close() flushes them to disk so the
        # summaries logged above are not lost when the process exits.
        fw.close()

    print('Validation completed {}'.format(val_dirs))
示例#4
0
def train(autoencoder_config_path, probclass_config_path,
          restore_manager: RestoreManager, log_dir_root, datasets: Datasets,
          train_flags: TrainFlags, ckpt_interval_hours: float,
          description: str):
    """Build the train/test graph for autoencoder + probability classifier and train.

    Args:
        autoencoder_config_path: Path of the autoencoder config, parsed with
            ``config_parser.parse``.
        probclass_config_path: Path of the probability-classifier config.
        restore_manager: Optional. When given, ``init_vars=False`` is passed to
            ``start_queues_in_sess`` and ``restore_manager.restore(sess)`` is
            called. When its ``continue_in_ckpt_dir`` is truthy, logging goes
            to ``restore_manager.log_dir`` instead of a new unique log dir.
        log_dir_root: Root under which a new unique log dir is created.
        datasets: Supplies the ``train`` and ``test`` datasets and the
            ``codec_distance`` key passed to ``CodecDistance``.
        train_flags: Forwarded to ``train_loop``.
        ckpt_interval_hours: Forwarded to the Saver as
            ``keep_checkpoint_every_n_hours``.
        description: If non-empty, the run is recorded via
            ``_write_to_sheets``.
    """
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)
    print_configs(('ae_config', ae_config), ('pc_config', pc_config))

    # Evaluates to None (not False) when restore_manager is None; that value is
    # passed as-is to _write_to_sheets below.
    continue_in_ckpt_dir = restore_manager and restore_manager.continue_in_ckpt_dir
    if continue_in_ckpt_dir:
        logdir = restore_manager.log_dir
    else:
        logdir = logdir_helpers.create_unique_log_dir(
            [ae_config_rel_path, pc_config_rel_path],
            log_dir_root,
            restore_dir=restore_manager.ckpt_dir if restore_manager else None)
    print(_LOG_DIR_FORMAT.format(logdir))

    if description:
        _write_to_sheets(logdir_helpers.log_date_from_log_dir(logdir),
                         ae_config_rel_path,
                         pc_config_rel_path,
                         description,
                         git_ref=_get_git_ref(),
                         log_dir_root=log_dir_root,
                         is_continue=continue_in_ckpt_dir)

    # NOTE: the graph construction order below may determine auto-generated
    # TF op/variable names, which restored checkpoints are matched against;
    # keep it stable.
    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # train ---
    # NOTE(review): the training pipeline is built with shuffle=False --
    # confirm this is intended.
    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(datasets.train),
        ae_config.crop_size,
        batch_size=ae_config.batch_size,
        shuffle=False,
        num_preprocess_threads=NUM_PREPROCESS_THREADS,
        num_crops_per_img=NUM_CROPS_PER_IMG)
    x_train = ip_train.get_batch()

    enc_out_train = ae.encode(
        x_train, is_training=True)  # qbar is masked by the heatmap
    x_out_train = ae.decode(enc_out_train.qbar, is_training=True)
    # stop_gradient is beneficial for training. it prevents multiple gradients flowing into the heatmap.
    pc_in = tf.stop_gradient(enc_out_train.qbar)
    bc_train = pc.bitcost(pc_in,
                          enc_out_train.symbols,
                          is_training=True,
                          pad_value=pc.auto_pad_value(ae))
    bpp_train = bits.bitcost_to_bpp(bc_train, x_train)
    d_train = Distortions(ae_config, x_train, x_out_train, is_training=True)
    # summing over channel dimension gives 2D heatmap
    heatmap2D = (tf.reduce_sum(enc_out_train.heatmap, 1)
                 if enc_out_train.heatmap is not None else None)

    # loss ---
    total_loss, H_real, pc_comps, ae_comps = get_loss(ae_config, ae, pc,
                                                      d_train.d_loss_scaled,
                                                      bc_train,
                                                      enc_out_train.heatmap)
    train_op = get_train_op(ae_config, pc_config, ip_train, pc.variables(),
                            total_loss)

    # test ---
    with tf.name_scope('test'):
        ip_test = inputpipeline.InputPipeline(
            inputpipeline.get_dataset(datasets.test),
            ae_config.crop_size,
            batch_size=ae_config.batch_size,
            num_preprocess_threads=NUM_PREPROCESS_THREADS,
            num_crops_per_img=1,
            big_queues=False,
            shuffle=False)
        x_test = ip_test.get_batch()
        enc_out_test = ae.encode(x_test, is_training=False)
        x_out_test = ae.decode(enc_out_test.qhard, is_training=False)
        bc_test = pc.bitcost(enc_out_test.qhard,
                             enc_out_test.symbols,
                             is_training=False,
                             pad_value=pc.auto_pad_value(ae))
        bpp_test = bits.bitcost_to_bpp(bc_test, x_test)
        d_test = Distortions(ae_config, x_test, x_out_test, is_training=False)

    try:  # Try to get codec distance for current dataset
        codec_distance_ms_ssim = CodecDistance(datasets.codec_distance,
                                               codec='bpg',
                                               metric='ms-ssim')
        # A ValueError raised by distance() is mapped to NaN by the handler.
        get_distance = functools_ext.catcher(ValueError,
                                             handler=functools_ext.const(
                                                 np.nan),
                                             f=codec_distance_ms_ssim.distance)
        get_distance = functools_ext.compose(np.float32,
                                             get_distance)  # cast to float32
        d_BPG_test = tf.py_func(get_distance, [bpp_test, d_test.ms_ssim],
                                tf.float32,
                                stateful=False,
                                name='d_BPG')
        d_BPG_test.set_shape(())
    except CodecDistanceReadException as e:
        print('Cannot compute CodecDistance: {}'.format(e))
        d_BPG_test = tf.constant(np.nan, shape=(), name='ConstNaN')

    # ---

    train_logger = Logger()
    test_logger = Logger()

    train_logger.add_summaries(d_train.summaries_with_prefix('train'))
    # Visualize components of losses
    train_logger.add_summaries([
        tf.summary.scalar('train/PC_loss/{}'.format(name), comp)
        for name, comp in pc_comps
    ])
    train_logger.add_summaries([
        tf.summary.scalar('train/AE_loss/{}'.format(name), comp)
        for name, comp in ae_comps
    ])
    train_logger.add_summaries([tf.summary.scalar('train/bpp', bpp_train)])
    train_logger.add_console_tensor('loss={:.3f}', total_loss)
    train_logger.add_console_tensor('ms_ssim={:.3f}', d_train.ms_ssim)
    train_logger.add_console_tensor('bpp={:.3f}', bpp_train)
    train_logger.add_console_tensor('H_real={:.3f}', H_real)

    test_logger.add_summaries(d_test.summaries_with_prefix('test'))
    test_logger.add_summaries([
        tf.summary.scalar('test/bpp', bpp_test),
        tf.summary.scalar('test/distance_BPG_MS-SSIM', d_BPG_test),
        tf.summary.image('test/x_in',
                         prep_for_image_summary(x_test, n=3, name='x_in')),
        tf.summary.image('test/x_out',
                         prep_for_image_summary(x_out_test, n=3, name='x_out'))
    ])
    if heatmap2D is not None:
        # NOTE(review): heatmap2D is derived from enc_out_train (the training
        # batch), yet it is logged with the test summaries as 'test/hm' --
        # confirm this is intended.
        test_logger.add_summaries([
            tf.summary.image(
                'test/hm',
                prep_for_grayscale_image_summary(heatmap2D,
                                                 n=3,
                                                 autoscale=True,
                                                 name='hm'))
        ])

    test_logger.add_console_tensor('ms_ssim={:.3f}', d_test.ms_ssim)
    test_logger.add_console_tensor('bpp={:.3f}', bpp_test)
    test_logger.add_summaries([
        tf.summary.histogram('centers', ae.get_centers_variable()),
        tf.summary.histogram(
            'test/qbar', enc_out_test.qbar[:ae_config.batch_size // 2, ...])
    ])
    test_logger.add_console_tensor('d_BPG={:.6f}', d_BPG_test)
    test_logger.add_console_tensor(Logger.Numpy1DFormatter('centers={}'),
                                   ae.get_centers_variable())

    print('Starting session and queues...')
    with tf_helpers.start_queues_in_sess(
            init_vars=restore_manager is None) as (sess, coord):
        train_logger.finalize_with_sess(sess)
        test_logger.finalize_with_sess(sess)

        if restore_manager:
            restore_manager.restore(sess)

        saver = Saver(Saver.ckpt_dir_for_log_dir(logdir),
                      max_to_keep=1,
                      keep_checkpoint_every_n_hours=ckpt_interval_hours)

        train_loop(ae_config,
                   sess,
                   coord,
                   train_op,
                   train_logger,
                   test_logger,
                   train_flags,
                   logdir,
                   saver,
                   is_restored=restore_manager is not None)