示例#1
0
def train_loop(config,
               sess,
               coord,
               train_op,
               train_logger: Logger,
               test_logger: Logger,
               train_flags: TrainFlags,
               log_dir,
               saver: Saver,
               is_restored=False):
    """Run `train_op` until `coord` requests a stop, logging and saving periodically.

    After every training step, the global step returned by that step decides
    what else happens in the same iteration:
      * every `train_flags.log_interval_train` steps: train summaries are
        written to TensorBoard and printed together with throughput and the
        job id;
      * every `train_flags.log_interval_save` steps (only if > 0): a
        checkpoint is written via `saver`;
      * every `train_flags.log_interval_test` steps (only if > 0): test
        summaries are written and printed.
    If `train_flags.log_run_metadata` is set, up to `_MAX_METADATA_RUNS`
    steps are run with run-metadata collection, and that metadata is added
    to the summary writer.

    Args:
        config: only `config.batch_size` is read (for the throughput timer).
        sess: session used to run `train_op` and the global step.
        coord: the loop runs while `coord.should_stop()` is False.
        train_op: op executed once per iteration.
        train_logger: logger for training tensors, used on train-log steps.
        test_logger: logger for test tensors, used on test-log steps.
        train_flags: supplies `log_run_metadata` and the
            `log_interval_train` / `log_interval_save` / `log_interval_test`
            intervals.
        log_dir: directory of the summary FileWriter; its log date is shown
            on the console as the job id.
        saver: used to write checkpoints.
        is_restored: if True, log train and test once at the restored global
            step before training starts.
    """
    global_step = tf.train.get_or_create_global_step()

    job_id = logdir_helpers.log_date_from_log_dir(log_dir)
    fw = tf.summary.FileWriter(log_dir, graph=sess.graph)
    try:
        training_timer = _Timer(train_flags.log_interval_train,
                                config.batch_size)
        itr = 0
        num_metadata_runs = 0
        # Metadata runs are considered when the step from the previous
        # iteration is a multiple of (log_interval_train - 1). Fall back to a
        # period of 1 when log_interval_train == 1, which would otherwise be a
        # modulo by zero. All other values behave exactly as before.
        metadata_period = train_flags.log_interval_train - 1
        if metadata_period == 0:
            metadata_period = 1

        if is_restored:
            itr = sess.run(global_step)
            train_logger.log().to_tensorboard(fw, itr).to_console(
                itr, append='Restored')
            test_logger.log().to_tensorboard(fw, itr).to_console(itr)

        print(_STARTING_TRAINING_INFO_STR)
        while not coord.should_stop():
            # At this point `itr` still holds the step returned by the
            # previous iteration (0, or the restored step, on the first one).
            if (train_flags.log_run_metadata
                    and num_metadata_runs < _MAX_METADATA_RUNS
                    and itr % metadata_period == 0):
                print('Logging run metadata...', end=' ')
                num_metadata_runs += 1
                (_, itr), run_metadata = run_and_fetch_metadata(
                    [train_op, global_step], sess)
                fw.add_run_metadata(run_metadata, str(itr), itr)
                print('Done')
            else:
                _, itr = sess.run([train_op, global_step])

            is_train_log_step = itr % train_flags.log_interval_train == 0

            # Train Logging --
            if is_train_log_step:
                info_str = '(img/s: {:.1f}) {}'.format(
                    training_timer.get_avg_ex_per_sec(), job_id)
                train_logger.log().to_tensorboard(fw, itr).to_console(
                    itr, append=info_str)

            # Save -- (a non-positive interval disables saving, mirroring the
            # test-logging guard below instead of dividing by zero)
            if (train_flags.log_interval_save > 0
                    and itr % train_flags.log_interval_save == 0):
                print('Saving...')
                saver.save(sess, global_step)

            # Test Logging --
            if (train_flags.log_interval_test > 0
                    and itr % train_flags.log_interval_test == 0):
                test_logger.log().to_tensorboard(fw, itr).to_console(itr)

            # Reset after all of the above so that logging/saving/testing time
            # is excluded from the next throughput measurement.
            if is_train_log_step:
                training_timer.reset()
    finally:
        # Flush pending events and release the writer even if training is
        # interrupted by an exception (e.g. the input queues running out).
        fw.close()
示例#2
0
    def __init__(self, ckpt_dir, log_dir_root, dataset_name, reset=False):
        """Derive the log and output locations that belong to a checkpoint dir.

        The log dir is derived from `ckpt_dir` via
        `Saver.log_dir_from_ckpt_dir`. Its log date, followed by a space and
        `dataset_name`, names a sub-directory of `log_dir_root` (`out_dir`).
        That sub-directory also holds the path 'validated_ckpts.pkl'
        (`validated_ckpts_f`).

        Args:
            ckpt_dir: checkpoint directory of the run.
            log_dir_root: parent directory of `out_dir`.
            dataset_name: appended to the log date to name `out_dir`.
            reset: if True, `self._reset()` is called once all paths are set.
        """
        self.ckpt_dir = ckpt_dir
        self.log_dir = Saver.log_dir_from_ckpt_dir(ckpt_dir)
        self.log_dir_root, self.dataset_name = log_dir_root, dataset_name

        out_dir_name = '{log_date} {dataset_name}'.format(
            log_date=logdir_helpers.log_date_from_log_dir(self.log_dir),
            dataset_name=dataset_name)
        self.out_dir = path.join(log_dir_root, out_dir_name)
        self.validated_ckpts_f = path.join(self.out_dir, 'validated_ckpts.pkl')

        if not reset:
            return
        self._reset()
示例#3
0
def train(autoencoder_config_path, probclass_config_path,
          restore_manager: RestoreManager, log_dir_root, datasets: Datasets,
          train_flags: TrainFlags, ckpt_interval_hours: float,
          description: str):
    """Build train/test graphs for an autoencoder + probability classifier and train them.

    Steps:
      1. Parse both configs.
      2. Pick a log dir: reuse the restore manager's, or create a new one.
      3. Optionally record the run via `_write_to_sheets`.
      4. Build the training graph: encode, decode, bitcost, distortion, loss
         and train op.
      5. Build a test graph under the 'test' name scope.
      6. Attach summaries and console tensors to a train and a test `Logger`.
      7. Start a session with queue runners and hand off to `train_loop`.

    Args:
        autoencoder_config_path: autoencoder config path, parsed with
            `config_parser.parse` (which returns the config and a relative
            path).
        probclass_config_path: probability-classifier config path, parsed the
            same way.
        restore_manager: optional (may be None). When given, it is restored
            into the session, and `start_queues_in_sess` is called with
            `init_vars=False`. If its `continue_in_ckpt_dir` is truthy, its
            `log_dir` is reused instead of creating a new log dir.
        log_dir_root: root under which a new unique log dir is created; also
            forwarded to `_write_to_sheets`.
        datasets: provides `train`, `test` and `codec_distance`.
        train_flags: forwarded to `train_loop`.
        ckpt_interval_hours: forwarded to `Saver` as
            `keep_checkpoint_every_n_hours`.
        description: if non-empty, the run is recorded via `_write_to_sheets`.
    """
    ae_config, ae_config_rel_path = config_parser.parse(
        autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)
    print_configs(('ae_config', ae_config), ('pc_config', pc_config))

    # None when there is no restore_manager, otherwise its continue flag.
    continue_in_ckpt_dir = restore_manager and restore_manager.continue_in_ckpt_dir
    if continue_in_ckpt_dir:
        logdir = restore_manager.log_dir
    else:
        logdir = logdir_helpers.create_unique_log_dir(
            [ae_config_rel_path, pc_config_rel_path],
            log_dir_root,
            restore_dir=restore_manager.ckpt_dir if restore_manager else None)
    print(_LOG_DIR_FORMAT.format(logdir))

    if description:
        _write_to_sheets(logdir_helpers.log_date_from_log_dir(logdir),
                         ae_config_rel_path,
                         pc_config_rel_path,
                         description,
                         git_ref=_get_git_ref(),
                         log_dir_root=log_dir_root,
                         is_continue=continue_in_ckpt_dir)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # train ---
    # NOTE(review): the training pipeline is built with shuffle=False; confirm
    # this is intended (e.g. the training records are already shuffled).
    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(datasets.train),
        ae_config.crop_size,
        batch_size=ae_config.batch_size,
        shuffle=False,
        num_preprocess_threads=NUM_PREPROCESS_THREADS,
        num_crops_per_img=NUM_CROPS_PER_IMG)
    x_train = ip_train.get_batch()

    enc_out_train = ae.encode(
        x_train, is_training=True)  # qbar is masked by the heatmap
    x_out_train = ae.decode(enc_out_train.qbar, is_training=True)
    # stop_gradient is beneficial for training. it prevents multiple gradients flowing into the heatmap.
    pc_in = tf.stop_gradient(enc_out_train.qbar)
    bc_train = pc.bitcost(pc_in,
                          enc_out_train.symbols,
                          is_training=True,
                          pad_value=pc.auto_pad_value(ae))
    bpp_train = bits.bitcost_to_bpp(bc_train, x_train)
    d_train = Distortions(ae_config, x_train, x_out_train, is_training=True)
    # summing over channel dimension gives 2D heatmap
    heatmap2D = (tf.reduce_sum(enc_out_train.heatmap, 1)
                 if enc_out_train.heatmap is not None else None)

    # loss ---
    total_loss, H_real, pc_comps, ae_comps = get_loss(ae_config, ae, pc,
                                                      d_train.d_loss_scaled,
                                                      bc_train,
                                                      enc_out_train.heatmap)
    train_op = get_train_op(ae_config, pc_config, ip_train, pc.variables(),
                            total_loss)

    # test ---
    # Test graph: unlike training, it decodes and bitcosts the hard
    # quantization `qhard`, with is_training=False and one crop per image.
    with tf.name_scope('test'):
        ip_test = inputpipeline.InputPipeline(
            inputpipeline.get_dataset(datasets.test),
            ae_config.crop_size,
            batch_size=ae_config.batch_size,
            num_preprocess_threads=NUM_PREPROCESS_THREADS,
            num_crops_per_img=1,
            big_queues=False,
            shuffle=False)
        x_test = ip_test.get_batch()
        enc_out_test = ae.encode(x_test, is_training=False)
        x_out_test = ae.decode(enc_out_test.qhard, is_training=False)
        bc_test = pc.bitcost(enc_out_test.qhard,
                             enc_out_test.symbols,
                             is_training=False,
                             pad_value=pc.auto_pad_value(ae))
        bpp_test = bits.bitcost_to_bpp(bc_test, x_test)
        d_test = Distortions(ae_config, x_test, x_out_test, is_training=False)

    try:  # Try to get codec distance for current dataset
        codec_distance_ms_ssim = CodecDistance(datasets.codec_distance,
                                               codec='bpg',
                                               metric='ms-ssim')
        # Presumably catcher returns the handler's value (NaN) when the
        # distance lookup raises ValueError, instead of failing the graph.
        get_distance = functools_ext.catcher(ValueError,
                                             handler=functools_ext.const(
                                                 np.nan),
                                             f=codec_distance_ms_ssim.distance)
        get_distance = functools_ext.compose(np.float32,
                                             get_distance)  # cast to float32
        # Evaluated in Python on each run via py_func, from the test bpp and
        # MS-SSIM; declared as a float32 scalar.
        d_BPG_test = tf.py_func(get_distance, [bpp_test, d_test.ms_ssim],
                                tf.float32,
                                stateful=False,
                                name='d_BPG')
        d_BPG_test.set_shape(())
    except CodecDistanceReadException as e:
        print('Cannot compute CodecDistance: {}'.format(e))
        # Constant NaN stand-in so the summary/console tensors below still exist.
        d_BPG_test = tf.constant(np.nan, shape=(), name='ConstNaN')

    # ---

    train_logger = Logger()
    test_logger = Logger()
    distortion_name = ae_config.distortion_to_minimize

    train_logger.add_summaries(d_train.summaries_with_prefix('train'))
    # Visualize components of losses
    train_logger.add_summaries([
        tf.summary.scalar('train/PC_loss/{}'.format(name), comp)
        for name, comp in pc_comps
    ])
    train_logger.add_summaries([
        tf.summary.scalar('train/AE_loss/{}'.format(name), comp)
        for name, comp in ae_comps
    ])
    train_logger.add_summaries([tf.summary.scalar('train/bpp', bpp_train)])
    train_logger.add_console_tensor('loss={:.3f}', total_loss)
    train_logger.add_console_tensor('ms_ssim={:.3f}', d_train.ms_ssim)
    train_logger.add_console_tensor('bpp={:.3f}', bpp_train)
    train_logger.add_console_tensor('H_real={:.3f}', H_real)

    test_logger.add_summaries(d_test.summaries_with_prefix('test'))
    test_logger.add_summaries([
        tf.summary.scalar('test/bpp', bpp_test),
        tf.summary.scalar('test/distance_BPG_MS-SSIM', d_BPG_test),
        tf.summary.image('test/x_in',
                         prep_for_image_summary(x_test, n=3, name='x_in')),
        tf.summary.image('test/x_out',
                         prep_for_image_summary(x_out_test, n=3, name='x_out'))
    ])
    # Heatmap summary only exists when the encoder produces a heatmap.
    if heatmap2D is not None:
        test_logger.add_summaries([
            tf.summary.image(
                'test/hm',
                prep_for_grayscale_image_summary(heatmap2D,
                                                 n=3,
                                                 autoscale=True,
                                                 name='hm'))
        ])

    test_logger.add_console_tensor('ms_ssim={:.3f}', d_test.ms_ssim)
    test_logger.add_console_tensor('bpp={:.3f}', bpp_test)
    # The qbar histogram uses only the first half of the test batch.
    test_logger.add_summaries([
        tf.summary.histogram('centers', ae.get_centers_variable()),
        tf.summary.histogram(
            'test/qbar', enc_out_test.qbar[:ae_config.batch_size // 2, ...])
    ])
    test_logger.add_console_tensor('d_BPG={:.6f}', d_BPG_test)
    test_logger.add_console_tensor(Logger.Numpy1DFormatter('centers={}'),
                                   ae.get_centers_variable())

    print('Starting session and queues...')
    # Variables are only initialized when nothing is restored; otherwise
    # restore_manager.restore() below provides their values.
    with tf_helpers.start_queues_in_sess(
            init_vars=restore_manager is None) as (sess, coord):
        train_logger.finalize_with_sess(sess)
        test_logger.finalize_with_sess(sess)

        if restore_manager:
            restore_manager.restore(sess)

        # Checkpoints go to the ckpt dir associated with this run's log dir.
        saver = Saver(Saver.ckpt_dir_for_log_dir(logdir),
                      max_to_keep=1,
                      keep_checkpoint_every_n_hours=ckpt_interval_hours)

        train_loop(ae_config,
                   sess,
                   coord,
                   train_op,
                   train_logger,
                   test_logger,
                   train_flags,
                   logdir,
                   saver,
                   is_restored=restore_manager is not None)
示例#4
0
 def job_id_from_out_dir(out_dir):
     """Return the job id (log date) encoded in the last component of `out_dir`.

     The last path component of `out_dir` is expected to be named
     '{log_date} {dataset_name}'. A trailing path separator on `out_dir` is
     tolerated.

     Raises:
         ValueError: may be raised by `logdir_helpers.log_date_from_log_dir`
             when the name does not carry a log date.
     """
     # normpath strips a trailing separator; without it, basename() would
     # return '' for e.g. 'root/<log_date> <dataset>/'.
     base = path.basename(path.normpath(out_dir))  # should be {log_date} {dataset_name}
     return logdir_helpers.log_date_from_log_dir(
         base)  # may raise ValueError