def test():
    """Run the test (evaluation) phase and save the collected results.

    Builds a fresh TEST graph, then runs it through ``helpers.run_eval_loop``.
    That call receives ``FLAGS.load_path`` and ``FLAGS.checkpoint_dir``, the
    hooks created here, and an ``init_fn`` that initializes all global and
    local variables and resets the test counter. The tensors named in
    ``fetch_names`` are requested via ``fetches_to_collect``. Whatever the
    eval loop returns is passed to ``save_results``.
    """
    logging.info('Test (eval) phase.')
    # Use the tf.compat.v1 graph API (as make_training_hooks already does). The
    # top-level tf.reset_default_graph / tf.*_variables_initializer symbols do
    # not exist in the TF2 namespace, while tf.compat.v1.* exists in both.
    tf.compat.v1.reset_default_graph()

    # Iterate FLAGS.epochs times only in these test modes; otherwise one pass.
    if FLAGS.occlude_test or FLAGS.test_aug or FLAGS.multiepoch_test:
        n_epochs = FLAGS.epochs
    else:
        n_epochs = 1
    t = build_graph(TEST, n_epochs=n_epochs, shuffle=False)

    test_counter = tfu.get_or_create_counter('testc')
    counter_hook = session_hooks.counter_hook(test_counter)
    # Reports processed examples per second (counter value * test batch size).
    example_hook = session_hooks.log_increment_per_sec(
        'Examples', test_counter.var * FLAGS.batch_size_test, None,
        every_n_secs=FLAGS.hook_seconds)
    hooks = [example_hook, counter_hook]

    dataset = data.datasets.current_dataset()
    if FLAGS.gui:
        # Sends the first example of each batch (image, predicted and true 3D
        # coordinates) to a plotting worker. (x + 1) / 2 presumably rescales
        # images from [-1, 1] to [0, 1] for display -- TODO confirm.
        plot_hook = session_hooks.send_to_worker_hook(
            [(t.x[0] + 1) / 2, t.coords3d_pred[0], t.coords3d_true[0]],
            util3d.plot3d_worker,
            worker_args=[dataset.joint_info.stick_figure_edges],
            worker_kwargs=dict(batched=False, interval=100, has_ground_truth=True),
            every_n_steps=1, use_threading=False)
        # NOTE(review): presumably slows stepping down so plots can be viewed.
        rate_limit_hook = session_hooks.rate_limit_hook(0.5)
        hooks.append(plot_hook)
        hooks.append(rate_limit_hook)

    fetch_names = [
        'image_path', 'coords3d_true_orig_cam', 'coords3d_pred_orig_cam',
        'coords3d_true_world', 'coords3d_pred_world', 'activity_name',
        'scene_name', 'joint_validity_mask', 'confidences',
        'coords3d_pred_backproj_orig_cam', 'compactness']
    fetch_tensors = {fetch_name: t[fetch_name] for fetch_name in fetch_names}

    global_init_op = tf.compat.v1.global_variables_initializer()
    local_init_op = tf.compat.v1.local_variables_initializer()

    def init_fn(_, sess):
        # The first argument is ignored (presumably the Scaffold, matching the
        # tf.train.Scaffold init_fn(scaffold, session) signature).
        sess.run([global_init_op, local_init_op, test_counter.reset_op])

    f = helpers.run_eval_loop(
        fetches_to_collect=fetch_tensors, load_path=FLAGS.load_path,
        checkpoint_dir=FLAGS.checkpoint_dir, hooks=hooks, init_fn=init_fn)
    save_results(f)
def make_training_hooks(t_train, t_valid):
    """Create the list of session hooks used during training.

    These hooks are always included:
    - an images-per-second logger;
    - a per-step epoch/step/loss logger;
    - a checkpoint saver (every 30 minutes, keeping the 2 most recent).

    Depending on FLAGS, these are also added:
    - an ETA hook (``FLAGS.epochs``);
    - a periodic validation hook (``FLAGS.validate_period``, in epochs);
    - a TensorBoard summary hook (``FLAGS.tensorboard``);
    - 3D pose plotting hooks (``FLAGS.gui``).

    Args:
        t_train: Training graph endpoints. Read here: ``global_step``,
            ``n_examples``, ``loss``; ``summary_op`` if TensorBoard is enabled;
            ``x``, ``coords3d_pred``, ``coords3d_true`` (and optionally
            ``x_2d`` / ``coords3d_pred2d``) if the GUI is enabled.
        t_valid: Validation graph endpoints. Read only when validation is
            enabled: ``n_examples``, ``mean_error``,
            ``mean_error_procrustes``, ``mean_pck``, ``mean_auc``.

    Returns:
        A list of session hooks.
    """
    # Keep at most 2 checkpoints. When resuming, register the checkpoints
    # already present in logdir with the saver so it keeps rotating them.
    saver = tf.compat.v1.train.Saver(max_to_keep=2, save_relative_paths=True)
    checkpoint_state = tf.train.get_checkpoint_state(FLAGS.logdir)
    if checkpoint_state:
        saver.recover_last_checkpoints(
            checkpoint_state.all_model_checkpoint_paths)

    global_step_tensor = tf.compat.v1.train.get_or_create_global_step()
    checkpoint_hook = tf.estimator.CheckpointSaverHook(
        FLAGS.logdir, saver=saver, save_secs=30 * 60)

    # With FLAGS.train_mixed, each step counts as twice the batch size.
    total_batch_size = FLAGS.batch_size * (2 if FLAGS.train_mixed else 1)
    example_counter_hook = session_hooks.log_increment_per_sec(
        'Training images', t_train.global_step * total_batch_size,
        every_n_secs=FLAGS.hook_seconds, summary_output_dir=FLAGS.logdir)

    # Clamp to 1: a dataset smaller than one batch would otherwise make the
    # epoch tensor below an integer division by zero at run time.
    steps_per_epoch = max(1, t_train.n_examples // FLAGS.batch_size)
    i_epoch = t_train.global_step // steps_per_epoch
    logger_hook = session_hooks.logger_hook(
        'Epoch {:03d}, global step {:07d}. Loss: {:.15e}',
        [i_epoch, t_train.global_step, t_train.loss], every_n_steps=1)
    hooks = [example_counter_hook, logger_hook, checkpoint_hook]

    if FLAGS.epochs:
        eta_hook = session_hooks.eta_hook(
            n_total_steps=(t_train.n_examples * FLAGS.epochs) // FLAGS.batch_size,
            every_n_secs=600, summary_output_dir=FLAGS.logdir)
        hooks.append(eta_hook)

    if FLAGS.validate_period:
        # FLAGS.validate_period is expressed in epochs; convert to steps.
        every_n_steps = int(np.round(FLAGS.validate_period * steps_per_epoch))
        max_valid_steps = np.ceil(t_valid.n_examples / FLAGS.batch_size_test)
        summary_output_dir = FLAGS.logdir if FLAGS.tensorboard else None
        metrics = [
            EvaluationMetric(
                t_valid.mean_error, 'MPJPE', '.3f', is_higher_better=False),
            EvaluationMetric(
                t_valid.mean_error_procrustes, 'MPJPE-procrustes', '.3f',
                is_higher_better=False),
            EvaluationMetric(t_valid.mean_pck, '3DPCK@150mm', '.3%'),
            EvaluationMetric(t_valid.mean_auc, 'AUC', '.3%'),
        ]
        validation_hook = session_hooks.validation_hook(
            metrics, summary_output_dir=summary_output_dir,
            max_steps=max_valid_steps, max_seconds=120,
            every_n_steps=every_n_steps, _step_tensor=global_step_tensor)
        hooks.append(validation_hook)

    if FLAGS.tensorboard:
        # merge_all() returns None when no summaries are registered; drop it.
        other_summary_ops = [
            a for a in [tf.compat.v1.summary.merge_all()] if a is not None]
        summary_hook = tf.estimator.SummarySaverHook(
            save_steps=1, output_dir=FLAGS.logdir,
            summary_op=tf.compat.v1.summary.merge(
                [*other_summary_ops, t_train.summary_op]))
        # The inner hook saves on every step it is run; the wrapper presumably
        # invokes it only every 10 global steps -- TODO confirm.
        summary_hook = tfasync.PeriodicHookWrapper(
            summary_hook, every_n_steps=10, step_tensor=global_step_tensor)
        hooks.append(summary_hook)

    if FLAGS.gui:
        dataset = data.datasets3d.get_dataset(FLAGS.dataset)
        plot_hook = session_hooks.send_to_worker_hook(
            [tfu.std_to_nhwc(t_train.x[0]), t_train.coords3d_pred[0],
             t_train.coords3d_true[0]],
            util3d.plot3d_worker,
            worker_args=[dataset.joint_info.stick_figure_edges],
            worker_kwargs=dict(batched=False, interval=100),
            every_n_secs=FLAGS.hook_seconds, use_threading=False)
        hooks.append(plot_hook)

        if 'coords3d_pred2d' in t_train:
            # There is no 3D ground truth here: the prediction is passed in
            # the ground-truth slot as a placeholder, with
            # has_ground_truth=False.
            plot_hook = session_hooks.send_to_worker_hook(
                [tfu.std_to_nhwc(t_train.x_2d[0]), t_train.coords3d_pred2d[0],
                 t_train.coords3d_pred2d[0]],
                util3d.plot3d_worker,
                worker_args=[dataset.joint_info.stick_figure_edges],
                worker_kwargs=dict(
                    batched=False, interval=100, has_ground_truth=False),
                every_n_secs=FLAGS.hook_seconds, use_threading=False)
            hooks.append(plot_hook)

    return hooks