Example #1
    def test_capture(self):
        global_step = tf.contrib.framework.get_or_create_global_step()
        # Some test computation
        some_weights = tf.get_variable("weights", [2, 128])
        computation = tf.nn.softmax(some_weights)

        hook = hooks.MetadataCaptureHook(params={"step": 5},
                                         model_dir=self.model_dir)
        hook.begin()

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            #pylint: disable=W0212
            mon_sess = monitored_session._HookedSession(sess, [hook])
            # Should not trigger for step 0
            sess.run(tf.assign(global_step, 0))
            mon_sess.run(computation)
            self.assertEqual(gfile.ListDirectory(self.model_dir), [])
            # Should trigger *after* step 5
            sess.run(tf.assign(global_step, 5))
            mon_sess.run(computation)
            self.assertEqual(gfile.ListDirectory(self.model_dir), [])
            mon_sess.run(computation)
            self.assertEqual(set(gfile.ListDirectory(self.model_dir)),
                             set(["run_meta", "tfprof_log", "timeline.json"]))
Example #2
def create_default_training_hooks(estimator,
                                  sample_frequency=500,
                                  delimiter=" "):
    """Creates common SessionRunHooks used for training.

  Args:
    estimator: The estimator instance
    sample_frequency: frequency of samples passed to the TrainSampleHook

  Returns:
    An array of `SessionRunHook` items.
  """
    output_dir = estimator.model_dir
    training_hooks = []

    model_analysis_hook = hooks.PrintModelAnalysisHook(
        filename=os.path.join(output_dir, "model_analysis.txt"))
    training_hooks.append(model_analysis_hook)

    train_sample_hook = hooks.TrainSampleHook(
        every_n_steps=sample_frequency,
        sample_dir=os.path.join(output_dir, "samples"),
        delimiter=delimiter)
    training_hooks.append(train_sample_hook)

    metadata_hook = hooks.MetadataCaptureHook(
        output_dir=os.path.join(output_dir, "metadata"),
        step=10)
    training_hooks.append(metadata_hook)

    tokens_per_sec_counter = hooks.TokensPerSecondCounter(
        every_n_steps=100, output_dir=output_dir)
    training_hooks.append(tokens_per_sec_counter)

    return training_hooks
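The hooks returned here are meant to be handed to tf.contrib.learn, either as Experiment train_monitors (see Example #3) or directly to Estimator.fit() via monitors. A minimal usage sketch, assuming the create_default_training_hooks above plus an already-built estimator and train_input_fn (both placeholders here):

# Sketch only: `estimator` and `train_input_fn` are assumed to exist,
# e.g. constructed as in Example #3.
training_hooks = create_default_training_hooks(estimator,
                                               sample_frequency=1000,
                                               delimiter=" ")

# tf.contrib.learn estimators accept SessionRunHooks through the `monitors` argument.
estimator.fit(input_fn=train_input_fn,
              steps=10000,
              monitors=training_hooks)
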
Example #3
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    # Load vocabulary info
    source_vocab_info = inputs.get_vocab_info(FLAGS.vocab_source)
    target_vocab_info = inputs.get_vocab_info(FLAGS.vocab_target)

    # Create data providers
    train_data_provider = lambda: inputs.make_data_provider([FLAGS.data_train])
    dev_data_provider = lambda: inputs.make_data_provider([FLAGS.data_dev])

    # Find model class
    model_class = getattr(models, FLAGS.model)

    # Parse parameters and merge with defaults
    hparams = model_class.default_params()
    if FLAGS.hparams is not None:
        hparams = HParamsParser(hparams).parse(FLAGS.hparams)

    # Print hyperparameter values
    tf.logging.info("Model Hyperparameters")
    tf.logging.info("=" * 50)
    for param, value in sorted(hparams.items()):
        tf.logging.info("%s=%s", param, value)
    tf.logging.info("=" * 50)

    # Create model
    model = model_class(source_vocab_info=source_vocab_info,
                        target_vocab_info=target_vocab_info,
                        params=hparams)
    featurizer = model.create_featurizer()

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Create input functions
    train_input_fn = training_utils.create_input_fn(
        train_data_provider,
        featurizer,
        FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries)
    eval_input_fn = training_utils.create_input_fn(dev_data_provider,
                                                   featurizer,
                                                   FLAGS.batch_size)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        return model(features, labels, params, mode)

    estimator = tf.contrib.learn.estimator.Estimator(model_fn=model_fn,
                                                     model_dir=output_dir)

    # Create training Hooks
    model_analysis_hook = hooks.PrintModelAnalysisHook(
        filename=os.path.join(estimator.model_dir, "model_analysis.txt"))
    train_sample_hook = hooks.TrainSampleHook(
        every_n_steps=FLAGS.sample_every_n_steps)
    metadata_hook = hooks.MetadataCaptureHook(
        output_dir=os.path.join(estimator.model_dir, "metadata"),
        step=10)
    train_monitors = [model_analysis_hook, train_sample_hook, metadata_hook]

    experiment = tf.contrib.learn.experiment.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        train_monitors=train_monitors)

    return experiment
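create_experiment only builds the Experiment; running it is typically delegated to tf.contrib.learn's learn_runner, which calls the factory with an output directory and then drives the chosen schedule. A minimal launch sketch, assuming the create_experiment above and a FLAGS.output_dir flag (the flag name is an assumption):

import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner

def main(_argv):
    """Hands the experiment factory to learn_runner, which runs train/eval."""
    learn_runner.run(experiment_fn=create_experiment,
                     output_dir=FLAGS.output_dir,  # assumed flag for the model directory
                     schedule="train_and_evaluate")

if __name__ == "__main__":
    tf.app.run()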