def test_capture(self):
  """The MetadataCaptureHook should write its trace files only on the
  run that follows the configured capture step."""
  global_step = tf.contrib.framework.get_or_create_global_step()
  # A trivial graph so the hooked session has something to execute.
  weights = tf.get_variable("weigths", [2, 128])
  fetch_op = tf.nn.softmax(weights)

  capture_hook = hooks.MetadataCaptureHook(
      params={"step": 5}, model_dir=self.model_dir)
  capture_hook.begin()

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    #pylint: disable=W0212
    hooked_sess = monitored_session._HookedSession(sess, [capture_hook])

    # At step 0 nothing should have been written yet.
    sess.run(tf.assign(global_step, 0))
    hooked_sess.run(fetch_op)
    self.assertEqual(gfile.ListDirectory(self.model_dir), [])

    # The hook fires on the run *after* step 5, so the directory is
    # still empty immediately at step 5 ...
    sess.run(tf.assign(global_step, 5))
    hooked_sess.run(fetch_op)
    self.assertEqual(gfile.ListDirectory(self.model_dir), [])

    # ... and is populated after one more run.
    hooked_sess.run(fetch_op)
    self.assertEqual(
        set(gfile.ListDirectory(self.model_dir)),
        set(["run_meta", "tfprof_log", "timeline.json"]))
def create_default_training_hooks(estimator, sample_frequency=500,
                                  delimiter=" "):
  """Creates common SessionRunHooks used for training.

  Args:
    estimator: The estimator instance
    sample_frequency: frequency of samples passed to the TrainSampleHook
    delimiter: Token delimiter used by the TrainSampleHook when writing
      out training samples.

  Returns:
    An array of `SessionRunHook` items.
  """
  output_dir = estimator.model_dir

  # Writes a static analysis of the model graph to a text file.
  model_analysis_hook = hooks.PrintModelAnalysisHook(
      filename=os.path.join(output_dir, "model_analysis.txt"))

  # Periodically writes training samples to the samples directory.
  train_sample_hook = hooks.TrainSampleHook(
      every_n_steps=sample_frequency,
      sample_dir=os.path.join(output_dir, "samples"),
      delimiter=delimiter)

  # Captures run metadata (timeline/profiling output) once, after step 10.
  metadata_hook = hooks.MetadataCaptureHook(
      output_dir=os.path.join(output_dir, "metadata"), step=10)

  # Logs training throughput.
  tokens_per_sec_counter = hooks.TokensPerSecondCounter(
      every_n_steps=100, output_dir=output_dir)

  return [
      model_analysis_hook,
      train_sample_hook,
      metadata_hook,
      tokens_per_sec_counter,
  ]
def create_experiment(output_dir):
  """Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.

  Returns:
    A `tf.contrib.learn.Experiment` wired up with the model, input
    functions, and training hooks specified by the command-line FLAGS.
  """
  # Load vocabulary info
  source_vocab_info = inputs.get_vocab_info(FLAGS.vocab_source)
  target_vocab_info = inputs.get_vocab_info(FLAGS.vocab_target)

  # Data-provider factories (PEP 8: prefer def over lambda assignment).
  # Wrapped in functions so each call builds a fresh provider.
  def train_data_provider():
    return inputs.make_data_provider([FLAGS.data_train])

  def dev_data_provider():
    return inputs.make_data_provider([FLAGS.data_dev])

  # Find model class
  model_class = getattr(models, FLAGS.model)

  # Parse parameter overrides and merge with the model defaults
  hparams = model_class.default_params()
  if FLAGS.hparams is not None:
    hparams = HParamsParser(hparams).parse(FLAGS.hparams)

  # Print hyperparameter values
  tf.logging.info("Model Hyperparameters")
  tf.logging.info("=" * 50)
  for param, value in sorted(hparams.items()):
    tf.logging.info("%s=%s", param, value)
  tf.logging.info("=" * 50)

  # Create model
  model = model_class(
      source_vocab_info=source_vocab_info,
      target_vocab_info=target_vocab_info,
      params=hparams)
  featurizer = model.create_featurizer()

  # Optional bucketing boundaries, e.g. "--buckets=10,20,30".
  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Create input functions
  train_input_fn = training_utils.create_input_fn(
      train_data_provider, featurizer, FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries)
  eval_input_fn = training_utils.create_input_fn(
      dev_data_provider, featurizer, FLAGS.batch_size)

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    return model(features, labels, params, mode)

  estimator = tf.contrib.learn.estimator.Estimator(
      model_fn=model_fn, model_dir=output_dir)

  # Create training Hooks
  model_analysis_hook = hooks.PrintModelAnalysisHook(
      filename=os.path.join(estimator.model_dir, "model_analysis.txt"))
  train_sample_hook = hooks.TrainSampleHook(
      every_n_steps=FLAGS.sample_every_n_steps)
  metadata_hook = hooks.MetadataCaptureHook(
      output_dir=os.path.join(estimator.model_dir, "metadata"), step=10)
  train_monitors = [model_analysis_hook, train_sample_hook, metadata_hook]

  experiment = tf.contrib.learn.experiment.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      train_monitors=train_monitors)

  return experiment