# Imports for this module. The project-internal module paths (session_hooks,
# tfasync, tfu) are assumptions based on how they are used below, and
# make_session_creator is assumed to be defined elsewhere in this project.
import glob
import logging
import re
import signal

import attrdict
import numpy as np
import tensorflow as tf

import session_hooks
import tfasync
import tfu
from session_hooks import (
    collect_hook, counter_hook, eta_hook, stop_after_steps_or_seconds_hook,
    stop_on_signal_hook)


def run_eval_loop(
        sess=None, fetches_to_collect=None, other_ops=(), hooks=(),
        checkpoint_dir=None, load_path=None, max_steps=None, max_seconds=None,
        init_fn=None):
    # If the fetches are given as a dict, collect the values for each fetch and
    # return them as an AttrDict with the same keys.
    if isinstance(fetches_to_collect, dict):
        keys, values = zip(*fetches_to_collect.items())
        results = run_eval_loop(
            sess, list(values), other_ops, hooks, checkpoint_dir, load_path,
            max_steps, max_seconds, init_fn)  # forward init_fn (was dropped)
        return attrdict.AttrDict(dict(zip(keys, results)))

    sess_creator = None if sess else make_session_creator(
        checkpoint_dir, load_path, init_fn)
    collect_hook = session_hooks.collect_hook(fetches_to_collect)
    hooks = [collect_hook, *hooks]

    if max_seconds or max_steps:
        stop_hook = session_hooks.stop_after_steps_or_seconds_hook(
            seconds_limit=max_seconds, steps_limit=max_steps)
        hooks.append(stop_hook)

    tfasync.main_loop(
        sess=sess, sess_creator=sess_creator, ops=other_ops, hooks=hooks)
    return collect_hook.result
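# Below is a minimal usage sketch of `run_eval_loop` (never called here;
# `build_model` and the checkpoint path are hypothetical stand-ins, not part of
# this module). It shows the dict form of `fetches_to_collect`, which returns
# the collected values as an AttrDict keyed like the input dict.
def _example_run_eval_loop():
    loss_t, accuracy_t = build_model()  # hypothetical model-building function
    results = run_eval_loop(
        fetches_to_collect={'loss': loss_t, 'accuracy': accuracy_t},
        checkpoint_dir='/tmp/model', max_steps=100)
    # Each attribute holds the values collected over the evaluated steps.
    print(np.mean(results.loss), np.mean(results.accuracy))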
def run_train_loop(
        train_op, sess=None, checkpoint_dir=None, load_path=None,
        max_steps=None, hooks=(), init_fn=None):
    sess_creator = None if sess else make_session_creator(
        checkpoint_dir, load_path, init_fn)

    if max_steps:
        stop_hook = session_hooks.stop_after_steps_or_seconds_hook(
            steps_limit=max_steps)
        hooks = [stop_hook, *hooks]

    # Always stop cleanly on SIGINT/SIGTERM.
    hooks = [*hooks, session_hooks.stop_on_signal_hook()]
    tfasync.main_loop(
        sess=sess, sess_creator=sess_creator, ops=train_op, hooks=hooks)
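# A corresponding sketch for `run_train_loop` (`build_train_op` and the path
# are assumptions for illustration). A signal-handling hook is always appended,
# so SIGINT/SIGTERM stop the loop cleanly rather than killing the process.
def _example_run_train_loop():
    train_op = build_train_op()  # hypothetical training-op builder
    run_train_loop(train_op, checkpoint_dir='/tmp/model', max_steps=100_000)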
async def validation_hook(
        metrics, checkpoint_path_prefix=None, summary_output_dir=None,
        max_steps=None, max_seconds=None):
    """Evaluates the model and logs the resulting average evaluation metrics.

    Furthermore, if `checkpoint_path_prefix` is given, it also saves a
    checkpoint whenever there is a record-low loss. In this case the loss is
    assumed to be the first metric. To reiterate: if checkpointing is desired,
    the first metric must be a loss, for which lower values are better.

    This hook runs a separate validation loop within the existing session.

    Args:
        metrics: Metric objects describing the evaluation metrics on the
            validation set. Each provides the tensor to evaluate, a name and a
            format spec for pretty printing in the log.
        checkpoint_path_prefix: Prefix of the path where the best model's
            checkpoint should be saved. If None, no checkpointing is done.
        summary_output_dir: TensorBoard summary directory where the metrics
            should be written.
        max_steps: The maximum number of steps to run in the validation loop.
        max_seconds: The maximum time to spend on validation. After this the
            loop is stopped and the average metrics are calculated based on
            the steps that were run.

    Returns:
        A hook that can be used in a training loop.
    """
    saver = tf.compat.v1.train.Saver(max_to_keep=1, save_relative_paths=True)
    summary_writer = (
        tf.compat.v1.summary.FileWriterCache.get(summary_output_dir)
        if summary_output_dir else None)

    # Read the best loss so far from the filesystem. The loss value is encoded
    # in the checkpoint filename, e.g. "something-val-0.003348-158585.index".
    if checkpoint_path_prefix:
        paths = glob.glob(checkpoint_path_prefix + '*')
        filename_matches = [
            re.match(r'.+?-val-(.+?)-.+?\.index', path) for path in paths]
        losses = [float(match[1]) for match in filename_matches if match]
        best_main_metric_value = min(losses) if losses else np.inf
    else:
        best_main_metric_value = np.inf

    # We create some hooks for the internal, nested validation loop.
    # (Hooks within a hook!) First, a counter to count the steps.
    counter = tfu.get_or_create_counter('validation')
    counter_h = counter_hook(counter)
    eta_h = eta_hook(
        max_steps, step_tensor=counter.var, init_phase_seconds=30,
        every_n_secs=10)
    collect_h = collect_hook([m.tensor for m in metrics])
    stop_h = stop_after_steps_or_seconds_hook(
        seconds_limit=max_seconds, steps_limit=max_steps)
    sigint_h = stop_on_signal_hook(sig=signal.SIGINT)
    sigterm_h = stop_on_signal_hook(sig=signal.SIGTERM, is_additional=True)
    inner_hooks = [counter_h, eta_h, collect_h, stop_h, sigint_h, sigterm_h]

    global_step_tensor = tf.compat.v1.train.get_global_step()
    async for run_context, run_values in tfasync.run_detailed_iter(
            global_step_tensor):
        global_step_value = run_values.results

        # Don't validate at global step 0: the model hasn't been trained yet,
        # so there isn't much to validate.
        if global_step_value == 0:
            continue

        logging.info('Running validation')
        # Run the evaluation loop within the existing session that this
        # validation hook operates in.
        tfasync.main_loop(sess=run_context.session, hooks=inner_hooks)
        aggregated_values = [
            metric.get_aggregated_value(result)
            for metric, result in zip(metrics, collect_h.result)]

        # Write to the log
        for metric, aggregated_value in zip(metrics, aggregated_values):
            logging.info(metric.format(aggregated_value))

        # Write summaries for TensorBoard
        if summary_writer:
            summary = tfu.scalar_dict_to_summary({
                f'validation/{metric.name}': np.nanmean(value)
                for metric, value in zip(metrics, aggregated_values)})
            summary_writer.add_summary(summary, global_step=global_step_value)
            summary_writer.flush()

        # Save a checkpoint if the main metric (the loss) improved
        aggregated_main_value = np.nanmean(aggregated_values[0])
        if metrics[0].is_first_better(aggregated_main_value, best_main_metric_value):
            best_main_metric_value = aggregated_main_value
            logging.info(
                f'Main metric: {metrics[0].format(aggregated_main_value)} '
                f'(new record!)')
            if checkpoint_path_prefix:
                saver.save(
                    run_context.session,
                    f'{checkpoint_path_prefix}-val-{aggregated_main_value:.6f}',
                    global_step=global_step_value)
        else:
            logging.info(f'Main metric: {aggregated_main_value:.6f}')
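
# A sketch of wiring `validation_hook` into training. The metric objects are
# assumptions inferred from the attribute accesses above: each needs `.tensor`,
# `.name`, `.format(value)`, `.get_aggregated_value(collected)` and
# `.is_first_better(a, b)`. The loss metric must come first, since
# checkpointing tracks record-low values of the first metric.
def _example_validation_wiring():
    # Hypothetical builder returning the train op and the metric objects.
    train_op, loss_metric, accuracy_metric = build_model_and_metrics()
    val_hook = validation_hook(
        metrics=[loss_metric, accuracy_metric],
        checkpoint_path_prefix='/tmp/model/best',
        summary_output_dir='/tmp/model/logs',
        max_steps=200, max_seconds=300)
    run_train_loop(
        train_op, checkpoint_dir='/tmp/model', max_steps=100_000,
        hooks=[val_hook])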