def experiment_fn(run_config, hparams):
  """Assemble a `tf.contrib.learn.Experiment` with logging/summary hooks.

  Args:
    run_config: RunConfig handed to the Estimator; its `model_dir` also
      anchors the evaluation summary directory.
    hparams: hyperparameters; `batch_size` feeds the throughput hook and
      the whole object is forwarded as `params` and to `make_input_fn`.

  Returns:
    A fully wired `tf.contrib.learn.Experiment`.
  """
  model = tf.estimator.Estimator(
      model_fn=make_model_fn(), config=run_config, params=hparams)

  # All periodic hooks fire on the same cadence as summary saving.
  log_interval = FLAGS.save_summary_steps
  training_hooks = [
      hooks.ExamplesPerSecondHook(
          batch_size=hparams.batch_size, every_n_iter=log_interval),
      # Two tensor loggers: one for per-batch tensors, one for scalars
      # collected outside the batch dimension.
      hooks.LoggingTensorHook(
          collection="batch_logging", every_n_iter=log_interval, batch=True),
      hooks.LoggingTensorHook(
          collection="logging", every_n_iter=log_interval, batch=False),
  ]
  # Evaluation summaries are written under <model_dir>/eval.
  evaluation_hooks = [
      hooks.SummarySaverHook(
          every_n_iter=log_interval,
          output_dir=os.path.join(run_config.model_dir, "eval")),
  ]

  experiment = tf.contrib.learn.Experiment(
      estimator=model,
      train_input_fn=make_input_fn(tf.estimator.ModeKeys.TRAIN, hparams),
      eval_input_fn=make_input_fn(tf.estimator.ModeKeys.EVAL, hparams),
      eval_steps=None,
      min_eval_frequency=FLAGS.eval_frequency,
      eval_hooks=evaluation_hooks)
  experiment.extend_train_hooks(training_hooks)
  return experiment
def _train():
  """Run estimator training after persisting or validating hyperparameters.

  On a fresh model directory (or with --overwrite_params set) the current
  hyperparameters are serialized to <model_dir>/params.json; on a resumed
  run the stored text must match `str(hp)` exactly.

  Raises:
    RuntimeError: if an existing params.json disagrees with the current hp.
  """
  serialized_hp = str(hp)
  params_path = os.path.join(model_dir, "params.json")
  if os.path.exists(params_path) and not FLAGS.overwrite_params:
    # Resuming: refuse to train a directory created with different params.
    with open(params_path, "r") as handle:
      if handle.read() != serialized_hp:
        raise RuntimeError("Mismatching parameters found.")
  else:
    with open(params_path, "w") as handle:
      handle.write(serialized_hp)

  # Tuple-typed train sets are flattened to a plain dict before being
  # handed to the input pipeline — presumably the form make_input_fn
  # expects; TODO confirm against io_utils.
  train_sets = cfg.train_sets
  if isinstance(train_sets, misc_utils.Tuple):
    train_sets = train_sets.to_dict()

  summary_interval = cfg.save_summary_steps
  training_hooks = [
      hooks.ExamplesPerSecondHook(
          batch_size=hp.batch_size, every_n_iter=summary_interval),
      hooks.LoggingTensorHook(
          collection="batch_logging", every_n_iter=summary_interval,
          batch=True),
      hooks.LoggingTensorHook(
          collection="logging", every_n_iter=summary_interval, batch=False),
      # Periodic checkpointing; the listener re-evaluates each checkpoint
      # and keeps the best one per the configured metric/comparator.
      tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=cfg.save_checkpoints_steps,
          listeners=[
              hooks.BestCheckpointKeeper(
                  model_dir,
                  eval_fn=_eval,
                  eval_set=cfg.checkpoint_selector.eval_set,
                  eval_metric=cfg.checkpoint_selector.eval_metric,
                  compare_fn=cfg.checkpoint_selector.compare_fn),
          ]),
  ]

  estimator.train(
      input_fn=io_utils.make_input_fn(
          d, train_sets, tf.estimator.ModeKeys.TRAIN, hp,
          num_epochs=cfg.num_epochs,
          shuffle_batches=cfg.shuffle_batches,
          num_threads=cfg.num_reader_threads,
          prefetch_buffer_size=cfg.prefetch_buffer_size),
      hooks=training_hooks)
def experiment_fn(run_config, hparams):
  """Assemble a `tf.contrib.learn.Experiment` for the flag-selected model.

  Args:
    run_config: RunConfig handed to the Estimator; its `model_dir` also
      anchors the evaluation summary directory.
    hparams: hyperparameters; `batch_size` feeds the throughput hook and
      the whole object is forwarded as `params` and to the input builder.

  Returns:
    A fully wired `tf.contrib.learn.Experiment`.
  """
  model = tf.estimator.Estimator(
      model_fn=optimizer.make_model_fn(MODELS[FLAGS.model].model,
                                       FLAGS.num_gpus),
      config=run_config,
      params=hparams)

  def _dataset_input_fn(mode):
    # The train and eval pipelines differ only in mode; every other knob
    # comes from flags.
    return common_io.make_input_fn(
        DATASETS[FLAGS.dataset], mode, hparams,
        num_epochs=FLAGS.num_epochs,
        shuffle_batches=FLAGS.shuffle_batches,
        num_threads=FLAGS.num_reader_threads)

  # All periodic hooks fire on the same cadence as summary saving.
  log_interval = FLAGS.save_summary_steps
  training_hooks = [
      hooks.ExamplesPerSecondHook(
          batch_size=hparams.batch_size, every_n_iter=log_interval),
      hooks.LoggingTensorHook(
          collection="batch_logging", every_n_iter=log_interval, batch=True),
      hooks.LoggingTensorHook(
          collection="logging", every_n_iter=log_interval, batch=False),
  ]
  # Evaluation summaries are written under <model_dir>/eval.
  evaluation_hooks = [
      hooks.SummarySaverHook(
          every_n_iter=log_interval,
          output_dir=os.path.join(run_config.model_dir, "eval")),
  ]

  experiment = tf.contrib.learn.Experiment(
      estimator=model,
      train_input_fn=_dataset_input_fn(tf.estimator.ModeKeys.TRAIN),
      eval_input_fn=_dataset_input_fn(tf.estimator.ModeKeys.EVAL),
      eval_steps=None,
      min_eval_frequency=FLAGS.eval_frequency,
      eval_hooks=evaluation_hooks)
  experiment.extend_train_hooks(training_hooks)
  return experiment