Beispiel #1
0
def test():
    """Run the evaluation (test) phase and save the collected results.

    Builds a fresh TEST-mode graph and attaches the hooks. These are an
    examples-per-second logger and a test counter hook, plus a live 3D plot
    and a rate limiter when ``FLAGS.gui`` is set. It then runs
    ``helpers.run_eval_loop`` and passes the fetched values to
    ``save_results``.
    """
    logging.info('Test (eval) phase.')
    # Start from an empty graph so ops from an earlier phase do not leak in.
    # The tf.compat.v1 spelling (also used by make_training_hooks) works in
    # TF1 and TF2; the bare TF1 alias is missing from TF2's top level.
    tf.compat.v1.reset_default_graph()

    # Several passes over the test set are requested only when one of these
    # flags is set (presumably each pass then sees a different variant of the
    # data); otherwise a single pass is enough.
    needs_multiple_passes = (
        FLAGS.occlude_test or FLAGS.test_aug or FLAGS.multiepoch_test)
    n_epochs = FLAGS.epochs if needs_multiple_passes else 1
    t = build_graph(TEST, n_epochs=n_epochs, shuffle=False)

    # The counter is presumably incremented once per batch, so it is scaled
    # by the test batch size to log examples (not batches) per second.
    test_counter = tfu.get_or_create_counter('testc')
    counter_hook = session_hooks.counter_hook(test_counter)
    example_hook = session_hooks.log_increment_per_sec(
        'Examples',
        test_counter.var * FLAGS.batch_size_test,
        None,
        every_n_secs=FLAGS.hook_seconds)

    hooks = [example_hook, counter_hook]
    if FLAGS.gui:
        # The dataset is only needed here, for the skeleton edges to draw.
        dataset = data.datasets.current_dataset()
        # NOTE(review): (x + 1) / 2 assumes the image tensor is scaled to
        # [-1, 1] and maps it to [0, 1] for display -- confirm.
        plot_hook = session_hooks.send_to_worker_hook(
            [(t.x[0] + 1) / 2, t.coords3d_pred[0], t.coords3d_true[0]],
            util3d.plot3d_worker,
            worker_args=[dataset.joint_info.stick_figure_edges],
            worker_kwargs=dict(batched=False,
                               interval=100,
                               has_ground_truth=True),
            every_n_steps=1,
            use_threading=False)
        # Presumably throttles the loop so the live plot stays watchable.
        rate_limit_hook = session_hooks.rate_limit_hook(0.5)
        hooks.append(plot_hook)
        hooks.append(rate_limit_hook)

    # Tensors whose per-step values are collected by the eval loop.
    fetch_names = [
        'image_path', 'coords3d_true_orig_cam', 'coords3d_pred_orig_cam',
        'coords3d_true_world', 'coords3d_pred_world', 'activity_name',
        'scene_name', 'joint_validity_mask', 'confidences',
        'coords3d_pred_backproj_orig_cam', 'compactness'
    ]
    fetch_tensors = {name: t[name] for name in fetch_names}

    # Created only after the whole graph is built, so the initializers cover
    # every variable defined above.
    global_init_op = tf.compat.v1.global_variables_initializer()
    local_init_op = tf.compat.v1.local_variables_initializer()

    def init_fn(_, sess):
        # Initialize all variables and restart the test counter from zero.
        sess.run([global_init_op, local_init_op, test_counter.reset_op])

    f = helpers.run_eval_loop(fetches_to_collect=fetch_tensors,
                              load_path=FLAGS.load_path,
                              checkpoint_dir=FLAGS.checkpoint_dir,
                              hooks=hooks,
                              init_fn=init_fn)
    save_results(f)
Beispiel #2
0
def make_training_hooks(t_train, t_valid):
    """Create the session hooks used during training.

    These are always included:

    * a training-images-per-second logger;
    * a per-step epoch/step/loss logger;
    * a checkpoint saver that writes to ``FLAGS.logdir`` every 30 minutes
      and keeps the last 2 checkpoints.

    Depending on flags, these are also added:

    * an ETA hook (``FLAGS.epochs``);
    * a periodic validation hook (``FLAGS.validate_period``, in epochs);
    * a TensorBoard summary hook (``FLAGS.tensorboard``);
    * live 3D plotting hooks (``FLAGS.gui``).

    Args:
        t_train: Training graph tensors. Attributes read here:
            ``global_step``, ``n_examples``, ``loss``, ``summary_op``, ``x``,
            ``coords3d_pred``, ``coords3d_true`` and, if present,
            ``x_2d`` / ``coords3d_pred2d``.
        t_valid: Validation graph tensors. Only read when validation is
            enabled (``n_examples`` and the ``mean_*`` metric tensors).

    Returns:
        A list of the created hooks.
    """
    saver = tf.compat.v1.train.Saver(max_to_keep=2, save_relative_paths=True)

    # When resuming in an existing logdir, register the checkpoints already
    # there with the saver so that max_to_keep also prunes those old files.
    checkpoint_state = tf.train.get_checkpoint_state(FLAGS.logdir)
    if checkpoint_state:
        saver.recover_last_checkpoints(
            checkpoint_state.all_model_checkpoint_paths)

    global_step_tensor = tf.compat.v1.train.get_or_create_global_step()
    checkpoint_hook = tf.estimator.CheckpointSaverHook(FLAGS.logdir,
                                                       saver=saver,
                                                       save_secs=30 * 60)

    # With mixed training each step processes twice the base batch size.
    total_batch_size = FLAGS.batch_size * (2 if FLAGS.train_mixed else 1)
    example_counter_hook = session_hooks.log_increment_per_sec(
        'Training images',
        t_train.global_step * total_batch_size,
        every_n_secs=FLAGS.hook_seconds,
        summary_output_dir=FLAGS.logdir)

    # Clamped to 1: with fewer examples than one batch the plain floor
    # division is 0, which would make the epoch tensor below divide by zero
    # at run time and the validation interval collapse to 0.
    steps_per_epoch = max(1, t_train.n_examples // FLAGS.batch_size)
    i_epoch = t_train.global_step // steps_per_epoch
    logger_hook = session_hooks.logger_hook(
        'Epoch {:03d}, global step {:07d}. Loss: {:.15e}',
        [i_epoch, t_train.global_step, t_train.loss],
        every_n_steps=1)

    hooks = [example_counter_hook, logger_hook, checkpoint_hook]

    if FLAGS.epochs:
        eta_hook = session_hooks.eta_hook(
            n_total_steps=(t_train.n_examples * FLAGS.epochs) //
            FLAGS.batch_size,
            every_n_secs=600,
            summary_output_dir=FLAGS.logdir)
        hooks.append(eta_hook)

    if FLAGS.validate_period:
        # validate_period is expressed in epochs; convert it to a step
        # interval of at least one step.
        every_n_steps = max(
            1, int(np.round(FLAGS.validate_period * steps_per_epoch)))

        # One validation run covers at most one sweep over the validation
        # set (and at most 120 s, see max_seconds below). Stored as an int
        # because it is a step count; np.ceil alone returns a float.
        max_valid_steps = int(
            np.ceil(t_valid.n_examples / FLAGS.batch_size_test))
        summary_output_dir = FLAGS.logdir if FLAGS.tensorboard else None

        metrics = [
            EvaluationMetric(t_valid.mean_error,
                             'MPJPE',
                             '.3f',
                             is_higher_better=False),
            EvaluationMetric(t_valid.mean_error_procrustes,
                             'MPJPE-procrustes',
                             '.3f',
                             is_higher_better=False),
            EvaluationMetric(t_valid.mean_pck, '3DPCK@150mm', '.3%'),
            EvaluationMetric(t_valid.mean_auc, 'AUC', '.3%'),
        ]

        validation_hook = session_hooks.validation_hook(
            metrics,
            summary_output_dir=summary_output_dir,
            max_steps=max_valid_steps,
            max_seconds=120,
            every_n_steps=every_n_steps,
            _step_tensor=global_step_tensor)
        hooks.append(validation_hook)

    if FLAGS.tensorboard:
        # merge_all() returns None when no summaries were registered, so it
        # is only merged in when present.
        other_summary_ops = [
            a for a in [tf.compat.v1.summary.merge_all()] if a is not None
        ]
        summary_hook = tf.estimator.SummarySaverHook(
            save_steps=1,
            output_dir=FLAGS.logdir,
            summary_op=tf.compat.v1.summary.merge(
                [*other_summary_ops, t_train.summary_op]))
        # The inner hook would save every step; the wrapper limits it to
        # every 10th global step.
        summary_hook = tfasync.PeriodicHookWrapper(
            summary_hook, every_n_steps=10, step_tensor=global_step_tensor)
        hooks.append(summary_hook)

    if FLAGS.gui:
        dataset = data.datasets3d.get_dataset(FLAGS.dataset)
        plot_hook = session_hooks.send_to_worker_hook(
            [
                tfu.std_to_nhwc(t_train.x[0]), t_train.coords3d_pred[0],
                t_train.coords3d_true[0]
            ],
            util3d.plot3d_worker,
            worker_args=[dataset.joint_info.stick_figure_edges],
            worker_kwargs=dict(batched=False, interval=100),
            every_n_secs=FLAGS.hook_seconds,
            use_threading=False)
        hooks.append(plot_hook)
        if 'coords3d_pred2d' in t_train:
            # No 3D ground truth for the 2D-annotated input: the prediction
            # is repeated in the ground-truth slot as a placeholder and
            # has_ground_truth=False is passed (presumably the worker then
            # ignores that slot -- confirm in util3d.plot3d_worker).
            plot_hook = session_hooks.send_to_worker_hook(
                [
                    tfu.std_to_nhwc(t_train.x_2d[0]),
                    t_train.coords3d_pred2d[0], t_train.coords3d_pred2d[0]
                ],
                util3d.plot3d_worker,
                worker_args=[dataset.joint_info.stick_figure_edges],
                worker_kwargs=dict(batched=False,
                                   interval=100,
                                   has_ground_truth=False),
                every_n_secs=FLAGS.hook_seconds,
                use_threading=False)
            hooks.append(plot_hook)

    return hooks