def experiment_fn(run_config, hparams):
    """Build the train/eval Experiment for the model and dataset chosen by flags.

    The ``(run_config, hparams)`` signature presumably matches what
    ``tf.contrib.learn.learn_runner.run`` passes to an experiment function;
    confirm against the caller.

    Args:
      run_config: RunConfig handed to the Estimator. Its ``model_dir`` also
        determines where eval summaries are written (``<model_dir>/eval``).
      hparams: Hyper-parameters passed to the Estimator as ``params`` and to
        both input functions. ``hparams.batch_size`` sizes the
        examples-per-second hook.

    Returns:
      A ``tf.contrib.learn.Experiment`` with the training hooks attached.
    """
    dataset = DATASETS[FLAGS.dataset]
    estimator = tf.estimator.Estimator(
        model_fn=optimizer.make_model_fn(
            MODELS[FLAGS.model].model, FLAGS.num_gpus),
        config=run_config,
        params=hparams)

    every_n = FLAGS.save_summary_steps
    train_hooks = [
        hooks.ExamplesPerSecondHook(
            batch_size=hparams.batch_size, every_n_iter=every_n),
        hooks.LoggingTensorHook(
            collection="batch_logging", every_n_iter=every_n, batch=True),
        hooks.LoggingTensorHook(
            collection="logging", every_n_iter=every_n, batch=False),
    ]
    eval_hooks = [
        hooks.SummarySaverHook(
            every_n_iter=every_n,
            output_dir=os.path.join(run_config.model_dir, "eval")),
    ]

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=common_io.make_input_fn(
            dataset, tf.estimator.ModeKeys.TRAIN, hparams,
            num_epochs=FLAGS.num_epochs,
            shuffle_batches=FLAGS.shuffle_batches,
            num_threads=FLAGS.num_reader_threads),
        # With eval_steps=None the Experiment evaluates until the eval input
        # is exhausted, so the eval input must be finite: use one unshuffled
        # pass. Reusing the training epoch count would re-evaluate the same
        # examples several times, or never finish if FLAGS.num_epochs is None
        # and make_input_fn treats that as "repeat forever" (TODO confirm).
        eval_input_fn=common_io.make_input_fn(
            dataset, tf.estimator.ModeKeys.EVAL, hparams,
            num_epochs=1,
            shuffle_batches=False,
            num_threads=FLAGS.num_reader_threads),
        eval_steps=None,
        min_eval_frequency=FLAGS.eval_frequency,
        eval_hooks=eval_hooks)
    # Training hooks are attached after construction rather than via the
    # constructor.
    experiment.extend_train_hooks(train_hooks)
    return experiment
def main(unused_argv):
    """Run prediction over the dataset's test list and print results as CSV.

    Builds an Estimator rooted at ``FLAGS.output_dir``. It then prints a
    ``fname,label`` header followed by one ``<file>,<word>`` row per entry of
    the dataset's ``TEST_LIST``. Each row pairs a file with
    ``WORDS[prediction["predictions"]]``.

    Args:
      unused_argv: Arguments left over after flag parsing; ignored.

    Raises:
      NotImplementedError: If ``FLAGS.output_dir`` is empty or unset.
    """
    if not FLAGS.output_dir:
        raise NotImplementedError(
            "--output_dir must be set to the model directory to predict from")
    model_dir = FLAGS.output_dir

    dataset = DATASETS[FLAGS.dataset]
    model = MODELS[FLAGS.model]
    dataset.prepare()

    session_config = tf.ConfigProto()
    session_config.allow_soft_placement = True
    session_config.gpu_options.allow_growth = True
    run_config = tf.contrib.learn.RunConfig(
        model_dir=model_dir,
        save_summary_steps=FLAGS.save_summary_steps,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        save_checkpoints_secs=None,
        session_config=session_config)

    # Resolve hyper-parameters once so the Estimator and the input function
    # share the same params object. The local is named `params` rather than
    # `hparams` so it does not shadow the module-level `hparams` used here.
    params = hparams.get_params(model, dataset, FLAGS.hparams)
    estimator = tf.estimator.Estimator(
        model_fn=optimizer.make_model_fn(model.model, FLAGS.num_gpus),
        config=run_config,
        params=params)

    # A single unshuffled pass, so that predictions are meant to line up with
    # TEST_LIST. NOTE(review): with num_threads > 1, confirm make_input_fn
    # still preserves example order.
    predictions = estimator.predict(input_fn=common_io.make_input_fn(
        dataset, tf.estimator.ModeKeys.PREDICT, params,
        num_epochs=1,
        shuffle_batches=False,
        num_threads=FLAGS.num_reader_threads))

    print("fname,label")
    words = dataset.WORDS
    # NOTE(review): zip stops at the shorter of TEST_LIST and the prediction
    # stream, so a length mismatch would silently drop rows.
    for fname, pred in zip(dataset.TEST_LIST, predictions):
        print(fname, words[pred["predictions"]], sep=',')