Example #1
def create_default_training_hooks(estimator,
                                  sample_frequency=500,
                                  delimiter=" "):
    """Creates common SessionRunHooks used for training.

  Args:
    estimator: The estimator instance
    sample_frequency: frequency of samples passed to the TrainSampleHook

  Returns:
    An array of `SessionRunHook` items.
  """
    output_dir = estimator.model_dir
    training_hooks = []

    model_analysis_hook = hooks.PrintModelAnalysisHook(
        filename=os.path.join(output_dir, "model_analysis.txt"))
    training_hooks.append(model_analysis_hook)

    train_sample_hook = hooks.TrainSampleHook(
        every_n_steps=sample_frequency,
        sample_dir=os.path.join(output_dir, "samples"),
        delimiter=delimiter)
    training_hooks.append(train_sample_hook)

    metadata_hook = hooks.MetadataCaptureHook(
        output_dir=os.path.join(output_dir, "metadata"),
        step=10)
    training_hooks.append(metadata_hook)

    tokens_per_sec_counter = hooks.TokensPerSecondCounter(
        every_n_steps=100, output_dir=output_dir)
    training_hooks.append(tokens_per_sec_counter)

    return training_hooks
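
Not shown above: the hook list returned by create_default_training_hooks is meant to be handed to the training loop. A minimal usage sketch, assuming a tf.contrib.learn-style estimator and a train_input_fn defined elsewhere (both names are placeholders, not part of the example):

# Sketch (assumption): attach the returned hooks to the estimator's fit call.
# `estimator` and `train_input_fn` are placeholders defined elsewhere.
training_hooks = create_default_training_hooks(estimator, sample_frequency=500)
estimator.fit(input_fn=train_input_fn,
              steps=1000,
              monitors=training_hooks)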
Example #2
  def test_begin(self):
    outfile = tempfile.NamedTemporaryFile()
    tf.get_variable("weights", [128, 128])
    hook = hooks.PrintModelAnalysisHook(filename=outfile.name)
    hook.begin()
    file_contents = outfile.read().strip()
    self.assertEqual(
        file_contents.decode(),
        "_TFProfRoot (--/16.38k params)\n"
        "  weights (128x128, 16.38k/16.38k params)")
    outfile.close()
Example #3
    def test_begin(self):
        model_dir = tempfile.mkdtemp()
        outfile = tempfile.NamedTemporaryFile()
        tf.get_variable("weigths", [128, 128])
        hook = hooks.PrintModelAnalysisHook(params={}, model_dir=model_dir)
        hook.begin()

        with gfile.GFile(os.path.join(model_dir,
                                      "model_analysis.txt")) as file:
            file_contents = file.read().strip()

        self.assertEqual(
            tf.compat.as_text(file_contents),
            "_TFProfRoot (--/16.38k params)\n"
            "  weights (128x128, 16.38k/16.38k params)")
        outfile.close()
Example #4
  def test_begin(self):
    model_dir = tempfile.mkdtemp()
    outfile = tempfile.NamedTemporaryFile()
    tf.get_variable("weights", [128, 128])
    hook = hooks.PrintModelAnalysisHook(
        params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with gfile.GFile(os.path.join(model_dir, "model_analysis.txt")) as file:
      file_contents = file.read().strip()

    lines = tf.compat.as_text(file_contents).split("\n")
    if len(lines) == 3:
      # TensorFlow v1.2 includes an extra header line
      self.assertEqual(lines[0], "node name | # parameters")

    self.assertEqual(lines[-2], "_TFProfRoot (--/16.38k params)")
    self.assertEqual(lines[-1], "  weights (128x128, 16.38k/16.38k params)")

    outfile.close()
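
For context on Examples #2-#4: the tests call hook.begin() by hand, but in normal training the hook is attached to a monitored session, which invokes begin() once the graph is built. A rough sketch of that path, reusing the constructor form from Example #4 (the model_dir path is illustrative only):

# Sketch (assumption): in real training the hook is handed to a monitored
# session, which calls begin() itself; no manual call is needed.
# "/tmp/model_dir" is an illustrative path, not taken from the examples.
tf.reset_default_graph()
tf.get_variable("weights", [128, 128])
analysis_hook = hooks.PrintModelAnalysisHook(
    params={}, model_dir="/tmp/model_dir",
    run_config=tf.contrib.learn.RunConfig())
with tf.train.SingularMonitoredSession(hooks=[analysis_hook]):
  pass  # begin() has already run; model_analysis.txt was written to model_dir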
Example #5
def create_experiment(output_dir):
    """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

    # Load vocabulary info
    source_vocab_info = inputs.get_vocab_info(FLAGS.vocab_source)
    target_vocab_info = inputs.get_vocab_info(FLAGS.vocab_target)

    # Create data providers
    train_data_provider = lambda: inputs.make_data_provider([FLAGS.data_train])
    dev_data_provider = lambda: inputs.make_data_provider([FLAGS.data_dev])

    # Find model class
    model_class = getattr(models, FLAGS.model)

    # Parse parameters and merge them with defaults
    hparams = model_class.default_params()
    if FLAGS.hparams is not None:
        hparams = HParamsParser(hparams).parse(FLAGS.hparams)

    # Print hyperparameter values
    tf.logging.info("Model Hyperparameters")
    tf.logging.info("=" * 50)
    for param, value in sorted(hparams.items()):
        tf.logging.info("%s=%s", param, value)
    tf.logging.info("=" * 50)

    # Create model
    model = model_class(source_vocab_info=source_vocab_info,
                        target_vocab_info=target_vocab_info,
                        params=hparams)
    featurizer = model.create_featurizer()

    bucket_boundaries = None
    if FLAGS.buckets:
        bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

    # Create input functions
    train_input_fn = training_utils.create_input_fn(
        train_data_provider,
        featurizer,
        FLAGS.batch_size,
        bucket_boundaries=bucket_boundaries)
    eval_input_fn = training_utils.create_input_fn(dev_data_provider,
                                                   featurizer,
                                                   FLAGS.batch_size)

    def model_fn(features, labels, params, mode):
        """Builds the model graph"""
        return model(features, labels, params, mode)

    estimator = tf.contrib.learn.Estimator(model_fn=model_fn,
                                           model_dir=output_dir)

    # Create training Hooks
    model_analysis_hook = hooks.PrintModelAnalysisHook(
        filename=os.path.join(estimator.model_dir, "model_analysis.txt"))
    train_sample_hook = hooks.TrainSampleHook(
        every_n_steps=FLAGS.sample_every_n_steps)
    metadata_hook = hooks.MetadataCaptureHook(
        output_dir=os.path.join(estimator.model_dir, "metadata"),
        step=10)
    train_monitors = [model_analysis_hook, train_sample_hook, metadata_hook]

    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        min_eval_frequency=FLAGS.eval_every_n_steps,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        train_monitors=train_monitors)

    return experiment
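
The experiment returned here is typically not run directly; a plausible wiring, assuming the standard tf.contrib.learn learn_runner and an output-dir flag (FLAGS.output_dir is an assumption, not defined in the example):

# Sketch (assumption): drive the Experiment through learn_runner, which
# calls create_experiment(output_dir) and runs the chosen schedule.
# FLAGS.output_dir is assumed to be defined alongside the other flags above.
from tensorflow.contrib.learn.python.learn import learn_runner

learn_runner.run(experiment_fn=create_experiment,
                 output_dir=FLAGS.output_dir,
                 schedule="train_and_evaluate")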